diff --git a/core/src/main/scala/org/apache/spark/SecurityManager.scala b/core/src/main/scala/org/apache/spark/SecurityManager.scala
index dbfd5a514c189..b87476322573d 100644
--- a/core/src/main/scala/org/apache/spark/SecurityManager.scala
+++ b/core/src/main/scala/org/apache/spark/SecurityManager.scala
@@ -17,15 +17,10 @@
package org.apache.spark
-import java.lang.{Byte => JByte}
import java.net.{Authenticator, PasswordAuthentication}
import java.nio.charset.StandardCharsets.UTF_8
-import java.security.{KeyStore, SecureRandom}
-import java.security.cert.X509Certificate
import javax.net.ssl._
-import com.google.common.hash.HashCodes
-import com.google.common.io.Files
import org.apache.hadoop.io.Text
import org.apache.hadoop.security.{Credentials, UserGroupInformation}
@@ -365,13 +360,8 @@ private[spark] class SecurityManager(
return
}
- val rnd = new SecureRandom()
- val length = sparkConf.getInt("spark.authenticate.secretBitLength", 256) / JByte.SIZE
- val secretBytes = new Array[Byte](length)
- rnd.nextBytes(secretBytes)
-
+ secretKey = Utils.createSecret(sparkConf)
val creds = new Credentials()
- secretKey = HashCodes.fromBytes(secretBytes).toString()
creds.addSecretKey(SECRET_LOOKUP_KEY, secretKey.getBytes(UTF_8))
UserGroupInformation.getCurrentUser().addCredentials(creds)
}
diff --git a/core/src/main/scala/org/apache/spark/SparkConf.scala b/core/src/main/scala/org/apache/spark/SparkConf.scala
index 129956e9f9ffa..dab409572646f 100644
--- a/core/src/main/scala/org/apache/spark/SparkConf.scala
+++ b/core/src/main/scala/org/apache/spark/SparkConf.scala
@@ -454,8 +454,9 @@ class SparkConf(loadDefaults: Boolean) extends Cloneable with Logging with Seria
*/
private[spark] def validateSettings() {
if (contains("spark.local.dir")) {
- val msg = "In Spark 1.0 and later spark.local.dir will be overridden by the value set by " +
- "the cluster manager (via SPARK_LOCAL_DIRS in mesos/standalone and LOCAL_DIRS in YARN)."
+ val msg = "Note that spark.local.dir will be overridden by the value set by " +
+ "the cluster manager (via SPARK_LOCAL_DIRS in mesos/standalone/kubernetes and LOCAL_DIRS" +
+ " in YARN)."
logWarning(msg)
}
diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonGatewayServer.scala b/core/src/main/scala/org/apache/spark/api/python/PythonGatewayServer.scala
index 11f2432575d84..9ddc4a4910180 100644
--- a/core/src/main/scala/org/apache/spark/api/python/PythonGatewayServer.scala
+++ b/core/src/main/scala/org/apache/spark/api/python/PythonGatewayServer.scala
@@ -17,26 +17,39 @@
package org.apache.spark.api.python
-import java.io.DataOutputStream
-import java.net.Socket
+import java.io.{DataOutputStream, File, FileOutputStream}
+import java.net.InetAddress
+import java.nio.charset.StandardCharsets.UTF_8
+import java.nio.file.Files
import py4j.GatewayServer
+import org.apache.spark.SparkConf
import org.apache.spark.internal.Logging
import org.apache.spark.util.Utils
/**
- * Process that starts a Py4J GatewayServer on an ephemeral port and communicates the bound port
- * back to its caller via a callback port specified by the caller.
+ * Process that starts a Py4J GatewayServer on an ephemeral port.
*
* This process is launched (via SparkSubmit) by the PySpark driver (see java_gateway.py).
*/
private[spark] object PythonGatewayServer extends Logging {
initializeLogIfNecessary(true)
- def main(args: Array[String]): Unit = Utils.tryOrExit {
- // Start a GatewayServer on an ephemeral port
- val gatewayServer: GatewayServer = new GatewayServer(null, 0)
+ def main(args: Array[String]): Unit = {
+ val secret = Utils.createSecret(new SparkConf())
+
+ // Start a GatewayServer on an ephemeral port. Make sure the callback client is configured
+ // with the same secret, in case the app needs callbacks from the JVM to the underlying
+ // python processes.
+ val localhost = InetAddress.getLoopbackAddress()
+ val gatewayServer: GatewayServer = new GatewayServer.GatewayServerBuilder()
+ .authToken(secret)
+ .javaPort(0)
+ .javaAddress(localhost)
+ .callbackClient(GatewayServer.DEFAULT_PYTHON_PORT, localhost, secret)
+ .build()
+
gatewayServer.start()
val boundPort: Int = gatewayServer.getListeningPort
if (boundPort == -1) {
@@ -46,15 +59,24 @@ private[spark] object PythonGatewayServer extends Logging {
logDebug(s"Started PythonGatewayServer on port $boundPort")
}
- // Communicate the bound port back to the caller via the caller-specified callback port
- val callbackHost = sys.env("_PYSPARK_DRIVER_CALLBACK_HOST")
- val callbackPort = sys.env("_PYSPARK_DRIVER_CALLBACK_PORT").toInt
- logDebug(s"Communicating GatewayServer port to Python driver at $callbackHost:$callbackPort")
- val callbackSocket = new Socket(callbackHost, callbackPort)
- val dos = new DataOutputStream(callbackSocket.getOutputStream)
+ // Communicate the connection information back to the python process by writing the
+ // information in the requested file. This needs to match the read side in java_gateway.py.
+ val connectionInfoPath = new File(sys.env("_PYSPARK_DRIVER_CONN_INFO_PATH"))
+ val tmpPath = Files.createTempFile(connectionInfoPath.getParentFile().toPath(),
+ "connection", ".info").toFile()
+
+ val dos = new DataOutputStream(new FileOutputStream(tmpPath))
dos.writeInt(boundPort)
+
+ val secretBytes = secret.getBytes(UTF_8)
+ dos.writeInt(secretBytes.length)
+ dos.write(secretBytes, 0, secretBytes.length)
dos.close()
- callbackSocket.close()
+
+ if (!tmpPath.renameTo(connectionInfoPath)) {
+ logError(s"Unable to write connection information to $connectionInfoPath.")
+ System.exit(1)
+ }
// Exit on EOF or broken pipe to ensure that this process dies when the Python driver dies:
while (System.in.read() != -1) {
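For reference, the connection-info file written above carries an int port, an int secret length, and the raw secret bytes, in that order. Below is a minimal Scala sketch of a matching reader (illustrative only; the real reader is the Python side in java_gateway.py, and readConnectionInfo is a hypothetical name):

import java.io.{DataInputStream, FileInputStream}
import java.nio.charset.StandardCharsets.UTF_8

// Hypothetical reader mirroring the write side above: a 4-byte port, a 4-byte
// secret length, then the UTF-8 secret bytes.
def readConnectionInfo(path: String): (Int, String) = {
  val in = new DataInputStream(new FileInputStream(path))
  try {
    val port = in.readInt()
    val secretBytes = new Array[Byte](in.readInt())
    in.readFully(secretBytes)
    (port, new String(secretBytes, UTF_8))
  } finally {
    in.close()
  }
}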
diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
index f6293c0dc5091..a1ee2f7d1b119 100644
--- a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
+++ b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
@@ -38,6 +38,7 @@ import org.apache.spark.broadcast.Broadcast
import org.apache.spark.input.PortableDataStream
import org.apache.spark.internal.Logging
import org.apache.spark.rdd.RDD
+import org.apache.spark.security.SocketAuthHelper
import org.apache.spark.util._
@@ -107,6 +108,12 @@ private[spark] object PythonRDD extends Logging {
// remember the broadcasts sent to each worker
private val workerBroadcasts = new mutable.WeakHashMap[Socket, mutable.Set[Long]]()
+ // Authentication helper used when serving iterator data.
+ private lazy val authHelper = {
+ val conf = Option(SparkEnv.get).map(_.conf).getOrElse(new SparkConf())
+ new SocketAuthHelper(conf)
+ }
+
def getWorkerBroadcasts(worker: Socket): mutable.Set[Long] = {
synchronized {
workerBroadcasts.getOrElseUpdate(worker, new mutable.HashSet[Long]())
@@ -129,12 +136,13 @@ private[spark] object PythonRDD extends Logging {
* (effectively a collect()), but allows you to run on a certain subset of partitions,
* or to enable local execution.
*
- * @return the port number of a local socket which serves the data collected from this job.
+ * @return 2-tuple (as a Java array) with the port number of a local socket which serves the
+ * data collected from this job, and the secret for authentication.
*/
def runJob(
sc: SparkContext,
rdd: JavaRDD[Array[Byte]],
- partitions: JArrayList[Int]): Int = {
+ partitions: JArrayList[Int]): Array[Any] = {
type ByteArray = Array[Byte]
type UnrolledPartition = Array[ByteArray]
val allPartitions: Array[UnrolledPartition] =
@@ -147,13 +155,14 @@ private[spark] object PythonRDD extends Logging {
/**
* A helper function to collect an RDD as an iterator, then serve it via socket.
*
- * @return the port number of a local socket which serves the data collected from this job.
+ * @return 2-tuple (as a Java array) with the port number of a local socket which serves the
+ * data collected from this job, and the secret for authentication.
*/
- def collectAndServe[T](rdd: RDD[T]): Int = {
+ def collectAndServe[T](rdd: RDD[T]): Array[Any] = {
serveIterator(rdd.collect().iterator, s"serve RDD ${rdd.id}")
}
- def toLocalIteratorAndServe[T](rdd: RDD[T]): Int = {
+ def toLocalIteratorAndServe[T](rdd: RDD[T]): Array[Any] = {
serveIterator(rdd.toLocalIterator, s"serve toLocalIterator")
}
@@ -384,8 +393,11 @@ private[spark] object PythonRDD extends Logging {
* and send them into this connection.
*
* The thread will terminate after all the data are sent or any exceptions happen.
+ *
+ * @return 2-tuple (as a Java array) with the port number of a local socket which serves the
+ * data collected from this job, and the secret for authentication.
*/
- def serveIterator[T](items: Iterator[T], threadName: String): Int = {
+ def serveIterator(items: Iterator[_], threadName: String): Array[Any] = {
val serverSocket = new ServerSocket(0, 1, InetAddress.getByName("localhost"))
// Close the socket if no connection in 15 seconds
serverSocket.setSoTimeout(15000)
@@ -395,11 +407,14 @@ private[spark] object PythonRDD extends Logging {
override def run() {
try {
val sock = serverSocket.accept()
+ authHelper.authClient(sock)
+
val out = new DataOutputStream(new BufferedOutputStream(sock.getOutputStream))
Utils.tryWithSafeFinally {
writeIteratorToStream(items, out)
} {
out.close()
+ sock.close()
}
} catch {
case NonFatal(e) =>
@@ -410,7 +425,7 @@ private[spark] object PythonRDD extends Logging {
}
}.start()
- serverSocket.getLocalPort
+ Array(serverSocket.getLocalPort, authHelper.secret)
}
private def getMergedConf(confAsMap: java.util.HashMap[String, String],
diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonUtils.scala b/core/src/main/scala/org/apache/spark/api/python/PythonUtils.scala
index 92e228a9dd10c..27a5e19f96a14 100644
--- a/core/src/main/scala/org/apache/spark/api/python/PythonUtils.scala
+++ b/core/src/main/scala/org/apache/spark/api/python/PythonUtils.scala
@@ -32,7 +32,7 @@ private[spark] object PythonUtils {
val pythonPath = new ArrayBuffer[String]
for (sparkHome <- sys.env.get("SPARK_HOME")) {
pythonPath += Seq(sparkHome, "python", "lib", "pyspark.zip").mkString(File.separator)
- pythonPath += Seq(sparkHome, "python", "lib", "py4j-0.10.6-src.zip").mkString(File.separator)
+ pythonPath += Seq(sparkHome, "python", "lib", "py4j-0.10.7-src.zip").mkString(File.separator)
}
pythonPath ++= SparkContext.jarOfObject(this)
pythonPath.mkString(File.pathSeparator)
diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonWorkerFactory.scala b/core/src/main/scala/org/apache/spark/api/python/PythonWorkerFactory.scala
index 2340580b54f67..6afa37aa36fd3 100644
--- a/core/src/main/scala/org/apache/spark/api/python/PythonWorkerFactory.scala
+++ b/core/src/main/scala/org/apache/spark/api/python/PythonWorkerFactory.scala
@@ -27,6 +27,7 @@ import scala.collection.mutable
import org.apache.spark._
import org.apache.spark.internal.Logging
+import org.apache.spark.security.SocketAuthHelper
import org.apache.spark.util.{RedirectThread, Utils}
private[spark] class PythonWorkerFactory(pythonExec: String, envVars: Map[String, String])
@@ -67,6 +68,8 @@ private[spark] class PythonWorkerFactory(pythonExec: String, envVars: Map[String
value
}.getOrElse("pyspark.worker")
+ private val authHelper = new SocketAuthHelper(SparkEnv.get.conf)
+
var daemon: Process = null
val daemonHost = InetAddress.getByAddress(Array(127, 0, 0, 1))
var daemonPort: Int = 0
@@ -108,6 +111,8 @@ private[spark] class PythonWorkerFactory(pythonExec: String, envVars: Map[String
if (pid < 0) {
throw new IllegalStateException("Python daemon failed to launch worker with code " + pid)
}
+
+ authHelper.authToServer(socket)
daemonWorkers.put(socket, pid)
socket
}
@@ -145,25 +150,24 @@ private[spark] class PythonWorkerFactory(pythonExec: String, envVars: Map[String
workerEnv.put("PYTHONPATH", pythonPath)
// This is equivalent to setting the -u flag; we use it because ipython doesn't support -u:
workerEnv.put("PYTHONUNBUFFERED", "YES")
+ workerEnv.put("PYTHON_WORKER_FACTORY_PORT", serverSocket.getLocalPort.toString)
+ workerEnv.put("PYTHON_WORKER_FACTORY_SECRET", authHelper.secret)
val worker = pb.start()
// Redirect worker stdout and stderr
redirectStreamsToStderr(worker.getInputStream, worker.getErrorStream)
- // Tell the worker our port
- val out = new OutputStreamWriter(worker.getOutputStream, StandardCharsets.UTF_8)
- out.write(serverSocket.getLocalPort + "\n")
- out.flush()
-
- // Wait for it to connect to our socket
+ // Wait for it to connect to our socket, and validate the auth secret.
serverSocket.setSoTimeout(10000)
+
try {
val socket = serverSocket.accept()
+ authHelper.authClient(socket)
simpleWorkers.put(socket, worker)
return socket
} catch {
case e: Exception =>
- throw new SparkException("Python worker did not connect back in time", e)
+ throw new SparkException("Python worker failed to connect back.", e)
}
} finally {
if (serverSocket != null) {
@@ -187,6 +191,7 @@ private[spark] class PythonWorkerFactory(pythonExec: String, envVars: Map[String
val workerEnv = pb.environment()
workerEnv.putAll(envVars.asJava)
workerEnv.put("PYTHONPATH", pythonPath)
+ workerEnv.put("PYTHON_WORKER_FACTORY_SECRET", authHelper.secret)
// This is equivalent to setting the -u flag; we use it because ipython doesn't support -u:
workerEnv.put("PYTHONUNBUFFERED", "YES")
daemon = pb.start()
@@ -218,7 +223,6 @@ private[spark] class PythonWorkerFactory(pythonExec: String, envVars: Map[String
// Redirect daemon stdout and stderr
redirectStreamsToStderr(in, daemon.getErrorStream)
-
} catch {
case e: Exception =>
diff --git a/core/src/main/scala/org/apache/spark/api/r/RAuthHelper.scala b/core/src/main/scala/org/apache/spark/api/r/RAuthHelper.scala
new file mode 100644
index 0000000000000..ac6826a9ec774
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/api/r/RAuthHelper.scala
@@ -0,0 +1,38 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.api.r
+
+import java.io.{DataInputStream, DataOutputStream}
+import java.net.Socket
+
+import org.apache.spark.SparkConf
+import org.apache.spark.security.SocketAuthHelper
+
+private[spark] class RAuthHelper(conf: SparkConf) extends SocketAuthHelper(conf) {
+
+ override protected def readUtf8(s: Socket): String = {
+ SerDe.readString(new DataInputStream(s.getInputStream()))
+ }
+
+ override protected def writeUtf8(str: String, s: Socket): Unit = {
+ val out = s.getOutputStream()
+ SerDe.writeString(new DataOutputStream(out), str)
+ out.flush()
+ }
+
+}
diff --git a/core/src/main/scala/org/apache/spark/api/r/RBackend.scala b/core/src/main/scala/org/apache/spark/api/r/RBackend.scala
index 2d1152a036449..3b2e809408e0f 100644
--- a/core/src/main/scala/org/apache/spark/api/r/RBackend.scala
+++ b/core/src/main/scala/org/apache/spark/api/r/RBackend.scala
@@ -17,8 +17,8 @@
package org.apache.spark.api.r
-import java.io.{DataOutputStream, File, FileOutputStream, IOException}
-import java.net.{InetAddress, InetSocketAddress, ServerSocket}
+import java.io.{DataInputStream, DataOutputStream, File, FileOutputStream, IOException}
+import java.net.{InetAddress, InetSocketAddress, ServerSocket, Socket}
import java.util.concurrent.TimeUnit
import io.netty.bootstrap.ServerBootstrap
@@ -32,6 +32,8 @@ import io.netty.handler.timeout.ReadTimeoutHandler
import org.apache.spark.SparkConf
import org.apache.spark.internal.Logging
+import org.apache.spark.network.util.JavaUtils
+import org.apache.spark.util.Utils
/**
* Netty-based backend server that is used to communicate between R and Java.
@@ -45,7 +47,7 @@ private[spark] class RBackend {
/** Tracks JVM objects returned to R for this RBackend instance. */
private[r] val jvmObjectTracker = new JVMObjectTracker
- def init(): Int = {
+ def init(): (Int, RAuthHelper) = {
val conf = new SparkConf()
val backendConnectionTimeout = conf.getInt(
"spark.r.backendConnectionTimeout", SparkRDefaults.DEFAULT_CONNECTION_TIMEOUT)
@@ -53,6 +55,7 @@ private[spark] class RBackend {
conf.getInt("spark.r.numRBackendThreads", SparkRDefaults.DEFAULT_NUM_RBACKEND_THREADS))
val workerGroup = bossGroup
val handler = new RBackendHandler(this)
+ val authHelper = new RAuthHelper(conf)
bootstrap = new ServerBootstrap()
.group(bossGroup, workerGroup)
@@ -71,13 +74,16 @@ private[spark] class RBackend {
new LengthFieldBasedFrameDecoder(Integer.MAX_VALUE, 0, 4, 0, 4))
.addLast("decoder", new ByteArrayDecoder())
.addLast("readTimeoutHandler", new ReadTimeoutHandler(backendConnectionTimeout))
+ .addLast(new RBackendAuthHandler(authHelper.secret))
.addLast("handler", handler)
}
})
channelFuture = bootstrap.bind(new InetSocketAddress("localhost", 0))
channelFuture.syncUninterruptibly()
- channelFuture.channel().localAddress().asInstanceOf[InetSocketAddress].getPort()
+
+ val port = channelFuture.channel().localAddress().asInstanceOf[InetSocketAddress].getPort()
+ (port, authHelper)
}
def run(): Unit = {
@@ -116,7 +122,7 @@ private[spark] object RBackend extends Logging {
val sparkRBackend = new RBackend()
try {
// bind to random port
- val boundPort = sparkRBackend.init()
+ val (boundPort, authHelper) = sparkRBackend.init()
val serverSocket = new ServerSocket(0, 1, InetAddress.getByName("localhost"))
val listenPort = serverSocket.getLocalPort()
// Connection timeout is set by socket client. To make it configurable we will pass the
@@ -133,6 +139,7 @@ private[spark] object RBackend extends Logging {
dos.writeInt(listenPort)
SerDe.writeString(dos, RUtils.rPackages.getOrElse(""))
dos.writeInt(backendConnectionTimeout)
+ SerDe.writeString(dos, authHelper.secret)
dos.close()
f.renameTo(new File(path))
@@ -144,12 +151,35 @@ private[spark] object RBackend extends Logging {
val buf = new Array[Byte](1024)
// shutdown JVM if R does not connect back in 10 seconds
serverSocket.setSoTimeout(10000)
+
+ // Wait for the R process to connect back, ignoring any failed auth attempts. Allow
+ // a max number of connection attempts to avoid looping forever.
try {
- val inSocket = serverSocket.accept()
+ var remainingAttempts = 10
+ var inSocket: Socket = null
+ while (inSocket == null) {
+ inSocket = serverSocket.accept()
+ try {
+ authHelper.authClient(inSocket)
+ } catch {
+ case e: Exception =>
+ remainingAttempts -= 1
+ if (remainingAttempts == 0) {
+ val msg = "Too many failed authentication attempts."
+ logError(msg)
+ throw new IllegalStateException(msg)
+ }
+ logInfo("Client connection failed authentication.")
+ inSocket = null
+ }
+ }
+
serverSocket.close()
+
// wait for the end of socket, closed if R process die
inSocket.getInputStream().read(buf)
} finally {
+ serverSocket.close()
sparkRBackend.close()
System.exit(0)
}
@@ -165,4 +195,5 @@ private[spark] object RBackend extends Logging {
}
System.exit(0)
}
+
}
diff --git a/core/src/main/scala/org/apache/spark/api/r/RBackendAuthHandler.scala b/core/src/main/scala/org/apache/spark/api/r/RBackendAuthHandler.scala
new file mode 100644
index 0000000000000..4162e4a6c7476
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/api/r/RBackendAuthHandler.scala
@@ -0,0 +1,55 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.api.r
+
+import java.io.{ByteArrayOutputStream, DataOutputStream}
+import java.nio.charset.StandardCharsets.UTF_8
+
+import io.netty.channel.{Channel, ChannelHandlerContext, SimpleChannelInboundHandler}
+
+import org.apache.spark.internal.Logging
+import org.apache.spark.util.Utils
+
+/**
+ * Authentication handler for connections from the R process.
+ */
+private class RBackendAuthHandler(secret: String)
+ extends SimpleChannelInboundHandler[Array[Byte]] with Logging {
+
+ override def channelRead0(ctx: ChannelHandlerContext, msg: Array[Byte]): Unit = {
+ // The R code adds a null terminator to serialized strings, so ignore it here.
+ val clientSecret = new String(msg, 0, msg.length - 1, UTF_8)
+ try {
+ require(secret == clientSecret, "Auth secret mismatch.")
+ ctx.pipeline().remove(this)
+ writeReply("ok", ctx.channel())
+ } catch {
+ case e: Exception =>
+ logInfo("Authentication failure.", e)
+ writeReply("err", ctx.channel())
+ ctx.close()
+ }
+ }
+
+ private def writeReply(reply: String, chan: Channel): Unit = {
+ val out = new ByteArrayOutputStream()
+ SerDe.writeString(new DataOutputStream(out), reply)
+ chan.writeAndFlush(out.toByteArray())
+ }
+
+}
diff --git a/core/src/main/scala/org/apache/spark/api/r/RRunner.scala b/core/src/main/scala/org/apache/spark/api/r/RRunner.scala
index 88118392003e8..e7fdc3963945a 100644
--- a/core/src/main/scala/org/apache/spark/api/r/RRunner.scala
+++ b/core/src/main/scala/org/apache/spark/api/r/RRunner.scala
@@ -74,14 +74,19 @@ private[spark] class RRunner[U](
// the socket used to send out the input of task
serverSocket.setSoTimeout(10000)
- val inSocket = serverSocket.accept()
- startStdinThread(inSocket.getOutputStream(), inputIterator, partitionIndex)
-
- // the socket used to receive the output of task
- val outSocket = serverSocket.accept()
- val inputStream = new BufferedInputStream(outSocket.getInputStream)
- dataStream = new DataInputStream(inputStream)
- serverSocket.close()
+ dataStream = try {
+ val inSocket = serverSocket.accept()
+ RRunner.authHelper.authClient(inSocket)
+ startStdinThread(inSocket.getOutputStream(), inputIterator, partitionIndex)
+
+ // the socket used to receive the output of task
+ val outSocket = serverSocket.accept()
+ RRunner.authHelper.authClient(outSocket)
+ val inputStream = new BufferedInputStream(outSocket.getInputStream)
+ new DataInputStream(inputStream)
+ } finally {
+ serverSocket.close()
+ }
try {
return new Iterator[U] {
@@ -315,6 +320,11 @@ private[r] object RRunner {
private[this] var errThread: BufferedStreamThread = _
private[this] var daemonChannel: DataOutputStream = _
+ private lazy val authHelper = {
+ val conf = Option(SparkEnv.get).map(_.conf).getOrElse(new SparkConf())
+ new RAuthHelper(conf)
+ }
+
/**
* Start a thread to print the process's stderr to ours
*/
@@ -349,6 +359,7 @@ private[r] object RRunner {
pb.environment().put("SPARKR_BACKEND_CONNECTION_TIMEOUT", rConnectionTimeout.toString)
pb.environment().put("SPARKR_SPARKFILES_ROOT_DIR", SparkFiles.getRootDirectory())
pb.environment().put("SPARKR_IS_RUNNING_ON_WORKER", "TRUE")
+ pb.environment().put("SPARKR_WORKER_SECRET", authHelper.secret)
pb.redirectErrorStream(true) // redirect stderr into stdout
val proc = pb.start()
val errThread = startStdoutThread(proc)
@@ -370,8 +381,12 @@ private[r] object RRunner {
// the socket used to send out the input of task
serverSocket.setSoTimeout(10000)
val sock = serverSocket.accept()
- daemonChannel = new DataOutputStream(new BufferedOutputStream(sock.getOutputStream))
- serverSocket.close()
+ try {
+ authHelper.authClient(sock)
+ daemonChannel = new DataOutputStream(new BufferedOutputStream(sock.getOutputStream))
+ } finally {
+ serverSocket.close()
+ }
}
try {
daemonChannel.writeInt(port)
diff --git a/core/src/main/scala/org/apache/spark/deploy/PythonRunner.scala b/core/src/main/scala/org/apache/spark/deploy/PythonRunner.scala
index 7aca305783a7f..1b7e031ee0678 100644
--- a/core/src/main/scala/org/apache/spark/deploy/PythonRunner.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/PythonRunner.scala
@@ -18,7 +18,7 @@
package org.apache.spark.deploy
import java.io.File
-import java.net.URI
+import java.net.{InetAddress, URI}
import scala.collection.JavaConverters._
import scala.collection.mutable.ArrayBuffer
@@ -39,6 +39,7 @@ object PythonRunner {
val pyFiles = args(1)
val otherArgs = args.slice(2, args.length)
val sparkConf = new SparkConf()
+ val secret = Utils.createSecret(sparkConf)
val pythonExec = sparkConf.get(PYSPARK_DRIVER_PYTHON)
.orElse(sparkConf.get(PYSPARK_PYTHON))
.orElse(sys.env.get("PYSPARK_DRIVER_PYTHON"))
@@ -51,7 +52,13 @@ object PythonRunner {
// Launch a Py4J gateway server for the process to connect to; this will let it see our
// Java system properties and such
- val gatewayServer = new py4j.GatewayServer(null, 0)
+ val localhost = InetAddress.getLoopbackAddress()
+ val gatewayServer = new py4j.GatewayServer.GatewayServerBuilder()
+ .authToken(secret)
+ .javaPort(0)
+ .javaAddress(localhost)
+ .callbackClient(py4j.GatewayServer.DEFAULT_PYTHON_PORT, localhost, secret)
+ .build()
val thread = new Thread(new Runnable() {
override def run(): Unit = Utils.logUncaughtExceptions {
gatewayServer.start()
@@ -82,6 +89,7 @@ object PythonRunner {
// This is equivalent to setting the -u flag; we use it because ipython doesn't support -u:
env.put("PYTHONUNBUFFERED", "YES") // value is needed to be set to a non-empty string
env.put("PYSPARK_GATEWAY_PORT", "" + gatewayServer.getListeningPort)
+ env.put("PYSPARK_GATEWAY_SECRET", secret)
// pass conf spark.pyspark.python to python process, the only way to pass info to
// python process is through environment variable.
sparkConf.get(PYSPARK_PYTHON).foreach(env.put("PYSPARK_PYTHON", _))
diff --git a/core/src/main/scala/org/apache/spark/deploy/RRunner.scala b/core/src/main/scala/org/apache/spark/deploy/RRunner.scala
index 6eb53a8252205..e86b362639e57 100644
--- a/core/src/main/scala/org/apache/spark/deploy/RRunner.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/RRunner.scala
@@ -68,10 +68,13 @@ object RRunner {
// Java system properties etc.
val sparkRBackend = new RBackend()
@volatile var sparkRBackendPort = 0
+ @volatile var sparkRBackendSecret: String = null
val initialized = new Semaphore(0)
val sparkRBackendThread = new Thread("SparkR backend") {
override def run() {
- sparkRBackendPort = sparkRBackend.init()
+ val (port, authHelper) = sparkRBackend.init()
+ sparkRBackendPort = port
+ sparkRBackendSecret = authHelper.secret
initialized.release()
sparkRBackend.run()
}
@@ -91,6 +94,7 @@ object RRunner {
env.put("SPARKR_PACKAGE_DIR", rPackageDir.mkString(","))
env.put("R_PROFILE_USER",
Seq(rPackageDir(0), "SparkR", "profile", "general.R").mkString(File.separator))
+ env.put("SPARKR_BACKEND_AUTH_SECRET", sparkRBackendSecret)
builder.redirectErrorStream(true) // Ugly but needed for stdout and stderr to synchronize
val process = builder.start()
diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala b/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala
index 427c797755b84..087e9c31a9c9a 100644
--- a/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala
@@ -22,6 +22,7 @@ import java.lang.reflect.{InvocationTargetException, Modifier, UndeclaredThrowab
import java.net.URL
import java.security.PrivilegedExceptionAction
import java.text.ParseException
+import java.util.UUID
import scala.annotation.tailrec
import scala.collection.mutable.{ArrayBuffer, HashMap, Map}
@@ -1204,7 +1205,33 @@ private[spark] object SparkSubmitUtils {
/** A nice function to use in tests as well. Values are dummy strings. */
def getModuleDescriptor: DefaultModuleDescriptor = DefaultModuleDescriptor.newDefaultInstance(
- ModuleRevisionId.newInstance("org.apache.spark", "spark-submit-parent", "1.0"))
+ // Include UUID in the module name, so that multiple clients resolving maven coordinates at the
+ // same time do not modify the same resolution file concurrently.
+ ModuleRevisionId.newInstance("org.apache.spark",
+ s"spark-submit-parent-${UUID.randomUUID.toString}",
+ "1.0"))
+
+ /**
+ * Clear ivy resolution from current launch. The resolution file is usually at
+ * ~/.ivy2/org.apache.spark-spark-submit-parent-$UUID-default.xml,
+ * ~/.ivy2/resolved-org.apache.spark-spark-submit-parent-$UUID-1.0.xml, and
+ * ~/.ivy2/resolved-org.apache.spark-spark-submit-parent-$UUID-1.0.properties.
+ * Since each launch will have its own resolution files created, delete them after
+ * each resolution to prevent accumulation of these files in the ivy cache dir.
+ */
+ private def clearIvyResolutionFiles(
+ mdId: ModuleRevisionId,
+ ivySettings: IvySettings,
+ ivyConfName: String): Unit = {
+ val currentResolutionFiles = Seq(
+ s"${mdId.getOrganisation}-${mdId.getName}-$ivyConfName.xml",
+ s"resolved-${mdId.getOrganisation}-${mdId.getName}-${mdId.getRevision}.xml",
+ s"resolved-${mdId.getOrganisation}-${mdId.getName}-${mdId.getRevision}.properties"
+ )
+ currentResolutionFiles.foreach { filename =>
+ new File(ivySettings.getDefaultCache, filename).delete()
+ }
+ }
/**
* Resolves any dependencies that were supplied through maven coordinates
@@ -1255,14 +1282,6 @@ private[spark] object SparkSubmitUtils {
// A Module descriptor must be specified. Entries are dummy strings
val md = getModuleDescriptor
- // clear ivy resolution from previous launches. The resolution file is usually at
- // ~/.ivy2/org.apache.spark-spark-submit-parent-default.xml. In between runs, this file
- // leads to confusion with Ivy when the files can no longer be found at the repository
- // declared in that file/
- val mdId = md.getModuleRevisionId
- val previousResolution = new File(ivySettings.getDefaultCache,
- s"${mdId.getOrganisation}-${mdId.getName}-$ivyConfName.xml")
- if (previousResolution.exists) previousResolution.delete
md.setDefaultConf(ivyConfName)
@@ -1283,7 +1302,10 @@ private[spark] object SparkSubmitUtils {
packagesDirectory.getAbsolutePath + File.separator +
"[organization]_[artifact]-[revision](-[classifier]).[ext]",
retrieveOptions.setConfs(Array(ivyConfName)))
- resolveDependencyPaths(rr.getArtifacts.toArray, packagesDirectory)
+ val paths = resolveDependencyPaths(rr.getArtifacts.toArray, packagesDirectory)
+ val mdId = md.getModuleRevisionId
+ clearIvyResolutionFiles(mdId, ivySettings, ivyConfName)
+ paths
} finally {
System.setOut(sysOut)
}
diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala b/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala
index 0733fdb72cafb..fed4e0a5069c3 100644
--- a/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala
@@ -36,7 +36,6 @@ import org.apache.spark.launcher.SparkSubmitArgumentsParser
import org.apache.spark.network.util.JavaUtils
import org.apache.spark.util.Utils
-
/**
* Parses and encapsulates arguments from the spark-submit script.
* The env argument is used for testing.
@@ -76,6 +75,7 @@ private[deploy] class SparkSubmitArguments(args: Seq[String], env: Map[String, S
var proxyUser: String = null
var principal: String = null
var keytab: String = null
+ private var dynamicAllocationEnabled: Boolean = false
// Standalone cluster mode only
var supervise: Boolean = false
@@ -198,6 +198,8 @@ private[deploy] class SparkSubmitArguments(args: Seq[String], env: Map[String, S
queue = Option(queue).orElse(sparkProperties.get("spark.yarn.queue")).orNull
keytab = Option(keytab).orElse(sparkProperties.get("spark.yarn.keytab")).orNull
principal = Option(principal).orElse(sparkProperties.get("spark.yarn.principal")).orNull
+ dynamicAllocationEnabled =
+ sparkProperties.get("spark.dynamicAllocation.enabled").exists("true".equalsIgnoreCase)
// Try to set main class from JAR if no --class argument is given
if (mainClass == null && !isPython && !isR && primaryResource != null) {
@@ -274,7 +276,8 @@ private[deploy] class SparkSubmitArguments(args: Seq[String], env: Map[String, S
if (totalExecutorCores != null && Try(totalExecutorCores.toInt).getOrElse(-1) <= 0) {
error("Total executor cores must be a positive number")
}
- if (numExecutors != null && Try(numExecutors.toInt).getOrElse(-1) <= 0) {
+ if (!dynamicAllocationEnabled &&
+ numExecutors != null && Try(numExecutors.toInt).getOrElse(-1) <= 0) {
error("Number of executors must be a positive number")
}
if (pyFiles != null && !isPython) {
diff --git a/core/src/main/scala/org/apache/spark/internal/config/package.scala b/core/src/main/scala/org/apache/spark/internal/config/package.scala
index 6bb98c37b4479..82f0a04e94b1c 100644
--- a/core/src/main/scala/org/apache/spark/internal/config/package.scala
+++ b/core/src/main/scala/org/apache/spark/internal/config/package.scala
@@ -352,6 +352,11 @@ package object config {
.regexConf
.createOptional
+ private[spark] val AUTH_SECRET_BIT_LENGTH =
+ ConfigBuilder("spark.authenticate.secretBitLength")
+ .intConf
+ .createWithDefault(256)
+
private[spark] val NETWORK_AUTH_ENABLED =
ConfigBuilder("spark.authenticate")
.booleanConf
diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala
index 5627a557a12f3..d8794e8e551aa 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala
@@ -633,7 +633,7 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp
}
doRequestTotalExecutors(requestedTotalExecutors)
} else {
- numPendingExecutors += knownExecutors.size
+ numPendingExecutors += executorsToKill.size
Future.successful(true)
}
diff --git a/core/src/main/scala/org/apache/spark/security/SocketAuthHelper.scala b/core/src/main/scala/org/apache/spark/security/SocketAuthHelper.scala
new file mode 100644
index 0000000000000..d15e7937b0523
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/security/SocketAuthHelper.scala
@@ -0,0 +1,102 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.security
+
+import java.io.{DataInputStream, DataOutputStream, InputStream}
+import java.net.Socket
+import java.nio.charset.StandardCharsets.UTF_8
+
+import org.apache.spark.SparkConf
+import org.apache.spark.network.util.JavaUtils
+import org.apache.spark.util.Utils
+
+/**
+ * A class that can be used to add a simple authentication protocol to socket-based communication.
+ *
+ * The protocol is simple: an auth secret is written to the socket, and the other side checks the
+ * secret and writes either "ok" or "err" to the output. If authentication fails, the socket is
+ * not expected to be valid anymore.
+ *
+ * There's no secrecy, so this relies on the sockets being either local or somehow encrypted.
+ */
+private[spark] class SocketAuthHelper(conf: SparkConf) {
+
+ val secret = Utils.createSecret(conf)
+
+ /**
+ * Read the auth secret from the socket and compare to the expected value. Write the reply back
+ * to the socket.
+ *
+ * If authentication fails, this method will close the socket.
+ *
+ * @param s The client socket.
+ * @throws IllegalArgumentException If authentication fails.
+ */
+ def authClient(s: Socket): Unit = {
+ // Set the socket timeout while checking the auth secret. Reset it before returning.
+ val currentTimeout = s.getSoTimeout()
+ try {
+ s.setSoTimeout(10000)
+ val clientSecret = readUtf8(s)
+ if (secret == clientSecret) {
+ writeUtf8("ok", s)
+ } else {
+ writeUtf8("err", s)
+ JavaUtils.closeQuietly(s)
+ throw new IllegalArgumentException("Authentication failed.")
+ }
+ } finally {
+ s.setSoTimeout(currentTimeout)
+ }
+ }
+
+ /**
+ * Authenticate with a server by writing the auth secret and checking the server's reply.
+ *
+ * If authentication fails, this method will close the socket.
+ *
+ * @param s The socket connected to the server.
+ * @throws IllegalArgumentException If authentication fails.
+ */
+ def authToServer(s: Socket): Unit = {
+ writeUtf8(secret, s)
+
+ val reply = readUtf8(s)
+ if (reply != "ok") {
+ JavaUtils.closeQuietly(s)
+ throw new IllegalArgumentException("Authentication failed.")
+ }
+ }
+
+ protected def readUtf8(s: Socket): String = {
+ val din = new DataInputStream(s.getInputStream())
+ val len = din.readInt()
+ val bytes = new Array[Byte](len)
+ din.readFully(bytes)
+ new String(bytes, UTF_8)
+ }
+
+ protected def writeUtf8(str: String, s: Socket): Unit = {
+ val bytes = str.getBytes(UTF_8)
+ val dout = new DataOutputStream(s.getOutputStream())
+ dout.writeInt(bytes.length)
+ dout.write(bytes, 0, bytes.length)
+ dout.flush()
+ }
+
+}
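A minimal end-to-end sketch of the handshake implemented above, assuming both ends share the same SocketAuthHelper instance (in the real code the secret reaches the client through environment variables or the connection-info file); this mirrors what SocketAuthHelperSuite below exercises:

import java.net.{InetAddress, ServerSocket, Socket}

import org.apache.spark.SparkConf
import org.apache.spark.security.SocketAuthHelper

val helper = new SocketAuthHelper(new SparkConf())
val server = new ServerSocket(0, 1, InetAddress.getLoopbackAddress())

new Thread(new Runnable {
  override def run(): Unit = {
    val conn = server.accept()
    helper.authClient(conn)  // replies "ok" or "err"; closes the socket and throws on mismatch
    // ... stream data over conn ...
  }
}).start()

val client = new Socket(InetAddress.getLoopbackAddress(), server.getLocalPort)
helper.authToServer(client)  // writes the secret, expects "ok" back
// ... read data from client ...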
diff --git a/core/src/main/scala/org/apache/spark/util/Utils.scala b/core/src/main/scala/org/apache/spark/util/Utils.scala
index dcad1b914038f..13adaa921dc23 100644
--- a/core/src/main/scala/org/apache/spark/util/Utils.scala
+++ b/core/src/main/scala/org/apache/spark/util/Utils.scala
@@ -18,6 +18,7 @@
package org.apache.spark.util
import java.io._
+import java.lang.{Byte => JByte}
import java.lang.management.{LockInfo, ManagementFactory, MonitorInfo, ThreadInfo}
import java.lang.reflect.InvocationTargetException
import java.math.{MathContext, RoundingMode}
@@ -26,11 +27,11 @@ import java.nio.ByteBuffer
import java.nio.channels.{Channels, FileChannel}
import java.nio.charset.StandardCharsets
import java.nio.file.Files
+import java.security.SecureRandom
import java.util.{Locale, Properties, Random, UUID}
import java.util.concurrent._
import java.util.concurrent.atomic.AtomicBoolean
import java.util.zip.GZIPInputStream
-import javax.net.ssl.HttpsURLConnection
import scala.annotation.tailrec
import scala.collection.JavaConverters._
@@ -44,6 +45,7 @@ import scala.util.matching.Regex
import _root_.io.netty.channel.unix.Errors.NativeIoException
import com.google.common.cache.{CacheBuilder, CacheLoader, LoadingCache}
+import com.google.common.hash.HashCodes
import com.google.common.io.{ByteStreams, Files => GFiles}
import com.google.common.net.InetAddresses
import org.apache.commons.lang3.SystemUtils
@@ -2704,6 +2706,15 @@ private[spark] object Utils extends Logging {
def substituteAppId(opt: String, appId: String): String = {
opt.replace("{{APP_ID}}", appId)
}
+
+ def createSecret(conf: SparkConf): String = {
+ val bits = conf.get(AUTH_SECRET_BIT_LENGTH)
+ val rnd = new SecureRandom()
+ val secretBytes = new Array[Byte](bits / JByte.SIZE)
+ rnd.nextBytes(secretBytes)
+ HashCodes.fromBytes(secretBytes).toString()
+ }
+
}
private[util] object CallerContext extends Logging {
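A quick illustration of the helper added above (hedged; assumes the default spark.authenticate.secretBitLength of 256 and Guava's hex rendering of HashCode): the secret is 32 random bytes rendered as a 64-character lowercase hex string.

import org.apache.spark.SparkConf
import org.apache.spark.util.Utils

val secret = Utils.createSecret(new SparkConf())  // default bit length: 256
assert(secret.length == 64)                       // 32 bytes, two hex chars per byte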
diff --git a/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala b/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala
index 7451e07b25a1f..43286953e4383 100644
--- a/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala
+++ b/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala
@@ -180,6 +180,26 @@ class SparkSubmitSuite
appArgs.toString should include ("thequeue")
}
+ test("SPARK-24241: do not fail fast if executor num is 0 when dynamic allocation is enabled") {
+ val clArgs1 = Seq(
+ "--name", "myApp",
+ "--class", "Foo",
+ "--num-executors", "0",
+ "--conf", "spark.dynamicAllocation.enabled=true",
+ "thejar.jar")
+ new SparkSubmitArguments(clArgs1)
+
+ val clArgs2 = Seq(
+ "--name", "myApp",
+ "--class", "Foo",
+ "--num-executors", "0",
+ "--conf", "spark.dynamicAllocation.enabled=false",
+ "thejar.jar")
+
+ val e = intercept[SparkException](new SparkSubmitArguments(clArgs2))
+ assert(e.getMessage.contains("Number of executors must be a positive number"))
+ }
+
test("specify deploy mode through configuration") {
val clArgs = Seq(
"--master", "yarn",
diff --git a/core/src/test/scala/org/apache/spark/deploy/SparkSubmitUtilsSuite.scala b/core/src/test/scala/org/apache/spark/deploy/SparkSubmitUtilsSuite.scala
index eb8c203ae7751..a0f09891787e0 100644
--- a/core/src/test/scala/org/apache/spark/deploy/SparkSubmitUtilsSuite.scala
+++ b/core/src/test/scala/org/apache/spark/deploy/SparkSubmitUtilsSuite.scala
@@ -256,4 +256,19 @@ class SparkSubmitUtilsSuite extends SparkFunSuite with BeforeAndAfterAll {
assert(jarPath.indexOf("mydep") >= 0, "should find dependency")
}
}
+
+ test("SPARK-10878: test resolution files cleaned after resolving artifact") {
+ val main = new MavenCoordinate("my.great.lib", "mylib", "0.1")
+
+ IvyTestUtils.withRepository(main, None, None) { repo =>
+ val ivySettings = SparkSubmitUtils.buildIvySettings(Some(repo), Some(tempIvyPath))
+ val jarPath = SparkSubmitUtils.resolveMavenCoordinates(
+ main.toString,
+ ivySettings,
+ isTest = true)
+ val r = """.*org.apache.spark-spark-submit-parent-.*""".r
+ assert(!ivySettings.getDefaultCache.listFiles.map(_.getName)
+ .exists(r.findFirstIn(_).isDefined), "resolution files should be cleaned")
+ }
+ }
}
diff --git a/core/src/test/scala/org/apache/spark/scheduler/SparkListenerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/SparkListenerSuite.scala
index da6ecb82c7e42..fa47a52bbbc47 100644
--- a/core/src/test/scala/org/apache/spark/scheduler/SparkListenerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/scheduler/SparkListenerSuite.scala
@@ -17,6 +17,7 @@
package org.apache.spark.scheduler
+import java.io.{Externalizable, IOException, ObjectInput, ObjectOutput}
import java.util.concurrent.Semaphore
import scala.collection.JavaConverters._
@@ -294,10 +295,13 @@ class SparkListenerSuite extends SparkFunSuite with LocalSparkContext with Match
val listener = new SaveStageAndTaskInfo
sc.addSparkListener(listener)
sc.addSparkListener(new StatsReportListener)
- // just to make sure some of the tasks take a noticeable amount of time
+ // just to make sure some of the tasks and their deserialization take a noticeable
+ // amount of time
+ val slowDeserializable = new SlowDeserializable
val w = { i: Int =>
if (i == 0) {
Thread.sleep(100)
+ slowDeserializable.use()
}
i
}
@@ -583,3 +587,12 @@ private class FirehoseListenerThatAcceptsSparkConf(conf: SparkConf) extends Spar
case _ =>
}
}
+
+private class SlowDeserializable extends Externalizable {
+
+ override def writeExternal(out: ObjectOutput): Unit = { }
+
+ override def readExternal(in: ObjectInput): Unit = Thread.sleep(1)
+
+ def use(): Unit = { }
+}
diff --git a/core/src/test/scala/org/apache/spark/security/SocketAuthHelperSuite.scala b/core/src/test/scala/org/apache/spark/security/SocketAuthHelperSuite.scala
new file mode 100644
index 0000000000000..e57cb701b6284
--- /dev/null
+++ b/core/src/test/scala/org/apache/spark/security/SocketAuthHelperSuite.scala
@@ -0,0 +1,97 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.security
+
+import java.io.Closeable
+import java.net._
+
+import org.apache.spark.{SparkConf, SparkFunSuite}
+import org.apache.spark.internal.config._
+import org.apache.spark.util.Utils
+
+class SocketAuthHelperSuite extends SparkFunSuite {
+
+ private val conf = new SparkConf()
+ private val authHelper = new SocketAuthHelper(conf)
+
+ test("successful auth") {
+ Utils.tryWithResource(new ServerThread()) { server =>
+ Utils.tryWithResource(server.createClient()) { client =>
+ authHelper.authToServer(client)
+ server.close()
+ server.join()
+ assert(server.error == null)
+ assert(server.authenticated)
+ }
+ }
+ }
+
+ test("failed auth") {
+ Utils.tryWithResource(new ServerThread()) { server =>
+ Utils.tryWithResource(server.createClient()) { client =>
+ val badHelper = new SocketAuthHelper(new SparkConf().set(AUTH_SECRET_BIT_LENGTH, 128))
+ intercept[IllegalArgumentException] {
+ badHelper.authToServer(client)
+ }
+ server.close()
+ server.join()
+ assert(server.error != null)
+ assert(!server.authenticated)
+ }
+ }
+ }
+
+ private class ServerThread extends Thread with Closeable {
+
+ private val ss = new ServerSocket()
+ ss.bind(new InetSocketAddress(InetAddress.getLoopbackAddress(), 0))
+
+ @volatile var error: Exception = _
+ @volatile var authenticated = false
+
+ setDaemon(true)
+ start()
+
+ def createClient(): Socket = {
+ new Socket(InetAddress.getLoopbackAddress(), ss.getLocalPort())
+ }
+
+ override def run(): Unit = {
+ var clientConn: Socket = null
+ try {
+ clientConn = ss.accept()
+ authHelper.authClient(clientConn)
+ authenticated = true
+ } catch {
+ case e: Exception =>
+ error = e
+ } finally {
+ Option(clientConn).foreach(_.close())
+ }
+ }
+
+ override def close(): Unit = {
+ try {
+ ss.close()
+ } finally {
+ interrupt()
+ }
+ }
+
+ }
+
+}
diff --git a/dev/create-release/release-build.sh b/dev/create-release/release-build.sh
index c00b00b845401..5faa3d3260a56 100755
--- a/dev/create-release/release-build.sh
+++ b/dev/create-release/release-build.sh
@@ -371,11 +371,18 @@ if [[ "$1" == "publish-release" ]]; then
find . -type f |grep -v \.jar |grep -v \.pom | xargs rm
echo "Creating hash and signature files"
- # this must have .asc and .sha1 - it really doesn't like anything else there
+ # this must have .asc, .md5 and .sha1 - it really doesn't like anything else there
for file in $(find . -type f)
do
echo $GPG_PASSPHRASE | $GPG --passphrase-fd 0 --output $file.asc \
--detach-sig --armour $file;
+ if [ $(command -v md5) ]; then
+ # Available on OS X; -q to keep only hash
+ md5 -q $file > $file.md5
+ else
+ # Available on Linux; cut to keep only hash
+ md5sum $file | cut -f1 -d' ' > $file.md5
+ fi
sha1sum $file | cut -f1 -d' ' > $file.sha1
done
diff --git a/dev/deps/spark-deps-hadoop-2.6 b/dev/deps/spark-deps-hadoop-2.6
index c3d1dd444b506..e710e26348117 100644
--- a/dev/deps/spark-deps-hadoop-2.6
+++ b/dev/deps/spark-deps-hadoop-2.6
@@ -162,15 +162,15 @@ orc-mapreduce-1.4.3-nohive.jar
oro-2.0.8.jar
osgi-resource-locator-1.0.1.jar
paranamer-2.8.jar
-parquet-column-1.8.2.jar
-parquet-common-1.8.2.jar
-parquet-encoding-1.8.2.jar
-parquet-format-2.3.1.jar
-parquet-hadoop-1.8.2.jar
+parquet-column-1.10.0.jar
+parquet-common-1.10.0.jar
+parquet-encoding-1.10.0.jar
+parquet-format-2.4.0.jar
+parquet-hadoop-1.10.0.jar
parquet-hadoop-bundle-1.6.0.jar
-parquet-jackson-1.8.2.jar
+parquet-jackson-1.10.0.jar
protobuf-java-2.5.0.jar
-py4j-0.10.6.jar
+py4j-0.10.7.jar
pyrolite-4.13.jar
scala-compiler-2.11.8.jar
scala-library-2.11.8.jar
@@ -190,7 +190,7 @@ stax-api-1.0.1.jar
stream-2.7.0.jar
stringtemplate-3.2.1.jar
super-csv-2.2.0.jar
-univocity-parsers-2.5.9.jar
+univocity-parsers-2.6.3.jar
validation-api-1.1.0.Final.jar
xbean-asm5-shaded-4.4.jar
xercesImpl-2.9.1.jar
diff --git a/dev/deps/spark-deps-hadoop-2.7 b/dev/deps/spark-deps-hadoop-2.7
index 290867035f91d..97ad17a9ff7b1 100644
--- a/dev/deps/spark-deps-hadoop-2.7
+++ b/dev/deps/spark-deps-hadoop-2.7
@@ -163,15 +163,15 @@ orc-mapreduce-1.4.3-nohive.jar
oro-2.0.8.jar
osgi-resource-locator-1.0.1.jar
paranamer-2.8.jar
-parquet-column-1.8.2.jar
-parquet-common-1.8.2.jar
-parquet-encoding-1.8.2.jar
-parquet-format-2.3.1.jar
-parquet-hadoop-1.8.2.jar
+parquet-column-1.10.0.jar
+parquet-common-1.10.0.jar
+parquet-encoding-1.10.0.jar
+parquet-format-2.4.0.jar
+parquet-hadoop-1.10.0.jar
parquet-hadoop-bundle-1.6.0.jar
-parquet-jackson-1.8.2.jar
+parquet-jackson-1.10.0.jar
protobuf-java-2.5.0.jar
-py4j-0.10.6.jar
+py4j-0.10.7.jar
pyrolite-4.13.jar
scala-compiler-2.11.8.jar
scala-library-2.11.8.jar
@@ -191,7 +191,7 @@ stax-api-1.0.1.jar
stream-2.7.0.jar
stringtemplate-3.2.1.jar
super-csv-2.2.0.jar
-univocity-parsers-2.5.9.jar
+univocity-parsers-2.6.3.jar
validation-api-1.1.0.Final.jar
xbean-asm5-shaded-4.4.jar
xercesImpl-2.9.1.jar
diff --git a/dev/deps/spark-deps-hadoop-3.1 b/dev/deps/spark-deps-hadoop-3.1
index 97ad65a4096cb..e21bfef8c4291 100644
--- a/dev/deps/spark-deps-hadoop-3.1
+++ b/dev/deps/spark-deps-hadoop-3.1
@@ -181,15 +181,15 @@ orc-mapreduce-1.4.3-nohive.jar
oro-2.0.8.jar
osgi-resource-locator-1.0.1.jar
paranamer-2.8.jar
-parquet-column-1.8.2.jar
-parquet-common-1.8.2.jar
-parquet-encoding-1.8.2.jar
-parquet-format-2.3.1.jar
-parquet-hadoop-1.8.2.jar
+parquet-column-1.10.0.jar
+parquet-common-1.10.0.jar
+parquet-encoding-1.10.0.jar
+parquet-format-2.4.0.jar
+parquet-hadoop-1.10.0.jar
parquet-hadoop-bundle-1.6.0.jar
-parquet-jackson-1.8.2.jar
+parquet-jackson-1.10.0.jar
protobuf-java-2.5.0.jar
-py4j-0.10.6.jar
+py4j-0.10.7.jar
pyrolite-4.13.jar
re2j-1.1.jar
scala-compiler-2.11.8.jar
@@ -211,7 +211,7 @@ stream-2.7.0.jar
stringtemplate-3.2.1.jar
super-csv-2.2.0.jar
token-provider-1.0.1.jar
-univocity-parsers-2.5.9.jar
+univocity-parsers-2.6.3.jar
validation-api-1.1.0.Final.jar
woodstox-core-5.0.3.jar
xbean-asm5-shaded-4.4.jar
diff --git a/dev/run-pip-tests b/dev/run-pip-tests
index 1321c2be4c192..7271d1014e4ae 100755
--- a/dev/run-pip-tests
+++ b/dev/run-pip-tests
@@ -89,7 +89,7 @@ for python in "${PYTHON_EXECS[@]}"; do
source "$VIRTUALENV_PATH"/bin/activate
fi
# Upgrade pip & friends if using virutal env
- if [ ! -n "USE_CONDA" ]; then
+ if [ ! -n "$USE_CONDA" ]; then
pip install --upgrade pip pypandoc wheel numpy
fi
diff --git a/docs/ml-classification-regression.md b/docs/ml-classification-regression.md
index d660655e193eb..b3d109039da4d 100644
--- a/docs/ml-classification-regression.md
+++ b/docs/ml-classification-regression.md
@@ -455,11 +455,29 @@ Refer to the [Python API docs](api/python/pyspark.ml.html#pyspark.ml.classificat
## Naive Bayes
[Naive Bayes classifiers](http://en.wikipedia.org/wiki/Naive_Bayes_classifier) are a family of simple
-probabilistic classifiers based on applying Bayes' theorem with strong (naive) independence
-assumptions between the features. The `spark.ml` implementation currently supports both [multinomial
-naive Bayes](http://nlp.stanford.edu/IR-book/html/htmledition/naive-bayes-text-classification-1.html)
+probabilistic, multiclass classifiers based on applying Bayes' theorem with strong (naive) independence
+assumptions between every pair of features.
+
+Naive Bayes can be trained very efficiently. With a single pass over the training data,
+it computes the conditional probability distribution of each feature given each label.
+For prediction, it applies Bayes' theorem to compute the conditional probability distribution
+of each label given an observation.
+
+MLlib supports both [multinomial naive Bayes](http://en.wikipedia.org/wiki/Naive_Bayes_classifier#Multinomial_naive_Bayes)
and [Bernoulli naive Bayes](http://nlp.stanford.edu/IR-book/html/htmledition/the-bernoulli-model-1.html).
-More information can be found in the section on [Naive Bayes in MLlib](mllib-naive-bayes.html#naive-bayes-sparkmllib).
+
+*Input data*:
+These models are typically used for [document classification](http://nlp.stanford.edu/IR-book/html/htmledition/naive-bayes-text-classification-1.html).
+Within that context, each observation is a document and each feature represents a term.
+A feature's value is the frequency of the term (in multinomial Naive Bayes) or
+a zero or one indicating whether the term was found in the document (in Bernoulli Naive Bayes).
+Feature values must be *non-negative*. The model type is selected with an optional parameter
+"multinomial" or "bernoulli" with "multinomial" as the default.
+For document classification, the input feature vectors should usually be sparse vectors.
+Since the training data is only used once, it is not necessary to cache it.
+
+[Additive smoothing](http://en.wikipedia.org/wiki/Lidstone_smoothing) can be used by
+setting the parameter $\lambda$ (default to $1.0$).
**Examples**
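Purely as an illustration of the parameters described above (a sketch, not part of this change; training is assumed to be a DataFrame of label/features rows with non-negative feature values):

import org.apache.spark.ml.classification.NaiveBayes

val model = new NaiveBayes()
  .setModelType("multinomial")  // or "bernoulli"
  .setSmoothing(1.0)            // additive (Lidstone) smoothing, lambda
  .fit(training)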
diff --git a/docs/running-on-yarn.md b/docs/running-on-yarn.md
index ceda8a3ae2403..c9e68c3bfd056 100644
--- a/docs/running-on-yarn.md
+++ b/docs/running-on-yarn.md
@@ -133,9 +133,8 @@ To use a custom metrics.properties for the application master and executors, upd
<td><code>spark.yarn.am.waitTime</code></td>
<td><code>100s</code></td>
- In cluster mode, time for the YARN Application Master to wait for the
- SparkContext to be initialized. In client mode, time for the YARN Application Master to wait
- for the driver to connect to it.
+ Only used in cluster mode. Time for the YARN Application Master to wait for the
+ SparkContext to be initialized.
</td>
diff --git a/docs/sql-programming-guide.md b/docs/sql-programming-guide.md
index 075b953a0898e..3f79ed6422205 100644
--- a/docs/sql-programming-guide.md
+++ b/docs/sql-programming-guide.md
@@ -964,7 +964,7 @@ Configuration of Parquet can be done using the `setConf` method on `SparkSession
Sets the compression codec used when writing Parquet files. If either `compression` or
`parquet.compression` is specified in the table-specific options/properties, the precedence would be
`compression`, `parquet.compression`, `spark.sql.parquet.compression.codec`. Acceptable values include:
- none, uncompressed, snappy, gzip, lzo.
+ none, uncompressed, snappy, gzip, lzo, brotli, lz4, zstd.
@@ -1017,7 +1017,7 @@ the vectorized reader is used when `spark.sql.hive.convertMetastoreOrc` is also
<tr><th><b>Property Name</b></th><th><b>Default</b></th><th><b>Meaning</b></th></tr>
<td><code>spark.sql.orc.impl</code></td>
- <td><code>hive</code></td>
+ <td><code>native</code></td>
<td>The name of ORC implementation. It can be one of <code>native</code> and <code>hive</code>. <code>native</code> means the native ORC support that is built on Apache ORC 1.4. `hive` means the ORC library in Hive 1.2.1.</td>
@@ -1812,6 +1812,9 @@ working with timestamps in `pandas_udf`s to get the best performance, see
- Since Spark 2.4, creating a managed table with a nonempty location is not allowed; an exception is thrown when such a table is created. Setting `spark.sql.allowCreatingManagedTableUsingNonemptyLocation` to `true` restores the previous behavior. This option will be removed in Spark 3.0.
- Since Spark 2.4, the type coercion rules can automatically promote the argument types of variadic SQL functions (e.g., IN/COALESCE) to the widest common type, regardless of the order of the input arguments. In prior Spark versions, the promotion could fail for some specific orders (e.g., TimestampType, IntegerType and StringType) and throw an exception.
- In version 2.3 and earlier, `to_utc_timestamp` and `from_utc_timestamp` respect the timezone in the input timestamp string, which breaks the assumption that the input timestamp is in a specific timezone. Therefore, these two functions can return unexpected results. In version 2.4 and later, this problem has been fixed: `to_utc_timestamp` and `from_utc_timestamp` return null if the input timestamp string contains a timezone. As an example, `from_utc_timestamp('2000-10-10 00:00:00', 'GMT+1')` will return `2000-10-10 01:00:00` in both Spark 2.3 and 2.4. However, `from_utc_timestamp('2000-10-10 00:00:00+00:00', 'GMT+1')`, assuming a local timezone of GMT+8, will return `2000-10-10 09:00:00` in Spark 2.3 but `null` in 2.4. If you are not affected by this problem and want to keep your queries unchanged, you can set `spark.sql.function.rejectTimezoneInString` to `false` to retain the previous behavior. This option will be removed in Spark 3.0 and should only be used as a temporary workaround.
+ - In version 2.3 and earlier, Spark converts Parquet Hive tables by default but ignores table properties like `TBLPROPERTIES (parquet.compression 'NONE')`. The same happens for ORC Hive table properties such as `TBLPROPERTIES (orc.compress 'NONE')` when `spark.sql.hive.convertMetastoreOrc=true`. Since Spark 2.4, Spark respects Parquet/ORC-specific table properties while converting Parquet/ORC Hive tables. As an example, `CREATE TABLE t(id int) STORED AS PARQUET TBLPROPERTIES (parquet.compression 'NONE')` would generate Snappy Parquet files during insertion in Spark 2.3, while in Spark 2.4 the result would be uncompressed Parquet files.
+ - Since Spark 2.0, Spark converts Parquet Hive tables by default for better performance. Since Spark 2.4, Spark converts ORC Hive tables by default too, which means Spark uses its own ORC support by default instead of Hive SerDe. As an example, `CREATE TABLE t(id int) STORED AS ORC` would be handled with Hive SerDe in Spark 2.3, while in Spark 2.4 it would be converted into Spark's ORC data source table and ORC vectorization would be applied. Setting `spark.sql.hive.convertMetastoreOrc` to `false` restores the previous behavior, as sketched below.
+
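An illustrative sketch (not part of this patch) of the workarounds mentioned in the notes above, assuming a local SparkSession:

```scala
import org.apache.spark.sql.SparkSession

val spark = SparkSession.builder().master("local[1]").appName("sql-24-migration").getOrCreate()

// Restore the Spark 2.3 behavior for Hive ORC tables: use Hive SerDe instead of
// Spark's native ORC data source. Only relevant when Hive support is enabled.
spark.conf.set("spark.sql.hive.convertMetastoreOrc", "false")

// Temporary workaround for the to_utc_timestamp/from_utc_timestamp change;
// scheduled for removal in Spark 3.0.
spark.conf.set("spark.sql.function.rejectTimezoneInString", "false")
```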
## Upgrading From Spark SQL 2.2 to 2.3
- Since Spark 2.3, the queries from raw JSON/CSV files are disallowed when the referenced columns only include the internal corrupt record column (named `_corrupt_record` by default). For example, `spark.read.schema(schema).json(file).filter($"_corrupt_record".isNotNull).count()` and `spark.read.schema(schema).json(file).select("_corrupt_record").show()`. Instead, you can cache or save the parsed results and then send the same query. For example, `val df = spark.read.schema(schema).json(file).cache()` and then `df.filter($"_corrupt_record".isNotNull).count()`.
diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaContinuousReader.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaContinuousReader.scala
index f26c134c2f6e9..88abf8a8dd027 100644
--- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaContinuousReader.scala
+++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaContinuousReader.scala
@@ -29,7 +29,7 @@ import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.catalyst.expressions.UnsafeRow
import org.apache.spark.sql.kafka010.KafkaSourceProvider.{INSTRUCTION_FOR_FAIL_ON_DATA_LOSS_FALSE, INSTRUCTION_FOR_FAIL_ON_DATA_LOSS_TRUE}
import org.apache.spark.sql.sources.v2.reader._
-import org.apache.spark.sql.sources.v2.reader.streaming.{ContinuousDataReader, ContinuousReader, Offset, PartitionOffset}
+import org.apache.spark.sql.sources.v2.reader.streaming.{ContinuousInputPartitionReader, ContinuousReader, Offset, PartitionOffset}
import org.apache.spark.sql.types.StructType
/**
@@ -86,7 +86,7 @@ class KafkaContinuousReader(
KafkaSourceOffset(JsonUtils.partitionOffsets(json))
}
- override def createUnsafeRowReaderFactories(): ju.List[DataReaderFactory[UnsafeRow]] = {
+ override def planUnsafeInputPartitions(): ju.List[InputPartition[UnsafeRow]] = {
import scala.collection.JavaConverters._
val oldStartPartitionOffsets = KafkaSourceOffset.getPartitionOffsets(offset)
@@ -108,7 +108,7 @@ class KafkaContinuousReader(
case (topicPartition, start) =>
KafkaContinuousDataReaderFactory(
topicPartition, start, kafkaParams, pollTimeoutMs, failOnDataLoss)
- .asInstanceOf[DataReaderFactory[UnsafeRow]]
+ .asInstanceOf[InputPartition[UnsafeRow]]
}.asJava
}
@@ -161,18 +161,18 @@ case class KafkaContinuousDataReaderFactory(
startOffset: Long,
kafkaParams: ju.Map[String, Object],
pollTimeoutMs: Long,
- failOnDataLoss: Boolean) extends ContinuousDataReaderFactory[UnsafeRow] {
+ failOnDataLoss: Boolean) extends ContinuousInputPartition[UnsafeRow] {
- override def createDataReaderWithOffset(offset: PartitionOffset): DataReader[UnsafeRow] = {
+ override def createContinuousReader(offset: PartitionOffset): InputPartitionReader[UnsafeRow] = {
val kafkaOffset = offset.asInstanceOf[KafkaSourcePartitionOffset]
require(kafkaOffset.topicPartition == topicPartition,
s"Expected topicPartition: $topicPartition, but got: ${kafkaOffset.topicPartition}")
- new KafkaContinuousDataReader(
+ new KafkaContinuousInputPartitionReader(
topicPartition, kafkaOffset.partitionOffset, kafkaParams, pollTimeoutMs, failOnDataLoss)
}
- override def createDataReader(): KafkaContinuousDataReader = {
- new KafkaContinuousDataReader(
+ override def createPartitionReader(): KafkaContinuousInputPartitionReader = {
+ new KafkaContinuousInputPartitionReader(
topicPartition, startOffset, kafkaParams, pollTimeoutMs, failOnDataLoss)
}
}
@@ -187,12 +187,12 @@ case class KafkaContinuousDataReaderFactory(
* @param failOnDataLoss Flag indicating whether data reader should fail if some offsets
* are skipped.
*/
-class KafkaContinuousDataReader(
+class KafkaContinuousInputPartitionReader(
topicPartition: TopicPartition,
startOffset: Long,
kafkaParams: ju.Map[String, Object],
pollTimeoutMs: Long,
- failOnDataLoss: Boolean) extends ContinuousDataReader[UnsafeRow] {
+ failOnDataLoss: Boolean) extends ContinuousInputPartitionReader[UnsafeRow] {
private val consumer = KafkaDataConsumer.acquire(topicPartition, kafkaParams, useCache = false)
private val converter = new KafkaRecordToUnsafeRowConverter
diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchReader.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchReader.scala
index cbe655f9bff1f..8a377738ea782 100644
--- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchReader.scala
+++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchReader.scala
@@ -33,7 +33,7 @@ import org.apache.spark.sql.catalyst.expressions.UnsafeRow
import org.apache.spark.sql.execution.streaming.{HDFSMetadataLog, SerializedOffset}
import org.apache.spark.sql.kafka010.KafkaSourceProvider.{INSTRUCTION_FOR_FAIL_ON_DATA_LOSS_FALSE, INSTRUCTION_FOR_FAIL_ON_DATA_LOSS_TRUE}
import org.apache.spark.sql.sources.v2.DataSourceOptions
-import org.apache.spark.sql.sources.v2.reader.{DataReader, DataReaderFactory, SupportsScanUnsafeRow}
+import org.apache.spark.sql.sources.v2.reader.{InputPartition, InputPartitionReader, SupportsScanUnsafeRow}
import org.apache.spark.sql.sources.v2.reader.streaming.{MicroBatchReader, Offset}
import org.apache.spark.sql.types.StructType
import org.apache.spark.util.UninterruptibleThread
@@ -101,7 +101,7 @@ private[kafka010] class KafkaMicroBatchReader(
}
}
- override def createUnsafeRowReaderFactories(): ju.List[DataReaderFactory[UnsafeRow]] = {
+ override def planUnsafeInputPartitions(): ju.List[InputPartition[UnsafeRow]] = {
// Find the new partitions, and get their earliest offsets
val newPartitions = endPartitionOffsets.keySet.diff(startPartitionOffsets.keySet)
val newPartitionInitialOffsets = kafkaOffsetReader.fetchEarliestOffsets(newPartitions.toSeq)
@@ -146,7 +146,7 @@ private[kafka010] class KafkaMicroBatchReader(
new KafkaMicroBatchDataReaderFactory(
range, executorKafkaParams, pollTimeoutMs, failOnDataLoss, reuseKafkaConsumer)
}
- factories.map(_.asInstanceOf[DataReaderFactory[UnsafeRow]]).asJava
+ factories.map(_.asInstanceOf[InputPartition[UnsafeRow]]).asJava
}
override def getStartOffset: Offset = {
@@ -299,27 +299,28 @@ private[kafka010] class KafkaMicroBatchReader(
}
}
-/** A [[DataReaderFactory]] for reading Kafka data in a micro-batch streaming query. */
+/** An [[InputPartition]] for reading Kafka data in a micro-batch streaming query. */
private[kafka010] case class KafkaMicroBatchDataReaderFactory(
offsetRange: KafkaOffsetRange,
executorKafkaParams: ju.Map[String, Object],
pollTimeoutMs: Long,
failOnDataLoss: Boolean,
- reuseKafkaConsumer: Boolean) extends DataReaderFactory[UnsafeRow] {
+ reuseKafkaConsumer: Boolean) extends InputPartition[UnsafeRow] {
override def preferredLocations(): Array[String] = offsetRange.preferredLoc.toArray
- override def createDataReader(): DataReader[UnsafeRow] = new KafkaMicroBatchDataReader(
- offsetRange, executorKafkaParams, pollTimeoutMs, failOnDataLoss, reuseKafkaConsumer)
+ override def createPartitionReader(): InputPartitionReader[UnsafeRow] =
+ new KafkaMicroBatchInputPartitionReader(offsetRange, executorKafkaParams, pollTimeoutMs,
+ failOnDataLoss, reuseKafkaConsumer)
}
-/** A [[DataReader]] for reading Kafka data in a micro-batch streaming query. */
-private[kafka010] case class KafkaMicroBatchDataReader(
+/** An [[InputPartitionReader]] for reading Kafka data in a micro-batch streaming query. */
+private[kafka010] case class KafkaMicroBatchInputPartitionReader(
offsetRange: KafkaOffsetRange,
executorKafkaParams: ju.Map[String, Object],
pollTimeoutMs: Long,
failOnDataLoss: Boolean,
- reuseKafkaConsumer: Boolean) extends DataReader[UnsafeRow] with Logging {
+ reuseKafkaConsumer: Boolean) extends InputPartitionReader[UnsafeRow] with Logging {
private val consumer = KafkaDataConsumer.acquire(
offsetRange.topicPartition, executorKafkaParams, reuseKafkaConsumer)
diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceProvider.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceProvider.scala
index 36b9f0466566b..d225c1ea6b7f1 100644
--- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceProvider.scala
+++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceProvider.scala
@@ -31,6 +31,7 @@ import org.apache.spark.sql.{AnalysisException, DataFrame, SaveMode, SparkSessio
import org.apache.spark.sql.execution.streaming.{Sink, Source}
import org.apache.spark.sql.sources._
import org.apache.spark.sql.sources.v2.{ContinuousReadSupport, DataSourceOptions, MicroBatchReadSupport, StreamWriteSupport}
+import org.apache.spark.sql.sources.v2.reader.streaming.ContinuousInputPartitionReader
import org.apache.spark.sql.sources.v2.writer.streaming.StreamWriter
import org.apache.spark.sql.streaming.OutputMode
import org.apache.spark.sql.types.StructType
@@ -149,7 +150,7 @@ private[kafka010] class KafkaSourceProvider extends DataSourceRegister
}
/**
- * Creates a [[org.apache.spark.sql.sources.v2.reader.streaming.ContinuousDataReader]] to read
+ * Creates a [[ContinuousInputPartitionReader]] to read
* Kafka data in a continuous streaming query.
*/
override def createContinuousReader(
diff --git a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchSourceSuite.scala b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchSourceSuite.scala
index d2d04b68de6ab..871f9700cd1db 100644
--- a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchSourceSuite.scala
+++ b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchSourceSuite.scala
@@ -678,7 +678,7 @@ class KafkaMicroBatchV2SourceSuite extends KafkaMicroBatchSourceSuiteBase {
Optional.of[OffsetV2](KafkaSourceOffset(Map(tp -> 0L))),
Optional.of[OffsetV2](KafkaSourceOffset(Map(tp -> 100L)))
)
- val factories = reader.createUnsafeRowReaderFactories().asScala
+ val factories = reader.planUnsafeInputPartitions().asScala
.map(_.asInstanceOf[KafkaMicroBatchDataReaderFactory])
withClue(s"minPartitions = $minPartitions generated factories $factories\n\t") {
assert(factories.size == numPartitionsGenerated)
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala
index 57797d1cc4978..c9786f1f7ceb1 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala
@@ -97,9 +97,11 @@ class DecisionTreeClassifier @Since("1.4.0") (
override def setSeed(value: Long): this.type = set(seed, value)
override protected def train(dataset: Dataset[_]): DecisionTreeClassificationModel = {
+ val instr = Instrumentation.create(this, dataset)
val categoricalFeatures: Map[Int, Int] =
MetadataUtils.getCategoricalFeatures(dataset.schema($(featuresCol)))
val numClasses: Int = getNumClasses(dataset)
+ instr.logNumClasses(numClasses)
if (isDefined(thresholds)) {
require($(thresholds).length == numClasses, this.getClass.getSimpleName +
@@ -110,8 +112,8 @@ class DecisionTreeClassifier @Since("1.4.0") (
val oldDataset: RDD[LabeledPoint] = extractLabeledPoints(dataset, numClasses)
val strategy = getOldStrategy(categoricalFeatures, numClasses)
- val instr = Instrumentation.create(this, oldDataset)
- instr.logParams(params: _*)
+ instr.logParams(maxDepth, maxBins, minInstancesPerNode, minInfoGain, maxMemoryInMB,
+ cacheNodeIds, checkpointInterval, impurity, seed)
val trees = RandomForest.run(oldDataset, strategy, numTrees = 1, featureSubsetStrategy = "all",
seed = $(seed), instr = Some(instr), parentUID = Some(uid))
@@ -125,7 +127,8 @@ class DecisionTreeClassifier @Since("1.4.0") (
private[ml] def train(data: RDD[LabeledPoint],
oldStrategy: OldStrategy): DecisionTreeClassificationModel = {
val instr = Instrumentation.create(this, data)
- instr.logParams(params: _*)
+ instr.logParams(maxDepth, maxBins, minInstancesPerNode, minInfoGain, maxMemoryInMB,
+ cacheNodeIds, checkpointInterval, impurity, seed)
val trees = RandomForest.run(data, oldStrategy, numTrees = 1, featureSubsetStrategy = "all",
seed = 0L, instr = Some(instr), parentUID = Some(uid))
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala
index 0aa24f0a3cfcc..3fb6d1e4e4f3e 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala
@@ -334,6 +334,21 @@ class GBTClassificationModel private[ml](
// hard coded loss, which is not meant to be changed in the model
private val loss = getOldLossType
+ /**
+ * Method to compute error or loss for every iteration of gradient boosting.
+ *
+ * @param dataset Dataset for validation.
+ */
+ @Since("2.4.0")
+ def evaluateEachIteration(dataset: Dataset[_]): Array[Double] = {
+ val data = dataset.select(col($(labelCol)), col($(featuresCol))).rdd.map {
+ case Row(label: Double, features: Vector) => LabeledPoint(label, features)
+ }
+ GradientBoostedTrees.evaluateEachIteration(data, trees, treeWeights, loss,
+ OldAlgo.Classification
+ )
+ }
+
@Since("2.0.0")
override def write: MLWriter = new GBTClassificationModel.GBTClassificationModelWriter(this)
}
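A hedged usage sketch of the `evaluateEachIteration` method added above (the data, column names, and parameter values are illustrative only):

```scala
import org.apache.spark.ml.classification.GBTClassifier
import org.apache.spark.ml.linalg.Vectors
import org.apache.spark.sql.SparkSession

val spark = SparkSession.builder().master("local[1]").appName("gbt-eval").getOrCreate()
import spark.implicits._

val data = Seq(
  (0.0, Vectors.dense(0.0, 1.0)),
  (1.0, Vectors.dense(1.0, 0.0)),
  (0.0, Vectors.dense(0.1, 0.9)),
  (1.0, Vectors.dense(0.9, 0.1))
).toDF("label", "features")

val model = new GBTClassifier().setMaxIter(3).setMaxDepth(2).setSeed(1L).fit(data)

// One loss value per boosting iteration, computed on the given dataset
// (here the training set, purely for brevity).
val lossPerIteration: Array[Double] = model.evaluateEachIteration(data)
```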
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/LinearSVC.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/LinearSVC.scala
index 80c537e1e0eb2..38eb04556b775 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/LinearSVC.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/LinearSVC.scala
@@ -170,7 +170,7 @@ class LinearSVC @Since("2.2.0") (
Instance(label, weight, features)
}
- val instr = Instrumentation.create(this, instances)
+ val instr = Instrumentation.create(this, dataset)
instr.logParams(regParam, maxIter, fitIntercept, tol, standardization, threshold,
aggregationDepth)
@@ -187,6 +187,9 @@ class LinearSVC @Since("2.2.0") (
(new MultivariateOnlineSummarizer, new MultiClassSummarizer)
)(seqOp, combOp, $(aggregationDepth))
}
+ instr.logNamedValue(Instrumentation.loggerTags.numExamples, summarizer.count)
+ instr.logNamedValue("lowestLabelWeight", labelSummarizer.histogram.min.toString)
+ instr.logNamedValue("highestLabelWeight", labelSummarizer.histogram.max.toString)
val histogram = labelSummarizer.histogram
val numInvalid = labelSummarizer.countInvalid
@@ -209,7 +212,7 @@ class LinearSVC @Since("2.2.0") (
if (numInvalid != 0) {
val msg = s"Classification labels should be in [0 to ${numClasses - 1}]. " +
s"Found $numInvalid invalid labels."
- logError(msg)
+ instr.logError(msg)
throw new SparkException(msg)
}
@@ -246,7 +249,7 @@ class LinearSVC @Since("2.2.0") (
bcFeaturesStd.destroy(blocking = false)
if (state == null) {
val msg = s"${optimizer.getClass.getName} failed."
- logError(msg)
+ instr.logError(msg)
throw new SparkException(msg)
}
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala
index 45fb585ed2262..1dde18d2d1a31 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala
@@ -126,8 +126,10 @@ class NaiveBayes @Since("1.5.0") (
private[spark] def trainWithLabelCheck(
dataset: Dataset[_],
positiveLabel: Boolean): NaiveBayesModel = {
+ val instr = Instrumentation.create(this, dataset)
if (positiveLabel && isDefined(thresholds)) {
val numClasses = getNumClasses(dataset)
+ instr.logNumClasses(numClasses)
require($(thresholds).length == numClasses, this.getClass.getSimpleName +
".train() called with non-matching numClasses and thresholds.length." +
s" numClasses=$numClasses, but thresholds has length ${$(thresholds).length}")
@@ -146,7 +148,6 @@ class NaiveBayes @Since("1.5.0") (
}
}
- val instr = Instrumentation.create(this, dataset)
instr.logParams(labelCol, featuresCol, weightCol, predictionCol, rawPredictionCol,
probabilityCol, modelType, smoothing, thresholds)
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala
index 7df53a6b8ad10..3474b61e40136 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala
@@ -366,7 +366,7 @@ final class OneVsRest @Since("1.4.0") (
transformSchema(dataset.schema)
val instr = Instrumentation.create(this, dataset)
- instr.logParams(labelCol, featuresCol, predictionCol, parallelism)
+ instr.logParams(labelCol, featuresCol, predictionCol, parallelism, rawPredictionCol)
instr.logNamedValue("classifier", $(classifier).getClass.getCanonicalName)
// determine number of classes either from metadata if provided, or via computation.
@@ -383,7 +383,7 @@ final class OneVsRest @Since("1.4.0") (
getClassifier match {
case _: HasWeightCol => true
case c =>
- logWarning(s"weightCol is ignored, as it is not supported by $c now.")
+ instr.logWarning(s"weightCol is ignored, as it is not supported by $c now.")
false
}
}
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala
index f1ef26a07d3f8..040db3b94b041 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala
@@ -116,6 +116,7 @@ class RandomForestClassifier @Since("1.4.0") (
set(featureSubsetStrategy, value)
override protected def train(dataset: Dataset[_]): RandomForestClassificationModel = {
+ val instr = Instrumentation.create(this, dataset)
val categoricalFeatures: Map[Int, Int] =
MetadataUtils.getCategoricalFeatures(dataset.schema($(featuresCol)))
val numClasses: Int = getNumClasses(dataset)
@@ -130,7 +131,6 @@ class RandomForestClassifier @Since("1.4.0") (
val strategy =
super.getOldStrategy(categoricalFeatures, numClasses, OldAlgo.Classification, getOldImpurity)
- val instr = Instrumentation.create(this, oldDataset)
instr.logParams(labelCol, featuresCol, predictionCol, probabilityCol, rawPredictionCol,
impurity, numTrees, featureSubsetStrategy, maxDepth, maxBins, maxMemoryInMB, minInfoGain,
minInstancesPerNode, seed, subsamplingRate, thresholds, cacheNodeIds, checkpointInterval)
@@ -141,6 +141,8 @@ class RandomForestClassifier @Since("1.4.0") (
val numFeatures = oldDataset.first().features.size
val m = new RandomForestClassificationModel(uid, trees, numFeatures, numClasses)
+ instr.logNumClasses(numClasses)
+ instr.logNumFeatures(numFeatures)
instr.logSuccess(m)
m
}
diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/BisectingKMeans.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/BisectingKMeans.scala
index 438e53ba6197c..1ad4e097246a3 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/clustering/BisectingKMeans.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/BisectingKMeans.scala
@@ -261,8 +261,9 @@ class BisectingKMeans @Since("2.0.0") (
transformSchema(dataset.schema, logging = true)
val rdd = DatasetUtils.columnToOldVector(dataset, getFeaturesCol)
- val instr = Instrumentation.create(this, rdd)
- instr.logParams(featuresCol, predictionCol, k, maxIter, seed, minDivisibleClusterSize)
+ val instr = Instrumentation.create(this, dataset)
+ instr.logParams(featuresCol, predictionCol, k, maxIter, seed,
+ minDivisibleClusterSize, distanceMeasure)
val bkm = new MLlibBisectingKMeans()
.setK($(k))
@@ -275,6 +276,8 @@ class BisectingKMeans @Since("2.0.0") (
val summary = new BisectingKMeansSummary(
model.transform(dataset), $(predictionCol), $(featuresCol), $(k))
model.setSummary(Some(summary))
+ // TODO: need to extend logNamedValue to support Array
+ instr.logNamedValue("clusterSizes", summary.clusterSizes.mkString("[", ",", "]"))
instr.logSuccess(model)
model
}
diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/GaussianMixture.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/GaussianMixture.scala
index 88d618c3a03a8..3091bb5a2e54c 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/clustering/GaussianMixture.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/GaussianMixture.scala
@@ -352,7 +352,7 @@ class GaussianMixture @Since("2.0.0") (
s"than ${GaussianMixture.MAX_NUM_FEATURES} features because the size of the covariance" +
s" matrix is quadratic in the number of features.")
- val instr = Instrumentation.create(this, instances)
+ val instr = Instrumentation.create(this, dataset)
instr.logParams(featuresCol, predictionCol, probabilityCol, k, maxIter, seed, tol)
instr.logNumFeatures(numFeatures)
@@ -425,6 +425,9 @@ class GaussianMixture @Since("2.0.0") (
val summary = new GaussianMixtureSummary(model.transform(dataset),
$(predictionCol), $(probabilityCol), $(featuresCol), $(k), logLikelihood)
model.setSummary(Some(summary))
+ instr.logNamedValue("logLikelihood", logLikelihood)
+ // TODO: need to extend logNamedValue to support Array
+ instr.logNamedValue("clusterSizes", summary.clusterSizes.mkString("[", ",", "]"))
instr.logSuccess(model)
model
}
diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala
index 97f246fbfd859..e72d7f9485e6a 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala
@@ -342,7 +342,7 @@ class KMeans @Since("1.5.0") (
instances.persist(StorageLevel.MEMORY_AND_DISK)
}
- val instr = Instrumentation.create(this, instances)
+ val instr = Instrumentation.create(this, dataset)
instr.logParams(featuresCol, predictionCol, k, initMode, initSteps, distanceMeasure,
maxIter, seed, tol)
val algo = new MLlibKMeans()
@@ -359,6 +359,8 @@ class KMeans @Since("1.5.0") (
model.transform(dataset), $(predictionCol), $(featuresCol), $(k))
model.setSummary(Some(summary))
+ // TODO: need to extend logNamedValue to support Array
+ instr.logNamedValue("clusterSizes", summary.clusterSizes.mkString("[", ",", "]"))
instr.logSuccess(model)
if (handlePersistence) {
instances.unpersist()
diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/LDA.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/LDA.scala
index afe599cd167cb..fed42c959b5ef 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/clustering/LDA.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/LDA.scala
@@ -569,10 +569,14 @@ abstract class LDAModel private[ml] (
class LocalLDAModel private[ml] (
uid: String,
vocabSize: Int,
- @Since("1.6.0") override private[clustering] val oldLocalModel: OldLocalLDAModel,
+ private[clustering] val oldLocalModel_ : OldLocalLDAModel,
sparkSession: SparkSession)
extends LDAModel(uid, vocabSize, sparkSession) {
+ override private[clustering] def oldLocalModel: OldLocalLDAModel = {
+ oldLocalModel_.setSeed(getSeed)
+ }
+
@Since("1.6.0")
override def copy(extra: ParamMap): LocalLDAModel = {
val copied = new LocalLDAModel(uid, vocabSize, oldLocalModel, sparkSession)
diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala
index 8598e808c4946..d7e054bf55ef6 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala
@@ -34,7 +34,7 @@ import org.apache.spark.ml.util.DefaultParamsReader.Metadata
import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo}
import org.apache.spark.mllib.tree.model.{GradientBoostedTreesModel => OldGBTModel}
import org.apache.spark.rdd.RDD
-import org.apache.spark.sql.{DataFrame, Dataset}
+import org.apache.spark.sql.{DataFrame, Dataset, Row}
import org.apache.spark.sql.functions._
/**
@@ -269,6 +269,21 @@ class GBTRegressionModel private[ml](
new OldGBTModel(OldAlgo.Regression, _trees.map(_.toOld), _treeWeights)
}
+ /**
+ * Method to compute error or loss for every iteration of gradient boosting.
+ *
+ * @param dataset Dataset for validation.
+ * @param loss The loss function used to compute error. Supported options: squared, absolute
+ */
+ @Since("2.4.0")
+ def evaluateEachIteration(dataset: Dataset[_], loss: String): Array[Double] = {
+ val data = dataset.select(col($(labelCol)), col($(featuresCol))).rdd.map {
+ case Row(label: Double, features: Vector) => LabeledPoint(label, features)
+ }
+ GradientBoostedTrees.evaluateEachIteration(data, trees, treeWeights,
+ convertToOldLossType(loss), OldAlgo.Regression)
+ }
+
@Since("2.0.0")
override def write: MLWriter = new GBTRegressionModel.GBTRegressionModelWriter(this)
}
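Similarly, a hedged sketch of the regression counterpart added above; note the extra `loss` argument ("squared" or "absolute"):

```scala
import org.apache.spark.ml.linalg.Vectors
import org.apache.spark.ml.regression.GBTRegressor
import org.apache.spark.sql.SparkSession

val spark = SparkSession.builder().master("local[1]").appName("gbt-reg-eval").getOrCreate()
import spark.implicits._

val data = Seq(
  (1.0, Vectors.dense(0.0)),
  (2.0, Vectors.dense(1.0)),
  (3.0, Vectors.dense(2.0)),
  (4.0, Vectors.dense(3.0))
).toDF("label", "features")

val model = new GBTRegressor().setMaxIter(3).setMaxDepth(2).setSeed(1L).fit(data)

// Per-iteration error under the requested loss function.
val squaredErrors: Array[Double] = model.evaluateEachIteration(data, "squared")
```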
diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala
index 81b6222acc7ce..ec8868bb42cbb 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala
@@ -579,7 +579,11 @@ private[ml] trait GBTRegressorParams extends GBTParams with TreeRegressorParams
/** (private[ml]) Convert new loss to old loss. */
override private[ml] def getOldLossType: OldLoss = {
- getLossType match {
+ convertToOldLossType(getLossType)
+ }
+
+ private[ml] def convertToOldLossType(loss: String): OldLoss = {
+ loss match {
case "squared" => OldSquaredError
case "absolute" => OldAbsoluteError
case _ =>
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala
index b8a6e94248421..f915062d77389 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala
@@ -32,7 +32,7 @@ import org.apache.spark.mllib.linalg.{Matrices, Matrix, Vector, Vectors}
import org.apache.spark.mllib.util.{Loader, Saveable}
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.{Row, SparkSession}
-import org.apache.spark.util.BoundedPriorityQueue
+import org.apache.spark.util.{BoundedPriorityQueue, Utils}
/**
* Latent Dirichlet Allocation (LDA) model.
@@ -194,6 +194,8 @@ class LocalLDAModel private[spark] (
override protected[spark] val gammaShape: Double = 100)
extends LDAModel with Serializable {
+ private var seed: Long = Utils.random.nextLong()
+
@Since("1.3.0")
override def k: Int = topics.numCols
@@ -216,6 +218,21 @@ class LocalLDAModel private[spark] (
override protected def formatVersion = "1.0"
+ /**
+ * Random seed for cluster initialization.
+ */
+ @Since("2.4.0")
+ def getSeed: Long = seed
+
+ /**
+ * Set the random seed for cluster initialization.
+ */
+ @Since("2.4.0")
+ def setSeed(seed: Long): this.type = {
+ this.seed = seed
+ this
+ }
+
@Since("1.5.0")
override def save(sc: SparkContext, path: String): Unit = {
LocalLDAModel.SaveLoadV1_0.save(sc, path, topicsMatrix, docConcentration, topicConcentration,
@@ -298,6 +315,7 @@ class LocalLDAModel private[spark] (
// by topic (columns of lambda)
val Elogbeta = LDAUtils.dirichletExpectation(lambda.t).t
val ElogbetaBc = documents.sparkContext.broadcast(Elogbeta)
+ val gammaSeed = this.seed
// Sum bound components for each document:
// component for prob(tokens) + component for prob(document-topic distribution)
@@ -306,7 +324,7 @@ class LocalLDAModel private[spark] (
val localElogbeta = ElogbetaBc.value
var docBound = 0.0D
val (gammad: BDV[Double], _, _) = OnlineLDAOptimizer.variationalTopicInference(
- termCounts, exp(localElogbeta), brzAlpha, gammaShape, k)
+ termCounts, exp(localElogbeta), brzAlpha, gammaShape, k, gammaSeed + id)
val Elogthetad: BDV[Double] = LDAUtils.dirichletExpectation(gammad)
// E[log p(doc | theta, beta)]
@@ -352,6 +370,7 @@ class LocalLDAModel private[spark] (
val docConcentrationBrz = this.docConcentration.asBreeze
val gammaShape = this.gammaShape
val k = this.k
+ val gammaSeed = this.seed
documents.map { case (id: Long, termCounts: Vector) =>
if (termCounts.numNonzeros == 0) {
@@ -362,7 +381,8 @@ class LocalLDAModel private[spark] (
expElogbetaBc.value,
docConcentrationBrz,
gammaShape,
- k)
+ k,
+ gammaSeed + id)
(id, Vectors.dense(normalize(gamma, 1.0).toArray))
}
}
@@ -376,6 +396,7 @@ class LocalLDAModel private[spark] (
val docConcentrationBrz = this.docConcentration.asBreeze
val gammaShape = this.gammaShape
val k = this.k
+ val gammaSeed = this.seed
(termCounts: Vector) =>
if (termCounts.numNonzeros == 0) {
@@ -386,7 +407,8 @@ class LocalLDAModel private[spark] (
expElogbeta,
docConcentrationBrz,
gammaShape,
- k)
+ k,
+ gammaSeed)
Vectors.dense(normalize(gamma, 1.0).toArray)
}
}
@@ -403,6 +425,7 @@ class LocalLDAModel private[spark] (
*/
@Since("2.0.0")
def topicDistribution(document: Vector): Vector = {
+ val gammaSeed = this.seed
val expElogbeta = exp(LDAUtils.dirichletExpectation(topicsMatrix.asBreeze.toDenseMatrix.t).t)
if (document.numNonzeros == 0) {
Vectors.zeros(this.k)
@@ -412,7 +435,8 @@ class LocalLDAModel private[spark] (
expElogbeta,
this.docConcentration.asBreeze,
gammaShape,
- this.k)
+ this.k,
+ gammaSeed)
Vectors.dense(normalize(gamma, 1.0).toArray)
}
}
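A sketch (illustrative only) of the new `setSeed`/`getSeed` pair on `LocalLDAModel`: fixing the seed makes repeated `topicDistribution` calls on the same document reproducible, which is what the LDA read/write test further below relies on. The toy corpus and parameters are made up:

```scala
import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.mllib.clustering.{LDA, LocalLDAModel}
import org.apache.spark.mllib.linalg.Vectors

val sc = new SparkContext(new SparkConf().setMaster("local[1]").setAppName("lda-seed"))

// Tiny toy corpus of (docId, termCountVector) pairs over a 3-term vocabulary.
val corpus = sc.parallelize(Seq(
  (0L, Vectors.dense(1.0, 2.0, 0.0)),
  (1L, Vectors.dense(0.0, 1.0, 3.0))
))

// The online optimizer produces a LocalLDAModel directly.
val model = new LDA().setK(2).setOptimizer("online").setSeed(7L)
  .run(corpus).asInstanceOf[LocalLDAModel]

// With a fixed gamma seed, repeated inference on the same document agrees.
model.setSeed(42L)
val doc = Vectors.dense(1.0, 0.0, 1.0)
val dist1 = model.topicDistribution(doc)
val dist2 = model.topicDistribution(doc)  // identical to dist1
```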
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAOptimizer.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAOptimizer.scala
index 693a2a31f026b..f8e5f3ed76457 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAOptimizer.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAOptimizer.scala
@@ -30,6 +30,7 @@ import org.apache.spark.internal.Logging
import org.apache.spark.mllib.linalg.{DenseVector, Matrices, SparseVector, Vector, Vectors}
import org.apache.spark.rdd.RDD
import org.apache.spark.storage.StorageLevel
+import org.apache.spark.util.Utils
/**
* :: DeveloperApi ::
@@ -464,6 +465,7 @@ final class OnlineLDAOptimizer extends LDAOptimizer with Logging {
val alpha = this.alpha.asBreeze
val gammaShape = this.gammaShape
val optimizeDocConcentration = this.optimizeDocConcentration
+ val seed = randomGenerator.nextLong()
// If and only if optimizeDocConcentration is set true,
// we calculate logphat in the same pass as other statistics.
// No calculation of loghat happens otherwise.
@@ -473,20 +475,21 @@ final class OnlineLDAOptimizer extends LDAOptimizer with Logging {
None
}
- val stats: RDD[(BDM[Double], Option[BDV[Double]], Long)] = batch.mapPartitions { docs =>
- val nonEmptyDocs = docs.filter(_._2.numNonzeros > 0)
-
- val stat = BDM.zeros[Double](k, vocabSize)
- val logphatPartOption = logphatPartOptionBase()
- var nonEmptyDocCount: Long = 0L
- nonEmptyDocs.foreach { case (_, termCounts: Vector) =>
- nonEmptyDocCount += 1
- val (gammad, sstats, ids) = OnlineLDAOptimizer.variationalTopicInference(
- termCounts, expElogbetaBc.value, alpha, gammaShape, k)
- stat(::, ids) := stat(::, ids) + sstats
- logphatPartOption.foreach(_ += LDAUtils.dirichletExpectation(gammad))
- }
- Iterator((stat, logphatPartOption, nonEmptyDocCount))
+ val stats: RDD[(BDM[Double], Option[BDV[Double]], Long)] = batch.mapPartitionsWithIndex {
+ (index, docs) =>
+ val nonEmptyDocs = docs.filter(_._2.numNonzeros > 0)
+
+ val stat = BDM.zeros[Double](k, vocabSize)
+ val logphatPartOption = logphatPartOptionBase()
+ var nonEmptyDocCount: Long = 0L
+ nonEmptyDocs.foreach { case (_, termCounts: Vector) =>
+ nonEmptyDocCount += 1
+ val (gammad, sstats, ids) = OnlineLDAOptimizer.variationalTopicInference(
+ termCounts, expElogbetaBc.value, alpha, gammaShape, k, seed + index)
+ stat(::, ids) := stat(::, ids) + sstats
+ logphatPartOption.foreach(_ += LDAUtils.dirichletExpectation(gammad))
+ }
+ Iterator((stat, logphatPartOption, nonEmptyDocCount))
}
val elementWiseSum = (
@@ -578,7 +581,8 @@ final class OnlineLDAOptimizer extends LDAOptimizer with Logging {
}
override private[clustering] def getLDAModel(iterationTimes: Array[Double]): LDAModel = {
- new LocalLDAModel(Matrices.fromBreeze(lambda).transpose, alpha, eta, gammaShape)
+ new LocalLDAModel(Matrices.fromBreeze(lambda).transpose, alpha, eta)
+ .setSeed(randomGenerator.nextLong())
}
}
@@ -605,18 +609,20 @@ private[clustering] object OnlineLDAOptimizer {
expElogbeta: BDM[Double],
alpha: breeze.linalg.Vector[Double],
gammaShape: Double,
- k: Int): (BDV[Double], BDM[Double], List[Int]) = {
+ k: Int,
+ seed: Long): (BDV[Double], BDM[Double], List[Int]) = {
val (ids: List[Int], cts: Array[Double]) = termCounts match {
case v: DenseVector => ((0 until v.size).toList, v.values)
case v: SparseVector => (v.indices.toList, v.values)
}
// Initialize the variational distribution q(theta|gamma) for the mini-batch
+ val randBasis = new RandBasis(new org.apache.commons.math3.random.MersenneTwister(seed))
val gammad: BDV[Double] =
- new Gamma(gammaShape, 1.0 / gammaShape).samplesVector(k) // K
+ new Gamma(gammaShape, 1.0 / gammaShape)(randBasis).samplesVector(k) // K
val expElogthetad: BDV[Double] = exp(LDAUtils.dirichletExpectation(gammad)) // K
val expElogbetad = expElogbeta(ids, ::).toDenseMatrix // ids * K
- val phiNorm: BDV[Double] = expElogbetad * expElogthetad +:+ 1e-100 // ids
+ val phiNorm: BDV[Double] = expElogbetad * expElogthetad +:+ 1e-100 // ids
var meanGammaChange = 1D
val ctsVector = new BDV[Double](cts) // ids
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/StreamingKMeans.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/StreamingKMeans.scala
index 3ca75e8cdb97a..7a5e520d5818e 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/StreamingKMeans.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/StreamingKMeans.scala
@@ -43,7 +43,7 @@ import org.apache.spark.util.random.XORShiftRandom
* $$
* \begin{align}
* c_t+1 &= [(c_t * n_t * a) + (x_t * m_t)] / [n_t + m_t] \\
- * n_t+t &= n_t * a + m_t
+ * n_t+1 &= n_t * a + m_t
* \end{align}
* $$
*
diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala
index f0ee5496f9d1d..e20de196d65ca 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala
@@ -25,7 +25,7 @@ import org.apache.spark.ml.linalg.{Vector, Vectors}
import org.apache.spark.ml.param.ParamsSuite
import org.apache.spark.ml.regression.DecisionTreeRegressionModel
import org.apache.spark.ml.tree.RegressionLeafNode
-import org.apache.spark.ml.tree.impl.TreeTests
+import org.apache.spark.ml.tree.impl.{GradientBoostedTrees, TreeTests}
import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTest, MLTestingUtils}
import org.apache.spark.ml.util.TestingUtils._
import org.apache.spark.mllib.regression.{LabeledPoint => OldLabeledPoint}
@@ -365,6 +365,33 @@ class GBTClassifierSuite extends MLTest with DefaultReadWriteTest {
assert(mostImportantFeature !== mostIF)
}
+ test("model evaluateEachIteration") {
+ val gbt = new GBTClassifier()
+ .setSeed(1L)
+ .setMaxDepth(2)
+ .setMaxIter(3)
+ .setLossType("logistic")
+ val model3 = gbt.fit(trainData.toDF)
+ val model1 = new GBTClassificationModel("gbt-cls-model-test1",
+ model3.trees.take(1), model3.treeWeights.take(1), model3.numFeatures, model3.numClasses)
+ val model2 = new GBTClassificationModel("gbt-cls-model-test2",
+ model3.trees.take(2), model3.treeWeights.take(2), model3.numFeatures, model3.numClasses)
+
+ val evalArr = model3.evaluateEachIteration(validationData.toDF)
+ val remappedValidationData = validationData.map(
+ x => new LabeledPoint((x.label * 2) - 1, x.features))
+ val lossErr1 = GradientBoostedTrees.computeError(remappedValidationData,
+ model1.trees, model1.treeWeights, model1.getOldLossType)
+ val lossErr2 = GradientBoostedTrees.computeError(remappedValidationData,
+ model2.trees, model2.treeWeights, model2.getOldLossType)
+ val lossErr3 = GradientBoostedTrees.computeError(remappedValidationData,
+ model3.trees, model3.treeWeights, model3.getOldLossType)
+
+ assert(evalArr(0) ~== lossErr1 relTol 1E-3)
+ assert(evalArr(1) ~== lossErr2 relTol 1E-3)
+ assert(evalArr(2) ~== lossErr3 relTol 1E-3)
+ }
+
/////////////////////////////////////////////////////////////////////////////
// Tests of model save/load
/////////////////////////////////////////////////////////////////////////////
diff --git a/mllib/src/test/scala/org/apache/spark/ml/clustering/LDASuite.scala b/mllib/src/test/scala/org/apache/spark/ml/clustering/LDASuite.scala
index 8d728f063dd8c..4d848205034c0 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/clustering/LDASuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/clustering/LDASuite.scala
@@ -253,6 +253,12 @@ class LDASuite extends SparkFunSuite with MLlibTestSparkContext with DefaultRead
val lda = new LDA()
testEstimatorAndModelReadWrite(lda, dataset, LDASuite.allParamSettings,
LDASuite.allParamSettings, checkModelData)
+
+ // Make sure the result is deterministic after saving and loading the model
+ val model = lda.fit(dataset)
+ val model2 = testDefaultReadWrite(model)
+ assert(model.logLikelihood(dataset) ~== model2.logLikelihood(dataset) absTol 1e-6)
+ assert(model.logPerplexity(dataset) ~== model2.logPerplexity(dataset) absTol 1e-6)
}
test("read/write DistributedLDAModel") {
diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala
index fad11d078250f..773f6d2c542fe 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala
@@ -20,8 +20,9 @@ package org.apache.spark.ml.regression
import org.apache.spark.SparkFunSuite
import org.apache.spark.ml.feature.LabeledPoint
import org.apache.spark.ml.linalg.{Vector, Vectors}
-import org.apache.spark.ml.tree.impl.TreeTests
+import org.apache.spark.ml.tree.impl.{GradientBoostedTrees, TreeTests}
import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTest, MLTestingUtils}
+import org.apache.spark.ml.util.TestingUtils._
import org.apache.spark.mllib.regression.{LabeledPoint => OldLabeledPoint}
import org.apache.spark.mllib.tree.{EnsembleTestHelper, GradientBoostedTrees => OldGBT}
import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo}
@@ -201,7 +202,34 @@ class GBTRegressorSuite extends MLTest with DefaultReadWriteTest {
assert(mostImportantFeature !== mostIF)
}
-
+ test("model evaluateEachIteration") {
+ for (lossType <- GBTRegressor.supportedLossTypes) {
+ val gbt = new GBTRegressor()
+ .setSeed(1L)
+ .setMaxDepth(2)
+ .setMaxIter(3)
+ .setLossType(lossType)
+ val model3 = gbt.fit(trainData.toDF)
+ val model1 = new GBTRegressionModel("gbt-reg-model-test1",
+ model3.trees.take(1), model3.treeWeights.take(1), model3.numFeatures)
+ val model2 = new GBTRegressionModel("gbt-reg-model-test2",
+ model3.trees.take(2), model3.treeWeights.take(2), model3.numFeatures)
+
+ for (evalLossType <- GBTRegressor.supportedLossTypes) {
+ val evalArr = model3.evaluateEachIteration(validationData.toDF, evalLossType)
+ val lossErr1 = GradientBoostedTrees.computeError(validationData,
+ model1.trees, model1.treeWeights, model1.convertToOldLossType(evalLossType))
+ val lossErr2 = GradientBoostedTrees.computeError(validationData,
+ model2.trees, model2.treeWeights, model2.convertToOldLossType(evalLossType))
+ val lossErr3 = GradientBoostedTrees.computeError(validationData,
+ model3.trees, model3.treeWeights, model3.convertToOldLossType(evalLossType))
+
+ assert(evalArr(0) ~== lossErr1 relTol 1E-3)
+ assert(evalArr(1) ~== lossErr2 relTol 1E-3)
+ assert(evalArr(2) ~== lossErr3 relTol 1E-3)
+ }
+ }
+ }
/////////////////////////////////////////////////////////////////////////////
// Tests of model save/load
diff --git a/pom.xml b/pom.xml
index 88e77ff874748..6e37e518d86e4 100644
--- a/pom.xml
+++ b/pom.xml
@@ -129,7 +129,7 @@
1.2.1
10.12.1.1
-    <parquet.version>1.8.2</parquet.version>
+    <parquet.version>1.10.0</parquet.version>
1.4.3
nohive
1.6.0
@@ -1778,6 +1778,12 @@
        <artifactId>parquet-hadoop</artifactId>
        <version>${parquet.version}</version>
        <scope>${parquet.deps.scope}</scope>
+        <exclusions>
+          <exclusion>
+            <groupId>commons-pool</groupId>
+            <artifactId>commons-pool</artifactId>
+          </exclusion>
+        </exclusions>
        <groupId>org.apache.parquet</groupId>
diff --git a/python/README.md b/python/README.md
index 2e0112da58b94..c020d84b01ffd 100644
--- a/python/README.md
+++ b/python/README.md
@@ -29,4 +29,4 @@ The Python packaging for Spark is not intended to replace all of the other use c
## Python Requirements
-At its core PySpark depends on Py4J (currently version 0.10.6), but some additional sub-packages have their own extra requirements for some features (including numpy, pandas, and pyarrow).
+At its core PySpark depends on Py4J (currently version 0.10.7), but some additional sub-packages have their own extra requirements for some features (including numpy, pandas, and pyarrow).
diff --git a/python/docs/Makefile b/python/docs/Makefile
index 09898f29950ed..b8e079483c90c 100644
--- a/python/docs/Makefile
+++ b/python/docs/Makefile
@@ -7,7 +7,7 @@ SPHINXBUILD ?= sphinx-build
PAPER ?=
BUILDDIR ?= _build
-export PYTHONPATH=$(realpath ..):$(realpath ../lib/py4j-0.10.6-src.zip)
+export PYTHONPATH=$(realpath ..):$(realpath ../lib/py4j-0.10.7-src.zip)
# User-friendly check for sphinx-build
ifeq ($(shell which $(SPHINXBUILD) >/dev/null 2>&1; echo $$?), 1)
diff --git a/python/lib/py4j-0.10.6-src.zip b/python/lib/py4j-0.10.6-src.zip
deleted file mode 100644
index 2f8edcc0c0b88..0000000000000
Binary files a/python/lib/py4j-0.10.6-src.zip and /dev/null differ
diff --git a/python/lib/py4j-0.10.7-src.zip b/python/lib/py4j-0.10.7-src.zip
new file mode 100644
index 0000000000000..128e321078793
Binary files /dev/null and b/python/lib/py4j-0.10.7-src.zip differ
diff --git a/python/pyspark/context.py b/python/pyspark/context.py
index 7c664966ed74e..ede3b6af0a8cf 100644
--- a/python/pyspark/context.py
+++ b/python/pyspark/context.py
@@ -211,9 +211,21 @@ def _do_init(self, master, appName, sparkHome, pyFiles, environment, batchSize,
for path in self._conf.get("spark.submit.pyFiles", "").split(","):
if path != "":
(dirname, filename) = os.path.split(path)
- if filename[-4:].lower() in self.PACKAGE_EXTENSIONS:
- self._python_includes.append(filename)
- sys.path.insert(1, os.path.join(SparkFiles.getRootDirectory(), filename))
+ try:
+ filepath = os.path.join(SparkFiles.getRootDirectory(), filename)
+ if not os.path.exists(filepath):
+ # In case of YARN with shell mode, 'spark.submit.pyFiles' files are
+ # not added via SparkContext.addFile. Here we check if the file exists,
+ # try to copy and then add it to the path. See SPARK-21945.
+ shutil.copyfile(path, filepath)
+ if filename[-4:].lower() in self.PACKAGE_EXTENSIONS:
+ self._python_includes.append(filename)
+ sys.path.insert(1, filepath)
+ except Exception:
+ warnings.warn(
+ "Failed to add file [%s] speficied in 'spark.submit.pyFiles' to "
+ "Python path:\n %s" % (path, "\n ".join(sys.path)),
+ RuntimeWarning)
# Create a temporary directory inside spark.local.dir:
local_dir = self._jvm.org.apache.spark.util.Utils.getLocalDir(self._jsc.sc().conf())
@@ -998,8 +1010,8 @@ def runJob(self, rdd, partitionFunc, partitions=None, allowLocal=False):
# by runJob() in order to avoid having to pass a Python lambda into
# SparkContext#runJob.
mappedRDD = rdd.mapPartitions(partitionFunc)
- port = self._jvm.PythonRDD.runJob(self._jsc.sc(), mappedRDD._jrdd, partitions)
- return list(_load_from_socket(port, mappedRDD._jrdd_deserializer))
+ sock_info = self._jvm.PythonRDD.runJob(self._jsc.sc(), mappedRDD._jrdd, partitions)
+ return list(_load_from_socket(sock_info, mappedRDD._jrdd_deserializer))
def show_profiles(self):
""" Print the profile stats to stdout """
diff --git a/python/pyspark/daemon.py b/python/pyspark/daemon.py
index 7bed5216eabf3..ebdd665e349c5 100644
--- a/python/pyspark/daemon.py
+++ b/python/pyspark/daemon.py
@@ -29,7 +29,7 @@
from signal import SIGHUP, SIGTERM, SIGCHLD, SIG_DFL, SIG_IGN, SIGINT
from pyspark.worker import main as worker_main
-from pyspark.serializers import read_int, write_int
+from pyspark.serializers import read_int, write_int, write_with_length, UTF8Deserializer
def compute_real_exit_code(exit_code):
@@ -40,7 +40,7 @@ def compute_real_exit_code(exit_code):
return 1
-def worker(sock):
+def worker(sock, authenticated):
"""
Called by a worker process after the fork().
"""
@@ -56,6 +56,18 @@ def worker(sock):
# otherwise writes also cause a seek that makes us miss data on the read side.
infile = os.fdopen(os.dup(sock.fileno()), "rb", 65536)
outfile = os.fdopen(os.dup(sock.fileno()), "wb", 65536)
+
+ if not authenticated:
+ client_secret = UTF8Deserializer().loads(infile)
+ if os.environ["PYTHON_WORKER_FACTORY_SECRET"] == client_secret:
+ write_with_length("ok".encode("utf-8"), outfile)
+ outfile.flush()
+ else:
+ write_with_length("err".encode("utf-8"), outfile)
+ outfile.flush()
+ sock.close()
+ return 1
+
exit_code = 0
try:
worker_main(infile, outfile)
@@ -153,8 +165,11 @@ def handle_sigterm(*args):
write_int(os.getpid(), outfile)
outfile.flush()
outfile.close()
+ authenticated = False
while True:
- code = worker(sock)
+ code = worker(sock, authenticated)
+ if code == 0:
+ authenticated = True
if not reuse or code:
# wait for closing
try:
diff --git a/python/pyspark/java_gateway.py b/python/pyspark/java_gateway.py
index 3e704fe9bf6ec..0afbe9dc6aa3e 100644
--- a/python/pyspark/java_gateway.py
+++ b/python/pyspark/java_gateway.py
@@ -21,16 +21,19 @@
import select
import signal
import shlex
+import shutil
import socket
import platform
+import tempfile
+import time
from subprocess import Popen, PIPE
if sys.version >= '3':
xrange = range
-from py4j.java_gateway import java_import, JavaGateway, GatewayClient
+from py4j.java_gateway import java_import, JavaGateway, GatewayParameters
from pyspark.find_spark_home import _find_spark_home
-from pyspark.serializers import read_int
+from pyspark.serializers import read_int, write_with_length, UTF8Deserializer
def launch_gateway(conf=None):
@@ -41,6 +44,7 @@ def launch_gateway(conf=None):
"""
if "PYSPARK_GATEWAY_PORT" in os.environ:
gateway_port = int(os.environ["PYSPARK_GATEWAY_PORT"])
+ gateway_secret = os.environ["PYSPARK_GATEWAY_SECRET"]
else:
SPARK_HOME = _find_spark_home()
# Launch the Py4j gateway using Spark's run command so that we pick up the
@@ -59,40 +63,40 @@ def launch_gateway(conf=None):
])
command = command + shlex.split(submit_args)
- # Start a socket that will be used by PythonGatewayServer to communicate its port to us
- callback_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
- callback_socket.bind(('127.0.0.1', 0))
- callback_socket.listen(1)
- callback_host, callback_port = callback_socket.getsockname()
- env = dict(os.environ)
- env['_PYSPARK_DRIVER_CALLBACK_HOST'] = callback_host
- env['_PYSPARK_DRIVER_CALLBACK_PORT'] = str(callback_port)
-
- # Launch the Java gateway.
- # We open a pipe to stdin so that the Java gateway can die when the pipe is broken
- if not on_windows:
- # Don't send ctrl-c / SIGINT to the Java gateway:
- def preexec_func():
- signal.signal(signal.SIGINT, signal.SIG_IGN)
- proc = Popen(command, stdin=PIPE, preexec_fn=preexec_func, env=env)
- else:
- # preexec_fn not supported on Windows
- proc = Popen(command, stdin=PIPE, env=env)
-
- gateway_port = None
- # We use select() here in order to avoid blocking indefinitely if the subprocess dies
- # before connecting
- while gateway_port is None and proc.poll() is None:
- timeout = 1 # (seconds)
- readable, _, _ = select.select([callback_socket], [], [], timeout)
- if callback_socket in readable:
- gateway_connection = callback_socket.accept()[0]
- # Determine which ephemeral port the server started on:
- gateway_port = read_int(gateway_connection.makefile(mode="rb"))
- gateway_connection.close()
- callback_socket.close()
- if gateway_port is None:
- raise Exception("Java gateway process exited before sending the driver its port number")
+ # Create a temporary directory where the gateway server should write the connection
+ # information.
+ conn_info_dir = tempfile.mkdtemp()
+ try:
+ fd, conn_info_file = tempfile.mkstemp(dir=conn_info_dir)
+ os.close(fd)
+ os.unlink(conn_info_file)
+
+ env = dict(os.environ)
+ env["_PYSPARK_DRIVER_CONN_INFO_PATH"] = conn_info_file
+
+ # Launch the Java gateway.
+ # We open a pipe to stdin so that the Java gateway can die when the pipe is broken
+ if not on_windows:
+ # Don't send ctrl-c / SIGINT to the Java gateway:
+ def preexec_func():
+ signal.signal(signal.SIGINT, signal.SIG_IGN)
+ proc = Popen(command, stdin=PIPE, preexec_fn=preexec_func, env=env)
+ else:
+ # preexec_fn not supported on Windows
+ proc = Popen(command, stdin=PIPE, env=env)
+
+ # Wait for the file to appear, or for the process to exit, whichever happens first.
+ while not proc.poll() and not os.path.isfile(conn_info_file):
+ time.sleep(0.1)
+
+ if not os.path.isfile(conn_info_file):
+ raise Exception("Java gateway process exited before sending its port number")
+
+ with open(conn_info_file, "rb") as info:
+ gateway_port = read_int(info)
+ gateway_secret = UTF8Deserializer().loads(info)
+ finally:
+ shutil.rmtree(conn_info_dir)
# In Windows, ensure the Java child processes do not linger after Python has exited.
# In UNIX-based systems, the child process can kill itself on broken pipe (i.e. when
@@ -111,7 +115,9 @@ def killChild():
atexit.register(killChild)
# Connect to the gateway
- gateway = JavaGateway(GatewayClient(port=gateway_port), auto_convert=True)
+ gateway = JavaGateway(
+ gateway_parameters=GatewayParameters(port=gateway_port, auth_token=gateway_secret,
+ auto_convert=True))
# Import the classes used by PySpark
java_import(gateway.jvm, "org.apache.spark.SparkConf")
@@ -126,3 +132,16 @@ def killChild():
java_import(gateway.jvm, "scala.Tuple2")
return gateway
+
+
+def do_server_auth(conn, auth_secret):
+ """
+ Performs the authentication protocol defined by the SocketAuthHelper class on the given
+ file-like object 'conn'.
+ """
+ write_with_length(auth_secret.encode("utf-8"), conn)
+ conn.flush()
+ reply = UTF8Deserializer().loads(conn)
+ if reply != "ok":
+ conn.close()
+ raise Exception("Unexpected reply from iterator server.")
diff --git a/python/pyspark/ml/classification.py b/python/pyspark/ml/classification.py
index ec17653a1adf9..424ecfd89b060 100644
--- a/python/pyspark/ml/classification.py
+++ b/python/pyspark/ml/classification.py
@@ -1222,6 +1222,10 @@ class GBTClassifier(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol
True
>>> model.trees
[DecisionTreeRegressionModel (uid=...) of depth..., DecisionTreeRegressionModel...]
+ >>> validation = spark.createDataFrame([(0.0, Vectors.dense(-1.0),)],
+ ... ["indexed", "features"])
+ >>> model.evaluateEachIteration(validation)
+ [0.25..., 0.23..., 0.21..., 0.19..., 0.18...]
.. versionadded:: 1.4.0
"""
@@ -1319,6 +1323,17 @@ def trees(self):
"""Trees in this ensemble. Warning: These have null parent Estimators."""
return [DecisionTreeRegressionModel(m) for m in list(self._call_java("trees"))]
+ @since("2.4.0")
+ def evaluateEachIteration(self, dataset):
+ """
+ Method to compute error or loss for every iteration of gradient boosting.
+
+ :param dataset:
+ Test dataset to evaluate model on, where dataset is an
+ instance of :py:class:`pyspark.sql.DataFrame`
+ """
+ return self._call_java("evaluateEachIteration", dataset)
+
@inherit_doc
class NaiveBayes(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol, HasProbabilityCol,
diff --git a/python/pyspark/ml/regression.py b/python/pyspark/ml/regression.py
index 9a66d87d7f211..dd0b62f184d26 100644
--- a/python/pyspark/ml/regression.py
+++ b/python/pyspark/ml/regression.py
@@ -1056,6 +1056,10 @@ class GBTRegressor(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol,
True
>>> model.trees
[DecisionTreeRegressionModel (uid=...) of depth..., DecisionTreeRegressionModel...]
+ >>> validation = spark.createDataFrame([(0.0, Vectors.dense(-1.0))],
+ ... ["label", "features"])
+ >>> model.evaluateEachIteration(validation, "squared")
+ [0.0, 0.0, 0.0, 0.0, 0.0]
.. versionadded:: 1.4.0
"""
@@ -1156,6 +1160,20 @@ def trees(self):
"""Trees in this ensemble. Warning: These have null parent Estimators."""
return [DecisionTreeRegressionModel(m) for m in list(self._call_java("trees"))]
+ @since("2.4.0")
+ def evaluateEachIteration(self, dataset, loss):
+ """
+ Method to compute error or loss for every iteration of gradient boosting.
+
+ :param dataset:
+ Test dataset to evaluate model on, where dataset is an
+ instance of :py:class:`pyspark.sql.DataFrame`
+ :param loss:
+ The loss function used to compute error.
+ Supported options: squared, absolute
+ """
+ return self._call_java("evaluateEachIteration", dataset, loss)
+
@inherit_doc
class AFTSurvivalRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol,
diff --git a/python/pyspark/ml/tests.py b/python/pyspark/ml/tests.py
index 093593132e56d..0dde0db9e3339 100755
--- a/python/pyspark/ml/tests.py
+++ b/python/pyspark/ml/tests.py
@@ -1595,6 +1595,44 @@ def test_default_read_write(self):
self.assertEqual(lr.uid, lr3.uid)
self.assertEqual(lr.extractParamMap(), lr3.extractParamMap())
+ def test_default_read_write_default_params(self):
+ lr = LogisticRegression()
+ self.assertFalse(lr.isSet(lr.getParam("threshold")))
+
+ lr.setMaxIter(50)
+ lr.setThreshold(.75)
+
+ # `threshold` is set by user, default param `predictionCol` is not set by user.
+ self.assertTrue(lr.isSet(lr.getParam("threshold")))
+ self.assertFalse(lr.isSet(lr.getParam("predictionCol")))
+ self.assertTrue(lr.hasDefault(lr.getParam("predictionCol")))
+
+ writer = DefaultParamsWriter(lr)
+ metadata = json.loads(writer._get_metadata_to_save(lr, self.sc))
+ self.assertTrue("defaultParamMap" in metadata)
+
+ reader = DefaultParamsReadable.read()
+ metadataStr = json.dumps(metadata, separators=[',', ':'])
+ loadedMetadata = reader._parseMetaData(metadataStr, )
+ reader.getAndSetParams(lr, loadedMetadata)
+
+ self.assertTrue(lr.isSet(lr.getParam("threshold")))
+ self.assertFalse(lr.isSet(lr.getParam("predictionCol")))
+ self.assertTrue(lr.hasDefault(lr.getParam("predictionCol")))
+
+ # manually create metadata without `defaultParamMap` section.
+ del metadata['defaultParamMap']
+ metadataStr = json.dumps(metadata, separators=[',', ':'])
+ loadedMetadata = reader._parseMetaData(metadataStr, )
+ with self.assertRaisesRegexp(AssertionError, "`defaultParamMap` section not found"):
+ reader.getAndSetParams(lr, loadedMetadata)
+
+ # Prior to 2.4.0, metadata doesn't have `defaultParamMap`.
+ metadata['sparkVersion'] = '2.3.0'
+ metadataStr = json.dumps(metadata, separators=[',', ':'])
+ loadedMetadata = reader._parseMetaData(metadataStr, )
+ reader.getAndSetParams(lr, loadedMetadata)
+
class LDATest(SparkSessionTestCase):
diff --git a/python/pyspark/ml/util.py b/python/pyspark/ml/util.py
index a486c6a3fdeb5..9fa85664939b8 100644
--- a/python/pyspark/ml/util.py
+++ b/python/pyspark/ml/util.py
@@ -30,6 +30,7 @@
from pyspark import SparkContext, since
from pyspark.ml.common import inherit_doc
from pyspark.sql import SparkSession
+from pyspark.util import VersionUtils
def _jvm():
@@ -396,6 +397,7 @@ def saveMetadata(instance, path, sc, extraMetadata=None, paramMap=None):
- sparkVersion
- uid
- paramMap
+ - defaultParamMap (since 2.4.0)
- (optionally, extra metadata)
:param extraMetadata: Extra metadata to be saved at same level as uid, paramMap, etc.
:param paramMap: If given, this is saved in the "paramMap" field.
@@ -417,15 +419,24 @@ def _get_metadata_to_save(instance, sc, extraMetadata=None, paramMap=None):
"""
uid = instance.uid
cls = instance.__module__ + '.' + instance.__class__.__name__
- params = instance.extractParamMap()
+
+ # User-supplied param values
+ params = instance._paramMap
jsonParams = {}
if paramMap is not None:
jsonParams = paramMap
else:
for p in params:
jsonParams[p.name] = params[p]
+
+ # Default param values
+ jsonDefaultParams = {}
+ for p in instance._defaultParamMap:
+ jsonDefaultParams[p.name] = instance._defaultParamMap[p]
+
basicMetadata = {"class": cls, "timestamp": long(round(time.time() * 1000)),
- "sparkVersion": sc.version, "uid": uid, "paramMap": jsonParams}
+ "sparkVersion": sc.version, "uid": uid, "paramMap": jsonParams,
+ "defaultParamMap": jsonDefaultParams}
if extraMetadata is not None:
basicMetadata.update(extraMetadata)
return json.dumps(basicMetadata, separators=[',', ':'])
@@ -523,11 +534,26 @@ def getAndSetParams(instance, metadata):
"""
Extract Params from metadata, and set them in the instance.
"""
+ # Set user-supplied param values
for paramName in metadata['paramMap']:
param = instance.getParam(paramName)
paramValue = metadata['paramMap'][paramName]
instance.set(param, paramValue)
+ # Set default param values
+ majorAndMinorVersions = VersionUtils.majorMinorVersion(metadata['sparkVersion'])
+ major = majorAndMinorVersions[0]
+ minor = majorAndMinorVersions[1]
+
+ # Metadata files written prior to Spark 2.4 have no `defaultParamMap` section.
+ if major > 2 or (major == 2 and minor >= 4):
+ assert 'defaultParamMap' in metadata, "Error loading metadata: Expected " + \
+ "`defaultParamMap` section not found"
+
+ for paramName in metadata['defaultParamMap']:
+ paramValue = metadata['defaultParamMap'][paramName]
+ instance._setDefault(**{paramName: paramValue})
+
@staticmethod
def loadParamsInstance(path, sc):
"""
diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py
index 4b44f76747264..d5a237a5b2855 100644
--- a/python/pyspark/rdd.py
+++ b/python/pyspark/rdd.py
@@ -39,9 +39,11 @@
else:
from itertools import imap as map, ifilter as filter
+from pyspark.java_gateway import do_server_auth
from pyspark.serializers import NoOpSerializer, CartesianDeserializer, \
BatchedSerializer, CloudPickleSerializer, PairDeserializer, \
- PickleSerializer, pack_long, AutoBatchedSerializer
+ PickleSerializer, pack_long, AutoBatchedSerializer, write_with_length, \
+ UTF8Deserializer
from pyspark.join import python_join, python_left_outer_join, \
python_right_outer_join, python_full_outer_join, python_cogroup
from pyspark.statcounter import StatCounter
@@ -136,7 +138,8 @@ def _parse_memory(s):
return int(float(s[:-1]) * units[s[-1].lower()])
-def _load_from_socket(port, serializer):
+def _load_from_socket(sock_info, serializer):
+ port, auth_secret = sock_info
sock = None
# Support for both IPv4 and IPv6.
# On most of IPv6-ready systems, IPv6 will take precedence.
@@ -156,8 +159,12 @@ def _load_from_socket(port, serializer):
# The RDD materialization time is unpredictable; if we set a timeout for the socket read
# operation, it will very likely fail. See SPARK-18281.
sock.settimeout(None)
+
+ sockfile = sock.makefile("rwb", 65536)
+ do_server_auth(sockfile, auth_secret)
+
# The socket will be automatically closed when garbage-collected.
- return serializer.load_stream(sock.makefile("rb", 65536))
+ return serializer.load_stream(sockfile)
def ignore_unicode_prefix(f):
@@ -822,8 +829,8 @@ def collect(self):
to be small, as all the data is loaded into the driver's memory.
"""
with SCCallSiteSync(self.context) as css:
- port = self.ctx._jvm.PythonRDD.collectAndServe(self._jrdd.rdd())
- return list(_load_from_socket(port, self._jrdd_deserializer))
+ sock_info = self.ctx._jvm.PythonRDD.collectAndServe(self._jrdd.rdd())
+ return list(_load_from_socket(sock_info, self._jrdd_deserializer))
def reduce(self, f):
"""
@@ -2380,8 +2387,8 @@ def toLocalIterator(self):
[0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
"""
with SCCallSiteSync(self.context) as css:
- port = self.ctx._jvm.PythonRDD.toLocalIteratorAndServe(self._jrdd.rdd())
- return _load_from_socket(port, self._jrdd_deserializer)
+ sock_info = self.ctx._jvm.PythonRDD.toLocalIteratorAndServe(self._jrdd.rdd())
+ return _load_from_socket(sock_info, self._jrdd_deserializer)
def _prepare_for_python_RDD(sc, command):
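The sock_info returned by the JVM is now a (port, auth_secret) pair, and _load_from_socket authenticates on the connection before deserializing the stream. A rough sketch of what do_server_auth is assumed to do, written with the serializer helpers imported above (the authoritative implementation lives in pyspark/java_gateway.py):

from pyspark.serializers import UTF8Deserializer, write_with_length

def do_server_auth_sketch(conn, auth_secret):
    # Send the shared secret, length-prefixed, then wait for the JVM's acknowledgement.
    write_with_length(auth_secret.encode("utf-8"), conn)
    conn.flush()
    reply = UTF8Deserializer().loads(conn)
    if reply != "ok":
        conn.close()
        raise Exception("Unexpected reply from the server during authentication")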
diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py
index 16f8e52dead7b..213dc158f9328 100644
--- a/python/pyspark/sql/dataframe.py
+++ b/python/pyspark/sql/dataframe.py
@@ -463,8 +463,8 @@ def collect(self):
[Row(age=2, name=u'Alice'), Row(age=5, name=u'Bob')]
"""
with SCCallSiteSync(self._sc) as css:
- port = self._jdf.collectToPython()
- return list(_load_from_socket(port, BatchedSerializer(PickleSerializer())))
+ sock_info = self._jdf.collectToPython()
+ return list(_load_from_socket(sock_info, BatchedSerializer(PickleSerializer())))
@ignore_unicode_prefix
@since(2.0)
@@ -477,8 +477,8 @@ def toLocalIterator(self):
[Row(age=2, name=u'Alice'), Row(age=5, name=u'Bob')]
"""
with SCCallSiteSync(self._sc) as css:
- port = self._jdf.toPythonIterator()
- return _load_from_socket(port, BatchedSerializer(PickleSerializer()))
+ sock_info = self._jdf.toPythonIterator()
+ return _load_from_socket(sock_info, BatchedSerializer(PickleSerializer()))
@ignore_unicode_prefix
@since(1.3)
@@ -2087,8 +2087,8 @@ def _collectAsArrow(self):
.. note:: Experimental.
"""
with SCCallSiteSync(self._sc) as css:
- port = self._jdf.collectAsArrowToPython()
- return list(_load_from_socket(port, ArrowSerializer()))
+ sock_info = self._jdf.collectAsArrowToPython()
+ return list(_load_from_socket(sock_info, ArrowSerializer()))
##########################################################################################
# Pandas compatibility
diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py
index 0b8eb19435300..ac312d2e3d395 100644
--- a/python/pyspark/sql/functions.py
+++ b/python/pyspark/sql/functions.py
@@ -152,6 +152,9 @@ def _():
_collect_list_doc = """
Aggregate function: returns a list of objects with duplicates.
+ .. note:: The function is non-deterministic because the order of collected results depends
+ on the order of rows, which may be non-deterministic after a shuffle.
+
>>> df2 = spark.createDataFrame([(2,), (5,), (5,)], ('age',))
>>> df2.agg(collect_list('age')).collect()
[Row(collect_list(age)=[2, 5, 5])]
@@ -159,6 +162,9 @@ def _():
_collect_set_doc = """
Aggregate function: returns a set of objects with duplicate elements eliminated.
+ .. note:: The function is non-deterministic because the order of collected results depends
+ on the order of rows, which may be non-deterministic after a shuffle.
+
>>> df2 = spark.createDataFrame([(2,), (5,), (5,)], ('age',))
>>> df2.agg(collect_set('age')).collect()
[Row(collect_set(age)=[5, 2])]
@@ -401,6 +407,9 @@ def first(col, ignorenulls=False):
The function by default returns the first values it sees. It will return the first non-null
value it sees when ignoreNulls is set to true. If all values are null, then null is returned.
+
+ .. note:: The function is non-deterministic because its result depends on the order of rows,
+ which may be non-deterministic after a shuffle.
"""
sc = SparkContext._active_spark_context
jc = sc._jvm.functions.first(_to_java_column(col), ignorenulls)
@@ -489,6 +498,9 @@ def last(col, ignorenulls=False):
The function by default returns the last values it sees. It will return the last non-null
value it sees when ignoreNulls is set to true. If all values are null, then null is returned.
+
+ .. note:: The function is non-deterministic because its result depends on the order of rows,
+ which may be non-deterministic after a shuffle.
"""
sc = SparkContext._active_spark_context
jc = sc._jvm.functions.last(_to_java_column(col), ignorenulls)
@@ -504,6 +516,8 @@ def monotonically_increasing_id():
within each partition in the lower 33 bits. The assumption is that the data frame has
less than 1 billion partitions, and each partition has less than 8 billion records.
+ .. note:: The function is non-deterministic because its result depends on partition IDs.
+
As an example, consider a :class:`DataFrame` with two partitions, each with 3 records.
This expression would return the following IDs:
0, 1, 2, 8589934592 (1L << 33), 8589934593, 8589934594.
@@ -536,6 +550,8 @@ def rand(seed=None):
"""Generates a random column with independent and identically distributed (i.i.d.) samples
from U[0.0, 1.0].
+ .. note:: The function is non-deterministic in the general case.
+
>>> df.withColumn('rand', rand(seed=42) * 3).collect()
[Row(age=2, name=u'Alice', rand=1.1568609015300986),
Row(age=5, name=u'Bob', rand=1.403379671529166)]
@@ -554,6 +570,8 @@ def randn(seed=None):
"""Generates a column with independent and identically distributed (i.i.d.) samples from
the standard normal distribution.
+ .. note:: The function is non-deterministic in the general case.
+
>>> df.withColumn('randn', randn(seed=42)).collect()
[Row(age=2, name=u'Alice', randn=-0.7556247885860078),
Row(age=5, name=u'Bob', randn=-0.0861619008451133)]
@@ -1090,8 +1108,11 @@ def add_months(start, months):
@since(1.5)
def months_between(date1, date2, roundOff=True):
"""
- Returns the number of months between date1 and date2.
- Unless `roundOff` is set to `False`, the result is rounded off to 8 digits.
+ Returns the number of months between dates date1 and date2.
+ If date1 is later than date2, then the result is positive.
+ If date1 and date2 are on the same day of the month, or both are the last day of the month,
+ returns an integer (time of day will be ignored).
+ The result is rounded off to 8 digits unless `roundOff` is set to `False`.
>>> df = spark.createDataFrame([('1997-02-28 10:30:00', '1996-10-30')], ['date1', 'date2'])
>>> df.select(months_between(df.date1, df.date2).alias('months')).collect()
@@ -2074,12 +2095,13 @@ def json_tuple(col, *fields):
return Column(jc)
+@ignore_unicode_prefix
@since(2.1)
def from_json(col, schema, options={}):
"""
- Parses a column containing a JSON string into a :class:`StructType` or :class:`ArrayType`
- of :class:`StructType`\\s with the specified schema. Returns `null`, in the case of an
- unparseable string.
+ Parses a column containing a JSON string into a :class:`MapType` with :class:`StringType`
+ as the key type, a :class:`StructType`, or an :class:`ArrayType` of :class:`StructType`\\s with
+ the specified schema. Returns `null` in the case of an unparseable string.
:param col: string column in json format
:param schema: a StructType or ArrayType of StructType to use when parsing the json column.
@@ -2096,6 +2118,9 @@ def from_json(col, schema, options={}):
[Row(json=Row(a=1))]
>>> df.select(from_json(df.value, "a INT").alias("json")).collect()
[Row(json=Row(a=1))]
+ >>> schema = MapType(StringType(), IntegerType())
+ >>> df.select(from_json(df.value, schema).alias("json")).collect()
+ [Row(json={u'a': 1})]
>>> data = [(1, '''[{"a": 1}]''')]
>>> schema = ArrayType(StructType([StructField("a", IntegerType())]))
>>> df = spark.createDataFrame(data, ("key", "value"))
@@ -2324,6 +2349,20 @@ def map_entries(col):
return Column(sc._jvm.functions.map_entries(_to_java_column(col)))
+@ignore_unicode_prefix
+@since(2.4)
+def array_repeat(col, count):
+ """
+ Collection function: creates an array containing a column repeated count times.
+
+ >>> df = spark.createDataFrame([('ab',)], ['data'])
+ >>> df.select(array_repeat(df.data, 3).alias('r')).collect()
+ [Row(r=[u'ab', u'ab', u'ab'])]
+ """
+ sc = SparkContext._active_spark_context
+ return Column(sc._jvm.functions.array_repeat(_to_java_column(col), count))
+
+
# ---------------------------- User Defined Function ----------------------------------
class PandasUDFType(object):
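A short snippet exercising the behaviours documented above, assuming an active SparkSession bound to `spark` (column values are illustrative):

from pyspark.sql.functions import array_repeat, from_json, months_between
from pyspark.sql.types import IntegerType, MapType, StringType

df = spark.createDataFrame(
    [('1997-02-28 10:30:00', '1996-10-30', '{"a": 1}')],
    ['date1', 'date2', 'json'])

df.select(
    months_between(df.date1, df.date2).alias('months'),               # rounded to 8 digits
    months_between(df.date1, df.date2, roundOff=False).alias('raw'),  # unrounded
    from_json(df.json, MapType(StringType(), IntegerType())).alias('as_map'),
    array_repeat(df.date2, 2).alias('twice')).show(truncate=False)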
diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py
index 16aa9378ad8ee..a1b6db71782bb 100644
--- a/python/pyspark/sql/tests.py
+++ b/python/pyspark/sql/tests.py
@@ -4680,6 +4680,26 @@ def test_supported_types(self):
self.assertPandasEqual(expected2, result2)
self.assertPandasEqual(expected3, result3)
+ def test_array_type_correct(self):
+ from pyspark.sql.functions import pandas_udf, PandasUDFType, array, col
+
+ df = self.data.withColumn("arr", array(col("id"))).repartition(1, "id")
+
+ output_schema = StructType(
+ [StructField('id', LongType()),
+ StructField('v', IntegerType()),
+ StructField('arr', ArrayType(LongType()))])
+
+ udf = pandas_udf(
+ lambda pdf: pdf,
+ output_schema,
+ PandasUDFType.GROUPED_MAP
+ )
+
+ result = df.groupby('id').apply(udf).sort('id').toPandas()
+ expected = df.toPandas().groupby('id').apply(udf.func).reset_index(drop=True)
+ self.assertPandasEqual(expected, result)
+
def test_register_grouped_map_udf(self):
from pyspark.sql.functions import pandas_udf, PandasUDFType
diff --git a/python/pyspark/tests.py b/python/pyspark/tests.py
index 7b8ce2c6b799f..498d6b57e4353 100644
--- a/python/pyspark/tests.py
+++ b/python/pyspark/tests.py
@@ -2312,6 +2312,10 @@ def test_py4j_exception_message(self):
self.assertTrue('NullPointerException' in _exception_message(context.exception))
+ def test_parsing_version_string(self):
+ from pyspark.util import VersionUtils
+ self.assertRaises(ValueError, lambda: VersionUtils.majorMinorVersion("abced"))
+
@unittest.skipIf(not _have_scipy, "SciPy not installed")
class SciPyTests(PySparkTestCase):
diff --git a/python/pyspark/util.py b/python/pyspark/util.py
index 04df835bf6717..59cc2a6329350 100644
--- a/python/pyspark/util.py
+++ b/python/pyspark/util.py
@@ -62,24 +62,31 @@ def _get_argspec(f):
return argspec
-def majorMinorVersion(version):
+class VersionUtils(object):
"""
- Get major and minor version numbers for given Spark version string.
-
- >>> version = "2.4.0"
- >>> majorMinorVersion(version)
- (2, 4)
+ Provides utility methods for determining Spark versions from a given input string.
+ """
+ @staticmethod
+ def majorMinorVersion(sparkVersion):
+ """
+ Given a Spark version string, return the (major version number, minor version number).
+ E.g., for 2.0.1-SNAPSHOT, return (2, 0).
- >>> version = "abc"
- >>> majorMinorVersion(version) is None
- True
+ >>> sparkVersion = "2.4.0"
+ >>> VersionUtils.majorMinorVersion(sparkVersion)
+ (2, 4)
+ >>> sparkVersion = "2.3.0-SNAPSHOT"
+ >>> VersionUtils.majorMinorVersion(sparkVersion)
+ (2, 3)
- """
- m = re.search('^(\d+)\.(\d+)(\..*)?$', version)
- if m is None:
- return None
- else:
- return (int(m.group(1)), int(m.group(2)))
+ """
+ m = re.search('^(\d+)\.(\d+)(\..*)?$', sparkVersion)
+ if m is not None:
+ return (int(m.group(1)), int(m.group(2)))
+ else:
+ raise ValueError("Spark tried to parse '%s' as a Spark" % sparkVersion +
+ " version string, but it could not find the major and minor" +
+ " version numbers.")
if __name__ == "__main__":
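Since majorMinorVersion now raises instead of returning None, callers need to handle the failure explicitly. A minimal sketch based only on the doctest and the new test:

from pyspark.util import VersionUtils

for version in ("2.4.0", "2.3.0-SNAPSHOT", "abced"):
    try:
        major, minor = VersionUtils.majorMinorVersion(version)
        print(version, "->", (major, minor))
    except ValueError as e:
        print(version, "-> cannot parse:", e)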
diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py
index a1a4336b1e8de..5d2e58bef6466 100644
--- a/python/pyspark/worker.py
+++ b/python/pyspark/worker.py
@@ -27,6 +27,7 @@
from pyspark.accumulators import _accumulatorRegistry
from pyspark.broadcast import Broadcast, _broadcastRegistry
+from pyspark.java_gateway import do_server_auth
from pyspark.taskcontext import TaskContext
from pyspark.files import SparkFiles
from pyspark.rdd import PythonEvalType
@@ -81,7 +82,7 @@ def wrap_scalar_pandas_udf(f, return_type):
def verify_result_length(*a):
result = f(*a)
if not hasattr(result, "__len__"):
- raise TypeError("Return type of the user-defined functon should be "
+ raise TypeError("Return type of the user-defined function should be "
"Pandas.Series, but is {}".format(type(result)))
if len(result) != len(a[0]):
raise RuntimeError("Result vector from pandas_udf was not the required length: "
@@ -301,9 +302,11 @@ def process():
if __name__ == '__main__':
- # Read a local port to connect to from stdin
- java_port = int(sys.stdin.readline())
+ # Read information about how to connect back to the JVM from the environment.
+ java_port = int(os.environ["PYTHON_WORKER_FACTORY_PORT"])
+ auth_secret = os.environ["PYTHON_WORKER_FACTORY_SECRET"]
sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
sock.connect(("127.0.0.1", java_port))
sock_file = sock.makefile("rwb", 65536)
+ do_server_auth(sock_file, auth_secret)
main(sock_file, sock_file)
diff --git a/python/setup.py b/python/setup.py
index 794ceceae3008..d309e0564530a 100644
--- a/python/setup.py
+++ b/python/setup.py
@@ -201,7 +201,7 @@ def _supports_symlinks():
'pyspark.examples.src.main.python': ['*.py', '*/*.py']},
scripts=scripts,
license='http://www.apache.org/licenses/LICENSE-2.0',
- install_requires=['py4j==0.10.6'],
+ install_requires=['py4j==0.10.7'],
setup_requires=['pypandoc'],
extras_require={
'ml': ['numpy>=1.7'],
diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/LocalDirsFeatureStep.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/LocalDirsFeatureStep.scala
new file mode 100644
index 0000000000000..70b307303d149
--- /dev/null
+++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/LocalDirsFeatureStep.scala
@@ -0,0 +1,77 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.deploy.k8s.features
+
+import java.nio.file.Paths
+import java.util.UUID
+
+import io.fabric8.kubernetes.api.model.{ContainerBuilder, HasMetadata, PodBuilder, VolumeBuilder, VolumeMountBuilder}
+
+import org.apache.spark.deploy.k8s.{KubernetesConf, KubernetesDriverSpecificConf, KubernetesRoleSpecificConf, SparkPod}
+
+private[spark] class LocalDirsFeatureStep(
+ conf: KubernetesConf[_ <: KubernetesRoleSpecificConf],
+ defaultLocalDir: String = s"/var/data/spark-${UUID.randomUUID}")
+ extends KubernetesFeatureConfigStep {
+
+ // Cannot use Utils.getConfiguredLocalDirs because that will default to the Java system
+ // property; instead, we want to default to mounting an emptyDir volume that doesn't already
+ // exist in the image.
+ // We could make Utils.getConfiguredLocalDirs opinionated about Kubernetes, as it is already
+ // a bit opinionated about YARN and Mesos.
+ private val resolvedLocalDirs = Option(conf.sparkConf.getenv("SPARK_LOCAL_DIRS"))
+ .orElse(conf.getOption("spark.local.dir"))
+ .getOrElse(defaultLocalDir)
+ .split(",")
+
+ override def configurePod(pod: SparkPod): SparkPod = {
+ val localDirVolumes = resolvedLocalDirs
+ .zipWithIndex
+ .map { case (localDir, index) =>
+ new VolumeBuilder()
+ .withName(s"spark-local-dir-${index + 1}")
+ .withNewEmptyDir()
+ .endEmptyDir()
+ .build()
+ }
+ val localDirVolumeMounts = localDirVolumes
+ .zip(resolvedLocalDirs)
+ .map { case (localDirVolume, localDirPath) =>
+ new VolumeMountBuilder()
+ .withName(localDirVolume.getName)
+ .withMountPath(localDirPath)
+ .build()
+ }
+ val podWithLocalDirVolumes = new PodBuilder(pod.pod)
+ .editSpec()
+ .addToVolumes(localDirVolumes: _*)
+ .endSpec()
+ .build()
+ val containerWithLocalDirVolumeMounts = new ContainerBuilder(pod.container)
+ .addNewEnv()
+ .withName("SPARK_LOCAL_DIRS")
+ .withValue(resolvedLocalDirs.mkString(","))
+ .endEnv()
+ .addToVolumeMounts(localDirVolumeMounts: _*)
+ .build()
+ SparkPod(podWithLocalDirVolumes, containerWithLocalDirVolumeMounts)
+ }
+
+ override def getAdditionalPodSystemProperties(): Map[String, String] = Map.empty
+
+ override def getAdditionalKubernetesResources(): Seq[HasMetadata] = Seq.empty
+}
diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/KubernetesDriverBuilder.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/KubernetesDriverBuilder.scala
index c7579ed8cb689..10b0154466a3a 100644
--- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/KubernetesDriverBuilder.scala
+++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/KubernetesDriverBuilder.scala
@@ -17,7 +17,7 @@
package org.apache.spark.deploy.k8s.submit
import org.apache.spark.deploy.k8s.{KubernetesConf, KubernetesDriverSpec, KubernetesDriverSpecificConf, KubernetesRoleSpecificConf}
-import org.apache.spark.deploy.k8s.features.{BasicDriverFeatureStep, DriverKubernetesCredentialsFeatureStep, DriverServiceFeatureStep, MountSecretsFeatureStep}
+import org.apache.spark.deploy.k8s.features.{BasicDriverFeatureStep, DriverKubernetesCredentialsFeatureStep, DriverServiceFeatureStep, LocalDirsFeatureStep, MountSecretsFeatureStep}
private[spark] class KubernetesDriverBuilder(
provideBasicStep: (KubernetesConf[KubernetesDriverSpecificConf]) => BasicDriverFeatureStep =
@@ -29,14 +29,18 @@ private[spark] class KubernetesDriverBuilder(
new DriverServiceFeatureStep(_),
provideSecretsStep: (KubernetesConf[_ <: KubernetesRoleSpecificConf]
=> MountSecretsFeatureStep) =
- new MountSecretsFeatureStep(_)) {
+ new MountSecretsFeatureStep(_),
+ provideLocalDirsStep: (KubernetesConf[_ <: KubernetesRoleSpecificConf])
+ => LocalDirsFeatureStep =
+ new LocalDirsFeatureStep(_)) {
def buildFromFeatures(
kubernetesConf: KubernetesConf[KubernetesDriverSpecificConf]): KubernetesDriverSpec = {
val baseFeatures = Seq(
provideBasicStep(kubernetesConf),
provideCredentialsStep(kubernetesConf),
- provideServiceStep(kubernetesConf))
+ provideServiceStep(kubernetesConf),
+ provideLocalDirsStep(kubernetesConf))
val allFeatures = if (kubernetesConf.roleSecretNamesToMountPaths.nonEmpty) {
baseFeatures ++ Seq(provideSecretsStep(kubernetesConf))
} else baseFeatures
diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesExecutorBuilder.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesExecutorBuilder.scala
index 22568fe7ea3be..d8f63d57574fb 100644
--- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesExecutorBuilder.scala
+++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesExecutorBuilder.scala
@@ -17,18 +17,21 @@
package org.apache.spark.scheduler.cluster.k8s
import org.apache.spark.deploy.k8s.{KubernetesConf, KubernetesExecutorSpecificConf, KubernetesRoleSpecificConf, SparkPod}
-import org.apache.spark.deploy.k8s.features.{BasicExecutorFeatureStep, MountSecretsFeatureStep}
+import org.apache.spark.deploy.k8s.features.{BasicExecutorFeatureStep, LocalDirsFeatureStep, MountSecretsFeatureStep}
private[spark] class KubernetesExecutorBuilder(
provideBasicStep: (KubernetesConf[KubernetesExecutorSpecificConf]) => BasicExecutorFeatureStep =
new BasicExecutorFeatureStep(_),
provideSecretsStep:
(KubernetesConf[_ <: KubernetesRoleSpecificConf]) => MountSecretsFeatureStep =
- new MountSecretsFeatureStep(_)) {
+ new MountSecretsFeatureStep(_),
+ provideLocalDirsStep: (KubernetesConf[_ <: KubernetesRoleSpecificConf])
+ => LocalDirsFeatureStep =
+ new LocalDirsFeatureStep(_)) {
def buildFromFeatures(
kubernetesConf: KubernetesConf[KubernetesExecutorSpecificConf]): SparkPod = {
- val baseFeatures = Seq(provideBasicStep(kubernetesConf))
+ val baseFeatures = Seq(provideBasicStep(kubernetesConf), provideLocalDirsStep(kubernetesConf))
val allFeatures = if (kubernetesConf.roleSecretNamesToMountPaths.nonEmpty) {
baseFeatures ++ Seq(provideSecretsStep(kubernetesConf))
} else baseFeatures
diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/LocalDirsFeatureStepSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/LocalDirsFeatureStepSuite.scala
new file mode 100644
index 0000000000000..91e184b84b86e
--- /dev/null
+++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/LocalDirsFeatureStepSuite.scala
@@ -0,0 +1,111 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.deploy.k8s.features
+
+import io.fabric8.kubernetes.api.model.{EnvVarBuilder, VolumeBuilder, VolumeMountBuilder}
+import org.mockito.Mockito
+import org.scalatest.BeforeAndAfter
+
+import org.apache.spark.{SparkConf, SparkFunSuite}
+import org.apache.spark.deploy.k8s.{KubernetesConf, KubernetesDriverSpecificConf, KubernetesExecutorSpecificConf, KubernetesRoleSpecificConf, SparkPod}
+
+class LocalDirsFeatureStepSuite extends SparkFunSuite with BeforeAndAfter {
+ private val defaultLocalDir = "/var/data/default-local-dir"
+ private var sparkConf: SparkConf = _
+ private var kubernetesConf: KubernetesConf[_ <: KubernetesRoleSpecificConf] = _
+
+ before {
+ val realSparkConf = new SparkConf(false)
+ sparkConf = Mockito.spy(realSparkConf)
+ kubernetesConf = KubernetesConf(
+ sparkConf,
+ KubernetesDriverSpecificConf(
+ None,
+ "app-name",
+ "main",
+ Seq.empty),
+ "resource",
+ "app-id",
+ Map.empty,
+ Map.empty,
+ Map.empty,
+ Map.empty)
+ }
+
+ test("Resolve to default local dir if neither env nor configuration are set") {
+ Mockito.doReturn(null).when(sparkConf).get("spark.local.dir")
+ Mockito.doReturn(null).when(sparkConf).getenv("SPARK_LOCAL_DIRS")
+ val stepUnderTest = new LocalDirsFeatureStep(kubernetesConf, defaultLocalDir)
+ val configuredPod = stepUnderTest.configurePod(SparkPod.initialPod())
+ assert(configuredPod.pod.getSpec.getVolumes.size === 1)
+ assert(configuredPod.pod.getSpec.getVolumes.get(0) ===
+ new VolumeBuilder()
+ .withName(s"spark-local-dir-1")
+ .withNewEmptyDir()
+ .endEmptyDir()
+ .build())
+ assert(configuredPod.container.getVolumeMounts.size === 1)
+ assert(configuredPod.container.getVolumeMounts.get(0) ===
+ new VolumeMountBuilder()
+ .withName(s"spark-local-dir-1")
+ .withMountPath(defaultLocalDir)
+ .build())
+ assert(configuredPod.container.getEnv.size === 1)
+ assert(configuredPod.container.getEnv.get(0) ===
+ new EnvVarBuilder()
+ .withName("SPARK_LOCAL_DIRS")
+ .withValue(defaultLocalDir)
+ .build())
+ }
+
+ test("Use configured local dirs split on comma if provided.") {
+ Mockito.doReturn("/var/data/my-local-dir-1,/var/data/my-local-dir-2")
+ .when(sparkConf).getenv("SPARK_LOCAL_DIRS")
+ val stepUnderTest = new LocalDirsFeatureStep(kubernetesConf, defaultLocalDir)
+ val configuredPod = stepUnderTest.configurePod(SparkPod.initialPod())
+ assert(configuredPod.pod.getSpec.getVolumes.size === 2)
+ assert(configuredPod.pod.getSpec.getVolumes.get(0) ===
+ new VolumeBuilder()
+ .withName(s"spark-local-dir-1")
+ .withNewEmptyDir()
+ .endEmptyDir()
+ .build())
+ assert(configuredPod.pod.getSpec.getVolumes.get(1) ===
+ new VolumeBuilder()
+ .withName(s"spark-local-dir-2")
+ .withNewEmptyDir()
+ .endEmptyDir()
+ .build())
+ assert(configuredPod.container.getVolumeMounts.size === 2)
+ assert(configuredPod.container.getVolumeMounts.get(0) ===
+ new VolumeMountBuilder()
+ .withName(s"spark-local-dir-1")
+ .withMountPath("/var/data/my-local-dir-1")
+ .build())
+ assert(configuredPod.container.getVolumeMounts.get(1) ===
+ new VolumeMountBuilder()
+ .withName(s"spark-local-dir-2")
+ .withMountPath("/var/data/my-local-dir-2")
+ .build())
+ assert(configuredPod.container.getEnv.size === 1)
+ assert(configuredPod.container.getEnv.get(0) ===
+ new EnvVarBuilder()
+ .withName("SPARK_LOCAL_DIRS")
+ .withValue("/var/data/my-local-dir-1,/var/data/my-local-dir-2")
+ .build())
+ }
+}
diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/KubernetesDriverBuilderSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/KubernetesDriverBuilderSuite.scala
index 161f9afe7bba9..a511d254d2175 100644
--- a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/KubernetesDriverBuilderSuite.scala
+++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/KubernetesDriverBuilderSuite.scala
@@ -18,13 +18,14 @@ package org.apache.spark.deploy.k8s.submit
import org.apache.spark.{SparkConf, SparkFunSuite}
import org.apache.spark.deploy.k8s.{KubernetesConf, KubernetesDriverSpec, KubernetesDriverSpecificConf}
-import org.apache.spark.deploy.k8s.features.{BasicDriverFeatureStep, DriverKubernetesCredentialsFeatureStep, DriverServiceFeatureStep, KubernetesFeaturesTestUtils, MountSecretsFeatureStep}
+import org.apache.spark.deploy.k8s.features.{BasicDriverFeatureStep, DriverKubernetesCredentialsFeatureStep, DriverServiceFeatureStep, KubernetesFeaturesTestUtils, LocalDirsFeatureStep, MountSecretsFeatureStep}
class KubernetesDriverBuilderSuite extends SparkFunSuite {
private val BASIC_STEP_TYPE = "basic"
private val CREDENTIALS_STEP_TYPE = "credentials"
private val SERVICE_STEP_TYPE = "service"
+ private val LOCAL_DIRS_STEP_TYPE = "local-dirs"
private val SECRETS_STEP_TYPE = "mount-secrets"
private val basicFeatureStep = KubernetesFeaturesTestUtils.getMockConfigStepForStepType(
@@ -36,6 +37,9 @@ class KubernetesDriverBuilderSuite extends SparkFunSuite {
private val serviceStep = KubernetesFeaturesTestUtils.getMockConfigStepForStepType(
SERVICE_STEP_TYPE, classOf[DriverServiceFeatureStep])
+ private val localDirsStep = KubernetesFeaturesTestUtils.getMockConfigStepForStepType(
+ LOCAL_DIRS_STEP_TYPE, classOf[LocalDirsFeatureStep])
+
private val secretsStep = KubernetesFeaturesTestUtils.getMockConfigStepForStepType(
SECRETS_STEP_TYPE, classOf[MountSecretsFeatureStep])
@@ -44,7 +48,8 @@ class KubernetesDriverBuilderSuite extends SparkFunSuite {
_ => basicFeatureStep,
_ => credentialsStep,
_ => serviceStep,
- _ => secretsStep)
+ _ => secretsStep,
+ _ => localDirsStep)
test("Apply fundamental steps all the time.") {
val conf = KubernetesConf(
@@ -64,7 +69,8 @@ class KubernetesDriverBuilderSuite extends SparkFunSuite {
builderUnderTest.buildFromFeatures(conf),
BASIC_STEP_TYPE,
CREDENTIALS_STEP_TYPE,
- SERVICE_STEP_TYPE)
+ SERVICE_STEP_TYPE,
+ LOCAL_DIRS_STEP_TYPE)
}
test("Apply secrets step if secrets are present.") {
@@ -86,6 +92,7 @@ class KubernetesDriverBuilderSuite extends SparkFunSuite {
BASIC_STEP_TYPE,
CREDENTIALS_STEP_TYPE,
SERVICE_STEP_TYPE,
+ LOCAL_DIRS_STEP_TYPE,
SECRETS_STEP_TYPE)
}
diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesExecutorBuilderSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesExecutorBuilderSuite.scala
index f5270623f8acc..9ee86b5a423a9 100644
--- a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesExecutorBuilderSuite.scala
+++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesExecutorBuilderSuite.scala
@@ -20,20 +20,24 @@ import io.fabric8.kubernetes.api.model.PodBuilder
import org.apache.spark.{SparkConf, SparkFunSuite}
import org.apache.spark.deploy.k8s.{KubernetesConf, KubernetesExecutorSpecificConf, SparkPod}
-import org.apache.spark.deploy.k8s.features.{BasicExecutorFeatureStep, KubernetesFeaturesTestUtils, MountSecretsFeatureStep}
+import org.apache.spark.deploy.k8s.features.{BasicExecutorFeatureStep, KubernetesFeaturesTestUtils, LocalDirsFeatureStep, MountSecretsFeatureStep}
class KubernetesExecutorBuilderSuite extends SparkFunSuite {
private val BASIC_STEP_TYPE = "basic"
private val SECRETS_STEP_TYPE = "mount-secrets"
+ private val LOCAL_DIRS_STEP_TYPE = "local-dirs"
private val basicFeatureStep = KubernetesFeaturesTestUtils.getMockConfigStepForStepType(
BASIC_STEP_TYPE, classOf[BasicExecutorFeatureStep])
private val mountSecretsStep = KubernetesFeaturesTestUtils.getMockConfigStepForStepType(
SECRETS_STEP_TYPE, classOf[MountSecretsFeatureStep])
+ private val localDirsStep = KubernetesFeaturesTestUtils.getMockConfigStepForStepType(
+ LOCAL_DIRS_STEP_TYPE, classOf[LocalDirsFeatureStep])
private val builderUnderTest = new KubernetesExecutorBuilder(
_ => basicFeatureStep,
- _ => mountSecretsStep)
+ _ => mountSecretsStep,
+ _ => localDirsStep)
test("Basic steps are consistently applied.") {
val conf = KubernetesConf(
@@ -46,7 +50,8 @@ class KubernetesExecutorBuilderSuite extends SparkFunSuite {
Map.empty,
Map.empty,
Map.empty)
- validateStepTypesApplied(builderUnderTest.buildFromFeatures(conf), BASIC_STEP_TYPE)
+ validateStepTypesApplied(
+ builderUnderTest.buildFromFeatures(conf), BASIC_STEP_TYPE, LOCAL_DIRS_STEP_TYPE)
}
test("Apply secrets step if secrets are present.") {
@@ -63,6 +68,7 @@ class KubernetesExecutorBuilderSuite extends SparkFunSuite {
validateStepTypesApplied(
builderUnderTest.buildFromFeatures(conf),
BASIC_STEP_TYPE,
+ LOCAL_DIRS_STEP_TYPE,
SECRETS_STEP_TYPE)
}
diff --git a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala
index 595077e7e809f..3d6ee50b070a3 100644
--- a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala
+++ b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala
@@ -346,7 +346,7 @@ private[spark] class ApplicationMaster(args: ApplicationMasterArguments) extends
synchronized {
if (!finished) {
val inShutdown = ShutdownHookManager.inShutdown()
- if (registered) {
+ if (registered || !isClusterMode) {
exitCode = code
finalStatus = status
} else {
@@ -389,37 +389,40 @@ private[spark] class ApplicationMaster(args: ApplicationMasterArguments) extends
}
private def registerAM(
+ host: String,
+ port: Int,
_sparkConf: SparkConf,
- _rpcEnv: RpcEnv,
- driverRef: RpcEndpointRef,
- uiAddress: Option[String]) = {
+ uiAddress: Option[String]): Unit = {
val appId = client.getAttemptId().getApplicationId().toString()
val attemptId = client.getAttemptId().getAttemptId().toString()
val historyAddress = ApplicationMaster
.getHistoryServerAddress(_sparkConf, yarnConf, appId, attemptId)
- val driverUrl = RpcEndpointAddress(
- _sparkConf.get("spark.driver.host"),
- _sparkConf.get("spark.driver.port").toInt,
+ client.register(host, port, yarnConf, _sparkConf, uiAddress, historyAddress)
+ registered = true
+ }
+
+ private def createAllocator(driverRef: RpcEndpointRef, _sparkConf: SparkConf): Unit = {
+ val appId = client.getAttemptId().getApplicationId().toString()
+ val driverUrl = RpcEndpointAddress(driverRef.address.host, driverRef.address.port,
CoarseGrainedSchedulerBackend.ENDPOINT_NAME).toString
// Before we initialize the allocator, let's log the information about how executors will
// be run up front, to avoid printing this out for every single executor being launched.
// Use placeholders for information that changes such as executor IDs.
logInfo {
- val executorMemory = sparkConf.get(EXECUTOR_MEMORY).toInt
- val executorCores = sparkConf.get(EXECUTOR_CORES)
- val dummyRunner = new ExecutorRunnable(None, yarnConf, sparkConf, driverUrl, "",
+ val executorMemory = _sparkConf.get(EXECUTOR_MEMORY).toInt
+ val executorCores = _sparkConf.get(EXECUTOR_CORES)
+ val dummyRunner = new ExecutorRunnable(None, yarnConf, _sparkConf, driverUrl, "",
"", executorMemory, executorCores, appId, securityMgr, localResources)
dummyRunner.launchContextDebugInfo()
}
- allocator = client.register(driverUrl,
- driverRef,
+ allocator = client.createAllocator(
yarnConf,
_sparkConf,
- uiAddress,
- historyAddress,
+ driverUrl,
+ driverRef,
securityMgr,
localResources)
@@ -434,15 +437,6 @@ private[spark] class ApplicationMaster(args: ApplicationMasterArguments) extends
reporterThread = launchReporterThread()
}
- /**
- * @return An [[RpcEndpoint]] that communicates with the driver's scheduler backend.
- */
- private def createSchedulerRef(host: String, port: String): RpcEndpointRef = {
- rpcEnv.setupEndpointRef(
- RpcAddress(host, port.toInt),
- YarnSchedulerBackend.ENDPOINT_NAME)
- }
-
private def runDriver(): Unit = {
addAmIpFilter(None)
userClassThread = startUserApplication()
@@ -456,11 +450,16 @@ private[spark] class ApplicationMaster(args: ApplicationMasterArguments) extends
Duration(totalWaitTime, TimeUnit.MILLISECONDS))
if (sc != null) {
rpcEnv = sc.env.rpcEnv
- val driverRef = createSchedulerRef(
- sc.getConf.get("spark.driver.host"),
- sc.getConf.get("spark.driver.port"))
- registerAM(sc.getConf, rpcEnv, driverRef, sc.ui.map(_.webUrl))
- registered = true
+
+ val userConf = sc.getConf
+ val host = userConf.get("spark.driver.host")
+ val port = userConf.get("spark.driver.port").toInt
+ registerAM(host, port, userConf, sc.ui.map(_.webUrl))
+
+ val driverRef = rpcEnv.setupEndpointRef(
+ RpcAddress(host, port),
+ YarnSchedulerBackend.ENDPOINT_NAME)
+ createAllocator(driverRef, userConf)
} else {
// Sanity check; should never happen in normal operation, since sc should only be null
// if the user app did not create a SparkContext.
@@ -486,10 +485,18 @@ private[spark] class ApplicationMaster(args: ApplicationMasterArguments) extends
val amCores = sparkConf.get(AM_CORES)
rpcEnv = RpcEnv.create("sparkYarnAM", hostname, hostname, -1, sparkConf, securityMgr,
amCores, true)
- val driverRef = waitForSparkDriver()
+
+ // The client-mode AM doesn't listen for incoming connections, so report an invalid port.
+ registerAM(hostname, -1, sparkConf, sparkConf.getOption("spark.driver.appUIAddress"))
+
+ // The driver should be up and listening, so unlike cluster mode, just try to connect to it
+ // with no waiting or retrying.
+ val (driverHost, driverPort) = Utils.parseHostPort(args.userArgs(0))
+ val driverRef = rpcEnv.setupEndpointRef(
+ RpcAddress(driverHost, driverPort),
+ YarnSchedulerBackend.ENDPOINT_NAME)
addAmIpFilter(Some(driverRef))
- registerAM(sparkConf, rpcEnv, driverRef, sparkConf.getOption("spark.driver.appUIAddress"))
- registered = true
+ createAllocator(driverRef, sparkConf)
// In client mode the actor will stop the reporter thread.
reporterThread.join()
@@ -600,40 +607,6 @@ private[spark] class ApplicationMaster(args: ApplicationMasterArguments) extends
}
}
- private def waitForSparkDriver(): RpcEndpointRef = {
- logInfo("Waiting for Spark driver to be reachable.")
- var driverUp = false
- val hostport = args.userArgs(0)
- val (driverHost, driverPort) = Utils.parseHostPort(hostport)
-
- // Spark driver should already be up since it launched us, but we don't want to
- // wait forever, so wait 100 seconds max to match the cluster mode setting.
- val totalWaitTimeMs = sparkConf.get(AM_MAX_WAIT_TIME)
- val deadline = System.currentTimeMillis + totalWaitTimeMs
-
- while (!driverUp && !finished && System.currentTimeMillis < deadline) {
- try {
- val socket = new Socket(driverHost, driverPort)
- socket.close()
- logInfo("Driver now available: %s:%s".format(driverHost, driverPort))
- driverUp = true
- } catch {
- case e: Exception =>
- logError("Failed to connect to driver at %s:%s, retrying ...".
- format(driverHost, driverPort))
- Thread.sleep(100L)
- }
- }
-
- if (!driverUp) {
- throw new SparkException("Failed to connect to driver!")
- }
-
- sparkConf.set("spark.driver.host", driverHost)
- sparkConf.set("spark.driver.port", driverPort.toString)
- createSchedulerRef(driverHost, driverPort.toString)
- }
-
/** Add the Yarn IP filter that is required for properly securing the UI. */
private def addAmIpFilter(driver: Option[RpcEndpointRef]) = {
val proxyBase = System.getenv(ApplicationConstants.APPLICATION_WEB_PROXY_BASE_ENV)
diff --git a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala
index bafb129032b49..7225ff03dc34e 100644
--- a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala
+++ b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala
@@ -1019,8 +1019,7 @@ private[spark] class Client(
appId: ApplicationId,
returnOnRunning: Boolean = false,
logApplicationReport: Boolean = true,
- interval: Long = sparkConf.get(REPORT_INTERVAL)):
- (YarnApplicationState, FinalApplicationStatus) = {
+ interval: Long = sparkConf.get(REPORT_INTERVAL)): YarnAppReport = {
var lastState: YarnApplicationState = null
while (true) {
Thread.sleep(interval)
@@ -1031,11 +1030,13 @@ private[spark] class Client(
case e: ApplicationNotFoundException =>
logError(s"Application $appId not found.")
cleanupStagingDir(appId)
- return (YarnApplicationState.KILLED, FinalApplicationStatus.KILLED)
+ return YarnAppReport(YarnApplicationState.KILLED, FinalApplicationStatus.KILLED, None)
case NonFatal(e) =>
- logError(s"Failed to contact YARN for application $appId.", e)
+ val msg = s"Failed to contact YARN for application $appId."
+ logError(msg, e)
// Don't necessarily clean up staging dir because status is unknown
- return (YarnApplicationState.FAILED, FinalApplicationStatus.FAILED)
+ return YarnAppReport(YarnApplicationState.FAILED, FinalApplicationStatus.FAILED,
+ Some(msg))
}
val state = report.getYarnApplicationState
@@ -1073,14 +1074,14 @@ private[spark] class Client(
}
if (state == YarnApplicationState.FINISHED ||
- state == YarnApplicationState.FAILED ||
- state == YarnApplicationState.KILLED) {
+ state == YarnApplicationState.FAILED ||
+ state == YarnApplicationState.KILLED) {
cleanupStagingDir(appId)
- return (state, report.getFinalApplicationStatus)
+ return createAppReport(report)
}
if (returnOnRunning && state == YarnApplicationState.RUNNING) {
- return (state, report.getFinalApplicationStatus)
+ return createAppReport(report)
}
lastState = state
@@ -1129,16 +1130,17 @@ private[spark] class Client(
throw new SparkException(s"Application $appId finished with status: $state")
}
} else {
- val (yarnApplicationState, finalApplicationStatus) = monitorApplication(appId)
- if (yarnApplicationState == YarnApplicationState.FAILED ||
- finalApplicationStatus == FinalApplicationStatus.FAILED) {
+ val YarnAppReport(appState, finalState, diags) = monitorApplication(appId)
+ if (appState == YarnApplicationState.FAILED || finalState == FinalApplicationStatus.FAILED) {
+ diags.foreach { err =>
+ logError(s"Application diagnostics message: $err")
+ }
throw new SparkException(s"Application $appId finished with failed status")
}
- if (yarnApplicationState == YarnApplicationState.KILLED ||
- finalApplicationStatus == FinalApplicationStatus.KILLED) {
+ if (appState == YarnApplicationState.KILLED || finalState == FinalApplicationStatus.KILLED) {
throw new SparkException(s"Application $appId is killed")
}
- if (finalApplicationStatus == FinalApplicationStatus.UNDEFINED) {
+ if (finalState == FinalApplicationStatus.UNDEFINED) {
throw new SparkException(s"The final status of application $appId is undefined")
}
}
@@ -1152,7 +1154,7 @@ private[spark] class Client(
val pyArchivesFile = new File(pyLibPath, "pyspark.zip")
require(pyArchivesFile.exists(),
s"$pyArchivesFile not found; cannot run pyspark application in YARN mode.")
- val py4jFile = new File(pyLibPath, "py4j-0.10.6-src.zip")
+ val py4jFile = new File(pyLibPath, "py4j-0.10.7-src.zip")
require(py4jFile.exists(),
s"$py4jFile not found; cannot run pyspark application in YARN mode.")
Seq(pyArchivesFile.getAbsolutePath(), py4jFile.getAbsolutePath())
@@ -1477,6 +1479,12 @@ private object Client extends Logging {
uri.startsWith(s"$LOCAL_SCHEME:")
}
+ def createAppReport(report: ApplicationReport): YarnAppReport = {
+ val diags = report.getDiagnostics()
+ val diagsOpt = if (diags != null && diags.nonEmpty) Some(diags) else None
+ YarnAppReport(report.getYarnApplicationState(), report.getFinalApplicationStatus(), diagsOpt)
+ }
+
}
private[spark] class YarnClusterApplication extends SparkApplication {
@@ -1491,3 +1499,8 @@ private[spark] class YarnClusterApplication extends SparkApplication {
}
}
+
+private[spark] case class YarnAppReport(
+ appState: YarnApplicationState,
+ finalState: FinalApplicationStatus,
+ diagnostics: Option[String])
diff --git a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnRMClient.scala b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnRMClient.scala
index 17234b120ae13..b59dcf158d87c 100644
--- a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnRMClient.scala
+++ b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnRMClient.scala
@@ -42,23 +42,20 @@ private[spark] class YarnRMClient extends Logging {
/**
* Registers the application master with the RM.
*
+ * @param driverHost Host name where driver is running.
+ * @param driverPort Port where driver is listening.
* @param conf The Yarn configuration.
* @param sparkConf The Spark configuration.
* @param uiAddress Address of the SparkUI.
* @param uiHistoryAddress Address of the application on the History Server.
- * @param securityMgr The security manager.
- * @param localResources Map with information about files distributed via YARN's cache.
*/
def register(
- driverUrl: String,
- driverRef: RpcEndpointRef,
+ driverHost: String,
+ driverPort: Int,
conf: YarnConfiguration,
sparkConf: SparkConf,
uiAddress: Option[String],
- uiHistoryAddress: String,
- securityMgr: SecurityManager,
- localResources: Map[String, LocalResource]
- ): YarnAllocator = {
+ uiHistoryAddress: String): Unit = {
amClient = AMRMClient.createAMRMClient()
amClient.init(conf)
amClient.start()
@@ -70,10 +67,19 @@ private[spark] class YarnRMClient extends Logging {
logInfo("Registering the ApplicationMaster")
synchronized {
- amClient.registerApplicationMaster(driverRef.address.host, driverRef.address.port,
- trackingUrl)
+ amClient.registerApplicationMaster(driverHost, driverPort, trackingUrl)
registered = true
}
+ }
+
+ def createAllocator(
+ conf: YarnConfiguration,
+ sparkConf: SparkConf,
+ driverUrl: String,
+ driverRef: RpcEndpointRef,
+ securityMgr: SecurityManager,
+ localResources: Map[String, LocalResource]): YarnAllocator = {
+ require(registered, "Must register AM before creating allocator.")
new YarnAllocator(driverUrl, driverRef, conf, sparkConf, amClient, getAttemptId(), securityMgr,
localResources, new SparkRackResolver())
}
@@ -88,6 +94,9 @@ private[spark] class YarnRMClient extends Logging {
if (registered) {
amClient.unregisterApplicationMaster(status, diagnostics, uiHistoryAddress)
}
+ if (amClient != null) {
+ amClient.stop()
+ }
}
/** Returns the attempt ID. */
diff --git a/resource-managers/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClientSchedulerBackend.scala b/resource-managers/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClientSchedulerBackend.scala
index 06e54a2eaf95a..f1a8df00f9c5b 100644
--- a/resource-managers/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClientSchedulerBackend.scala
+++ b/resource-managers/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClientSchedulerBackend.scala
@@ -22,7 +22,7 @@ import scala.collection.mutable.ArrayBuffer
import org.apache.hadoop.yarn.api.records.YarnApplicationState
import org.apache.spark.{SparkContext, SparkException}
-import org.apache.spark.deploy.yarn.{Client, ClientArguments}
+import org.apache.spark.deploy.yarn.{Client, ClientArguments, YarnAppReport}
import org.apache.spark.deploy.yarn.config._
import org.apache.spark.internal.Logging
import org.apache.spark.launcher.SparkAppHandle
@@ -75,13 +75,23 @@ private[spark] class YarnClientSchedulerBackend(
val monitorInterval = conf.get(CLIENT_LAUNCH_MONITOR_INTERVAL)
assert(client != null && appId.isDefined, "Application has not been submitted yet!")
- val (state, _) = client.monitorApplication(appId.get, returnOnRunning = true,
- interval = monitorInterval) // blocking
+ val YarnAppReport(state, _, diags) = client.monitorApplication(appId.get,
+ returnOnRunning = true, interval = monitorInterval)
if (state == YarnApplicationState.FINISHED ||
- state == YarnApplicationState.FAILED ||
- state == YarnApplicationState.KILLED) {
- throw new SparkException("Yarn application has already ended! " +
- "It might have been killed or unable to launch application master.")
+ state == YarnApplicationState.FAILED ||
+ state == YarnApplicationState.KILLED) {
+ val genericMessage = "The YARN application has already ended! " +
+ "It might have been killed or the Application Master may have failed to start. " +
+ "Check the YARN application logs for more details."
+ val exceptionMsg = diags match {
+ case Some(msg) =>
+ logError(genericMessage)
+ msg
+
+ case None =>
+ genericMessage
+ }
+ throw new SparkException(exceptionMsg)
}
if (state == YarnApplicationState.RUNNING) {
logInfo(s"Application ${appId.get} has started running.")
@@ -100,8 +110,13 @@ private[spark] class YarnClientSchedulerBackend(
override def run() {
try {
- val (state, _) = client.monitorApplication(appId.get, logApplicationReport = false)
- logError(s"Yarn application has already exited with state $state!")
+ val YarnAppReport(_, state, diags) =
+ client.monitorApplication(appId.get, logApplicationReport = true)
+ logError(s"YARN application has exited unexpectedly with state $state! " +
+ "Check the YARN application logs for more details.")
+ diags.foreach { err =>
+ logError(s"Diagnostics message: $err")
+ }
allowInterrupt = false
sc.stop()
} catch {
@@ -124,7 +139,7 @@ private[spark] class YarnClientSchedulerBackend(
private def asyncMonitorApplication(): MonitorThread = {
assert(client != null && appId.isDefined, "Application has not been submitted yet!")
val t = new MonitorThread
- t.setName("Yarn application state monitor")
+ t.setName("YARN application state monitor")
t.setDaemon(true)
t
}
diff --git a/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnClusterSuite.scala b/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnClusterSuite.scala
index a129be7c06b53..59b0f29e37d84 100644
--- a/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnClusterSuite.scala
+++ b/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnClusterSuite.scala
@@ -265,7 +265,7 @@ class YarnClusterSuite extends BaseYarnClusterSuite {
// needed locations.
val sparkHome = sys.props("spark.test.home")
val pythonPath = Seq(
- s"$sparkHome/python/lib/py4j-0.10.6-src.zip",
+ s"$sparkHome/python/lib/py4j-0.10.7-src.zip",
s"$sparkHome/python")
val extraEnvVars = Map(
"PYSPARK_ARCHIVES_PATH" -> pythonPath.map("local:" + _).mkString(File.pathSeparator),
diff --git a/sbin/spark-config.sh b/sbin/spark-config.sh
index bac154e10ae62..bf3da18c3706e 100755
--- a/sbin/spark-config.sh
+++ b/sbin/spark-config.sh
@@ -28,6 +28,6 @@ export SPARK_CONF_DIR="${SPARK_CONF_DIR:-"${SPARK_HOME}/conf"}"
# Add the PySpark classes to the PYTHONPATH:
if [ -z "${PYSPARK_PYTHONPATH_SET}" ]; then
export PYTHONPATH="${SPARK_HOME}/python:${PYTHONPATH}"
- export PYTHONPATH="${SPARK_HOME}/python/lib/py4j-0.10.6-src.zip:${PYTHONPATH}"
+ export PYTHONPATH="${SPARK_HOME}/python/lib/py4j-0.10.7-src.zip:${PYTHONPATH}"
export PYSPARK_PYTHONPATH_SET=1
fi
diff --git a/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 b/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4
index f7f921ec22c35..7c54851097af3 100644
--- a/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4
+++ b/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4
@@ -398,7 +398,7 @@ hintStatement
;
fromClause
- : FROM relation (',' relation)* (pivotClause | lateralView*)?
+ : FROM relation (',' relation)* lateralView* pivotClause?
;
aggregation
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/Encoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/Encoder.scala
index ccdb6bc5d4b7c..7b02317b8538f 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/Encoder.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/Encoder.scala
@@ -68,10 +68,10 @@ import org.apache.spark.sql.types._
*/
@Experimental
@InterfaceStability.Evolving
-@implicitNotFound("Unable to find encoder for type stored in a Dataset. Primitive types " +
- "(Int, String, etc) and Product types (case classes) are supported by importing " +
- "spark.implicits._ Support for serializing other types will be added in future " +
- "releases.")
+@implicitNotFound("Unable to find encoder for type ${T}. An implicit Encoder[${T}] is needed to " +
+ "store ${T} instances in a Dataset. Primitive types (Int, String, etc) and Product types (case " +
+ "classes) are supported by importing spark.implicits._ Support for serializing other types " +
+ "will be added in future releases.")
trait Encoder[T] extends Serializable {
/** Returns the schema of encoding this type of object as a Row. */
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
index dfdcdbc1eb2c7..3eaa9ecf5d075 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
@@ -676,13 +676,13 @@ class Analyzer(
try {
catalog.lookupRelation(tableIdentWithDb)
} catch {
- case _: NoSuchTableException =>
- u.failAnalysis(s"Table or view not found: ${tableIdentWithDb.unquotedString}")
+ case e: NoSuchTableException =>
+ u.failAnalysis(s"Table or view not found: ${tableIdentWithDb.unquotedString}", e)
// If the database is defined and that database is not found, throw an AnalysisException.
// Note that if the database is not defined, it is possible we are looking up a temp view.
case e: NoSuchDatabaseException =>
u.failAnalysis(s"Table or view not found: ${tableIdentWithDb.unquotedString}, the " +
- s"database ${e.db} doesn't exist.")
+ s"database ${e.db} doesn't exist.", e)
}
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
index f5e1d13125d1a..914027c27f18c 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
@@ -299,6 +299,15 @@ object FunctionRegistry {
expression[CollectList]("collect_list"),
expression[CollectSet]("collect_set"),
expression[CountMinSketchAgg]("count_min_sketch"),
+ expression[RegrCount]("regr_count"),
+ expression[RegrSXX]("regr_sxx"),
+ expression[RegrSYY]("regr_syy"),
+ expression[RegrAvgX]("regr_avgx"),
+ expression[RegrAvgY]("regr_avgy"),
+ expression[RegrSXY]("regr_sxy"),
+ expression[RegrSlope]("regr_slope"),
+ expression[RegrR2]("regr_r2"),
+ expression[RegrIntercept]("regr_intercept"),
// string functions
expression[Ascii]("ascii"),
@@ -419,6 +428,7 @@ object FunctionRegistry {
expression[Reverse]("reverse"),
expression[Concat]("concat"),
expression[Flatten]("flatten"),
+ expression[ArrayRepeat]("array_repeat"),
CreateStruct.registryEntry,
// misc functions
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveInlineTables.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveInlineTables.scala
index f2df3e132629f..71ed75454cd4d 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveInlineTables.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveInlineTables.scala
@@ -103,7 +103,7 @@ case class ResolveInlineTables(conf: SQLConf) extends Rule[LogicalPlan] with Cas
castedExpr.eval()
} catch {
case NonFatal(ex) =>
- table.failAnalysis(s"failed to evaluate expression ${e.sql}: ${ex.getMessage}")
+ table.failAnalysis(s"failed to evaluate expression ${e.sql}: ${ex.getMessage}", ex)
}
})
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationChecker.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationChecker.scala
index d3d6c636c4ba8..2bed41672fe33 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationChecker.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationChecker.scala
@@ -24,6 +24,7 @@ import org.apache.spark.sql.catalyst.planning.ExtractEquiJoinKeys
import org.apache.spark.sql.catalyst.plans._
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.streaming.InternalOutputModes
+import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.streaming.OutputMode
/**
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/package.scala
index 7731336d247db..354a3fa0602a9 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/package.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/package.scala
@@ -41,6 +41,11 @@ package object analysis {
def failAnalysis(msg: String): Nothing = {
throw new AnalysisException(msg, t.origin.line, t.origin.startPosition)
}
+
+ /** Fails the analysis at the point where a specific tree node was parsed. */
+ def failAnalysis(msg: String, cause: Throwable): Nothing = {
+ throw new AnalysisException(msg, t.origin.line, t.origin.startPosition, cause = Some(cause))
+ }
}
/** Catches any AnalysisExceptions thrown by `f` and attaches `t`'s position if any. */
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/MonotonicallyIncreasingID.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/MonotonicallyIncreasingID.scala
index ad1e7bdb31987..9f0779642271d 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/MonotonicallyIncreasingID.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/MonotonicallyIncreasingID.scala
@@ -38,6 +38,7 @@ import org.apache.spark.sql.types.{DataType, LongType}
puts the partition ID in the upper 31 bits, and the lower 33 bits represent the record number
within each partition. The assumption is that the data frame has less than 1 billion
partitions, and each partition has less than 8 billion records.
+ The function is non-deterministic because its result depends on partition IDs.
""")
case class MonotonicallyIncreasingID() extends LeafExpression with Stateful {
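[Example — not part of the patch] The bit layout described in the updated documentation can be sketched in plain Scala; `monotonicId` is a hypothetical helper, not the expression's implementation.

    // Upper 31 bits: partition ID; lower 33 bits: record number within the partition.
    def monotonicId(partitionId: Int, recordNumber: Long): Long =
      (partitionId.toLong << 33) | recordNumber

    monotonicId(0, 0L)   // 0
    monotonicId(1, 0L)   // 8589934592  (1L << 33)
    monotonicId(1, 2L)   // 8589934594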
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala
index 708bdbfc36058..a133bc2361eb5 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala
@@ -23,24 +23,12 @@ import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.util.TypeUtils
import org.apache.spark.sql.types._
-@ExpressionDescription(
- usage = "_FUNC_(expr) - Returns the mean calculated from values of a group.")
-case class Average(child: Expression) extends DeclarativeAggregate with ImplicitCastInputTypes {
-
- override def prettyName: String = "avg"
-
- override def children: Seq[Expression] = child :: Nil
+abstract class AverageLike(child: Expression) extends DeclarativeAggregate {
override def nullable: Boolean = true
-
// Return data type.
override def dataType: DataType = resultType
- override def inputTypes: Seq[AbstractDataType] = Seq(NumericType)
-
- override def checkInputDataTypes(): TypeCheckResult =
- TypeUtils.checkForNumericExpr(child.dataType, "function average")
-
private lazy val resultType = child.dataType match {
case DecimalType.Fixed(p, s) =>
DecimalType.bounded(p + 4, s + 4)
@@ -62,14 +50,6 @@ case class Average(child: Expression) extends DeclarativeAggregate with Implicit
/* count = */ Literal(0L)
)
- override lazy val updateExpressions = Seq(
- /* sum = */
- Add(
- sum,
- Coalesce(Cast(child, sumDataType) :: Cast(Literal(0), sumDataType) :: Nil)),
- /* count = */ If(IsNull(child), count, count + 1L)
- )
-
override lazy val mergeExpressions = Seq(
/* sum = */ sum.left + sum.right,
/* count = */ count.left + count.right
@@ -85,4 +65,29 @@ case class Average(child: Expression) extends DeclarativeAggregate with Implicit
case _ =>
Cast(sum, resultType) / Cast(count, resultType)
}
+
+ protected def updateExpressionsDef: Seq[Expression] = Seq(
+ /* sum = */
+ Add(
+ sum,
+ Coalesce(Cast(child, sumDataType) :: Cast(Literal(0), sumDataType) :: Nil)),
+ /* count = */ If(IsNull(child), count, count + 1L)
+ )
+
+ override lazy val updateExpressions = updateExpressionsDef
+}
+
+@ExpressionDescription(
+ usage = "_FUNC_(expr) - Returns the mean calculated from values of a group.")
+case class Average(child: Expression)
+ extends AverageLike(child) with ImplicitCastInputTypes {
+
+ override def prettyName: String = "avg"
+
+ override def children: Seq[Expression] = child :: Nil
+
+ override def inputTypes: Seq[AbstractDataType] = Seq(NumericType)
+
+ override def checkInputDataTypes(): TypeCheckResult =
+ TypeUtils.checkForNumericExpr(child.dataType, "function average")
}
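[Example — not part of the patch] The update/merge expressions factored into AverageLike are the usual sum-and-count bookkeeping; a plain-Scala sketch of the same logic, written over an Option-valued column for illustration only.

    // Nulls (None) leave the (sum, count) buffer unchanged, matching If(IsNull(child), ...).
    def average(values: Seq[Option[Double]]): Option[Double] = {
      val (sum, count) = values.foldLeft((0.0, 0L)) {
        case ((s, c), Some(v)) => (s + v, c + 1L)   // update
        case ((s, c), None)    => (s, c)
      }
      if (count == 0L) None else Some(sum / count)  // evaluate
    }

    average(Seq(Some(1.0), None, Some(3.0)))        // Some(2.0)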
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/CentralMomentAgg.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/CentralMomentAgg.scala
index 572d29caf5bc9..6bbb083f1e18e 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/CentralMomentAgg.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/CentralMomentAgg.scala
@@ -67,35 +67,7 @@ abstract class CentralMomentAgg(child: Expression)
override val initialValues: Seq[Expression] = Array.fill(momentOrder + 1)(Literal(0.0))
- override val updateExpressions: Seq[Expression] = {
- val newN = n + Literal(1.0)
- val delta = child - avg
- val deltaN = delta / newN
- val newAvg = avg + deltaN
- val newM2 = m2 + delta * (delta - deltaN)
-
- val delta2 = delta * delta
- val deltaN2 = deltaN * deltaN
- val newM3 = if (momentOrder >= 3) {
- m3 - Literal(3.0) * deltaN * newM2 + delta * (delta2 - deltaN2)
- } else {
- Literal(0.0)
- }
- val newM4 = if (momentOrder >= 4) {
- m4 - Literal(4.0) * deltaN * newM3 - Literal(6.0) * deltaN2 * newM2 +
- delta * (delta * delta2 - deltaN * deltaN2)
- } else {
- Literal(0.0)
- }
-
- trimHigherOrder(Seq(
- If(IsNull(child), n, newN),
- If(IsNull(child), avg, newAvg),
- If(IsNull(child), m2, newM2),
- If(IsNull(child), m3, newM3),
- If(IsNull(child), m4, newM4)
- ))
- }
+ override lazy val updateExpressions: Seq[Expression] = updateExpressionsDef
override val mergeExpressions: Seq[Expression] = {
@@ -128,6 +100,36 @@ abstract class CentralMomentAgg(child: Expression)
trimHigherOrder(Seq(newN, newAvg, newM2, newM3, newM4))
}
+
+ protected def updateExpressionsDef: Seq[Expression] = {
+ val newN = n + Literal(1.0)
+ val delta = child - avg
+ val deltaN = delta / newN
+ val newAvg = avg + deltaN
+ val newM2 = m2 + delta * (delta - deltaN)
+
+ val delta2 = delta * delta
+ val deltaN2 = deltaN * deltaN
+ val newM3 = if (momentOrder >= 3) {
+ m3 - Literal(3.0) * deltaN * newM2 + delta * (delta2 - deltaN2)
+ } else {
+ Literal(0.0)
+ }
+ val newM4 = if (momentOrder >= 4) {
+ m4 - Literal(4.0) * deltaN * newM3 - Literal(6.0) * deltaN2 * newM2 +
+ delta * (delta * delta2 - deltaN * deltaN2)
+ } else {
+ Literal(0.0)
+ }
+
+ trimHigherOrder(Seq(
+ If(IsNull(child), n, newN),
+ If(IsNull(child), avg, newAvg),
+ If(IsNull(child), m2, newM2),
+ If(IsNull(child), m3, newM3),
+ If(IsNull(child), m4, newM4)
+ ))
+ }
}
// Compute the population standard deviation of a column
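[Example — not part of the patch] The updateExpressionsDef moved above encodes a single-pass (Welford-style) update of the count, mean and second central moment; a minimal plain-Scala sketch of the second-order case, illustrative only.

    def centralMoments(values: Seq[Double]): (Double, Double, Double) = {
      var n = 0.0; var avg = 0.0; var m2 = 0.0
      values.foreach { x =>
        n += 1.0
        val delta  = x - avg
        val deltaN = delta / n
        avg += deltaN
        m2  += delta * (delta - deltaN)   // equivalent to delta * (x - newAvg)
      }
      (n, avg, m2)                        // population variance is m2 / n
    }

    centralMoments(Seq(1.0, 2.0, 3.0))    // (3.0, 2.0, 2.0), population variance 2/3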
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Corr.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Corr.scala
index 95a4a0d5af634..3cdef72c1f2c4 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Corr.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Corr.scala
@@ -22,17 +22,13 @@ import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.types._
/**
- * Compute Pearson correlation between two expressions.
+ * Base class for computing Pearson correlation between two expressions.
* When applied on empty data (i.e., count is zero), it returns NULL.
*
* Definition of Pearson correlation can be found at
* http://en.wikipedia.org/wiki/Pearson_product-moment_correlation_coefficient
*/
-// scalastyle:off line.size.limit
-@ExpressionDescription(
- usage = "_FUNC_(expr1, expr2) - Returns Pearson coefficient of correlation between a set of number pairs.")
-// scalastyle:on line.size.limit
-case class Corr(x: Expression, y: Expression)
+abstract class PearsonCorrelation(x: Expression, y: Expression)
extends DeclarativeAggregate with ImplicitCastInputTypes {
override def children: Seq[Expression] = Seq(x, y)
@@ -51,7 +47,26 @@ case class Corr(x: Expression, y: Expression)
override val initialValues: Seq[Expression] = Array.fill(6)(Literal(0.0))
- override val updateExpressions: Seq[Expression] = {
+ override lazy val updateExpressions: Seq[Expression] = updateExpressionsDef
+
+ override val mergeExpressions: Seq[Expression] = {
+ val n1 = n.left
+ val n2 = n.right
+ val newN = n1 + n2
+ val dx = xAvg.right - xAvg.left
+ val dxN = If(newN === Literal(0.0), Literal(0.0), dx / newN)
+ val dy = yAvg.right - yAvg.left
+ val dyN = If(newN === Literal(0.0), Literal(0.0), dy / newN)
+ val newXAvg = xAvg.left + dxN * n2
+ val newYAvg = yAvg.left + dyN * n2
+ val newCk = ck.left + ck.right + dx * dyN * n1 * n2
+ val newXMk = xMk.left + xMk.right + dx * dxN * n1 * n2
+ val newYMk = yMk.left + yMk.right + dy * dyN * n1 * n2
+
+ Seq(newN, newXAvg, newYAvg, newCk, newXMk, newYMk)
+ }
+
+ protected def updateExpressionsDef: Seq[Expression] = {
val newN = n + Literal(1.0)
val dx = x - xAvg
val dxN = dx / newN
@@ -73,24 +88,15 @@ case class Corr(x: Expression, y: Expression)
If(isNull, yMk, newYMk)
)
}
+}
- override val mergeExpressions: Seq[Expression] = {
-
- val n1 = n.left
- val n2 = n.right
- val newN = n1 + n2
- val dx = xAvg.right - xAvg.left
- val dxN = If(newN === Literal(0.0), Literal(0.0), dx / newN)
- val dy = yAvg.right - yAvg.left
- val dyN = If(newN === Literal(0.0), Literal(0.0), dy / newN)
- val newXAvg = xAvg.left + dxN * n2
- val newYAvg = yAvg.left + dyN * n2
- val newCk = ck.left + ck.right + dx * dyN * n1 * n2
- val newXMk = xMk.left + xMk.right + dx * dxN * n1 * n2
- val newYMk = yMk.left + yMk.right + dy * dyN * n1 * n2
- Seq(newN, newXAvg, newYAvg, newCk, newXMk, newYMk)
- }
+// scalastyle:off line.size.limit
+@ExpressionDescription(
+ usage = "_FUNC_(expr1, expr2) - Returns Pearson coefficient of correlation between a set of number pairs.")
+// scalastyle:on line.size.limit
+case class Corr(x: Expression, y: Expression)
+ extends PearsonCorrelation(x, y) {
override val evaluateExpression: Expression = {
If(n === Literal(0.0), Literal.create(null, DoubleType),
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Count.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Count.scala
index 1990f2f2f0722..40582d0abd762 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Count.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Count.scala
@@ -21,24 +21,16 @@ import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.types._
-// scalastyle:off line.size.limit
-@ExpressionDescription(
- usage = """
- _FUNC_(*) - Returns the total number of retrieved rows, including rows containing null.
-
- _FUNC_(expr) - Returns the number of rows for which the supplied expression is non-null.
-
- _FUNC_(DISTINCT expr[, expr...]) - Returns the number of rows for which the supplied expression(s) are unique and non-null.
- """)
-// scalastyle:on line.size.limit
-case class Count(children: Seq[Expression]) extends DeclarativeAggregate {
-
+/**
+ * Base class for all counting aggregators.
+ */
+abstract class CountLike extends DeclarativeAggregate {
override def nullable: Boolean = false
// Return data type.
override def dataType: DataType = LongType
- private lazy val count = AttributeReference("count", LongType, nullable = false)()
+ protected lazy val count = AttributeReference("count", LongType, nullable = false)()
override lazy val aggBufferAttributes = count :: Nil
@@ -46,6 +38,27 @@ case class Count(children: Seq[Expression]) extends DeclarativeAggregate {
/* count = */ Literal(0L)
)
+ override lazy val mergeExpressions = Seq(
+ /* count = */ count.left + count.right
+ )
+
+ override lazy val evaluateExpression = count
+
+ override def defaultResult: Option[Literal] = Option(Literal(0L))
+}
+
+// scalastyle:off line.size.limit
+@ExpressionDescription(
+ usage = """
+ _FUNC_(*) - Returns the total number of retrieved rows, including rows containing null.
+
+ _FUNC_(expr) - Returns the number of rows for which the supplied expression is non-null.
+
+ _FUNC_(DISTINCT expr[, expr...]) - Returns the number of rows for which the supplied expression(s) are unique and non-null.
+ """)
+// scalastyle:on line.size.limit
+case class Count(children: Seq[Expression]) extends CountLike {
+
override lazy val updateExpressions = {
val nullableChildren = children.filter(_.nullable)
if (nullableChildren.isEmpty) {
@@ -58,14 +71,6 @@ case class Count(children: Seq[Expression]) extends DeclarativeAggregate {
)
}
}
-
- override lazy val mergeExpressions = Seq(
- /* count = */ count.left + count.right
- )
-
- override lazy val evaluateExpression = count
-
- override def defaultResult: Option[Literal] = Option(Literal(0L))
}
object Count {
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Covariance.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Covariance.scala
index fc6c34baafdd1..72a7c62b328ee 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Covariance.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Covariance.scala
@@ -42,23 +42,7 @@ abstract class Covariance(x: Expression, y: Expression)
override val initialValues: Seq[Expression] = Array.fill(4)(Literal(0.0))
- override lazy val updateExpressions: Seq[Expression] = {
- val newN = n + Literal(1.0)
- val dx = x - xAvg
- val dy = y - yAvg
- val dyN = dy / newN
- val newXAvg = xAvg + dx / newN
- val newYAvg = yAvg + dyN
- val newCk = ck + dx * (y - newYAvg)
-
- val isNull = IsNull(x) || IsNull(y)
- Seq(
- If(isNull, n, newN),
- If(isNull, xAvg, newXAvg),
- If(isNull, yAvg, newYAvg),
- If(isNull, ck, newCk)
- )
- }
+ override lazy val updateExpressions: Seq[Expression] = updateExpressionsDef
override val mergeExpressions: Seq[Expression] = {
@@ -75,6 +59,24 @@ abstract class Covariance(x: Expression, y: Expression)
Seq(newN, newXAvg, newYAvg, newCk)
}
+
+ protected def updateExpressionsDef: Seq[Expression] = {
+ val newN = n + Literal(1.0)
+ val dx = x - xAvg
+ val dy = y - yAvg
+ val dyN = dy / newN
+ val newXAvg = xAvg + dx / newN
+ val newYAvg = yAvg + dyN
+ val newCk = ck + dx * (y - newYAvg)
+
+ val isNull = IsNull(x) || IsNull(y)
+ Seq(
+ If(isNull, n, newN),
+ If(isNull, xAvg, newXAvg),
+ If(isNull, yAvg, newYAvg),
+ If(isNull, ck, newCk)
+ )
+ }
}
@ExpressionDescription(
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/regression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/regression.scala
new file mode 100644
index 0000000000000..d8f4505588ff2
--- /dev/null
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/regression.scala
@@ -0,0 +1,190 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.expressions.aggregate
+
+import org.apache.spark.sql.catalyst.dsl.expressions._
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.types.{AbstractDataType, DoubleType}
+
+/**
+ * Base trait for all regression functions.
+ */
+trait RegrLike extends AggregateFunction with ImplicitCastInputTypes {
+ def y: Expression
+ def x: Expression
+
+ override def children: Seq[Expression] = Seq(y, x)
+ override def inputTypes: Seq[AbstractDataType] = Seq(DoubleType, DoubleType)
+
+ protected def updateIfNotNull(exprs: Seq[Expression]): Seq[Expression] = {
+ assert(aggBufferAttributes.length == exprs.length)
+ val nullableChildren = children.filter(_.nullable)
+ if (nullableChildren.isEmpty) {
+ exprs
+ } else {
+ exprs.zip(aggBufferAttributes).map { case (e, a) =>
+ If(nullableChildren.map(IsNull).reduce(Or), a, e)
+ }
+ }
+ }
+}
+
+
+@ExpressionDescription(
+ usage = "_FUNC_(y, x) - Returns the number of non-null pairs.",
+ since = "2.4.0")
+case class RegrCount(y: Expression, x: Expression)
+ extends CountLike with RegrLike {
+
+ override lazy val updateExpressions: Seq[Expression] = updateIfNotNull(Seq(count + 1L))
+
+ override def prettyName: String = "regr_count"
+}
+
+
+@ExpressionDescription(
+ usage = "_FUNC_(y, x) - Returns SUM(x*x)-SUM(x)*SUM(x)/N. Any pair with a NULL is ignored.",
+ since = "2.4.0")
+case class RegrSXX(y: Expression, x: Expression)
+ extends CentralMomentAgg(x) with RegrLike {
+
+ override protected def momentOrder = 2
+
+ override lazy val updateExpressions: Seq[Expression] = updateIfNotNull(updateExpressionsDef)
+
+ override val evaluateExpression: Expression = {
+ If(n === Literal(0.0), Literal.create(null, DoubleType), m2)
+ }
+
+ override def prettyName: String = "regr_sxx"
+}
+
+
+@ExpressionDescription(
+ usage = "_FUNC_(y, x) - Returns SUM(y*y)-SUM(y)*SUM(y)/N. Any pair with a NULL is ignored.",
+ since = "2.4.0")
+case class RegrSYY(y: Expression, x: Expression)
+ extends CentralMomentAgg(y) with RegrLike {
+
+ override protected def momentOrder = 2
+
+ override lazy val updateExpressions: Seq[Expression] = updateIfNotNull(updateExpressionsDef)
+
+ override val evaluateExpression: Expression = {
+ If(n === Literal(0.0), Literal.create(null, DoubleType), m2)
+ }
+
+ override def prettyName: String = "regr_syy"
+}
+
+
+@ExpressionDescription(
+ usage = "_FUNC_(y, x) - Returns the average of x. Any pair with a NULL is ignored.",
+ since = "2.4.0")
+case class RegrAvgX(y: Expression, x: Expression)
+ extends AverageLike(x) with RegrLike {
+
+ override lazy val updateExpressions: Seq[Expression] = updateIfNotNull(updateExpressionsDef)
+
+ override def prettyName: String = "regr_avgx"
+}
+
+
+@ExpressionDescription(
+ usage = "_FUNC_(y, x) - Returns the average of y. Any pair with a NULL is ignored.",
+ since = "2.4.0")
+case class RegrAvgY(y: Expression, x: Expression)
+ extends AverageLike(y) with RegrLike {
+
+ override lazy val updateExpressions: Seq[Expression] = updateIfNotNull(updateExpressionsDef)
+
+ override def prettyName: String = "regr_avgy"
+}
+
+// scalastyle:off line.size.limit
+@ExpressionDescription(
+ usage = "_FUNC_(y, x) - Returns the covariance of y and x multiplied for the number of items in the dataset. Any pair with a NULL is ignored.",
+ since = "2.4.0")
+// scalastyle:on line.size.limit
+case class RegrSXY(y: Expression, x: Expression)
+ extends Covariance(y, x) with RegrLike {
+
+ override lazy val updateExpressions: Seq[Expression] = updateIfNotNull(updateExpressionsDef)
+
+ override val evaluateExpression: Expression = {
+ If(n === Literal(0.0), Literal.create(null, DoubleType), ck)
+ }
+
+ override def prettyName: String = "regr_sxy"
+}
+
+
+// scalastyle:off line.size.limit
+@ExpressionDescription(
+ usage = "_FUNC_(y, x) - Returns the slope of the linear regression line. Any pair with a NULL is ignored.",
+ since = "2.4.0")
+// scalastyle:on line.size.limit
+case class RegrSlope(y: Expression, x: Expression)
+ extends PearsonCorrelation(y, x) with RegrLike {
+
+ override lazy val updateExpressions: Seq[Expression] = updateIfNotNull(updateExpressionsDef)
+
+ override val evaluateExpression: Expression = {
+ If(n < Literal(2.0) || yMk === Literal(0.0), Literal.create(null, DoubleType), ck / yMk)
+ }
+
+ override def prettyName: String = "regr_slope"
+}
+
+
+// scalastyle:off line.size.limit
+@ExpressionDescription(
+ usage = "_FUNC_(y, x) - Returns the coefficient of determination (also called R-squared or goodness of fit) for the regression line. Any pair with a NULL is ignored.",
+ since = "2.4.0")
+// scalastyle:on line.size.limit
+case class RegrR2(y: Expression, x: Expression)
+ extends PearsonCorrelation(y, x) with RegrLike {
+
+ override lazy val updateExpressions: Seq[Expression] = updateIfNotNull(updateExpressionsDef)
+
+ override val evaluateExpression: Expression = {
+ If(n < Literal(2.0) || yMk === Literal(0.0), Literal.create(null, DoubleType),
+ If(xMk === Literal(0.0), Literal(1.0), ck * ck / yMk / xMk))
+ }
+
+ override def prettyName: String = "regr_r2"
+}
+
+
+// scalastyle:off line.size.limit
+@ExpressionDescription(
+ usage = "_FUNC_(y, x) - Returns the y-intercept of the linear regression line. Any pair with a NULL is ignored.",
+ since = "2.4.0")
+// scalastyle:on line.size.limit
+case class RegrIntercept(y: Expression, x: Expression)
+ extends PearsonCorrelation(y, x) with RegrLike {
+
+ override lazy val updateExpressions: Seq[Expression] = updateIfNotNull(updateExpressionsDef)
+
+ override val evaluateExpression: Expression = {
+ If(n === Literal(0.0) || yMk === Literal(0.0), Literal.create(null, DoubleType),
+ xAvg - (ck / yMk) * yAvg)
+ }
+
+ override def prettyName: String = "regr_intercept"
+}
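[Example — not part of the patch] With the registrations above in place, the new aggregates can be exercised from SQL. A hedged sketch against a build that includes this patch; the `points` view and the assumed SparkSession `spark` are made up for illustration.

    spark.range(0, 100)
      .selectExpr("CAST(id AS DOUBLE) AS x", "2 * id + 3 AS y")
      .createOrReplaceTempView("points")

    // y = 2x + 3 exactly, so expect slope ~ 2.0, intercept ~ 3.0, r2 ~ 1.0, count = 100
    spark.sql(
      "SELECT regr_slope(y, x), regr_intercept(y, x), regr_r2(y, x), regr_count(y, x) FROM points"
    ).show()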
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
index baeefae570997..f29628d3bfb08 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
@@ -1621,3 +1621,152 @@ case class Flatten(child: Expression) extends UnaryExpression {
override def prettyName: String = "flatten"
}
+
+/**
+ * Returns the array containing the given input value (left) count (right) times.
+ */
+@ExpressionDescription(
+ usage = "_FUNC_(element, count) - Returns the array containing element count times.",
+ examples = """
+ Examples:
+ > SELECT _FUNC_('123', 2);
+ ['123', '123']
+ """,
+ since = "2.4.0")
+case class ArrayRepeat(left: Expression, right: Expression)
+ extends BinaryExpression with ExpectsInputTypes {
+
+ private val MAX_ARRAY_LENGTH = ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH
+
+ override def dataType: ArrayType = ArrayType(left.dataType, left.nullable)
+
+ override def inputTypes: Seq[AbstractDataType] = Seq(AnyDataType, IntegerType)
+
+ override def nullable: Boolean = right.nullable
+
+ override def eval(input: InternalRow): Any = {
+ val count = right.eval(input)
+ if (count == null) {
+ null
+ } else {
+ if (count.asInstanceOf[Int] > MAX_ARRAY_LENGTH) {
+ throw new RuntimeException(s"Unsuccessful try to create array with $count elements " +
+ s"due to exceeding the array size limit $MAX_ARRAY_LENGTH.");
+ }
+ val element = left.eval(input)
+ new GenericArrayData(Array.fill(count.asInstanceOf[Int])(element))
+ }
+ }
+
+ override def prettyName: String = "array_repeat"
+
+ override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
+ val leftGen = left.genCode(ctx)
+ val rightGen = right.genCode(ctx)
+ val element = leftGen.value
+ val count = rightGen.value
+ val et = dataType.elementType
+
+ val coreLogic = if (CodeGenerator.isPrimitiveType(et)) {
+ genCodeForPrimitiveElement(ctx, et, element, count, leftGen.isNull, ev.value)
+ } else {
+ genCodeForNonPrimitiveElement(ctx, element, count, leftGen.isNull, ev.value)
+ }
+ val resultCode = nullElementsProtection(ev, rightGen.isNull, coreLogic)
+
+ ev.copy(code =
+ s"""
+ |boolean ${ev.isNull} = false;
+ |${leftGen.code}
+ |${rightGen.code}
+ |${CodeGenerator.javaType(dataType)} ${ev.value} =
+ | ${CodeGenerator.defaultValue(dataType)};
+ |$resultCode
+ """.stripMargin)
+ }
+
+ private def nullElementsProtection(
+ ev: ExprCode,
+ rightIsNull: String,
+ coreLogic: String): String = {
+ if (nullable) {
+ s"""
+ |if ($rightIsNull) {
+ | ${ev.isNull} = true;
+ |} else {
+ | ${coreLogic}
+ |}
+ """.stripMargin
+ } else {
+ coreLogic
+ }
+ }
+
+ private def genCodeForNumberOfElements(ctx: CodegenContext, count: String): (String, String) = {
+ val numElements = ctx.freshName("numElements")
+ val numElementsCode =
+ s"""
+ |int $numElements = 0;
+ |if ($count > 0) {
+ | $numElements = $count;
+ |}
+ |if ($numElements > $MAX_ARRAY_LENGTH) {
+ | throw new RuntimeException("Unsuccessful try to create array with " + $numElements +
+ | " elements due to exceeding the array size limit $MAX_ARRAY_LENGTH.");
+ |}
+ """.stripMargin
+
+ (numElements, numElementsCode)
+ }
+
+ private def genCodeForPrimitiveElement(
+ ctx: CodegenContext,
+ elementType: DataType,
+ element: String,
+ count: String,
+ leftIsNull: String,
+ arrayDataName: String): String = {
+ val tempArrayDataName = ctx.freshName("tempArrayData")
+ val primitiveValueTypeName = CodeGenerator.primitiveTypeName(elementType)
+ val errorMessage = s" $prettyName failed."
+ val (numElemName, numElemCode) = genCodeForNumberOfElements(ctx, count)
+
+ s"""
+ |$numElemCode
+ |${ctx.createUnsafeArray(tempArrayDataName, numElemName, elementType, errorMessage)}
+ |if (!$leftIsNull) {
+ | for (int k = 0; k < $tempArrayDataName.numElements(); k++) {
+ | $tempArrayDataName.set$primitiveValueTypeName(k, $element);
+ | }
+ |} else {
+ | for (int k = 0; k < $tempArrayDataName.numElements(); k++) {
+ | $tempArrayDataName.setNullAt(k);
+ | }
+ |}
+ |$arrayDataName = $tempArrayDataName;
+ """.stripMargin
+ }
+
+ private def genCodeForNonPrimitiveElement(
+ ctx: CodegenContext,
+ element: String,
+ count: String,
+ leftIsNull: String,
+ arrayDataName: String): String = {
+ val genericArrayClass = classOf[GenericArrayData].getName
+ val arrayName = ctx.freshName("arrayObject")
+ val (numElemName, numElemCode) = genCodeForNumberOfElements(ctx, count)
+
+ s"""
+ |$numElemCode
+ |Object[] $arrayName = new Object[(int)$numElemName];
+ |if (!$leftIsNull) {
+ | for (int k = 0; k < $numElemName; k++) {
+ | $arrayName[k] = $element;
+ | }
+ |}
+ |$arrayDataName = new $genericArrayClass($arrayName);
+ """.stripMargin
+ }
+
+}
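[Example — not part of the patch] A quick usage sketch of the new function through SQL, assuming a SparkSession `spark` on a build with this patch.

    spark.sql("SELECT array_repeat('123', 2) AS arr").show(truncate = false)
    // expected: an array with '123' repeated twice; a non-positive count yields an empty array
    spark.sql("SELECT array_repeat(1, 0)").show()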
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala
index 76aa61415a11f..03422fecb3209 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala
@@ -1194,13 +1194,21 @@ case class AddMonths(startDate: Expression, numMonths: Expression)
}
/**
- * Returns number of months between dates date1 and date2.
+ * Returns number of months between times `timestamp1` and `timestamp2`.
+ * If `timestamp1` is later than `timestamp2`, then the result is positive.
+ * If `timestamp1` and `timestamp2` are on the same day of month, or both
+ * are the last day of month, time of day will be ignored. Otherwise, the
+ * difference is calculated based on 31 days per month, and rounded to
+ * 8 digits unless roundOff=false.
*/
// scalastyle:off line.size.limit
@ExpressionDescription(
usage = """
- _FUNC_(timestamp1, timestamp2[, roundOff]) - Returns number of months between `timestamp1` and `timestamp2`.
- The result is rounded to 8 decimal places by default. Set roundOff=false otherwise."""",
+ _FUNC_(timestamp1, timestamp2[, roundOff]) - If `timestamp1` is later than `timestamp2`, then the result
+ is positive. If `timestamp1` and `timestamp2` are on the same day of month, or both
+ are the last day of month, time of day will be ignored. Otherwise, the difference is
+ calculated based on 31 days per month, and rounded to 8 digits unless roundOff=false.
+ """,
examples = """
Examples:
> SELECT _FUNC_('1997-02-28 10:30:00', '1996-10-30');
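[Example — not part of the patch] The documented rounding behaviour can be checked from SQL; a hedged sketch assuming a SparkSession `spark` and the three-argument roundOff form shown in the usage string above.

    // Rounded (default) vs. unrounded result for the same pair of timestamps.
    spark.sql(
      """SELECT months_between('1997-02-28 10:30:00', '1996-10-30'),
        |       months_between('1997-02-28 10:30:00', '1996-10-30', false)""".stripMargin
    ).show(truncate = false)
    // the first column is rounded to 8 decimal places (3.94959677); the second keeps the full value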
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala
index 34161f0f03f4a..04a4eb0ffc032 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala
@@ -548,7 +548,7 @@ case class JsonToStructs(
forceNullableSchema = SQLConf.get.getConf(SQLConf.FROM_JSON_FORCE_NULLABLE_SCHEMA))
override def checkInputDataTypes(): TypeCheckResult = nullableSchema match {
- case _: StructType | ArrayType(_: StructType, _) =>
+ case _: StructType | ArrayType(_: StructType, _) | _: MapType =>
super.checkInputDataTypes()
case _ => TypeCheckResult.TypeCheckFailure(
s"Input schema ${nullableSchema.simpleString} must be a struct or an array of structs.")
@@ -558,6 +558,7 @@ case class JsonToStructs(
lazy val rowSchema = nullableSchema match {
case st: StructType => st
case ArrayType(st: StructType, _) => st
+ case mt: MapType => mt
}
// This converts parsed rows to the desired output by the given schema.
@@ -567,6 +568,8 @@ case class JsonToStructs(
(rows: Seq[InternalRow]) => if (rows.length == 1) rows.head else null
case ArrayType(_: StructType, _) =>
(rows: Seq[InternalRow]) => new GenericArrayData(rows)
+ case _: MapType =>
+ (rows: Seq[InternalRow]) => rows.head.getMap(0)
}
@transient
@@ -613,6 +616,11 @@ case class JsonToStructs(
}
override def inputTypes: Seq[AbstractDataType] = StringType :: Nil
+
+ override def sql: String = schema match {
+ case _: MapType => "entries"
+ case _ => super.sql
+ }
}
/**
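[Example — not part of the patch] With MapType now accepted as the root schema, from_json can parse a whole JSON object into a single map value. A hedged sketch, assuming a SparkSession `spark` and that the DDL-string form of the schema is accepted as below.

    spark.sql("""SELECT from_json('{"a": 1, "b": 2}', 'map<string, int>') AS entries""").show()
    // expected: one map column with keys "a" and "b"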
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala
index 7eda65a867028..b7834696cafc3 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala
@@ -117,12 +117,13 @@ case class CurrentDatabase() extends LeafExpression with Unevaluable {
// scalastyle:off line.size.limit
@ExpressionDescription(
- usage = "_FUNC_() - Returns an universally unique identifier (UUID) string. The value is returned as a canonical UUID 36-character string.",
+ usage = """_FUNC_() - Returns an universally unique identifier (UUID) string. The value is returned as a canonical UUID 36-character string.""",
examples = """
Examples:
> SELECT _FUNC_();
46707d92-02f4-4817-8116-a4c3b23e6266
- """)
+ """,
+ note = "The function is non-deterministic.")
// scalastyle:on line.size.limit
case class Uuid(randomSeed: Option[Long] = None) extends LeafExpression with Stateful {
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/randomExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/randomExpressions.scala
index 70186053617f8..2653b28f6c3bd 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/randomExpressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/randomExpressions.scala
@@ -68,7 +68,8 @@ abstract class RDG extends UnaryExpression with ExpectsInputTypes with Stateful
0.8446490682263027
> SELECT _FUNC_(null);
0.8446490682263027
- """)
+ """,
+ note = "The function is non-deterministic in general case.")
// scalastyle:on line.size.limit
case class Rand(child: Expression) extends RDG {
@@ -96,7 +97,7 @@ object Rand {
/** Generate a random column with i.i.d. values drawn from the standard normal distribution. */
// scalastyle:off line.size.limit
@ExpressionDescription(
- usage = "_FUNC_([seed]) - Returns a random value with independent and identically distributed (i.i.d.) values drawn from the standard normal distribution.",
+ usage = """_FUNC_([seed]) - Returns a random value with independent and identically distributed (i.i.d.) values drawn from the standard normal distribution.""",
examples = """
Examples:
> SELECT _FUNC_();
@@ -105,7 +106,8 @@ object Rand {
1.1164209726833079
> SELECT _FUNC_(null);
1.1164209726833079
- """)
+ """,
+ note = "The function is non-deterministic in general case.")
// scalastyle:on line.size.limit
case class Randn(child: Expression) extends RDG {
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JSONOptions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JSONOptions.scala
index 5f130af606e19..2ff12acb2946f 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JSONOptions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JSONOptions.scala
@@ -32,7 +32,7 @@ import org.apache.spark.sql.catalyst.util._
* Most of these map directly to Jackson's internal options, specified in [[JsonParser.Feature]].
*/
private[sql] class JSONOptions(
- @transient private val parameters: CaseInsensitiveMap[String],
+ @transient val parameters: CaseInsensitiveMap[String],
defaultTimeZoneId: String,
defaultColumnNameOfCorruptRecord: String)
extends Logging with Serializable {
@@ -110,11 +110,12 @@ private[sql] class JSONOptions(
val blacklist = Seq(Charset.forName("UTF-16"), Charset.forName("UTF-32"))
val isBlacklisted = blacklist.contains(Charset.forName(enc))
require(multiLine || !isBlacklisted,
- s"""The ${enc} encoding must not be included in the blacklist when multiLine is disabled:
- | ${blacklist.mkString(", ")}""".stripMargin)
+ s"""The $enc encoding in the blacklist is not allowed when multiLine is disabled.
+ |Blacklist: ${blacklist.mkString(", ")}""".stripMargin)
+
+ val isLineSepRequired =
+ multiLine || Charset.forName(enc) == StandardCharsets.UTF_8 || lineSeparator.nonEmpty
- val isLineSepRequired = !(multiLine == false &&
- Charset.forName(enc) != StandardCharsets.UTF_8 && lineSeparator.isEmpty)
require(isLineSepRequired, s"The lineSep option must be specified for the $enc encoding")
enc
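[Example — not part of the patch] In reader terms, the requirement enforced above means a non-UTF-8 encoding needs either multiLine or an explicit lineSep. A hedged sketch; the paths and separator are made up for illustration.

    // Per-line JSON in UTF-16LE: lineSep must be given explicitly.
    val df = spark.read
      .option("encoding", "UTF-16LE")
      .option("lineSep", "\n")
      .json("/path/to/data-utf16le.json")   // hypothetical path

    // UTF-16 / UTF-32 themselves are only allowed together with multiLine=true.
    val multiLineDf = spark.read
      .option("encoding", "UTF-16")
      .option("multiLine", true)
      .json("/path/to/data-utf16.json")     // hypothetical path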
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JacksonParser.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JacksonParser.scala
index a5a4a13eb608b..c3a4ca8f64bf6 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JacksonParser.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JacksonParser.scala
@@ -36,7 +36,7 @@ import org.apache.spark.util.Utils
* Constructs a parser for a given schema that translates a json string to an [[InternalRow]].
*/
class JacksonParser(
- schema: StructType,
+ schema: DataType,
val options: JSONOptions) extends Logging {
import JacksonUtils._
@@ -57,7 +57,14 @@ class JacksonParser(
* to a value according to a desired schema. This is a wrapper for the method
* `makeConverter()` to handle a row wrapped with an array.
*/
- private def makeRootConverter(st: StructType): JsonParser => Seq[InternalRow] = {
+ private def makeRootConverter(dt: DataType): JsonParser => Seq[InternalRow] = {
+ dt match {
+ case st: StructType => makeStructRootConverter(st)
+ case mt: MapType => makeMapRootConverter(mt)
+ }
+ }
+
+ private def makeStructRootConverter(st: StructType): JsonParser => Seq[InternalRow] = {
val elementConverter = makeConverter(st)
val fieldConverters = st.map(_.dataType).map(makeConverter).toArray
(parser: JsonParser) => parseJsonToken[Seq[InternalRow]](parser, st) {
@@ -87,6 +94,13 @@ class JacksonParser(
}
}
+ private def makeMapRootConverter(mt: MapType): JsonParser => Seq[InternalRow] = {
+ val fieldConverter = makeConverter(mt.valueType)
+ (parser: JsonParser) => parseJsonToken[Seq[InternalRow]](parser, mt) {
+ case START_OBJECT => Seq(InternalRow(convertMap(parser, fieldConverter)))
+ }
+ }
+
/**
* Create a converter which converts the JSON documents held by the `JsonParser`
* to a value according to a desired schema.
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala
index 64eed23884584..b9ece295c2510 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala
@@ -504,6 +504,9 @@ class AstBuilder(conf: SQLConf) extends SqlBaseBaseVisitor[AnyRef] with Logging
withJoinRelations(join, relation)
}
if (ctx.pivotClause() != null) {
+ if (!ctx.lateralView.isEmpty) {
+ throw new ParseException("LATERAL cannot be used together with PIVOT in FROM clause", ctx)
+ }
withPivot(ctx.pivotClause, from)
} else {
ctx.lateralView.asScala.foldLeft(from)(withGenerate)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LocalRelation.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LocalRelation.scala
index 720d42ab409a0..8c4828a4cef23 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LocalRelation.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LocalRelation.scala
@@ -21,6 +21,7 @@ import org.apache.spark.sql.Row
import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow}
import org.apache.spark.sql.catalyst.analysis
import org.apache.spark.sql.catalyst.expressions.{Attribute, Literal}
+import org.apache.spark.sql.catalyst.plans.logical.statsEstimation.EstimationUtils
import org.apache.spark.sql.types.{StructField, StructType}
object LocalRelation {
@@ -77,7 +78,7 @@ case class LocalRelation(
}
override def computeStats(): Statistics =
- Statistics(sizeInBytes = output.map(n => BigInt(n.dataType.defaultSize)).sum * data.length)
+ Statistics(sizeInBytes = EstimationUtils.getSizePerRow(output) * data.length)
def toSQL(inlineTableName: String): String = {
require(data.nonEmpty)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/EstimationUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/EstimationUtils.scala
index 0f147f0ffb135..211a2a0717371 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/EstimationUtils.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/EstimationUtils.scala
@@ -17,7 +17,6 @@
package org.apache.spark.sql.catalyst.plans.logical.statsEstimation
-import scala.collection.mutable
import scala.collection.mutable.ArrayBuffer
import scala.math.BigDecimal.RoundingMode
@@ -25,7 +24,6 @@ import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeMap}
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.types.{DecimalType, _}
-
object EstimationUtils {
/** Check if each plan has rowCount in its statistics. */
@@ -73,13 +71,12 @@ object EstimationUtils {
AttributeMap(output.flatMap(a => inputMap.get(a).map(a -> _)))
}
- def getOutputSize(
+ def getSizePerRow(
attributes: Seq[Attribute],
- outputRowCount: BigInt,
attrStats: AttributeMap[ColumnStat] = AttributeMap(Nil)): BigInt = {
// We assign a generic overhead for a Row object, the actual overhead is different for different
// Row format.
- val sizePerRow = 8 + attributes.map { attr =>
+ 8 + attributes.map { attr =>
if (attrStats.get(attr).map(_.avgLen.isDefined).getOrElse(false)) {
attr.dataType match {
case StringType =>
@@ -92,10 +89,15 @@ object EstimationUtils {
attr.dataType.defaultSize
}
}.sum
+ }
+ def getOutputSize(
+ attributes: Seq[Attribute],
+ outputRowCount: BigInt,
+ attrStats: AttributeMap[ColumnStat] = AttributeMap(Nil)): BigInt = {
// Output size can't be zero, or sizeInBytes of BinaryNode will also be zero
// (simple computation of statistics returns product of children).
- if (outputRowCount > 0) outputRowCount * sizePerRow else 1
+ if (outputRowCount > 0) outputRowCount * getSizePerRow(attributes, attrStats) else 1
}
/**
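[Example — not part of the patch] The refactored getSizePerRow keeps the same estimate: 8 bytes of per-row overhead plus each column's size, falling back to the type's default size when no column statistics are available. A worked plain-Scala sketch of that fallback path, illustrative only.

    import org.apache.spark.sql.types._

    val columnTypes = Seq(IntegerType, LongType, DoubleType)
    val sizePerRow  = 8 + columnTypes.map(_.defaultSize).sum   // 8 + 4 + 8 + 8 = 28 bytes
    val sizeInBytes = BigInt(sizePerRow) * 1000                // estimate for 1000 rows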
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/SizeInBytesOnlyStatsPlanVisitor.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/SizeInBytesOnlyStatsPlanVisitor.scala
index 85f67c7d66075..ee43f9126386b 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/SizeInBytesOnlyStatsPlanVisitor.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/SizeInBytesOnlyStatsPlanVisitor.scala
@@ -33,8 +33,8 @@ object SizeInBytesOnlyStatsPlanVisitor extends LogicalPlanVisitor[Statistics] {
private def visitUnaryNode(p: UnaryNode): Statistics = {
// There should be some overhead in Row object, the size should not be zero when there is
// no columns, this help to prevent divide-by-zero error.
- val childRowSize = p.child.output.map(_.dataType.defaultSize).sum + 8
- val outputRowSize = p.output.map(_.dataType.defaultSize).sum + 8
+ val childRowSize = EstimationUtils.getSizePerRow(p.child.output)
+ val outputRowSize = EstimationUtils.getSizePerRow(p.output)
// Assume there will be the same number of rows as child has.
var sizeInBytes = (p.child.stats.sizeInBytes * outputRowSize) / childRowSize
if (sizeInBytes == 0) {
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala
index e646da0659e85..80f15053005ff 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala
@@ -885,13 +885,13 @@ object DateTimeUtils {
/**
* Returns number of months between time1 and time2. time1 and time2 are expressed in
- * microseconds since 1.1.1970.
+ * microseconds since 1.1.1970. If time1 is later than time2, the result is positive.
*
- * If time1 and time2 having the same day of month, or both are the last day of month,
- * it returns an integer (time under a day will be ignored).
+ * If time1 and time2 are on the same day of month, or both are the last day of month,
+ * time of day will be ignored.
*
* Otherwise, the difference is calculated based on 31 days per month.
- * If `roundOff` is set to true, the result is rounded to 8 decimal places.
+ * The result is rounded to 8 decimal places if `roundOff` is set to true.
*/
def monthsBetween(
time1: SQLTimestamp,
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
index 895e150756567..b00edca97cd44 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
@@ -345,7 +345,7 @@ object SQLConf {
"snappy, gzip, lzo.")
.stringConf
.transform(_.toLowerCase(Locale.ROOT))
- .checkValues(Set("none", "uncompressed", "snappy", "gzip", "lzo"))
+ .checkValues(Set("none", "uncompressed", "snappy", "gzip", "lzo", "lz4", "brotli", "zstd"))
.createWithDefault("snappy")
val PARQUET_FILTER_PUSHDOWN_ENABLED = buildConf("spark.sql.parquet.filterPushdown")
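[Example — not part of the patch] With the extra codecs accepted by the conf above, the codec can be switched per session. A hedged sketch; whether lz4, brotli or zstd actually work depends on the Parquet and Hadoop codec libraries available at runtime, and the output path is made up.

    spark.conf.set("spark.sql.parquet.compression.codec", "zstd")
    spark.range(0, 1000).write.mode("overwrite").parquet("/tmp/parquet-zstd-demo")   // hypothetical path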
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala
index 3e08d56c028e5..142ff88821ea4 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala
@@ -491,4 +491,22 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper
checkEvaluation(Flatten(asa3), null)
checkEvaluation(Flatten(asa4), null)
}
+
+ test("ArrayRepeat") {
+ val intArray = Literal.create(Seq(1, 2), ArrayType(IntegerType))
+ val strArray = Literal.create(Seq("hi", "hola"), ArrayType(StringType))
+
+ checkEvaluation(ArrayRepeat(Literal("hi"), Literal(0)), Seq())
+ checkEvaluation(ArrayRepeat(Literal("hi"), Literal(-1)), Seq())
+ checkEvaluation(ArrayRepeat(Literal("hi"), Literal(1)), Seq("hi"))
+ checkEvaluation(ArrayRepeat(Literal("hi"), Literal(2)), Seq("hi", "hi"))
+ checkEvaluation(ArrayRepeat(Literal(true), Literal(2)), Seq(true, true))
+ checkEvaluation(ArrayRepeat(Literal(1), Literal(2)), Seq(1, 1))
+ checkEvaluation(ArrayRepeat(Literal(3.2), Literal(2)), Seq(3.2, 3.2))
+ checkEvaluation(ArrayRepeat(Literal(null), Literal(2)), Seq[String](null, null))
+ checkEvaluation(ArrayRepeat(Literal(null, IntegerType), Literal(2)), Seq[Integer](null, null))
+ checkEvaluation(ArrayRepeat(intArray, Literal(2)), Seq(Seq(1, 2), Seq(1, 2)))
+ checkEvaluation(ArrayRepeat(strArray, Literal(2)), Seq(Seq("hi", "hola"), Seq("hi", "hola")))
+ checkEvaluation(ArrayRepeat(Literal("hi"), Literal(null, IntegerType)), null)
+ }
}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala
index 812bfdd7bb885..fb51376c6163f 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala
@@ -318,6 +318,16 @@ class PlanParserSuite extends AnalysisTest {
assertEqual(
"select * from t lateral view posexplode(x) posexpl as x, y",
expected)
+
+ intercept(
+ """select *
+ |from t
+ |lateral view explode(x) expl
+ |pivot (
+ | sum(x)
+ | FOR y IN ('a', 'b')
+ |)""".stripMargin,
+ "LATERAL cannot be used together with PIVOT in FROM clause")
}
test("joins") {
diff --git a/sql/core/pom.xml b/sql/core/pom.xml
index ef41837f89d68..f270c70fbfcf0 100644
--- a/sql/core/pom.xml
+++ b/sql/core/pom.xml
@@ -38,7 +38,7 @@
com.univocity
univocity-parsers
- 2.5.9
+ 2.6.3
jar
diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/SpecificParquetRecordReaderBase.java b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/SpecificParquetRecordReaderBase.java
index e65cd252c3ddf..daedfd7e78f5f 100644
--- a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/SpecificParquetRecordReaderBase.java
+++ b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/SpecificParquetRecordReaderBase.java
@@ -18,7 +18,6 @@
package org.apache.spark.sql.execution.datasources.parquet;
-import java.io.ByteArrayInputStream;
import java.io.File;
import java.io.IOException;
import java.lang.reflect.InvocationTargetException;
@@ -293,7 +292,7 @@ protected static IntIterator createRLEIterator(
return new RLEIntIterator(
new RunLengthBitPackingHybridDecoder(
BytesUtils.getWidthFromMaxInt(maxLevel),
- new ByteArrayInputStream(bytes.toByteArray())));
+ bytes.toInputStream()));
} catch (IOException e) {
throw new IOException("could not read levels in page for col " + descriptor, e);
}
diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedColumnReader.java b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedColumnReader.java
index 72f1d024b08ce..d5969b55eef96 100644
--- a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedColumnReader.java
+++ b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedColumnReader.java
@@ -21,6 +21,8 @@
import java.util.Arrays;
import java.util.TimeZone;
+import org.apache.parquet.bytes.ByteBufferInputStream;
+import org.apache.parquet.bytes.BytesInput;
import org.apache.parquet.bytes.BytesUtils;
import org.apache.parquet.column.ColumnDescriptor;
import org.apache.parquet.column.Dictionary;
@@ -388,7 +390,8 @@ private void decodeDictionaryIds(
* is guaranteed that num is smaller than the number of values left in the current page.
*/
- private void readBooleanBatch(int rowId, int num, WritableColumnVector column) {
+ private void readBooleanBatch(int rowId, int num, WritableColumnVector column)
+ throws IOException {
if (column.dataType() != DataTypes.BooleanType) {
throw constructConvertNotSupportedException(descriptor, column);
}
@@ -396,7 +399,7 @@ private void readBooleanBatch(int rowId, int num, WritableColumnVector column) {
num, column, rowId, maxDefLevel, (VectorizedValuesReader) dataColumn);
}
- private void readIntBatch(int rowId, int num, WritableColumnVector column) {
+ private void readIntBatch(int rowId, int num, WritableColumnVector column) throws IOException {
// This is where we implement support for the valid type conversions.
// TODO: implement remaining type conversions
if (column.dataType() == DataTypes.IntegerType || column.dataType() == DataTypes.DateType ||
@@ -414,7 +417,7 @@ private void readIntBatch(int rowId, int num, WritableColumnVector column) {
}
}
- private void readLongBatch(int rowId, int num, WritableColumnVector column) {
+ private void readLongBatch(int rowId, int num, WritableColumnVector column) throws IOException {
// This is where we implement support for the valid type conversions.
if (column.dataType() == DataTypes.LongType ||
DecimalType.is64BitDecimalType(column.dataType()) ||
@@ -434,7 +437,7 @@ private void readLongBatch(int rowId, int num, WritableColumnVector column) {
}
}
- private void readFloatBatch(int rowId, int num, WritableColumnVector column) {
+ private void readFloatBatch(int rowId, int num, WritableColumnVector column) throws IOException {
// This is where we implement support for the valid type conversions.
// TODO: support implicit cast to double?
if (column.dataType() == DataTypes.FloatType) {
@@ -445,7 +448,7 @@ private void readFloatBatch(int rowId, int num, WritableColumnVector column) {
}
}
- private void readDoubleBatch(int rowId, int num, WritableColumnVector column) {
+ private void readDoubleBatch(int rowId, int num, WritableColumnVector column) throws IOException {
// This is where we implement support for the valid type conversions.
// TODO: implement remaining type conversions
if (column.dataType() == DataTypes.DoubleType) {
@@ -456,7 +459,7 @@ private void readDoubleBatch(int rowId, int num, WritableColumnVector column) {
}
}
- private void readBinaryBatch(int rowId, int num, WritableColumnVector column) {
+ private void readBinaryBatch(int rowId, int num, WritableColumnVector column) throws IOException {
// This is where we implement support for the valid type conversions.
// TODO: implement remaining type conversions
VectorizedValuesReader data = (VectorizedValuesReader) dataColumn;
@@ -556,7 +559,7 @@ public Void visit(DataPageV2 dataPageV2) {
});
}
- private void initDataReader(Encoding dataEncoding, byte[] bytes, int offset) throws IOException {
+ private void initDataReader(Encoding dataEncoding, ByteBufferInputStream in) throws IOException {
this.endOfPageValueCount = valuesRead + pageValueCount;
if (dataEncoding.usesDictionary()) {
this.dataColumn = null;
@@ -581,7 +584,7 @@ private void initDataReader(Encoding dataEncoding, byte[] bytes, int offset) thr
}
try {
- dataColumn.initFromPage(pageValueCount, bytes, offset);
+ dataColumn.initFromPage(pageValueCount, in);
} catch (IOException e) {
throw new IOException("could not read page in col " + descriptor, e);
}
@@ -602,12 +605,11 @@ private void readPageV1(DataPageV1 page) throws IOException {
this.repetitionLevelColumn = new ValuesReaderIntIterator(rlReader);
this.definitionLevelColumn = new ValuesReaderIntIterator(dlReader);
try {
- byte[] bytes = page.getBytes().toByteArray();
- rlReader.initFromPage(pageValueCount, bytes, 0);
- int next = rlReader.getNextOffset();
- dlReader.initFromPage(pageValueCount, bytes, next);
- next = dlReader.getNextOffset();
- initDataReader(page.getValueEncoding(), bytes, next);
+ BytesInput bytes = page.getBytes();
+ ByteBufferInputStream in = bytes.toInputStream();
+ rlReader.initFromPage(pageValueCount, in);
+ dlReader.initFromPage(pageValueCount, in);
+ initDataReader(page.getValueEncoding(), in);
} catch (IOException e) {
throw new IOException("could not read page " + page + " in col " + descriptor, e);
}
@@ -619,12 +621,13 @@ private void readPageV2(DataPageV2 page) throws IOException {
page.getRepetitionLevels(), descriptor);
int bitWidth = BytesUtils.getWidthFromMaxInt(descriptor.getMaxDefinitionLevel());
- this.defColumn = new VectorizedRleValuesReader(bitWidth);
+ // do not read the length from the stream. v2 pages handle dividing the page bytes.
+ this.defColumn = new VectorizedRleValuesReader(bitWidth, false);
this.definitionLevelColumn = new ValuesReaderIntIterator(this.defColumn);
- this.defColumn.initFromBuffer(
- this.pageValueCount, page.getDefinitionLevels().toByteArray());
+ this.defColumn.initFromPage(
+ this.pageValueCount, page.getDefinitionLevels().toInputStream());
try {
- initDataReader(page.getDataEncoding(), page.getData().toByteArray(), 0);
+ initDataReader(page.getDataEncoding(), page.getData().toInputStream());
} catch (IOException e) {
throw new IOException("could not read page " + page + " in col " + descriptor, e);
}
diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedPlainValuesReader.java b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedPlainValuesReader.java
index 5b75f719339fb..c62dc3d86386e 100644
--- a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedPlainValuesReader.java
+++ b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedPlainValuesReader.java
@@ -20,8 +20,9 @@
import java.nio.ByteBuffer;
import java.nio.ByteOrder;
+import org.apache.parquet.bytes.ByteBufferInputStream;
+import org.apache.parquet.io.ParquetDecodingException;
import org.apache.spark.sql.execution.vectorized.WritableColumnVector;
-import org.apache.spark.unsafe.Platform;
import org.apache.parquet.column.values.ValuesReader;
import org.apache.parquet.io.api.Binary;
@@ -30,24 +31,18 @@
* An implementation of the Parquet PLAIN decoder that supports the vectorized interface.
*/
public class VectorizedPlainValuesReader extends ValuesReader implements VectorizedValuesReader {
- private byte[] buffer;
- private int offset;
- private int bitOffset; // Only used for booleans.
- private ByteBuffer byteBuffer; // used to wrap the byte array buffer
+ private ByteBufferInputStream in = null;
- private static final boolean bigEndianPlatform =
- ByteOrder.nativeOrder().equals(ByteOrder.BIG_ENDIAN);
+ // Only used for booleans.
+ private int bitOffset;
+ private byte currentByte = 0;
public VectorizedPlainValuesReader() {
}
@Override
- public void initFromPage(int valueCount, byte[] bytes, int offset) throws IOException {
- this.buffer = bytes;
- this.offset = offset + Platform.BYTE_ARRAY_OFFSET;
- if (bigEndianPlatform) {
- byteBuffer = ByteBuffer.wrap(bytes).order(ByteOrder.LITTLE_ENDIAN);
- }
+ public void initFromPage(int valueCount, ByteBufferInputStream in) throws IOException {
+ this.in = in;
}
@Override
@@ -63,115 +58,157 @@ public final void readBooleans(int total, WritableColumnVector c, int rowId) {
}
}
+ private ByteBuffer getBuffer(int length) {
+ try {
+ return in.slice(length).order(ByteOrder.LITTLE_ENDIAN);
+ } catch (IOException e) {
+ throw new ParquetDecodingException("Failed to read " + length + " bytes", e);
+ }
+ }
+
@Override
public final void readIntegers(int total, WritableColumnVector c, int rowId) {
- c.putIntsLittleEndian(rowId, total, buffer, offset - Platform.BYTE_ARRAY_OFFSET);
- offset += 4 * total;
+ int requiredBytes = total * 4;
+ ByteBuffer buffer = getBuffer(requiredBytes);
+
+ if (buffer.hasArray()) {
+ int offset = buffer.arrayOffset() + buffer.position();
+ c.putIntsLittleEndian(rowId, total, buffer.array(), offset);
+ } else {
+ for (int i = 0; i < total; i += 1) {
+ c.putInt(rowId + i, buffer.getInt());
+ }
+ }
}
@Override
public final void readLongs(int total, WritableColumnVector c, int rowId) {
- c.putLongsLittleEndian(rowId, total, buffer, offset - Platform.BYTE_ARRAY_OFFSET);
- offset += 8 * total;
+ int requiredBytes = total * 8;
+ ByteBuffer buffer = getBuffer(requiredBytes);
+
+ if (buffer.hasArray()) {
+ int offset = buffer.arrayOffset() + buffer.position();
+ c.putLongsLittleEndian(rowId, total, buffer.array(), offset);
+ } else {
+ for (int i = 0; i < total; i += 1) {
+ c.putLong(rowId + i, buffer.getLong());
+ }
+ }
}
@Override
public final void readFloats(int total, WritableColumnVector c, int rowId) {
- c.putFloats(rowId, total, buffer, offset - Platform.BYTE_ARRAY_OFFSET);
- offset += 4 * total;
+ int requiredBytes = total * 4;
+ ByteBuffer buffer = getBuffer(requiredBytes);
+
+ if (buffer.hasArray()) {
+ int offset = buffer.arrayOffset() + buffer.position();
+ c.putFloats(rowId, total, buffer.array(), offset);
+ } else {
+ for (int i = 0; i < total; i += 1) {
+ c.putFloat(rowId + i, buffer.getFloat());
+ }
+ }
}
@Override
public final void readDoubles(int total, WritableColumnVector c, int rowId) {
- c.putDoubles(rowId, total, buffer, offset - Platform.BYTE_ARRAY_OFFSET);
- offset += 8 * total;
+ int requiredBytes = total * 8;
+ ByteBuffer buffer = getBuffer(requiredBytes);
+
+ if (buffer.hasArray()) {
+ int offset = buffer.arrayOffset() + buffer.position();
+ c.putDoubles(rowId, total, buffer.array(), offset);
+ } else {
+ for (int i = 0; i < total; i += 1) {
+ c.putDouble(rowId + i, buffer.getDouble());
+ }
+ }
}
@Override
public final void readBytes(int total, WritableColumnVector c, int rowId) {
- for (int i = 0; i < total; i++) {
- // Bytes are stored as a 4-byte little endian int. Just read the first byte.
- // TODO: consider pushing this in ColumnVector by adding a readBytes with a stride.
- c.putByte(rowId + i, Platform.getByte(buffer, offset));
- offset += 4;
+ // Bytes are stored as a 4-byte little endian int. Just read the first byte.
+ // TODO: consider pushing this in ColumnVector by adding a readBytes with a stride.
+ int requiredBytes = total * 4;
+ ByteBuffer buffer = getBuffer(requiredBytes);
+
+ for (int i = 0; i < total; i += 1) {
+ c.putByte(rowId + i, buffer.get());
+ // skip the next 3 bytes
+ buffer.position(buffer.position() + 3);
}
}
@Override
public final boolean readBoolean() {
- byte b = Platform.getByte(buffer, offset);
- boolean v = (b & (1 << bitOffset)) != 0;
+ // TODO: vectorize decoding and keep boolean[] instead of currentByte
+ if (bitOffset == 0) {
+ try {
+ currentByte = (byte) in.read();
+ } catch (IOException e) {
+ throw new ParquetDecodingException("Failed to read a byte", e);
+ }
+ }
+
+ boolean v = (currentByte & (1 << bitOffset)) != 0;
bitOffset += 1;
if (bitOffset == 8) {
bitOffset = 0;
- offset++;
}
return v;
}
@Override
public final int readInteger() {
- int v = Platform.getInt(buffer, offset);
- if (bigEndianPlatform) {
- v = java.lang.Integer.reverseBytes(v);
- }
- offset += 4;
- return v;
+ return getBuffer(4).getInt();
}
@Override
public final long readLong() {
- long v = Platform.getLong(buffer, offset);
- if (bigEndianPlatform) {
- v = java.lang.Long.reverseBytes(v);
- }
- offset += 8;
- return v;
+ return getBuffer(8).getLong();
}
@Override
public final byte readByte() {
- return (byte)readInteger();
+ return (byte) readInteger();
}
@Override
public final float readFloat() {
- float v;
- if (!bigEndianPlatform) {
- v = Platform.getFloat(buffer, offset);
- } else {
- v = byteBuffer.getFloat(offset - Platform.BYTE_ARRAY_OFFSET);
- }
- offset += 4;
- return v;
+ return getBuffer(4).getFloat();
}
@Override
public final double readDouble() {
- double v;
- if (!bigEndianPlatform) {
- v = Platform.getDouble(buffer, offset);
- } else {
- v = byteBuffer.getDouble(offset - Platform.BYTE_ARRAY_OFFSET);
- }
- offset += 8;
- return v;
+ return getBuffer(8).getDouble();
}
@Override
public final void readBinary(int total, WritableColumnVector v, int rowId) {
for (int i = 0; i < total; i++) {
int len = readInteger();
- int start = offset;
- offset += len;
- v.putByteArray(rowId + i, buffer, start - Platform.BYTE_ARRAY_OFFSET, len);
+ ByteBuffer buffer = getBuffer(len);
+ if (buffer.hasArray()) {
+ v.putByteArray(rowId + i, buffer.array(), buffer.arrayOffset() + buffer.position(), len);
+ } else {
+ byte[] bytes = new byte[len];
+ buffer.get(bytes);
+ v.putByteArray(rowId + i, bytes);
+ }
}
}
@Override
public final Binary readBinary(int len) {
- Binary result = Binary.fromConstantByteArray(buffer, offset - Platform.BYTE_ARRAY_OFFSET, len);
- offset += len;
- return result;
+ ByteBuffer buffer = getBuffer(len);
+ if (buffer.hasArray()) {
+ return Binary.fromConstantByteArray(
+ buffer.array(), buffer.arrayOffset() + buffer.position(), len);
+ } else {
+ byte[] bytes = new byte[len];
+ buffer.get(bytes);
+ return Binary.fromConstantByteArray(bytes);
+ }
}
}
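Each bulk read above slices `total * width` bytes from the stream, then either hands the backing array to the column vector (heap buffers) or falls back to element-wise gets (direct buffers). A small sketch of that fast-path/fallback split with a plain `java.nio.ByteBuffer`; in the real reader, `WritableColumnVector.putIntsLittleEndian` plays the role of the bulk copy:

```java
import java.nio.ByteBuffer;
import java.nio.ByteOrder;

public class LittleEndianIntRead {
  // Copies `total` little-endian ints out of `buffer`, preferring a bulk copy
  // from the backing array when one exists (heap buffers) and falling back to
  // element-wise reads for direct/off-heap buffers.
  static int[] readInts(ByteBuffer buffer, int total) {
    buffer.order(ByteOrder.LITTLE_ENDIAN);
    int[] out = new int[total];
    if (buffer.hasArray()) {
      int offset = buffer.arrayOffset() + buffer.position();
      ByteBuffer.wrap(buffer.array(), offset, total * 4)
          .order(ByteOrder.LITTLE_ENDIAN)
          .asIntBuffer()
          .get(out);
    } else {
      for (int i = 0; i < total; i += 1) {
        out[i] = buffer.getInt();
      }
    }
    return out;
  }

  public static void main(String[] args) {
    ByteBuffer buf = ByteBuffer.allocate(8).order(ByteOrder.LITTLE_ENDIAN);
    buf.putInt(7).putInt(42);
    buf.flip();
    int[] values = readInts(buf, 2);
    System.out.println(values[0] + " " + values[1]);  // 7 42
  }
}
```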
diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedRleValuesReader.java b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedRleValuesReader.java
index fc7fa70c39419..fe3d31ae8e746 100644
--- a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedRleValuesReader.java
+++ b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedRleValuesReader.java
@@ -18,6 +18,7 @@
package org.apache.spark.sql.execution.datasources.parquet;
import org.apache.parquet.Preconditions;
+import org.apache.parquet.bytes.ByteBufferInputStream;
import org.apache.parquet.bytes.BytesUtils;
import org.apache.parquet.column.values.ValuesReader;
import org.apache.parquet.column.values.bitpacking.BytePacker;
@@ -27,6 +28,9 @@
import org.apache.spark.sql.execution.vectorized.WritableColumnVector;
+import java.io.IOException;
+import java.nio.ByteBuffer;
+
/**
* A values reader for Parquet's run-length encoded data. This is based off of the version in
* parquet-mr with these changes:
@@ -49,9 +53,7 @@ private enum MODE {
}
// Encoded data.
- private byte[] in;
- private int end;
- private int offset;
+ private ByteBufferInputStream in;
// bit/byte width of decoded data and utility to batch unpack them.
private int bitWidth;
@@ -70,45 +72,40 @@ private enum MODE {
// If true, the bit width is fixed. This decoder is used in different places and this also
// controls if we need to read the bitwidth from the beginning of the data stream.
private final boolean fixedWidth;
+ private final boolean readLength;
public VectorizedRleValuesReader() {
- fixedWidth = false;
+ this.fixedWidth = false;
+ this.readLength = false;
}
public VectorizedRleValuesReader(int bitWidth) {
- fixedWidth = true;
+ this.fixedWidth = true;
+ this.readLength = bitWidth != 0;
+ init(bitWidth);
+ }
+
+ public VectorizedRleValuesReader(int bitWidth, boolean readLength) {
+ this.fixedWidth = true;
+ this.readLength = readLength;
init(bitWidth);
}
@Override
- public void initFromPage(int valueCount, byte[] page, int start) {
- this.offset = start;
- this.in = page;
+ public void initFromPage(int valueCount, ByteBufferInputStream in) throws IOException {
+ this.in = in;
if (fixedWidth) {
- if (bitWidth != 0) {
+ // initialize for repetition and definition levels
+ if (readLength) {
int length = readIntLittleEndian();
- this.end = this.offset + length;
+ this.in = in.sliceStream(length);
}
} else {
- this.end = page.length;
- if (this.end != this.offset) init(page[this.offset++] & 255);
- }
- if (bitWidth == 0) {
- // 0 bit width, treat this as an RLE run of valueCount number of 0's.
- this.mode = MODE.RLE;
- this.currentCount = valueCount;
- this.currentValue = 0;
- } else {
- this.currentCount = 0;
+ // initialize for values
+ if (in.available() > 0) {
+ init(in.read());
+ }
}
- }
-
- // Initialize the reader from a buffer. This is used for the V2 page encoding where the
- // definition are in its own buffer.
- public void initFromBuffer(int valueCount, byte[] data) {
- this.offset = 0;
- this.in = data;
- this.end = data.length;
if (bitWidth == 0) {
// 0 bit width, treat this as an RLE run of valueCount number of 0's.
this.mode = MODE.RLE;
@@ -129,11 +126,6 @@ private void init(int bitWidth) {
this.packer = Packer.LITTLE_ENDIAN.newBytePacker(bitWidth);
}
- @Override
- public int getNextOffset() {
- return this.end;
- }
-
@Override
public boolean readBoolean() {
return this.readInteger() != 0;
@@ -182,7 +174,7 @@ public void readIntegers(
WritableColumnVector c,
int rowId,
int level,
- VectorizedValuesReader data) {
+ VectorizedValuesReader data) throws IOException {
int left = total;
while (left > 0) {
if (this.currentCount == 0) this.readNextGroup();
@@ -217,7 +209,7 @@ public void readBooleans(
WritableColumnVector c,
int rowId,
int level,
- VectorizedValuesReader data) {
+ VectorizedValuesReader data) throws IOException {
int left = total;
while (left > 0) {
if (this.currentCount == 0) this.readNextGroup();
@@ -251,7 +243,7 @@ public void readBytes(
WritableColumnVector c,
int rowId,
int level,
- VectorizedValuesReader data) {
+ VectorizedValuesReader data) throws IOException {
int left = total;
while (left > 0) {
if (this.currentCount == 0) this.readNextGroup();
@@ -285,7 +277,7 @@ public void readShorts(
WritableColumnVector c,
int rowId,
int level,
- VectorizedValuesReader data) {
+ VectorizedValuesReader data) throws IOException {
int left = total;
while (left > 0) {
if (this.currentCount == 0) this.readNextGroup();
@@ -321,7 +313,7 @@ public void readLongs(
WritableColumnVector c,
int rowId,
int level,
- VectorizedValuesReader data) {
+ VectorizedValuesReader data) throws IOException {
int left = total;
while (left > 0) {
if (this.currentCount == 0) this.readNextGroup();
@@ -355,7 +347,7 @@ public void readFloats(
WritableColumnVector c,
int rowId,
int level,
- VectorizedValuesReader data) {
+ VectorizedValuesReader data) throws IOException {
int left = total;
while (left > 0) {
if (this.currentCount == 0) this.readNextGroup();
@@ -389,7 +381,7 @@ public void readDoubles(
WritableColumnVector c,
int rowId,
int level,
- VectorizedValuesReader data) {
+ VectorizedValuesReader data) throws IOException {
int left = total;
while (left > 0) {
if (this.currentCount == 0) this.readNextGroup();
@@ -423,7 +415,7 @@ public void readBinarys(
WritableColumnVector c,
int rowId,
int level,
- VectorizedValuesReader data) {
+ VectorizedValuesReader data) throws IOException {
int left = total;
while (left > 0) {
if (this.currentCount == 0) this.readNextGroup();
@@ -462,7 +454,7 @@ public void readIntegers(
WritableColumnVector nulls,
int rowId,
int level,
- VectorizedValuesReader data) {
+ VectorizedValuesReader data) throws IOException {
int left = total;
while (left > 0) {
if (this.currentCount == 0) this.readNextGroup();
@@ -559,12 +551,12 @@ public Binary readBinary(int len) {
/**
* Reads the next varint encoded int.
*/
- private int readUnsignedVarInt() {
+ private int readUnsignedVarInt() throws IOException {
int value = 0;
int shift = 0;
int b;
do {
- b = in[offset++] & 255;
+ b = in.read();
value |= (b & 0x7F) << shift;
shift += 7;
} while ((b & 0x80) != 0);
@@ -574,35 +566,32 @@ private int readUnsignedVarInt() {
/**
* Reads the next 4 byte little endian int.
*/
- private int readIntLittleEndian() {
- int ch4 = in[offset] & 255;
- int ch3 = in[offset + 1] & 255;
- int ch2 = in[offset + 2] & 255;
- int ch1 = in[offset + 3] & 255;
- offset += 4;
+ private int readIntLittleEndian() throws IOException {
+ int ch4 = in.read();
+ int ch3 = in.read();
+ int ch2 = in.read();
+ int ch1 = in.read();
return ((ch1 << 24) + (ch2 << 16) + (ch3 << 8) + (ch4 << 0));
}
/**
* Reads the next byteWidth little endian int.
*/
- private int readIntLittleEndianPaddedOnBitWidth() {
+ private int readIntLittleEndianPaddedOnBitWidth() throws IOException {
switch (bytesWidth) {
case 0:
return 0;
case 1:
- return in[offset++] & 255;
+ return in.read();
case 2: {
- int ch2 = in[offset] & 255;
- int ch1 = in[offset + 1] & 255;
- offset += 2;
+ int ch2 = in.read();
+ int ch1 = in.read();
return (ch1 << 8) + ch2;
}
case 3: {
- int ch3 = in[offset] & 255;
- int ch2 = in[offset + 1] & 255;
- int ch1 = in[offset + 2] & 255;
- offset += 3;
+ int ch3 = in.read();
+ int ch2 = in.read();
+ int ch1 = in.read();
return (ch1 << 16) + (ch2 << 8) + (ch3 << 0);
}
case 4: {
@@ -619,32 +608,36 @@ private int ceil8(int value) {
/**
* Reads the next group.
*/
- private void readNextGroup() {
- int header = readUnsignedVarInt();
- this.mode = (header & 1) == 0 ? MODE.RLE : MODE.PACKED;
- switch (mode) {
- case RLE:
- this.currentCount = header >>> 1;
- this.currentValue = readIntLittleEndianPaddedOnBitWidth();
- return;
- case PACKED:
- int numGroups = header >>> 1;
- this.currentCount = numGroups * 8;
- int bytesToRead = ceil8(this.currentCount * this.bitWidth);
-
- if (this.currentBuffer.length < this.currentCount) {
- this.currentBuffer = new int[this.currentCount];
- }
- currentBufferIdx = 0;
- int valueIndex = 0;
- for (int byteIndex = offset; valueIndex < this.currentCount; byteIndex += this.bitWidth) {
- this.packer.unpack8Values(in, byteIndex, this.currentBuffer, valueIndex);
- valueIndex += 8;
- }
- offset += bytesToRead;
- return;
- default:
- throw new ParquetDecodingException("not a valid mode " + this.mode);
+ private void readNextGroup() {
+ try {
+ int header = readUnsignedVarInt();
+ this.mode = (header & 1) == 0 ? MODE.RLE : MODE.PACKED;
+ switch (mode) {
+ case RLE:
+ this.currentCount = header >>> 1;
+ this.currentValue = readIntLittleEndianPaddedOnBitWidth();
+ return;
+ case PACKED:
+ int numGroups = header >>> 1;
+ this.currentCount = numGroups * 8;
+
+ if (this.currentBuffer.length < this.currentCount) {
+ this.currentBuffer = new int[this.currentCount];
+ }
+ currentBufferIdx = 0;
+ int valueIndex = 0;
+ while (valueIndex < this.currentCount) {
+ // values are bit packed 8 at a time, so reading bitWidth will always work
+ ByteBuffer buffer = in.slice(bitWidth);
+ this.packer.unpack8Values(buffer, buffer.position(), this.currentBuffer, valueIndex);
+ valueIndex += 8;
+ }
+ return;
+ default:
+ throw new ParquetDecodingException("not a valid mode " + this.mode);
+ }
+ } catch (IOException e) {
+ throw new ParquetDecodingException("Failed to read from input stream", e);
}
}
}
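`readNextGroup()` relies on Parquet's RLE/bit-packed hybrid layout: each run starts with an unsigned varint header whose low bit selects the mode and whose remaining bits carry either the run length (RLE) or the number of 8-value groups (PACKED). A self-contained sketch of the header decode with plain `java.io`, not the reader above:

```java
import java.io.ByteArrayInputStream;
import java.io.IOException;
import java.io.InputStream;

public class RleHybridHeader {
  // Parquet's RLE/bit-packed hybrid encoding prefixes each run with an
  // unsigned LEB128 varint; the low bit picks the mode and the remaining
  // bits carry the run length (RLE) or the number of 8-value groups (PACKED).
  static int readUnsignedVarInt(InputStream in) throws IOException {
    int value = 0;
    int shift = 0;
    int b;
    do {
      b = in.read();
      value |= (b & 0x7F) << shift;
      shift += 7;
    } while ((b & 0x80) != 0);
    return value;
  }

  public static void main(String[] args) throws IOException {
    // 0x03 = 0b11: low bit 1 -> PACKED, header >>> 1 = 1 group of 8 values.
    InputStream in = new ByteArrayInputStream(new byte[]{0x03});
    int header = readUnsignedVarInt(in);
    boolean packed = (header & 1) == 1;
    int count = packed ? (header >>> 1) * 8 : header >>> 1;
    System.out.println(packed + " " + count);  // true 8
  }
}
```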
diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/MicroBatchReadSupport.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/MicroBatchReadSupport.java
index 209ffa7a0b9fa..7f4a2c9593c76 100644
--- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/MicroBatchReadSupport.java
+++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/MicroBatchReadSupport.java
@@ -34,7 +34,7 @@ public interface MicroBatchReadSupport extends DataSourceV2 {
* streaming query.
*
* The execution engine will create a micro-batch reader at the start of a streaming query,
- * alternate calls to setOffsetRange and createDataReaderFactories for each batch to process, and
+ * alternate calls to setOffsetRange and planInputPartitions for each batch to process, and
* then call stop() when the execution is complete. Note that a single query may have multiple
* executions due to restart or failure recovery.
*
diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/ReadSupport.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/ReadSupport.java
index 0ea4dc6b5def3..b2526ded53d92 100644
--- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/ReadSupport.java
+++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/ReadSupport.java
@@ -30,7 +30,7 @@ public interface ReadSupport extends DataSourceV2 {
/**
* Creates a {@link DataSourceReader} to scan the data from this data source.
*
- * If this method fails (by throwing an exception), the action would fail and no Spark job was
+ * If this method fails (by throwing an exception), the action will fail and no Spark job will be
* submitted.
*
* @param options the options for the returned data source reader, which is an immutable
diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/ReadSupportWithSchema.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/ReadSupportWithSchema.java
index 3801402268af1..f31659904cc53 100644
--- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/ReadSupportWithSchema.java
+++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/ReadSupportWithSchema.java
@@ -35,7 +35,7 @@ public interface ReadSupportWithSchema extends DataSourceV2 {
/**
* Create a {@link DataSourceReader} to scan the data from this data source.
*
- * If this method fails (by throwing an exception), the action would fail and no Spark job was
+ * If this method fails (by throwing an exception), the action will fail and no Spark job will be
* submitted.
*
* @param schema the full schema of this data source reader. Full schema usually maps to the
diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/WriteSupport.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/WriteSupport.java
index cab56453816cc..83aeec0c47853 100644
--- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/WriteSupport.java
+++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/WriteSupport.java
@@ -35,7 +35,7 @@ public interface WriteSupport extends DataSourceV2 {
* Creates an optional {@link DataSourceWriter} to save the data to this data source. Data
* sources can return None if there is no writing needed to be done according to the save mode.
*
- * If this method fails (by throwing an exception), the action would fail and no Spark job was
+ * If this method fails (by throwing an exception), the action will fail and no Spark job will be
* submitted.
*
* @param jobId A unique string for the writing job. It's possible that there are many writing
diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/ContinuousDataReaderFactory.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/ContinuousInputPartition.java
similarity index 78%
rename from sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/ContinuousDataReaderFactory.java
rename to sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/ContinuousInputPartition.java
index a61697649c43e..c24f3b21eade1 100644
--- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/ContinuousDataReaderFactory.java
+++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/ContinuousInputPartition.java
@@ -21,15 +21,15 @@
import org.apache.spark.sql.sources.v2.reader.streaming.PartitionOffset;
/**
- * A mix-in interface for {@link DataReaderFactory}. Continuous data reader factories can
- * implement this interface to provide creating {@link DataReader} with particular offset.
+ * A mix-in interface for {@link InputPartition}. Continuous input partitions can
+ * implement this interface to provide creating {@link InputPartitionReader} with particular offset.
*/
@InterfaceStability.Evolving
-public interface ContinuousDataReaderFactory<T> extends DataReaderFactory<T> {
+public interface ContinuousInputPartition<T> extends InputPartition<T> {
/**
* Create a DataReader with particular offset as its startOffset.
*
* @param offset offset want to set as the DataReader's startOffset.
*/
- DataReader<T> createDataReaderWithOffset(PartitionOffset offset);
+ InputPartitionReader<T> createContinuousReader(PartitionOffset offset);
}
diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/DataSourceReader.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/DataSourceReader.java
index a470bccc5aad2..36a3e542b5a11 100644
--- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/DataSourceReader.java
+++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/DataSourceReader.java
@@ -31,8 +31,8 @@
* {@link ReadSupport#createReader(DataSourceOptions)} or
* {@link ReadSupportWithSchema#createReader(StructType, DataSourceOptions)}.
* It can mix in various query optimization interfaces to speed up the data scan. The actual scan
- * logic is delegated to {@link DataReaderFactory}s that are returned by
- * {@link #createDataReaderFactories()}.
+ * logic is delegated to {@link InputPartition}s, which are returned by
+ * {@link #planInputPartitions()}.
*
* There are mainly 3 kinds of query optimizations:
* 1. Operators push-down. E.g., filter push-down, required columns push-down(aka column
@@ -45,8 +45,8 @@
* only one of them would be respected, according to the priority list from high to low:
* {@link SupportsScanColumnarBatch}, {@link SupportsScanUnsafeRow}.
*
- * If an exception was throw when applying any of these query optimizations, the action would fail
- * and no Spark job was submitted.
+ * If an exception is thrown when applying any of these query optimizations, the action will fail
+ * and no Spark job will be submitted.
*
* Spark first applies all operator push-down optimizations that this data source supports. Then
* Spark collects information this data source reported for further optimizations. Finally Spark
@@ -59,22 +59,22 @@ public interface DataSourceReader {
* Returns the actual schema of this data source reader, which may be different from the physical
* schema of the underlying storage, as column pruning or other optimizations may happen.
*
- * If this method fails (by throwing an exception), the action would fail and no Spark job was
+ * If this method fails (by throwing an exception), the action will fail and no Spark job will be
* submitted.
*/
StructType readSchema();
/**
- * Returns a list of reader factories. Each factory is responsible for creating a data reader to
- * output data for one RDD partition. That means the number of factories returned here is same as
- * the number of RDD partitions this scan outputs.
+ * Returns a list of {@link InputPartition}s. Each {@link InputPartition} is responsible for
+ * creating a data reader to output data of one RDD partition. The number of input partitions
+ * returned here is the same as the number of RDD partitions this scan outputs.
*
* Note that, this may not be a full scan if the data source reader mixes in other optimization
* interfaces like column pruning, filter push-down, etc. These optimizations are applied before
* Spark issues the scan request.
*
- * If this method fails (by throwing an exception), the action would fail and no Spark job was
+ * If this method fails (by throwing an exception), the action will fail and no Spark job will be
* submitted.
*/
- List<DataReaderFactory<Row>> createDataReaderFactories();
+ List<InputPartition<Row>> planInputPartitions();
}
diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/DataReaderFactory.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/InputPartition.java
similarity index 68%
rename from sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/DataReaderFactory.java
rename to sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/InputPartition.java
index 32e98e8f5d8bd..3524481784fea 100644
--- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/DataReaderFactory.java
+++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/InputPartition.java
@@ -22,29 +22,30 @@
import org.apache.spark.annotation.InterfaceStability;
/**
- * A reader factory returned by {@link DataSourceReader#createDataReaderFactories()} and is
- * responsible for creating the actual data reader. The relationship between
- * {@link DataReaderFactory} and {@link DataReader}
+ * An input partition returned by {@link DataSourceReader#planInputPartitions()} and is
+ * responsible for creating the actual data reader of one RDD partition.
+ * The relationship between {@link InputPartition} and {@link InputPartitionReader}
* is similar to the relationship between {@link Iterable} and {@link java.util.Iterator}.
*
- * Note that, the reader factory will be serialized and sent to executors, then the data reader
- * will be created on executors and do the actual reading. So {@link DataReaderFactory} must be
- * serializable and {@link DataReader} doesn't need to be.
+ * Note that {@link InputPartition}s will be serialized and sent to executors, then
+ * {@link InputPartitionReader}s will be created on executors to do the actual reading. So
+ * {@link InputPartition} must be serializable while {@link InputPartitionReader} doesn't need to
+ * be.
*/
@InterfaceStability.Evolving
-public interface DataReaderFactory<T> extends Serializable {
+public interface InputPartition<T> extends Serializable {
/**
- * The preferred locations where the data reader returned by this reader factory can run faster,
+ * The preferred locations where the data reader returned by this partition can run faster,
* but Spark does not guarantee to run the data reader on these locations.
* The implementations should make sure that it can be run on any location.
* The location is a string representing the host name.
*
* Note that if a host name cannot be recognized by Spark, it will be ignored as it was not in
- * the returned locations. By default this method returns empty string array, which means this
- * task has no location preference.
+ * the returned locations. The default return value is an empty string array, which means this
+ * input partition's reader has no location preference.
*
- * If this method fails (by throwing an exception), the action would fail and no Spark job was
+ * If this method fails (by throwing an exception), the action will fail and no Spark job will be
* submitted.
*/
default String[] preferredLocations() {
@@ -57,5 +58,5 @@ default String[] preferredLocations() {
* If this method fails (by throwing an exception), the corresponding Spark task would fail and
* get retried until hitting the maximum retry times.
*/
- DataReader<T> createDataReader();
+ InputPartitionReader<T> createPartitionReader();
}
diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/DataReader.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/InputPartitionReader.java
similarity index 92%
rename from sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/DataReader.java
rename to sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/InputPartitionReader.java
index bb9790a1c819e..1b7051f1ad0af 100644
--- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/DataReader.java
+++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/InputPartitionReader.java
@@ -23,7 +23,7 @@
import org.apache.spark.annotation.InterfaceStability;
/**
- * A data reader returned by {@link DataReaderFactory#createDataReader()} and is responsible for
+ * A data reader returned by {@link InputPartition#createPartitionReader()} and is responsible for
* outputting data for a RDD partition.
*
* Note that, Currently the type `T` can only be {@link org.apache.spark.sql.Row} for normal data
@@ -31,7 +31,7 @@
* readers that mix in {@link SupportsScanUnsafeRow}.
*/
@InterfaceStability.Evolving
-public interface DataReader<T> extends Closeable {
+public interface InputPartitionReader<T> extends Closeable {
/**
* Proceed to next record, returns false if there is no more records.
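A minimal sketch of a source written against the renamed interfaces (`InputPartition` planning one reader per RDD partition, mirroring the Iterable/Iterator relationship described above). The `RangeReader`/`RangePartition` names and the single `value` column are illustrative only, not part of this patch:

```java
import java.io.IOException;
import java.util.Arrays;
import java.util.List;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.RowFactory;
import org.apache.spark.sql.sources.v2.reader.DataSourceReader;
import org.apache.spark.sql.sources.v2.reader.InputPartition;
import org.apache.spark.sql.sources.v2.reader.InputPartitionReader;
import org.apache.spark.sql.types.DataTypes;
import org.apache.spark.sql.types.StructType;

// Hypothetical reader that plans one InputPartition per integer range.
public class RangeReader implements DataSourceReader {
  @Override
  public StructType readSchema() {
    return new StructType().add("value", DataTypes.IntegerType);
  }

  @Override
  public List<InputPartition<Row>> planInputPartitions() {
    // One InputPartition per RDD partition of the resulting scan.
    return Arrays.asList(new RangePartition(0, 5), new RangePartition(5, 10));
  }
}

class RangePartition implements InputPartition<Row> {
  private final int start;
  private final int end;

  RangePartition(int start, int end) {
    this.start = start;
    this.end = end;
  }

  // The partition is serialized and sent to executors; the reader it creates
  // is built on the executor and never serialized.
  @Override
  public InputPartitionReader<Row> createPartitionReader() {
    return new InputPartitionReader<Row>() {
      private int current = start - 1;

      @Override
      public boolean next() {
        current += 1;
        return current < end;
      }

      @Override
      public Row get() {
        return RowFactory.create(current);
      }

      @Override
      public void close() throws IOException { }
    };
  }
}
```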
diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsReportPartitioning.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsReportPartitioning.java
index 607628746e873..6b60da7c4dc1d 100644
--- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsReportPartitioning.java
+++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsReportPartitioning.java
@@ -24,7 +24,7 @@
* A mix in interface for {@link DataSourceReader}. Data source readers can implement this
* interface to report data partitioning and try to avoid shuffle at Spark side.
*
- * Note that, when the reader creates exactly one {@link DataReaderFactory}, Spark may avoid
+ * Note that, when the reader creates exactly one {@link InputPartition}, Spark may avoid
* adding a shuffle even if the reader does not implement this interface.
*/
@InterfaceStability.Evolving
diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsScanColumnarBatch.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsScanColumnarBatch.java
index 2e5cfa78511f0..0faf81db24605 100644
--- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsScanColumnarBatch.java
+++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsScanColumnarBatch.java
@@ -30,22 +30,22 @@
@InterfaceStability.Evolving
public interface SupportsScanColumnarBatch extends DataSourceReader {
@Override
- default List<DataReaderFactory<Row>> createDataReaderFactories() {
+ default List<InputPartition<Row>> planInputPartitions() {
throw new IllegalStateException(
- "createDataReaderFactories not supported by default within SupportsScanColumnarBatch.");
+ "planInputPartitions not supported by default within SupportsScanColumnarBatch.");
}
/**
- * Similar to {@link DataSourceReader#createDataReaderFactories()}, but returns columnar data
+ * Similar to {@link DataSourceReader#planInputPartitions()}, but returns columnar data
* in batches.
*/
- List<DataReaderFactory<ColumnarBatch>> createBatchDataReaderFactories();
+ List<InputPartition<ColumnarBatch>> planBatchInputPartitions();
/**
* Returns true if the concrete data source reader can read data in batch according to the scan
* properties like required columns, pushes filters, etc. It's possible that the implementation
* can only support some certain columns with certain types. Users can overwrite this method and
- * {@link #createDataReaderFactories()} to fallback to normal read path under some conditions.
+ * {@link #planInputPartitions()} to fall back to the normal read path under some conditions.
*/
default boolean enableBatchRead() {
return true;
diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsScanUnsafeRow.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsScanUnsafeRow.java
index 9cd749e8e4ce9..f2220f6d31093 100644
--- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsScanUnsafeRow.java
+++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsScanUnsafeRow.java
@@ -33,14 +33,14 @@
public interface SupportsScanUnsafeRow extends DataSourceReader {
@Override
- default List<DataReaderFactory<Row>> createDataReaderFactories() {
+ default List<InputPartition<Row>> planInputPartitions() {
throw new IllegalStateException(
- "createDataReaderFactories not supported by default within SupportsScanUnsafeRow");
+ "planInputPartitions not supported by default within SupportsScanUnsafeRow");
}
/**
- * Similar to {@link DataSourceReader#createDataReaderFactories()},
+ * Similar to {@link DataSourceReader#planInputPartitions()},
* but returns data in unsafe row format.
*/
- List<DataReaderFactory<UnsafeRow>> createUnsafeRowReaderFactories();
+ List<InputPartition<UnsafeRow>> planUnsafeInputPartitions();
}
diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/partitioning/ClusteredDistribution.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/partitioning/ClusteredDistribution.java
index 2d0ee50212b56..38ca5fc6387b2 100644
--- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/partitioning/ClusteredDistribution.java
+++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/partitioning/ClusteredDistribution.java
@@ -18,12 +18,12 @@
package org.apache.spark.sql.sources.v2.reader.partitioning;
import org.apache.spark.annotation.InterfaceStability;
-import org.apache.spark.sql.sources.v2.reader.DataReader;
+import org.apache.spark.sql.sources.v2.reader.InputPartitionReader;
/**
* A concrete implementation of {@link Distribution}. Represents a distribution where records that
* share the same values for the {@link #clusteredColumns} will be produced by the same
- * {@link DataReader}.
+ * {@link InputPartitionReader}.
*/
@InterfaceStability.Evolving
public class ClusteredDistribution implements Distribution {
diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/partitioning/Distribution.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/partitioning/Distribution.java
index f6b111fdf220d..5e32ba6952e1c 100644
--- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/partitioning/Distribution.java
+++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/partitioning/Distribution.java
@@ -18,13 +18,14 @@
package org.apache.spark.sql.sources.v2.reader.partitioning;
import org.apache.spark.annotation.InterfaceStability;
-import org.apache.spark.sql.sources.v2.reader.DataReader;
+import org.apache.spark.sql.sources.v2.reader.InputPartitionReader;
/**
* An interface to represent data distribution requirement, which specifies how the records should
- * be distributed among the data partitions(one {@link DataReader} outputs data for one partition).
+ * be distributed among the data partitions (one {@link InputPartitionReader} outputs data for one
+ * partition).
* Note that this interface has nothing to do with the data ordering inside one
- * partition(the output records of a single {@link DataReader}).
+ * partition (the output records of a single {@link InputPartitionReader}).
*
* The instance of this interface is created and provided by Spark, then consumed by
* {@link Partitioning#satisfy(Distribution)}. This means data source developers don't need to
diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/partitioning/Partitioning.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/partitioning/Partitioning.java
index 309d9e5de0a0f..f460f6bfe3bb9 100644
--- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/partitioning/Partitioning.java
+++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/partitioning/Partitioning.java
@@ -18,7 +18,7 @@
package org.apache.spark.sql.sources.v2.reader.partitioning;
import org.apache.spark.annotation.InterfaceStability;
-import org.apache.spark.sql.sources.v2.reader.DataReaderFactory;
+import org.apache.spark.sql.sources.v2.reader.InputPartition;
import org.apache.spark.sql.sources.v2.reader.SupportsReportPartitioning;
/**
@@ -31,7 +31,7 @@
public interface Partitioning {
/**
- * Returns the number of partitions(i.e., {@link DataReaderFactory}s) the data source outputs.
+ * Returns the number of partitions (i.e., {@link InputPartition}s) the data source outputs.
*/
int numPartitions();
diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/ContinuousDataReader.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/ContinuousInputPartitionReader.java
similarity index 84%
rename from sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/ContinuousDataReader.java
rename to sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/ContinuousInputPartitionReader.java
index 47d26440841fd..7b0ba0bbdda90 100644
--- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/ContinuousDataReader.java
+++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/ContinuousInputPartitionReader.java
@@ -18,13 +18,13 @@
package org.apache.spark.sql.sources.v2.reader.streaming;
import org.apache.spark.annotation.InterfaceStability;
-import org.apache.spark.sql.sources.v2.reader.DataReader;
+import org.apache.spark.sql.sources.v2.reader.InputPartitionReader;
/**
- * A variation on {@link DataReader} for use with streaming in continuous processing mode.
+ * A variation on {@link InputPartitionReader} for use with streaming in continuous processing mode.
*/
@InterfaceStability.Evolving
-public interface ContinuousDataReader<T> extends DataReader<T> {
+public interface ContinuousInputPartitionReader<T> extends InputPartitionReader<T> {
/**
* Get the offset of the current record, or the start offset if no records have been read.
*
diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/ContinuousReader.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/ContinuousReader.java
index 7fe7f00ac2fa8..6e960bedf8020 100644
--- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/ContinuousReader.java
+++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/ContinuousReader.java
@@ -27,7 +27,7 @@
* A mix-in interface for {@link DataSourceReader}. Data source readers can implement this
* interface to allow reading in a continuous processing mode stream.
*
- * Implementations must ensure each reader factory output is a {@link ContinuousDataReader}.
+ * Implementations must ensure each partition reader is a {@link ContinuousInputPartitionReader}.
*
* Note: This class currently extends {@link BaseStreamingSource} to maintain compatibility with
* DataSource V1 APIs. This extension will be removed once we get rid of V1 completely.
@@ -35,8 +35,8 @@
@InterfaceStability.Evolving
public interface ContinuousReader extends BaseStreamingSource, DataSourceReader {
/**
- * Merge partitioned offsets coming from {@link ContinuousDataReader} instances for each
- * partition to a single global offset.
+ * Merge partitioned offsets coming from {@link ContinuousInputPartitionReader} instances
+ * for each partition to a single global offset.
*/
Offset mergeOffsets(PartitionOffset[] offsets);
@@ -47,7 +47,7 @@ public interface ContinuousReader extends BaseStreamingSource, DataSourceReader
Offset deserializeOffset(String json);
/**
- * Set the desired start offset for reader factories created from this reader. The scan will
+ * Set the desired start offset for partitions created from this reader. The scan will
* start from the first record after the provided offset, or from an implementation-defined
* inferred starting point if no offset is provided.
*/
@@ -61,8 +61,8 @@ public interface ContinuousReader extends BaseStreamingSource, DataSourceReader
Offset getStartOffset();
/**
- * The execution engine will call this method in every epoch to determine if new reader
- * factories need to be generated, which may be required if for example the underlying
+ * The execution engine will call this method in every epoch to determine if new input
+ * partitions need to be generated, which may be required if for example the underlying
* source system has had partitions added or removed.
*
* If true, the query will be shut down and restarted with a new reader.
diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/MicroBatchReader.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/MicroBatchReader.java
index 67ebde30d61a9..0159c731762d9 100644
--- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/MicroBatchReader.java
+++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/MicroBatchReader.java
@@ -33,7 +33,7 @@
@InterfaceStability.Evolving
public interface MicroBatchReader extends DataSourceReader, BaseStreamingSource {
/**
- * Set the desired offset range for reader factories created from this reader. Reader factories
+ * Set the desired offset range for input partitions created from this reader. Partition readers
* will generate only data within (`start`, `end`]; that is, from the first record after `start`
* to the record with offset `end`.
*
diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/DataSourceWriter.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/DataSourceWriter.java
index 0a0fd8db58035..0030a9f05dba7 100644
--- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/DataSourceWriter.java
+++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/DataSourceWriter.java
@@ -34,8 +34,8 @@
* It can mix in various writing optimization interfaces to speed up the data saving. The actual
* writing logic is delegated to {@link DataWriter}.
*
- * If an exception was throw when applying any of these writing optimizations, the action would fail
- * and no Spark job was submitted.
+ * If an exception is thrown when applying any of these writing optimizations, the action will fail
+ * and no Spark job will be submitted.
*
* The writing procedure is:
* 1. Create a writer factory by {@link #createWriterFactory()}, serialize and send it to all the
@@ -58,7 +58,7 @@ public interface DataSourceWriter {
/**
* Creates a writer factory which will be serialized and sent to executors.
*
- * If this method fails (by throwing an exception), the action would fail and no Spark job was
+ * If this method fails (by throwing an exception), the action will fail and no Spark job will be
* submitted.
*/
DataWriterFactory<Row> createWriterFactory();
diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/DataWriterFactory.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/DataWriterFactory.java
index c2c2ab73257e8..7527bcc0c4027 100644
--- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/DataWriterFactory.java
+++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/DataWriterFactory.java
@@ -35,7 +35,7 @@ public interface DataWriterFactory<T> extends Serializable {
/**
* Returns a data writer to do the actual writing work.
*
- * If this method fails (by throwing an exception), the action would fail and no Spark job was
+ * If this method fails (by throwing an exception), the action will fail and no Spark job will be
* submitted.
*
* @param partitionId A unique id of the RDD partition that the returned writer will process.
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala
index e183fa6f9542b..90bea2d676e22 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala
@@ -330,8 +330,8 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) {
}
private def getBucketSpec: Option[BucketSpec] = {
- if (sortColumnNames.isDefined) {
- require(numBuckets.isDefined, "sortBy must be used together with bucketBy")
+ if (sortColumnNames.isDefined && numBuckets.isEmpty) {
+ throw new AnalysisException("sortBy must be used together with bucketBy")
}
numBuckets.map { n =>
@@ -340,8 +340,12 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) {
}
private def assertNotBucketed(operation: String): Unit = {
- if (numBuckets.isDefined || sortColumnNames.isDefined) {
- throw new AnalysisException(s"'$operation' does not support bucketing right now")
+ if (getBucketSpec.isDefined) {
+ if (sortColumnNames.isEmpty) {
+ throw new AnalysisException(s"'$operation' does not support bucketBy right now")
+ } else {
+ throw new AnalysisException(s"'$operation' does not support bucketBy and sortBy right now")
+ }
}
}
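With this change, calling `sortBy` without `bucketBy` surfaces as an `AnalysisException` rather than an `IllegalArgumentException` from `require`, and the bucketing error messages distinguish the two cases. A hedged usage sketch (table names are illustrative):

```java
import static org.apache.spark.sql.functions.col;

import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.SparkSession;

public class BucketingExample {
  public static void main(String[] args) {
    SparkSession spark = SparkSession.builder()
        .master("local[*]")
        .appName("bucketing")
        .getOrCreate();

    Dataset<Row> df = spark.range(100).withColumn("key", col("id").mod(10));

    // Valid: sortBy used together with bucketBy on a saveAsTable write.
    df.write().bucketBy(4, "key").sortBy("id").mode("overwrite").saveAsTable("bucketed_ids");

    try {
      // sortBy alone: getBucketSpec now raises an AnalysisException instead of
      // the IllegalArgumentException that require() used to throw.
      df.write().sortBy("id").saveAsTable("sorted_only");
    } catch (Exception e) {
      System.out.println(e.getMessage());  // sortBy must be used together with bucketBy
    }
    spark.stop();
  }
}
```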
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
index cd4def71e6f3b..f001f16e1d5ee 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
@@ -511,6 +511,16 @@ class Dataset[T] private[sql](
*/
def isLocal: Boolean = logicalPlan.isInstanceOf[LocalRelation]
+ /**
+ * Returns true if the `Dataset` is empty.
+ *
+ * @group basic
+ * @since 2.4.0
+ */
+ def isEmpty: Boolean = withAction("isEmpty", limit(1).groupBy().count().queryExecution) { plan =>
+ plan.executeCollect().head.getLong(0) == 0
+ }
+
/**
* Returns true if this Dataset contains one or more sources that continuously
* return data as it arrives. A Dataset that reads data from a streaming source
@@ -3187,7 +3197,7 @@ class Dataset[T] private[sql](
EvaluatePython.javaToPython(rdd)
}
- private[sql] def collectToPython(): Int = {
+ private[sql] def collectToPython(): Array[Any] = {
EvaluatePython.registerPicklers()
withAction("collectToPython", queryExecution) { plan =>
val toJava: (Any) => Any = EvaluatePython.toJava(_, schema)
@@ -3200,7 +3210,7 @@ class Dataset[T] private[sql](
/**
* Collect a Dataset as ArrowPayload byte arrays and serve to PySpark.
*/
- private[sql] def collectAsArrowToPython(): Int = {
+ private[sql] def collectAsArrowToPython(): Array[Any] = {
withAction("collectAsArrowToPython", queryExecution) { plan =>
val iter: Iterator[Array[Byte]] =
toArrowPayload(plan).collect().iterator.map(_.asPythonSerializable)
@@ -3208,7 +3218,7 @@ class Dataset[T] private[sql](
}
}
- private[sql] def toPythonIterator(): Int = {
+ private[sql] def toPythonIterator(): Array[Any] = {
withNewExecutionId {
PythonRDD.toLocalIteratorAndServe(javaToPython.rdd)
}
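A quick usage sketch of the new `Dataset.isEmpty` added above (available from the 2.4.0 release this targets, per the `@since` tag):

```java
import org.apache.spark.sql.SparkSession;

public class IsEmptyExample {
  public static void main(String[] args) {
    SparkSession spark = SparkSession.builder()
        .master("local[*]")
        .appName("isEmpty")
        .getOrCreate();

    // isEmpty() runs limit(1).groupBy().count() and checks the single result
    // row, so it never pulls more than one record's worth of work to the driver.
    System.out.println(spark.range(10).isEmpty());                   // false
    System.out.println(spark.range(10).filter("id < 0").isEmpty());  // true
    spark.stop();
  }
}
```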
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
index 82b4eb9fba242..37a0b9d6c8728 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
@@ -361,7 +361,7 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
case Join(left, right, _, _) if left.isStreaming && right.isStreaming =>
throw new AnalysisException(
- "Stream stream joins without equality predicate is not supported", plan = Some(plan))
+ "Stream-stream join without equality predicate is not supported", plan = Some(plan))
case _ => Nil
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowWriter.scala
index 22b63513548fe..66888fce7f9f5 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowWriter.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowWriter.scala
@@ -133,6 +133,14 @@ private[arrow] abstract class ArrowFieldWriter {
valueVector match {
case fixedWidthVector: BaseFixedWidthVector => fixedWidthVector.reset()
case variableWidthVector: BaseVariableWidthVector => variableWidthVector.reset()
+ case listVector: ListVector =>
+ // Manual "reset" the underlying buffer.
+ // TODO: When we upgrade to Arrow 0.10.0, we can simply remove this and call
+ // `listVector.reset()`.
+ val buffers = listVector.getBuffers(false)
+ buffers.foreach(buf => buf.setZero(0, buf.capacity()))
+ listVector.setValueCount(0)
+ listVector.setLastSet(0)
case _ =>
}
count = 0
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVDataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVDataSource.scala
index bc1f4ab3bb053..dc54d182651b1 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVDataSource.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVDataSource.scala
@@ -185,7 +185,8 @@ object TextInputCSVDataSource extends CSVDataSource {
DataSource.apply(
sparkSession,
paths = paths,
- className = classOf[TextFileFormat].getName
+ className = classOf[TextFileFormat].getName,
+ options = options.parameters
).resolveRelation(checkFilesExist = false))
.select("value").as[String](Encoders.STRING)
} else {
@@ -250,7 +251,8 @@ object MultiLineCSVDataSource extends CSVDataSource {
options: CSVOptions): RDD[PortableDataStream] = {
val paths = inputPaths.map(_.getPath)
val name = paths.mkString(",")
- val job = Job.getInstance(sparkSession.sessionState.newHadoopConf())
+ val job = Job.getInstance(sparkSession.sessionState.newHadoopConfWithOptions(
+ options.parameters))
FileInputFormat.setInputPaths(job, paths: _*)
val conf = job.getConfiguration
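With the Hadoop `Job` now built from `newHadoopConfWithOptions(options.parameters)`, options passed to the reader are also visible to the Hadoop input format used during CSV schema inference, not just the scan itself. A hedged sketch; the `fs.s3a.connection.maximum` key and the path are purely illustrative, not part of this patch:

```java
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.SparkSession;

public class CsvHadoopOptions {
  public static void main(String[] args) {
    SparkSession spark = SparkSession.builder()
        .master("local[*]")
        .appName("csv-options")
        .getOrCreate();

    // Data source options are merged into the Hadoop configuration used for
    // schema inference as well, so a per-read Hadoop setting (illustrative
    // key below) no longer has to be set globally on the session.
    Dataset<Row> df = spark.read()
        .option("header", "true")
        .option("fs.s3a.connection.maximum", "64")  // hypothetical per-read Hadoop key
        .csv("s3a://bucket/path/*.csv");

    df.printSchema();
    spark.stop();
  }
}
```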
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVOptions.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVOptions.scala
index 2ec0fc605a84b..1066d156acd74 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVOptions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVOptions.scala
@@ -27,7 +27,7 @@ import org.apache.spark.internal.Logging
import org.apache.spark.sql.catalyst.util._
class CSVOptions(
- @transient private val parameters: CaseInsensitiveMap[String],
+ @transient val parameters: CaseInsensitiveMap[String],
defaultTimeZoneId: String,
defaultColumnNameOfCorruptRecord: String)
extends Logging with Serializable {
@@ -164,7 +164,7 @@ class CSVOptions(
writerSettings.setIgnoreLeadingWhitespaces(ignoreLeadingWhiteSpaceFlagInWrite)
writerSettings.setIgnoreTrailingWhitespaces(ignoreTrailingWhiteSpaceFlagInWrite)
writerSettings.setNullValue(nullValue)
- writerSettings.setEmptyValue(nullValue)
+ writerSettings.setEmptyValue("\"\"")
writerSettings.setSkipEmptyLines(true)
writerSettings.setQuoteAllFields(quoteAll)
writerSettings.setQuoteEscapingEnabled(escapeQuotes)
@@ -185,6 +185,7 @@ class CSVOptions(
settings.setInputBufferSize(inputBufferSize)
settings.setMaxColumns(maxColumns)
settings.setNullValue(nullValue)
+ settings.setEmptyValue("")
settings.setMaxCharsPerColumn(maxCharsPerColumn)
settings.setUnescapedQuoteHandling(UnescapedQuoteHandling.STOP_AT_DELIMITER)
settings
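The writer now emits `""` for empty strings instead of reusing `nullValue`, and the parser maps quoted empty fields back to the empty string, so a write/read round trip should keep empty strings and nulls distinct. A small sketch (the output path is illustrative):

```java
import java.util.Arrays;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Encoders;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.SparkSession;

public class EmptyValueRoundTrip {
  public static void main(String[] args) {
    SparkSession spark = SparkSession.builder()
        .master("local[*]")
        .appName("empty-value")
        .getOrCreate();

    Dataset<String> values = spark.createDataset(
        Arrays.asList("a", ""), Encoders.STRING());

    // Empty strings are written as "" (quoted) rather than the nullValue
    // token, so reading the file back can distinguish "" from null.
    values.write().mode("overwrite").csv("/tmp/empty-value-demo");
    Dataset<Row> back = spark.read().csv("/tmp/empty-value-demo");
    back.show();
    spark.stop();
  }
}
```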
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVUtils.scala
index 31464f1bcc68e..9dae41b63e810 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVUtils.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVUtils.scala
@@ -17,10 +17,8 @@
package org.apache.spark.sql.execution.datasources.csv
-import org.apache.spark.input.PortableDataStream
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.Dataset
-import org.apache.spark.sql.catalyst.json.JSONOptions
import org.apache.spark.sql.functions._
import org.apache.spark.sql.types._
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/UnivocityParser.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/UnivocityParser.scala
index 3d6cc30f2ba83..99557a1ceb0c8 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/UnivocityParser.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/UnivocityParser.scala
@@ -19,8 +19,6 @@ package org.apache.spark.sql.execution.datasources.csv
import java.io.InputStream
import java.math.BigDecimal
-import java.text.NumberFormat
-import java.util.Locale
import scala.util.Try
import scala.util.control.NonFatal
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonDataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonDataSource.scala
index 983a5f0dcade2..ba83df0efebd0 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonDataSource.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonDataSource.scala
@@ -121,7 +121,7 @@ object TextInputJsonDataSource extends JsonDataSource {
sparkSession,
paths = paths,
className = classOf[TextFileFormat].getName,
- options = textOptions
+ options = parsedOptions.parameters
).resolveRelation(checkFilesExist = false))
.select("value").as(Encoders.STRING)
}
@@ -159,7 +159,7 @@ object MultiLineJsonDataSource extends JsonDataSource {
sparkSession: SparkSession,
inputPaths: Seq[FileStatus],
parsedOptions: JSONOptions): StructType = {
- val json: RDD[PortableDataStream] = createBaseRdd(sparkSession, inputPaths)
+ val json: RDD[PortableDataStream] = createBaseRdd(sparkSession, inputPaths, parsedOptions)
val sampled: RDD[PortableDataStream] = JsonUtils.sample(json, parsedOptions)
val parser = parsedOptions.encoding
.map(enc => createParser(enc, _: JsonFactory, _: PortableDataStream))
@@ -170,9 +170,11 @@ object MultiLineJsonDataSource extends JsonDataSource {
private def createBaseRdd(
sparkSession: SparkSession,
- inputPaths: Seq[FileStatus]): RDD[PortableDataStream] = {
+ inputPaths: Seq[FileStatus],
+ parsedOptions: JSONOptions): RDD[PortableDataStream] = {
val paths = inputPaths.map(_.getPath)
- val job = Job.getInstance(sparkSession.sessionState.newHadoopConf())
+ val job = Job.getInstance(sparkSession.sessionState.newHadoopConfWithOptions(
+ parsedOptions.parameters))
val conf = job.getConfiguration
val name = paths.mkString(",")
FileInputFormat.setInputPaths(job, paths: _*)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetOptions.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetOptions.scala
index f36a89a4c3c5f..9cfc30725f03a 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetOptions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetOptions.scala
@@ -81,7 +81,10 @@ object ParquetOptions {
"uncompressed" -> CompressionCodecName.UNCOMPRESSED,
"snappy" -> CompressionCodecName.SNAPPY,
"gzip" -> CompressionCodecName.GZIP,
- "lzo" -> CompressionCodecName.LZO)
+ "lzo" -> CompressionCodecName.LZO,
+ "lz4" -> CompressionCodecName.LZ4,
+ "brotli" -> CompressionCodecName.BROTLI,
+ "zstd" -> CompressionCodecName.ZSTD)
def getParquetCompressionCodecName(name: String): String = {
shortParquetCompressionCodecNames(name).name()
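`lz4`, `brotli`, and `zstd` become valid values for the Parquet `compression` option. A usage sketch; note that brotli and zstd likely require the corresponding codec libraries on the classpath and Hadoop-side support, which is an assumption not covered by this diff:

```java
import org.apache.spark.sql.SparkSession;

public class ParquetCompression {
  public static void main(String[] args) {
    SparkSession spark = SparkSession.builder()
        .master("local[*]")
        .appName("parquet-codecs")
        .getOrCreate();

    // "zstd" is now an accepted short name alongside snappy, gzip, lzo, etc.;
    // the actual codec implementation still has to be available at runtime.
    spark.range(1000)
        .write()
        .option("compression", "zstd")
        .mode("overwrite")
        .parquet("/tmp/zstd-demo");
    spark.stop();
  }
}
```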
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/rules.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/rules.scala
index 0dea767840ed3..cab00251622b8 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/rules.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/rules.scala
@@ -61,7 +61,7 @@ class ResolveSQLOnFile(sparkSession: SparkSession) extends Rule[LogicalPlan] {
case _: ClassNotFoundException => u
case e: Exception =>
// the provider is valid, but failed to create a logical plan
- u.failAnalysis(e.getMessage)
+ u.failAnalysis(e.getMessage, e)
}
}
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceRDD.scala
index f85971be394b1..1a6b32429313a 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceRDD.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceRDD.scala
@@ -22,14 +22,14 @@ import scala.reflect.ClassTag
import org.apache.spark.{InterruptibleIterator, Partition, SparkContext, TaskContext}
import org.apache.spark.rdd.RDD
-import org.apache.spark.sql.sources.v2.reader.DataReaderFactory
+import org.apache.spark.sql.sources.v2.reader.InputPartition
-class DataSourceRDDPartition[T : ClassTag](val index: Int, val readerFactory: DataReaderFactory[T])
+class DataSourceRDDPartition[T : ClassTag](val index: Int, val inputPartition: InputPartition[T])
extends Partition with Serializable
class DataSourceRDD[T: ClassTag](
sc: SparkContext,
- @transient private val readerFactories: Seq[DataReaderFactory[T]])
+ @transient private val readerFactories: Seq[InputPartition[T]])
extends RDD[T](sc, Nil) {
override protected def getPartitions: Array[Partition] = {
@@ -39,7 +39,8 @@ class DataSourceRDD[T: ClassTag](
}
override def compute(split: Partition, context: TaskContext): Iterator[T] = {
- val reader = split.asInstanceOf[DataSourceRDDPartition[T]].readerFactory.createDataReader()
+ val reader = split.asInstanceOf[DataSourceRDDPartition[T]].inputPartition
+ .createPartitionReader()
context.addTaskCompletionListener(_ => reader.close())
val iter = new Iterator[T] {
private[this] var valuePrepared = false
@@ -63,6 +64,6 @@ class DataSourceRDD[T: ClassTag](
}
override def getPreferredLocations(split: Partition): Seq[String] = {
- split.asInstanceOf[DataSourceRDDPartition[T]].readerFactory.preferredLocations()
+ split.asInstanceOf[DataSourceRDDPartition[T]].inputPartition.preferredLocations()
}
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2ScanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2ScanExec.scala
index 77cb707340b0f..c6a7684bf6ab0 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2ScanExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2ScanExec.scala
@@ -59,13 +59,13 @@ case class DataSourceV2ScanExec(
}
override def outputPartitioning: physical.Partitioning = reader match {
- case r: SupportsScanColumnarBatch if r.enableBatchRead() && batchReaderFactories.size == 1 =>
+ case r: SupportsScanColumnarBatch if r.enableBatchRead() && batchPartitions.size == 1 =>
SinglePartition
- case r: SupportsScanColumnarBatch if !r.enableBatchRead() && readerFactories.size == 1 =>
+ case r: SupportsScanColumnarBatch if !r.enableBatchRead() && partitions.size == 1 =>
SinglePartition
- case r if !r.isInstanceOf[SupportsScanColumnarBatch] && readerFactories.size == 1 =>
+ case r if !r.isInstanceOf[SupportsScanColumnarBatch] && partitions.size == 1 =>
SinglePartition
case s: SupportsReportPartitioning =>
@@ -75,19 +75,19 @@ case class DataSourceV2ScanExec(
case _ => super.outputPartitioning
}
- private lazy val readerFactories: Seq[DataReaderFactory[UnsafeRow]] = reader match {
- case r: SupportsScanUnsafeRow => r.createUnsafeRowReaderFactories().asScala
+ private lazy val partitions: Seq[InputPartition[UnsafeRow]] = reader match {
+ case r: SupportsScanUnsafeRow => r.planUnsafeInputPartitions().asScala
case _ =>
- reader.createDataReaderFactories().asScala.map {
- new RowToUnsafeRowDataReaderFactory(_, reader.readSchema()): DataReaderFactory[UnsafeRow]
+ reader.planInputPartitions().asScala.map {
+ new RowToUnsafeRowInputPartition(_, reader.readSchema()): InputPartition[UnsafeRow]
}
}
- private lazy val batchReaderFactories: Seq[DataReaderFactory[ColumnarBatch]] = reader match {
+ private lazy val batchPartitions: Seq[InputPartition[ColumnarBatch]] = reader match {
case r: SupportsScanColumnarBatch if r.enableBatchRead() =>
assert(!reader.isInstanceOf[ContinuousReader],
"continuous stream reader does not support columnar read yet.")
- r.createBatchDataReaderFactories().asScala
+ r.planBatchInputPartitions().asScala
}
private lazy val inputRDD: RDD[InternalRow] = reader match {
@@ -95,19 +95,18 @@ case class DataSourceV2ScanExec(
EpochCoordinatorRef.get(
sparkContext.getLocalProperty(ContinuousExecution.EPOCH_COORDINATOR_ID_KEY),
sparkContext.env)
- .askSync[Unit](SetReaderPartitions(readerFactories.size))
+ .askSync[Unit](SetReaderPartitions(partitions.size))
new ContinuousDataSourceRDD(
sparkContext,
sqlContext.conf.continuousStreamingExecutorQueueSize,
sqlContext.conf.continuousStreamingExecutorPollIntervalMs,
- readerFactories)
- .asInstanceOf[RDD[InternalRow]]
+ partitions).asInstanceOf[RDD[InternalRow]]
case r: SupportsScanColumnarBatch if r.enableBatchRead() =>
- new DataSourceRDD(sparkContext, batchReaderFactories).asInstanceOf[RDD[InternalRow]]
+ new DataSourceRDD(sparkContext, batchPartitions).asInstanceOf[RDD[InternalRow]]
case _ =>
- new DataSourceRDD(sparkContext, readerFactories).asInstanceOf[RDD[InternalRow]]
+ new DataSourceRDD(sparkContext, partitions).asInstanceOf[RDD[InternalRow]]
}
override def inputRDDs(): Seq[RDD[InternalRow]] = Seq(inputRDD)
@@ -132,19 +131,22 @@ case class DataSourceV2ScanExec(
}
}
-class RowToUnsafeRowDataReaderFactory(rowReaderFactory: DataReaderFactory[Row], schema: StructType)
- extends DataReaderFactory[UnsafeRow] {
+class RowToUnsafeRowInputPartition(partition: InputPartition[Row], schema: StructType)
+ extends InputPartition[UnsafeRow] {
- override def preferredLocations: Array[String] = rowReaderFactory.preferredLocations
+ override def preferredLocations: Array[String] = partition.preferredLocations
- override def createDataReader: DataReader[UnsafeRow] = {
- new RowToUnsafeDataReader(
- rowReaderFactory.createDataReader, RowEncoder.apply(schema).resolveAndBind())
+ override def createPartitionReader: InputPartitionReader[UnsafeRow] = {
+ new RowToUnsafeInputPartitionReader(
+ partition.createPartitionReader, RowEncoder.apply(schema).resolveAndBind())
}
}
-class RowToUnsafeDataReader(val rowReader: DataReader[Row], encoder: ExpressionEncoder[Row])
- extends DataReader[UnsafeRow] {
+class RowToUnsafeInputPartitionReader(
+ val rowReader: InputPartitionReader[Row],
+ encoder: ExpressionEncoder[Row])
+
+ extends InputPartitionReader[UnsafeRow] {
override def next: Boolean = rowReader.next
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/PushDownOperatorsToDataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/PushDownOperatorsToDataSource.scala
index 9293d4f831bff..e894f8afd6762 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/PushDownOperatorsToDataSource.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/PushDownOperatorsToDataSource.scala
@@ -23,17 +23,10 @@ import org.apache.spark.sql.catalyst.plans.logical.{Filter, LogicalPlan, Project
import org.apache.spark.sql.catalyst.rules.Rule
object PushDownOperatorsToDataSource extends Rule[LogicalPlan] {
- override def apply(
- plan: LogicalPlan): LogicalPlan = plan transformUp {
+ override def apply(plan: LogicalPlan): LogicalPlan = plan match {
// PhysicalOperation guarantees that filters are deterministic; no need to check
- case PhysicalOperation(project, newFilters, relation : DataSourceV2Relation) =>
- // merge the filters
- val filters = relation.filters match {
- case Some(existing) =>
- existing ++ newFilters
- case _ =>
- newFilters
- }
+ case PhysicalOperation(project, filters, relation: DataSourceV2Relation) =>
+ assert(relation.filters.isEmpty, "data source v2 should do push down only once.")
val projectAttrs = project.map(_.toAttribute)
val projectSet = AttributeSet(project.flatMap(_.references))
@@ -67,5 +60,7 @@ object PushDownOperatorsToDataSource extends Rule[LogicalPlan] {
} else {
filtered
}
+
+ case other => other.mapChildren(apply)
}
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingRelation.scala
index f02d3a2c3733f..24195b5657e8a 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingRelation.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingRelation.scala
@@ -66,6 +66,7 @@ case class StreamingExecutionRelation(
output: Seq[Attribute])(session: SparkSession)
extends LeafNode with MultiInstanceRelation {
+ override def otherCopyArgs: Seq[AnyRef] = session :: Nil
override def isStreaming: Boolean = true
override def toString: String = source.toString
@@ -97,6 +98,7 @@ case class StreamingRelationV2(
output: Seq[Attribute],
v1Relation: Option[StreamingRelation])(session: SparkSession)
extends LeafNode with MultiInstanceRelation {
+ override def otherCopyArgs: Seq[AnyRef] = session :: Nil
override def isStreaming: Boolean = true
override def toString: String = sourceName
@@ -116,6 +118,7 @@ case class ContinuousExecutionRelation(
output: Seq[Attribute])(session: SparkSession)
extends LeafNode with MultiInstanceRelation {
+ override def otherCopyArgs: Seq[AnyRef] = session :: Nil
override def isStreaming: Boolean = true
override def toString: String = source.toString
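The added otherCopyArgs overrides matter because these relations take session in a second parameter list, which Scala's generated copy/product machinery ignores; Catalyst's makeCopy appends otherCopyArgs so the session survives tree transformations. A standalone illustration of the underlying Scala behaviour (hypothetical Relation class, not Spark code):

```scala
// A case class with a curried constructor: the second list is not part of
// the product and must be re-supplied on every copy.
case class Relation(name: String)(val session: String)

object OtherCopyArgsDemo extends App {
  val r = Relation("events")("session-A")

  // copy() only defaults the first parameter list; the second must be passed again.
  val copied = r.copy(name = "events_v2")(r.session)

  println(copied.session)            // session-A
  println(r.productIterator.toList)  // List(events) -- session is invisible here
}
```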
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingSymmetricHashJoinExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingSymmetricHashJoinExec.scala
index fa7c8ee906ecd..afa664eb76525 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingSymmetricHashJoinExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingSymmetricHashJoinExec.scala
@@ -187,6 +187,17 @@ case class StreamingSymmetricHashJoinExec(
s"${getClass.getSimpleName} should not take $x as the JoinType")
}
+ override def shouldRunAnotherBatch(newMetadata: OffsetSeqMetadata): Boolean = {
+ val watermarkUsedForStateCleanup =
+ stateWatermarkPredicates.left.nonEmpty || stateWatermarkPredicates.right.nonEmpty
+
+ // Latest watermark value is greater than the one used in the previously executed plan
+ val watermarkHasChanged =
+ eventTimeWatermark.isDefined && newMetadata.batchWatermarkMs > eventTimeWatermark.get
+
+ watermarkUsedForStateCleanup && watermarkHasChanged
+ }
+
protected override def doExecute(): RDD[InternalRow] = {
val stateStoreCoord = sqlContext.sessionState.streamingQueryManager.stateStoreCoordinator
val stateStoreNames = SymmetricHashJoinStateManager.allStateStoreNames(LeftSide, RightSide)
@@ -319,8 +330,7 @@ case class StreamingSymmetricHashJoinExec(
// outer join) if possible. In all cases, nothing needs to be outputted, hence the removal
// needs to be done greedily by immediately consuming the returned iterator.
val cleanupIter = joinType match {
- case Inner =>
- leftSideJoiner.removeOldState() ++ rightSideJoiner.removeOldState()
+ case Inner => leftSideJoiner.removeOldState() ++ rightSideJoiner.removeOldState()
case LeftOuter => rightSideJoiner.removeOldState()
case RightOuter => leftSideJoiner.removeOldState()
case _ => throwBadJoinTypeException()
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousDataSourceRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousDataSourceRDD.scala
index 0a3b9dcccb6c5..a7ccce10b0cee 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousDataSourceRDD.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousDataSourceRDD.scala
@@ -21,14 +21,14 @@ import org.apache.spark._
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.{Row, SQLContext}
import org.apache.spark.sql.catalyst.expressions.UnsafeRow
-import org.apache.spark.sql.execution.datasources.v2.{DataSourceRDDPartition, RowToUnsafeDataReader}
+import org.apache.spark.sql.execution.datasources.v2.{DataSourceRDDPartition, RowToUnsafeInputPartitionReader}
import org.apache.spark.sql.sources.v2.reader._
-import org.apache.spark.sql.sources.v2.reader.streaming.{ContinuousDataReader, PartitionOffset}
+import org.apache.spark.sql.sources.v2.reader.streaming.{ContinuousInputPartitionReader, PartitionOffset}
import org.apache.spark.util.{NextIterator, ThreadUtils}
class ContinuousDataSourceRDDPartition(
val index: Int,
- val readerFactory: DataReaderFactory[UnsafeRow])
+ val inputPartition: InputPartition[UnsafeRow])
extends Partition with Serializable {
// This is semantically a lazy val - it's initialized once the first time a call to
@@ -51,12 +51,12 @@ class ContinuousDataSourceRDD(
sc: SparkContext,
dataQueueSize: Int,
epochPollIntervalMs: Long,
- @transient private val readerFactories: Seq[DataReaderFactory[UnsafeRow]])
+ @transient private val readerFactories: Seq[InputPartition[UnsafeRow]])
extends RDD[UnsafeRow](sc, Nil) {
override protected def getPartitions: Array[Partition] = {
readerFactories.zipWithIndex.map {
- case (readerFactory, index) => new ContinuousDataSourceRDDPartition(index, readerFactory)
+ case (inputPartition, index) => new ContinuousDataSourceRDDPartition(index, inputPartition)
}.toArray
}
@@ -75,7 +75,7 @@ class ContinuousDataSourceRDD(
if (partition.queueReader == null) {
partition.queueReader =
new ContinuousQueuedDataReader(
- partition.readerFactory, context, dataQueueSize, epochPollIntervalMs)
+ partition.inputPartition, context, dataQueueSize, epochPollIntervalMs)
}
partition.queueReader
@@ -96,17 +96,17 @@ class ContinuousDataSourceRDD(
}
override def getPreferredLocations(split: Partition): Seq[String] = {
- split.asInstanceOf[ContinuousDataSourceRDDPartition].readerFactory.preferredLocations()
+ split.asInstanceOf[ContinuousDataSourceRDDPartition].inputPartition.preferredLocations()
}
}
object ContinuousDataSourceRDD {
private[continuous] def getContinuousReader(
- reader: DataReader[UnsafeRow]): ContinuousDataReader[_] = {
+ reader: InputPartitionReader[UnsafeRow]): ContinuousInputPartitionReader[_] = {
reader match {
- case r: ContinuousDataReader[UnsafeRow] => r
- case wrapped: RowToUnsafeDataReader =>
- wrapped.rowReader.asInstanceOf[ContinuousDataReader[Row]]
+ case r: ContinuousInputPartitionReader[UnsafeRow] => r
+ case wrapped: RowToUnsafeInputPartitionReader =>
+ wrapped.rowReader.asInstanceOf[ContinuousInputPartitionReader[Row]]
case _ =>
throw new IllegalStateException(s"Unknown continuous reader type ${reader.getClass}")
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala
index f58146ac42398..0e7d1019b9c8f 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala
@@ -122,16 +122,7 @@ class ContinuousExecution(
s"Batch $latestEpochId was committed without end epoch offsets!")
}
committedOffsets = nextOffsets.toStreamProgress(sources)
-
- // Get to an epoch ID that has definitely never been sent to a sink before. Since sink
- // commit happens between offset log write and commit log write, this means an epoch ID
- // which is not in the offset log.
- val (latestOffsetEpoch, _) = offsetLog.getLatest().getOrElse {
- throw new IllegalStateException(
- s"Offset log had no latest element. This shouldn't be possible because nextOffsets is" +
- s"an element.")
- }
- currentBatchId = latestOffsetEpoch + 1
+ currentBatchId = latestEpochId + 1
logDebug(s"Resuming at epoch $currentBatchId with committed offsets $committedOffsets")
nextOffsets
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousQueuedDataReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousQueuedDataReader.scala
index 01a999f6505fc..f38577b6a9f16 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousQueuedDataReader.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousQueuedDataReader.scala
@@ -18,15 +18,14 @@
package org.apache.spark.sql.execution.streaming.continuous
import java.io.Closeable
-import java.util.concurrent.{ArrayBlockingQueue, BlockingQueue, TimeUnit}
-import java.util.concurrent.atomic.AtomicBoolean
+import java.util.concurrent.{ArrayBlockingQueue, TimeUnit}
import scala.util.control.NonFatal
-import org.apache.spark.{Partition, SparkEnv, SparkException, TaskContext}
+import org.apache.spark.{SparkEnv, SparkException, TaskContext}
import org.apache.spark.internal.Logging
import org.apache.spark.sql.catalyst.expressions.UnsafeRow
-import org.apache.spark.sql.sources.v2.reader.{DataReader, DataReaderFactory}
+import org.apache.spark.sql.sources.v2.reader.{InputPartition, InputPartitionReader}
import org.apache.spark.sql.sources.v2.reader.streaming.PartitionOffset
import org.apache.spark.util.ThreadUtils
@@ -38,17 +37,15 @@ import org.apache.spark.util.ThreadUtils
* offsets across epochs. Each compute() should call the next() method here until null is returned.
*/
class ContinuousQueuedDataReader(
- factory: DataReaderFactory[UnsafeRow],
+ partition: InputPartition[UnsafeRow],
context: TaskContext,
dataQueueSize: Int,
epochPollIntervalMs: Long) extends Closeable {
- private val reader = factory.createDataReader()
+ private val reader = partition.createPartitionReader()
// Important sequencing - we must get our starting point before the provider threads start running
private var currentOffset: PartitionOffset =
ContinuousDataSourceRDD.getContinuousReader(reader).getOffset
- private var currentEpoch: Long =
- context.getLocalProperty(ContinuousExecution.START_EPOCH_KEY).toLong
/**
* The record types in the read buffer.
@@ -116,8 +113,7 @@ class ContinuousQueuedDataReader(
currentEntry match {
case EpochMarker =>
epochCoordEndpoint.send(ReportPartitionOffset(
- context.partitionId(), currentEpoch, currentOffset))
- currentEpoch += 1
+ context.partitionId(), EpochTracker.getCurrentEpoch.get, currentOffset))
null
case ContinuousRow(row, offset) =>
currentOffset = offset
@@ -132,7 +128,7 @@ class ContinuousQueuedDataReader(
/**
* The data component of [[ContinuousQueuedDataReader]]. Pushes (row, offset) to the queue when
- * a new row arrives to the [[DataReader]].
+ * a new row arrives to the [[InputPartitionReader]].
*/
class DataReaderThread extends Thread(
s"continuous-reader--${context.partitionId()}--" +
@@ -185,7 +181,7 @@ class ContinuousQueuedDataReader(
private val epochCoordEndpoint = EpochCoordinatorRef.get(
context.getLocalProperty(ContinuousExecution.EPOCH_COORDINATOR_ID_KEY), SparkEnv.get)
- // Note that this is *not* the same as the currentEpoch in [[ContinuousDataQueuedReader]]! That
+ // Note that this is *not* the same as the currentEpoch in [[ContinuousWriteRDD]]! That
// field represents the epoch wrt the data being processed. The currentEpoch here is just a
// counter to ensure we send the appropriate number of markers if we fall behind the driver.
private var currentEpoch = context.getLocalProperty(ContinuousExecution.START_EPOCH_KEY).toLong
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousRateStreamSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousRateStreamSource.scala
index 2f0de2612c150..8d25d9ccc43d3 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousRateStreamSource.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousRateStreamSource.scala
@@ -28,7 +28,7 @@ import org.apache.spark.sql.execution.streaming.{RateStreamOffset, ValueRunTimeM
import org.apache.spark.sql.execution.streaming.sources.RateStreamProvider
import org.apache.spark.sql.sources.v2.DataSourceOptions
import org.apache.spark.sql.sources.v2.reader._
-import org.apache.spark.sql.sources.v2.reader.streaming.{ContinuousDataReader, ContinuousReader, Offset, PartitionOffset}
+import org.apache.spark.sql.sources.v2.reader.streaming.{ContinuousInputPartitionReader, ContinuousReader, Offset, PartitionOffset}
import org.apache.spark.sql.types.StructType
case class RateStreamPartitionOffset(
@@ -67,7 +67,7 @@ class RateStreamContinuousReader(options: DataSourceOptions)
override def getStartOffset(): Offset = offset
- override def createDataReaderFactories(): java.util.List[DataReaderFactory[Row]] = {
+ override def planInputPartitions(): java.util.List[InputPartition[Row]] = {
val partitionStartMap = offset match {
case off: RateStreamOffset => off.partitionToValueAndRunTimeMs
case off =>
@@ -91,7 +91,7 @@ class RateStreamContinuousReader(options: DataSourceOptions)
i,
numPartitions,
perPartitionRate)
- .asInstanceOf[DataReaderFactory[Row]]
+ .asInstanceOf[InputPartition[Row]]
}.asJava
}
@@ -119,13 +119,13 @@ case class RateStreamContinuousDataReaderFactory(
partitionIndex: Int,
increment: Long,
rowsPerSecond: Double)
- extends ContinuousDataReaderFactory[Row] {
+ extends ContinuousInputPartition[Row] {
- override def createDataReaderWithOffset(offset: PartitionOffset): DataReader[Row] = {
+ override def createContinuousReader(offset: PartitionOffset): InputPartitionReader[Row] = {
val rateStreamOffset = offset.asInstanceOf[RateStreamPartitionOffset]
require(rateStreamOffset.partition == partitionIndex,
s"Expected partitionIndex: $partitionIndex, but got: ${rateStreamOffset.partition}")
- new RateStreamContinuousDataReader(
+ new RateStreamContinuousInputPartitionReader(
rateStreamOffset.currentValue,
rateStreamOffset.currentTimeMs,
partitionIndex,
@@ -133,18 +133,18 @@ case class RateStreamContinuousDataReaderFactory(
rowsPerSecond)
}
- override def createDataReader(): DataReader[Row] =
- new RateStreamContinuousDataReader(
+ override def createPartitionReader(): InputPartitionReader[Row] =
+ new RateStreamContinuousInputPartitionReader(
startValue, startTimeMs, partitionIndex, increment, rowsPerSecond)
}
-class RateStreamContinuousDataReader(
+class RateStreamContinuousInputPartitionReader(
startValue: Long,
startTimeMs: Long,
partitionIndex: Int,
increment: Long,
rowsPerSecond: Double)
- extends ContinuousDataReader[Row] {
+ extends ContinuousInputPartitionReader[Row] {
private var nextReadTime: Long = startTimeMs
private val readTimeIncrement: Long = (1000 / rowsPerSecond).toLong
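End-to-end, the renamed ContinuousInputPartitionReader classes back the built-in rate source when it runs under a continuous trigger. A usage sketch (local SparkSession and console sink assumed; not part of this patch):

```scala
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.streaming.Trigger

object RateContinuousExample {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[*]").appName("rate-continuous").getOrCreate()

    val query = spark.readStream
      .format("rate")
      .option("rowsPerSecond", "5")
      .option("numPartitions", "2")
      .load()
      .writeStream
      .format("console")
      .trigger(Trigger.Continuous("1 second"))
      .start()

    query.awaitTermination(10000)  // run briefly for the demo
    spark.stop()
  }
}
```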
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousWriteRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousWriteRDD.scala
index 91f1576581511..ef5f0da1e7cc2 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousWriteRDD.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousWriteRDD.scala
@@ -45,7 +45,8 @@ class ContinuousWriteRDD(var prev: RDD[InternalRow], writeTask: DataWriterFactor
val epochCoordinator = EpochCoordinatorRef.get(
context.getLocalProperty(ContinuousExecution.EPOCH_COORDINATOR_ID_KEY),
SparkEnv.get)
- var currentEpoch = context.getLocalProperty(ContinuousExecution.START_EPOCH_KEY).toLong
+ EpochTracker.initializeCurrentEpoch(
+ context.getLocalProperty(ContinuousExecution.START_EPOCH_KEY).toLong)
while (!context.isInterrupted() && !context.isCompleted()) {
var dataWriter: DataWriter[InternalRow] = null
@@ -54,19 +55,24 @@ class ContinuousWriteRDD(var prev: RDD[InternalRow], writeTask: DataWriterFactor
try {
val dataIterator = prev.compute(split, context)
dataWriter = writeTask.createDataWriter(
- context.partitionId(), context.attemptNumber(), currentEpoch)
+ context.partitionId(),
+ context.attemptNumber(),
+ EpochTracker.getCurrentEpoch.get)
while (dataIterator.hasNext) {
dataWriter.write(dataIterator.next())
}
logInfo(s"Writer for partition ${context.partitionId()} " +
- s"in epoch $currentEpoch is committing.")
+ s"in epoch ${EpochTracker.getCurrentEpoch.get} is committing.")
val msg = dataWriter.commit()
epochCoordinator.send(
- CommitPartitionEpoch(context.partitionId(), currentEpoch, msg)
+ CommitPartitionEpoch(
+ context.partitionId(),
+ EpochTracker.getCurrentEpoch.get,
+ msg)
)
logInfo(s"Writer for partition ${context.partitionId()} " +
- s"in epoch $currentEpoch committed.")
- currentEpoch += 1
+ s"in epoch ${EpochTracker.getCurrentEpoch.get} committed.")
+ EpochTracker.incrementCurrentEpoch()
} catch {
case _: InterruptedException =>
// Continuous shutdown always involves an interrupt. Just finish the task.
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/EpochTracker.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/EpochTracker.scala
new file mode 100644
index 0000000000000..bc0ae428d4521
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/EpochTracker.scala
@@ -0,0 +1,58 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.streaming.continuous
+
+import java.util.concurrent.atomic.AtomicLong
+
+/**
+ * Tracks the current continuous processing epoch within a task. Call
+ * EpochTracker.getCurrentEpoch to get the current epoch.
+ */
+object EpochTracker {
+ // The current epoch. Note that this is a shared reference; ContinuousWriteRDD.compute() will
+ // update the underlying AtomicLong as it finishes epochs. Other code should only read the value.
+ private val currentEpoch: ThreadLocal[AtomicLong] = new ThreadLocal[AtomicLong] {
+ override def initialValue() = new AtomicLong(-1)
+ }
+
+ /**
+ * Get the current epoch for the current task, or None if the task has no current epoch.
+ */
+ def getCurrentEpoch: Option[Long] = {
+ currentEpoch.get().get() match {
+ case n if n < 0 => None
+ case e => Some(e)
+ }
+ }
+
+ /**
+ * Increment the current epoch for this task thread. Should be called by [[ContinuousWriteRDD]]
+ * between epochs.
+ */
+ def incrementCurrentEpoch(): Unit = {
+ currentEpoch.get().incrementAndGet()
+ }
+
+ /**
+ * Initialize the current epoch for this task thread. Should be called by [[ContinuousWriteRDD]]
+ * at the beginning of a task.
+ */
+ def initializeCurrentEpoch(startEpoch: Long): Unit = {
+ currentEpoch.get().set(startEpoch)
+ }
+}
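The intended call pattern, condensed from ContinuousWriteRDD and StateStoreRDD elsewhere in this patch: the writer owns initialization and advancement, readers only observe. A small sketch against this internal helper (illustrative only):

```scala
import org.apache.spark.sql.execution.streaming.continuous.EpochTracker

object EpochTrackerSketch extends App {
  // Writer side (ContinuousWriteRDD.compute): set the task thread's starting epoch.
  EpochTracker.initializeCurrentEpoch(5L)

  // Reader side (e.g. StateStoreRDD): observe it, with a fallback when unset.
  val storeVersion = EpochTracker.getCurrentEpoch.getOrElse(0L)
  println(s"state store version to load: $storeVersion")  // 5

  // Writer side again, after committing an epoch's writes.
  EpochTracker.incrementCurrentEpoch()
  println(EpochTracker.getCurrentEpoch)  // Some(6)
}
```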
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala
index 22258274c70c1..daa2963220aef 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala
@@ -24,23 +24,21 @@ import javax.annotation.concurrent.GuardedBy
import scala.collection.JavaConverters._
import scala.collection.mutable.{ArrayBuffer, ListBuffer}
-import scala.reflect.ClassTag
import scala.util.control.NonFatal
import org.apache.spark.internal.Logging
import org.apache.spark.sql._
-import org.apache.spark.sql.catalyst.encoders.{encoderFor, ExpressionEncoder}
+import org.apache.spark.sql.catalyst.encoders.encoderFor
import org.apache.spark.sql.catalyst.expressions.{Attribute, UnsafeRow}
import org.apache.spark.sql.catalyst.plans.logical.{LeafNode, LogicalPlan, Statistics}
+import org.apache.spark.sql.catalyst.plans.logical.statsEstimation.EstimationUtils
import org.apache.spark.sql.catalyst.streaming.InternalOutputModes._
-import org.apache.spark.sql.execution.streaming.continuous.ContinuousTrigger
-import org.apache.spark.sql.sources.v2.reader.{DataReader, DataReaderFactory, SupportsScanUnsafeRow}
+import org.apache.spark.sql.sources.v2.reader.{InputPartition, InputPartitionReader, SupportsScanUnsafeRow}
import org.apache.spark.sql.sources.v2.reader.streaming.{MicroBatchReader, Offset => OffsetV2}
-import org.apache.spark.sql.streaming.{OutputMode, Trigger}
+import org.apache.spark.sql.streaming.OutputMode
import org.apache.spark.sql.types.StructType
import org.apache.spark.util.Utils
-
object MemoryStream {
protected val currentBlockId = new AtomicInteger(0)
protected val memoryStreamId = new AtomicInteger(0)
@@ -141,7 +139,7 @@ case class MemoryStream[A : Encoder](id: Int, sqlContext: SQLContext)
if (endOffset.offset == -1) null else endOffset
}
- override def createUnsafeRowReaderFactories(): ju.List[DataReaderFactory[UnsafeRow]] = {
+ override def planUnsafeInputPartitions(): ju.List[InputPartition[UnsafeRow]] = {
synchronized {
// Compute the internal batch numbers to fetch: [startOrdinal, endOrdinal)
val startOrdinal = startOffset.offset.toInt + 1
@@ -158,7 +156,7 @@ case class MemoryStream[A : Encoder](id: Int, sqlContext: SQLContext)
logDebug(generateDebugString(newBlocks.flatten, startOrdinal, endOrdinal))
newBlocks.map { block =>
- new MemoryStreamDataReaderFactory(block).asInstanceOf[DataReaderFactory[UnsafeRow]]
+ new MemoryStreamDataReaderFactory(block).asInstanceOf[InputPartition[UnsafeRow]]
}.asJava
}
}
@@ -204,9 +202,9 @@ case class MemoryStream[A : Encoder](id: Int, sqlContext: SQLContext)
class MemoryStreamDataReaderFactory(records: Array[UnsafeRow])
- extends DataReaderFactory[UnsafeRow] {
- override def createDataReader(): DataReader[UnsafeRow] = {
- new DataReader[UnsafeRow] {
+ extends InputPartition[UnsafeRow] {
+ override def createPartitionReader(): InputPartitionReader[UnsafeRow] = {
+ new InputPartitionReader[UnsafeRow] {
private var currentIndex = -1
override def next(): Boolean = {
@@ -307,7 +305,7 @@ class MemorySink(val schema: StructType, outputMode: OutputMode) extends Sink
case class MemoryPlan(sink: MemorySink, output: Seq[Attribute]) extends LeafNode {
def this(sink: MemorySink) = this(sink, sink.schema.toAttributes)
- private val sizePerRow = sink.schema.toAttributes.map(_.dataType.defaultSize).sum
+ private val sizePerRow = EstimationUtils.getSizePerRow(sink.schema.toAttributes)
override def computeStats(): Statistics = Statistics(sizePerRow * sink.allData.size)
}
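MemoryStream is the test source whose reader factories were renamed above; a minimal sketch of driving it through one micro-batch into the memory sink (local SparkSession assumed, not part of this patch):

```scala
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.execution.streaming.MemoryStream

object MemoryStreamExample {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[*]").appName("memory-stream").getOrCreate()
    import spark.implicits._
    implicit val sqlContext = spark.sqlContext

    val input = MemoryStream[Int]
    input.addData(1, 2, 3)

    val query = input.toDF()
      .writeStream
      .format("memory")
      .queryName("memory_stream_demo")
      .start()
    query.processAllAvailable()

    spark.sql("SELECT * FROM memory_stream_demo").show()
    query.stop()
    spark.stop()
  }
}
```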
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ContinuousMemoryStream.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ContinuousMemoryStream.scala
index a8fca3c19a2d2..4daafa65850de 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ContinuousMemoryStream.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ContinuousMemoryStream.scala
@@ -34,8 +34,8 @@ import org.apache.spark.sql.{Encoder, Row, SQLContext}
import org.apache.spark.sql.execution.streaming._
import org.apache.spark.sql.execution.streaming.sources.ContinuousMemoryStream.GetRecord
import org.apache.spark.sql.sources.v2.{ContinuousReadSupport, DataSourceOptions}
-import org.apache.spark.sql.sources.v2.reader.DataReaderFactory
-import org.apache.spark.sql.sources.v2.reader.streaming.{ContinuousDataReader, ContinuousReader, Offset, PartitionOffset}
+import org.apache.spark.sql.sources.v2.reader.InputPartition
+import org.apache.spark.sql.sources.v2.reader.streaming.{ContinuousInputPartitionReader, ContinuousReader, Offset, PartitionOffset}
import org.apache.spark.sql.types.StructType
import org.apache.spark.util.RpcUtils
@@ -47,10 +47,9 @@ import org.apache.spark.util.RpcUtils
* ContinuousMemoryStreamDataReader instances to poll. It returns the record at the specified
* offset within the list, or null if that offset doesn't yet have a record.
*/
-class ContinuousMemoryStream[A : Encoder](id: Int, sqlContext: SQLContext)
+class ContinuousMemoryStream[A : Encoder](id: Int, sqlContext: SQLContext, numPartitions: Int = 2)
extends MemoryStreamBase[A](sqlContext) with ContinuousReader with ContinuousReadSupport {
private implicit val formats = Serialization.formats(NoTypeHints)
- private val NUM_PARTITIONS = 2
protected val logicalPlan =
StreamingRelationV2(this, "memory", Map(), attributes, None)(sqlContext.sparkSession)
@@ -58,7 +57,7 @@ class ContinuousMemoryStream[A : Encoder](id: Int, sqlContext: SQLContext)
// ContinuousReader implementation
@GuardedBy("this")
- private val records = Seq.fill(NUM_PARTITIONS)(new ListBuffer[A])
+ private val records = Seq.fill(numPartitions)(new ListBuffer[A])
@GuardedBy("this")
private var startOffset: ContinuousMemoryStreamOffset = _
@@ -69,17 +68,17 @@ class ContinuousMemoryStream[A : Encoder](id: Int, sqlContext: SQLContext)
def addData(data: TraversableOnce[A]): Offset = synchronized {
// Distribute data evenly among partition lists.
data.toSeq.zipWithIndex.map {
- case (item, index) => records(index % NUM_PARTITIONS) += item
+ case (item, index) => records(index % numPartitions) += item
}
// The new target offset is the offset where all records in all partitions have been processed.
- ContinuousMemoryStreamOffset((0 until NUM_PARTITIONS).map(i => (i, records(i).size)).toMap)
+ ContinuousMemoryStreamOffset((0 until numPartitions).map(i => (i, records(i).size)).toMap)
}
override def setStartOffset(start: Optional[Offset]): Unit = synchronized {
// Inferred initial offset is position 0 in each partition.
startOffset = start.orElse {
- ContinuousMemoryStreamOffset((0 until NUM_PARTITIONS).map(i => (i, 0)).toMap)
+ ContinuousMemoryStreamOffset((0 until numPartitions).map(i => (i, 0)).toMap)
}.asInstanceOf[ContinuousMemoryStreamOffset]
}
@@ -99,7 +98,7 @@ class ContinuousMemoryStream[A : Encoder](id: Int, sqlContext: SQLContext)
)
}
- override def createDataReaderFactories(): ju.List[DataReaderFactory[Row]] = {
+ override def planInputPartitions(): ju.List[InputPartition[Row]] = {
synchronized {
val endpointName = s"ContinuousMemoryStreamRecordEndpoint-${java.util.UUID.randomUUID()}-$id"
endpointRef =
@@ -108,7 +107,7 @@ class ContinuousMemoryStream[A : Encoder](id: Int, sqlContext: SQLContext)
startOffset.partitionNums.map {
case (part, index) =>
new ContinuousMemoryStreamDataReaderFactory(
- endpointName, part, index): DataReaderFactory[Row]
+ endpointName, part, index): InputPartition[Row]
}.toList.asJava
}
}
@@ -152,6 +151,9 @@ object ContinuousMemoryStream {
def apply[A : Encoder](implicit sqlContext: SQLContext): ContinuousMemoryStream[A] =
new ContinuousMemoryStream[A](memoryStreamId.getAndIncrement(), sqlContext)
+
+ def singlePartition[A : Encoder](implicit sqlContext: SQLContext): ContinuousMemoryStream[A] =
+ new ContinuousMemoryStream[A](memoryStreamId.getAndIncrement(), sqlContext, 1)
}
/**
@@ -160,9 +162,9 @@ object ContinuousMemoryStream {
class ContinuousMemoryStreamDataReaderFactory(
driverEndpointName: String,
partition: Int,
- startOffset: Int) extends DataReaderFactory[Row] {
- override def createDataReader: ContinuousMemoryStreamDataReader =
- new ContinuousMemoryStreamDataReader(driverEndpointName, partition, startOffset)
+ startOffset: Int) extends InputPartition[Row] {
+ override def createPartitionReader: ContinuousMemoryStreamInputPartitionReader =
+ new ContinuousMemoryStreamInputPartitionReader(driverEndpointName, partition, startOffset)
}
/**
@@ -170,10 +172,10 @@ class ContinuousMemoryStreamDataReaderFactory(
*
* Polls the driver endpoint for new records.
*/
-class ContinuousMemoryStreamDataReader(
+class ContinuousMemoryStreamInputPartitionReader(
driverEndpointName: String,
partition: Int,
- startOffset: Int) extends ContinuousDataReader[Row] {
+ startOffset: Int) extends ContinuousInputPartitionReader[Row] {
private val endpoint = RpcUtils.makeDriverRef(
driverEndpointName,
SparkEnv.get.conf,
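The new numPartitions parameter and singlePartition helper let continuous-mode tests put every record into one partition, so they can reason about a single ordered stream. A construction sketch against this internal test utility (illustrative only):

```scala
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.execution.streaming.sources.ContinuousMemoryStream

object ContinuousMemoryStreamSketch extends App {
  val spark = SparkSession.builder().master("local[*]").appName("cms-sketch").getOrCreate()
  import spark.implicits._
  implicit val sqlContext = spark.sqlContext

  val singleParted = ContinuousMemoryStream.singlePartition[Int]  // numPartitions = 1
  val defaultParted = ContinuousMemoryStream[Int]                 // numPartitions = 2

  singleParted.addData(Seq(1, 2, 3))  // index % 1 == 0, so every row lands in partition 0
}
```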
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamMicroBatchReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamMicroBatchReader.scala
index f54291bea6678..723cc3ad5bb89 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamMicroBatchReader.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamMicroBatchReader.scala
@@ -134,7 +134,7 @@ class RateStreamMicroBatchReader(options: DataSourceOptions, checkpointLocation:
LongOffset(json.toLong)
}
- override def createDataReaderFactories(): java.util.List[DataReaderFactory[Row]] = {
+ override def planInputPartitions(): java.util.List[InputPartition[Row]] = {
val startSeconds = LongOffset.convert(start).map(_.offset).getOrElse(0L)
val endSeconds = LongOffset.convert(end).map(_.offset).getOrElse(0L)
assert(startSeconds <= endSeconds, s"startSeconds($startSeconds) > endSeconds($endSeconds)")
@@ -169,7 +169,7 @@ class RateStreamMicroBatchReader(options: DataSourceOptions, checkpointLocation:
(0 until numPartitions).map { p =>
new RateStreamMicroBatchDataReaderFactory(
p, numPartitions, rangeStart, rangeEnd, localStartTimeMs, relativeMsPerValue)
- : DataReaderFactory[Row]
+ : InputPartition[Row]
}.toList.asJava
}
@@ -188,19 +188,20 @@ class RateStreamMicroBatchDataReaderFactory(
rangeStart: Long,
rangeEnd: Long,
localStartTimeMs: Long,
- relativeMsPerValue: Double) extends DataReaderFactory[Row] {
+ relativeMsPerValue: Double) extends InputPartition[Row] {
- override def createDataReader(): DataReader[Row] = new RateStreamMicroBatchDataReader(
- partitionId, numPartitions, rangeStart, rangeEnd, localStartTimeMs, relativeMsPerValue)
+ override def createPartitionReader(): InputPartitionReader[Row] =
+ new RateStreamMicroBatchInputPartitionReader(partitionId, numPartitions, rangeStart, rangeEnd,
+ localStartTimeMs, relativeMsPerValue)
}
-class RateStreamMicroBatchDataReader(
+class RateStreamMicroBatchInputPartitionReader(
partitionId: Int,
numPartitions: Int,
rangeStart: Long,
rangeEnd: Long,
localStartTimeMs: Long,
- relativeMsPerValue: Double) extends DataReader[Row] {
+ relativeMsPerValue: Double) extends InputPartitionReader[Row] {
private var count = 0
override def next(): Boolean = {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/memoryV2.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/memoryV2.scala
index 0d6c239274dd8..468313bfe8c3c 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/memoryV2.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/memoryV2.scala
@@ -27,6 +27,7 @@ import org.apache.spark.internal.Logging
import org.apache.spark.sql.Row
import org.apache.spark.sql.catalyst.expressions.Attribute
import org.apache.spark.sql.catalyst.plans.logical.{LeafNode, Statistics}
+import org.apache.spark.sql.catalyst.plans.logical.statsEstimation.EstimationUtils
import org.apache.spark.sql.catalyst.streaming.InternalOutputModes.{Append, Complete, Update}
import org.apache.spark.sql.execution.streaming.{MemorySinkBase, Sink}
import org.apache.spark.sql.sources.v2.{DataSourceOptions, DataSourceV2, StreamWriteSupport}
@@ -182,7 +183,7 @@ class MemoryDataWriter(partition: Int, outputMode: OutputMode)
* Used to query the data that has been written into a [[MemorySinkV2]].
*/
case class MemoryPlanV2(sink: MemorySinkV2, override val output: Seq[Attribute]) extends LeafNode {
- private val sizePerRow = output.map(_.dataType.defaultSize).sum
+ private val sizePerRow = EstimationUtils.getSizePerRow(output)
override def computeStats(): Statistics = Statistics(sizePerRow * sink.allData.size)
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/socket.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/socket.scala
index 90f4a5ba4234d..8240e06d4ab72 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/socket.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/socket.scala
@@ -33,7 +33,7 @@ import org.apache.spark.sql._
import org.apache.spark.sql.execution.streaming.LongOffset
import org.apache.spark.sql.sources.DataSourceRegister
import org.apache.spark.sql.sources.v2.{DataSourceOptions, DataSourceV2, MicroBatchReadSupport}
-import org.apache.spark.sql.sources.v2.reader.{DataReader, DataReaderFactory}
+import org.apache.spark.sql.sources.v2.reader.{InputPartition, InputPartitionReader}
import org.apache.spark.sql.sources.v2.reader.streaming.{MicroBatchReader, Offset}
import org.apache.spark.sql.types.{StringType, StructField, StructType, TimestampType}
@@ -140,7 +140,7 @@ class TextSocketMicroBatchReader(options: DataSourceOptions) extends MicroBatchR
}
}
- override def createDataReaderFactories(): JList[DataReaderFactory[Row]] = {
+ override def planInputPartitions(): JList[InputPartition[Row]] = {
assert(startOffset != null && endOffset != null,
"start offset and end offset should already be set before create read tasks.")
@@ -165,21 +165,22 @@ class TextSocketMicroBatchReader(options: DataSourceOptions) extends MicroBatchR
(0 until numPartitions).map { i =>
val slice = slices(i)
- new DataReaderFactory[Row] {
- override def createDataReader(): DataReader[Row] = new DataReader[Row] {
- private var currentIdx = -1
+ new InputPartition[Row] {
+ override def createPartitionReader(): InputPartitionReader[Row] =
+ new InputPartitionReader[Row] {
+ private var currentIdx = -1
+
+ override def next(): Boolean = {
+ currentIdx += 1
+ currentIdx < slice.size
+ }
- override def next(): Boolean = {
- currentIdx += 1
- currentIdx < slice.size
- }
+ override def get(): Row = {
+ Row(slice(currentIdx)._1, slice(currentIdx)._2)
+ }
- override def get(): Row = {
- Row(slice(currentIdx)._1, slice(currentIdx)._2)
+ override def close(): Unit = {}
}
-
- override def close(): Unit = {}
- }
}
}.toList.asJava
}
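The socket reader now hands out anonymous InputPartitions instead of DataReaderFactory instances; its user-facing behaviour is unchanged. A usage sketch (assumes something is listening on localhost:9999, e.g. `nc -lk 9999`):

```scala
import org.apache.spark.sql.SparkSession

object SocketSourceExample {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[*]").appName("socket-source").getOrCreate()

    val lines = spark.readStream
      .format("socket")
      .option("host", "localhost")
      .option("port", "9999")
      .load()

    val query = lines.writeStream
      .format("console")
      .start()
    query.awaitTermination()
  }
}
```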
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDD.scala
index 01d8e75980993..3f11b8f79943c 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDD.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDD.scala
@@ -23,6 +23,7 @@ import scala.reflect.ClassTag
import org.apache.spark.{Partition, TaskContext}
import org.apache.spark.rdd.RDD
+import org.apache.spark.sql.execution.streaming.continuous.EpochTracker
import org.apache.spark.sql.internal.SessionState
import org.apache.spark.sql.types.StructType
import org.apache.spark.util.SerializableConfiguration
@@ -71,8 +72,15 @@ class StateStoreRDD[T: ClassTag, U: ClassTag](
StateStoreId(checkpointLocation, operatorId, partition.index),
queryRunId)
+ // If we're in continuous processing mode, we should get the store version for the current
+ // epoch rather than the one at planning time.
+ val currentVersion = EpochTracker.getCurrentEpoch match {
+ case None => storeVersion
+ case Some(value) => value
+ }
+
store = StateStore.get(
- storeProviderId, keySchema, valueSchema, indexOrdinal, storeVersion,
+ storeProviderId, keySchema, valueSchema, indexOrdinal, currentVersion,
storeConf, hadoopConfBroadcast.value.value)
val inputIter = dataRDD.iterator(partition, ctxt)
storeUpdateFunction(store, inputIter)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
index b74f1d54e04f0..10153d91b5a08 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
@@ -283,6 +283,9 @@ object functions {
/**
* Aggregate function: returns a list of objects with duplicates.
*
+ * @note The function is non-deterministic because the order of collected results depends
+ * on the order of rows, which may be non-deterministic after a shuffle.
+ *
* @group agg_funcs
* @since 1.6.0
*/
@@ -291,6 +294,9 @@ object functions {
/**
* Aggregate function: returns a list of objects with duplicates.
*
+ * @note The function is non-deterministic because the order of collected results depends
+ * on the order of rows, which may be non-deterministic after a shuffle.
+ *
* @group agg_funcs
* @since 1.6.0
*/
@@ -299,6 +305,9 @@ object functions {
/**
* Aggregate function: returns a set of objects with duplicate elements eliminated.
*
+ * @note The function is non-deterministic because the order of collected results depends
+ * on the order of rows, which may be non-deterministic after a shuffle.
+ *
* @group agg_funcs
* @since 1.6.0
*/
@@ -307,6 +316,9 @@ object functions {
/**
* Aggregate function: returns a set of objects with duplicate elements eliminated.
*
+ * @note The function is non-deterministic because the order of collected results depends
+ * on the order of rows, which may be non-deterministic after a shuffle.
+ *
* @group agg_funcs
* @since 1.6.0
*/
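A small sketch of the two aggregates whose docs gain the note above; whatever row order the shuffle delivers is the order that ends up in the collected arrays (local SparkSession assumed):

```scala
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.{collect_list, collect_set}

object CollectExample {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[*]").appName("collect-demo").getOrCreate()
    import spark.implicits._

    val df = Seq(("a", 1), ("a", 2), ("a", 1), ("b", 3)).toDF("k", "v")

    df.groupBy($"k")
      .agg(collect_list($"v").as("all_vs"), collect_set($"v").as("distinct_vs"))
      .show()
    // The element order inside all_vs / distinct_vs is not guaranteed.

    spark.stop()
  }
}
```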
@@ -422,6 +434,9 @@ object functions {
* The function by default returns the first values it sees. It will return the first non-null
* value it sees when ignoreNulls is set to true. If all values are null, then null is returned.
*
+ * @note The function is non-deterministic because its results depend on the order of rows, which
+ * may be non-deterministic after a shuffle.
+ *
* @group agg_funcs
* @since 2.0.0
*/
@@ -435,6 +450,9 @@ object functions {
* The function by default returns the first values it sees. It will return the first non-null
* value it sees when ignoreNulls is set to true. If all values are null, then null is returned.
*
+ * @note The function is non-deterministic because its results depend on the order of rows, which
+ * may be non-deterministic after a shuffle.
+ *
* @group agg_funcs
* @since 2.0.0
*/
@@ -448,6 +466,9 @@ object functions {
* The function by default returns the first values it sees. It will return the first non-null
* value it sees when ignoreNulls is set to true. If all values are null, then null is returned.
*
+ * @note The function is non-deterministic because its results depend on the order of rows, which
+ * may be non-deterministic after a shuffle.
+ *
* @group agg_funcs
* @since 1.3.0
*/
@@ -459,6 +480,9 @@ object functions {
* The function by default returns the first values it sees. It will return the first non-null
* value it sees when ignoreNulls is set to true. If all values are null, then null is returned.
*
+ * @note The function is non-deterministic because its results depend on the order of rows, which
+ * may be non-deterministic after a shuffle.
+ *
* @group agg_funcs
* @since 1.3.0
*/
@@ -535,6 +559,9 @@ object functions {
* The function by default returns the last values it sees. It will return the last non-null
* value it sees when ignoreNulls is set to true. If all values are null, then null is returned.
*
+ * @note The function is non-deterministic because its results depend on the order of rows, which
+ * may be non-deterministic after a shuffle.
+ *
* @group agg_funcs
* @since 2.0.0
*/
@@ -548,6 +575,9 @@ object functions {
* The function by default returns the last values it sees. It will return the last non-null
* value it sees when ignoreNulls is set to true. If all values are null, then null is returned.
*
+ * @note The function is non-deterministic because its results depend on the order of rows, which
+ * may be non-deterministic after a shuffle.
+ *
* @group agg_funcs
* @since 2.0.0
*/
@@ -561,6 +591,9 @@ object functions {
* The function by default returns the last values it sees. It will return the last non-null
* value it sees when ignoreNulls is set to true. If all values are null, then null is returned.
*
+ * @note The function is non-deterministic because its results depend on the order of rows, which
+ * may be non-deterministic after a shuffle.
+ *
* @group agg_funcs
* @since 1.3.0
*/
@@ -572,6 +605,9 @@ object functions {
* The function by default returns the last values it sees. It will return the last non-null
* value it sees when ignoreNulls is set to true. If all values are null, then null is returned.
*
+ * @note The function is non-deterministic because its results depend on the order of rows, which
+ * may be non-deterministic after a shuffle.
+ *
* @group agg_funcs
* @since 1.3.0
*/
@@ -775,6 +811,7 @@ object functions {
*/
def var_pop(columnName: String): Column = var_pop(Column(columnName))
+
//////////////////////////////////////////////////////////////////////////////////////////////
// Window functions
//////////////////////////////////////////////////////////////////////////////////////////////
@@ -1172,7 +1209,7 @@ object functions {
* Generate a random column with independent and identically distributed (i.i.d.) samples
* from U[0.0, 1.0].
*
- * @note This is indeterministic when data partitions are not fixed.
+ * @note The function is non-deterministic in the general case.
*
* @group normal_funcs
* @since 1.4.0
@@ -1183,6 +1220,8 @@ object functions {
* Generate a random column with independent and identically distributed (i.i.d.) samples
* from U[0.0, 1.0].
*
+ * @note The function is non-deterministic in the general case.
+ *
* @group normal_funcs
* @since 1.4.0
*/
@@ -1192,7 +1231,7 @@ object functions {
* Generate a column with independent and identically distributed (i.i.d.) samples from
* the standard normal distribution.
*
- * @note This is indeterministic when data partitions are not fixed.
+ * @note The function is non-deterministic in the general case.
*
* @group normal_funcs
* @since 1.4.0
@@ -1203,6 +1242,8 @@ object functions {
* Generate a column with independent and identically distributed (i.i.d.) samples from
* the standard normal distribution.
*
+ * @note The function is non-deterministic in the general case.
+ *
* @group normal_funcs
* @since 1.4.0
*/
@@ -1211,7 +1252,7 @@ object functions {
/**
* Partition ID.
*
- * @note This is indeterministic because it depends on data partitioning and task scheduling.
+ * @note This is non-deterministic because it depends on data partitioning and task scheduling.
*
* @group normal_funcs
* @since 1.6.0
@@ -2691,7 +2732,12 @@ object functions {
/**
* Returns number of months between dates `date1` and `date2`.
- * The result is rounded off to 8 digits.
+ * If `date1` is later than `date2`, then the result is positive.
+ * If `date1` and `date2` are on the same day of the month, or both are the last day of
+ * the month, the time of day will be ignored.
+ *
+ * Otherwise, the difference is calculated based on 31 days per month, and rounded to
+ * 8 digits.
* @group datetime_funcs
* @since 1.5.0
*/
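A sketch of the months_between semantics spelled out above: the same day of the month ignores the time component, anything else is computed on a 31-day month and rounded to 8 digits (local SparkSession assumed):

```scala
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.{lit, months_between, to_date}

object MonthsBetweenExample {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[*]").appName("months-between").getOrCreate()

    spark.range(1).select(
      // Same day of the month: exactly 1.0
      months_between(to_date(lit("2018-04-15")), to_date(lit("2018-03-15"))).as("same_dom"),
      // Otherwise based on 31-day months, rounded to 8 digits
      months_between(to_date(lit("2018-04-10")), to_date(lit("2018-03-15"))).as("partial")
    ).show()

    spark.stop()
  }
}
```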
@@ -3185,9 +3231,9 @@ object functions {
from_json(e, schema.asInstanceOf[DataType], options)
/**
- * (Scala-specific) Parses a column containing a JSON string into a `StructType` or `ArrayType`
- * of `StructType`s with the specified schema. Returns `null`, in the case of an unparseable
- * string.
+ * (Scala-specific) Parses a column containing a JSON string into a `MapType` with `StringType`
+ * as keys type, `StructType` or `ArrayType` of `StructType`s with the specified schema.
+ * Returns `null`, in the case of an unparseable string.
*
* @param e a string column containing JSON data.
* @param schema the schema to use when parsing the json string
@@ -3217,9 +3263,9 @@ object functions {
from_json(e, schema, options.asScala.toMap)
/**
- * (Java-specific) Parses a column containing a JSON string into a `StructType` or `ArrayType`
- * of `StructType`s with the specified schema. Returns `null`, in the case of an unparseable
- * string.
+ * (Java-specific) Parses a column containing a JSON string into a `MapType` with `StringType`
+ * as keys type, `StructType` or `ArrayType` of `StructType`s with the specified schema.
+ * Returns `null`, in the case of an unparseable string.
*
* @param e a string column containing JSON data.
* @param schema the schema to use when parsing the json string
@@ -3246,8 +3292,9 @@ object functions {
from_json(e, schema, Map.empty[String, String])
/**
- * Parses a column containing a JSON string into a `StructType` or `ArrayType` of `StructType`s
- * with the specified schema. Returns `null`, in the case of an unparseable string.
+ * Parses a column containing a JSON string into a `MapType` with `StringType` as keys type,
+ * `StructType` or `ArrayType` of `StructType`s with the specified schema.
+ * Returns `null`, in the case of an unparseable string.
*
* @param e a string column containing JSON data.
* @param schema the schema to use when parsing the json string
@@ -3259,9 +3306,9 @@ object functions {
from_json(e, schema, Map.empty[String, String])
/**
- * (Java-specific) Parses a column containing a JSON string into a `StructType` or `ArrayType`
- * of `StructType`s with the specified schema. Returns `null`, in the case of an unparseable
- * string.
+ * (Java-specific) Parses a column containing a JSON string into a `MapType` with `StringType`
+ * as keys type, `StructType` or `ArrayType` of `StructType`s with the specified schema.
+ * Returns `null`, in the case of an unparseable string.
*
* @param e a string column containing JSON data.
* @param schema the schema to use when parsing the json string as a json string. In Spark 2.1,
@@ -3276,9 +3323,9 @@ object functions {
}
/**
- * (Scala-specific) Parses a column containing a JSON string into a `StructType` or `ArrayType`
- * of `StructType`s with the specified schema. Returns `null`, in the case of an unparseable
- * string.
+ * (Scala-specific) Parses a column containing a JSON string into a `MapType` with `StringType`
+ * as keys type, `StructType` or `ArrayType` of `StructType`s with the specified schema.
+ * Returns `null`, in the case of an unparseable string.
*
* @param e a string column containing JSON data.
* @param schema the schema to use when parsing the json string as a json string, it could be a
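The reworded from_json docs above now mention MapType with StringType keys; a sketch of that path (local SparkSession assumed, not part of this patch):

```scala
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.from_json
import org.apache.spark.sql.types.{IntegerType, MapType, StringType}

object FromJsonMapExample {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[*]").appName("from-json-map").getOrCreate()
    import spark.implicits._

    val df = Seq("""{"a": 1, "b": 2}""").toDF("json")

    // Parses each JSON object into a map column: a -> 1, b -> 2.
    df.select(from_json($"json", MapType(StringType, IntegerType)).as("parsed"))
      .show(false)

    spark.stop()
  }
}
```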
@@ -3400,6 +3447,26 @@ object functions {
*/
def flatten(e: Column): Column = withExpr { Flatten(e.expr) }
+ /**
+ * Creates an array containing the left argument repeated the number of times given by the
+ * right argument.
+ *
+ * @group collection_funcs
+ * @since 2.4.0
+ */
+ def array_repeat(left: Column, right: Column): Column = withExpr {
+ ArrayRepeat(left.expr, right.expr)
+ }
+
+ /**
+ * Creates an array containing the left argument repeated the number of times given by the
+ * right argument.
+ *
+ * @group collection_funcs
+ * @since 2.4.0
+ */
+ def array_repeat(e: Column, count: Int): Column = array_repeat(e, lit(count))
+
/**
* Returns an unordered array containing the keys of the map.
* @group collection_funcs
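Both new array_repeat overloads in one sketch, the literal count and the column count (local SparkSession assumed):

```scala
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.array_repeat

object ArrayRepeatExample {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[*]").appName("array-repeat").getOrCreate()
    import spark.implicits._

    val df = Seq(("x", 2), ("y", 3)).toDF("value", "times")

    df.select(
      array_repeat($"value", 2).as("fixed_count"),         // [x, x] / [y, y]
      array_repeat($"value", $"times").as("column_count")  // [x, x] / [y, y, y]
    ).show()

    spark.stop()
  }
}
```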
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryManager.scala b/sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryManager.scala
index 7cefd03e43bc3..97da2b1325f58 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryManager.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryManager.scala
@@ -242,7 +242,9 @@ class StreamingQueryManager private[sql] (sparkSession: SparkSession) extends Lo
(sink, trigger) match {
case (v2Sink: StreamWriteSupport, trigger: ContinuousTrigger) =>
- UnsupportedOperationChecker.checkForContinuous(analyzedPlan, outputMode)
+ if (sparkSession.sessionState.conf.isUnsupportedOperationCheckEnabled) {
+ UnsupportedOperationChecker.checkForContinuous(analyzedPlan, outputMode)
+ }
new StreamingQueryWrapper(new ContinuousExecution(
sparkSession,
userSpecifiedName.orNull,
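The continuous-mode check is now gated on isUnsupportedOperationCheckEnabled, which, assuming it is backed by the internal spark.sql.streaming.unsupportedOperationCheck setting, is the flag tests flip off. A toggling sketch (illustrative only):

```scala
import org.apache.spark.sql.SparkSession

object UnsupportedOpCheckToggle extends App {
  val spark = SparkSession.builder().master("local[*]").appName("op-check-toggle").getOrCreate()

  // Internal flag read via isUnsupportedOperationCheckEnabled; normally left at its default (true).
  spark.conf.set("spark.sql.streaming.unsupportedOperationCheck", "false")

  spark.stop()
}
```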
diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaAdvancedDataSourceV2.java b/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaAdvancedDataSourceV2.java
index 172e5d5eebcbe..445cb29f5ee3a 100644
--- a/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaAdvancedDataSourceV2.java
+++ b/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaAdvancedDataSourceV2.java
@@ -79,8 +79,8 @@ public Filter[] pushedFilters() {
}
@Override
- public List<DataReaderFactory<Row>> createDataReaderFactories() {
- List<DataReaderFactory<Row>> res = new ArrayList<>();
+ public List<InputPartition<Row>> planInputPartitions() {
+ List<InputPartition<Row>> res = new ArrayList<>();
Integer lowerBound = null;
for (Filter filter : filters) {
@@ -94,33 +94,34 @@ public List<DataReaderFactory<Row>> createDataReaderFactories() {
}
if (lowerBound == null) {
- res.add(new JavaAdvancedDataReaderFactory(0, 5, requiredSchema));
- res.add(new JavaAdvancedDataReaderFactory(5, 10, requiredSchema));
+ res.add(new JavaAdvancedInputPartition(0, 5, requiredSchema));
+ res.add(new JavaAdvancedInputPartition(5, 10, requiredSchema));
} else if (lowerBound < 4) {
- res.add(new JavaAdvancedDataReaderFactory(lowerBound + 1, 5, requiredSchema));
- res.add(new JavaAdvancedDataReaderFactory(5, 10, requiredSchema));
+ res.add(new JavaAdvancedInputPartition(lowerBound + 1, 5, requiredSchema));
+ res.add(new JavaAdvancedInputPartition(5, 10, requiredSchema));
} else if (lowerBound < 9) {
- res.add(new JavaAdvancedDataReaderFactory(lowerBound + 1, 10, requiredSchema));
+ res.add(new JavaAdvancedInputPartition(lowerBound + 1, 10, requiredSchema));
}
return res;
}
}
- static class JavaAdvancedDataReaderFactory implements DataReaderFactory<Row>, DataReader<Row> {
+ static class JavaAdvancedInputPartition implements InputPartition<Row>,
+ InputPartitionReader<Row> {
private int start;
private int end;
private StructType requiredSchema;
- JavaAdvancedDataReaderFactory(int start, int end, StructType requiredSchema) {
+ JavaAdvancedInputPartition(int start, int end, StructType requiredSchema) {
this.start = start;
this.end = end;
this.requiredSchema = requiredSchema;
}
@Override
- public DataReader<Row> createDataReader() {
- return new JavaAdvancedDataReaderFactory(start - 1, end, requiredSchema);
+ public InputPartitionReader<Row> createPartitionReader() {
+ return new JavaAdvancedInputPartition(start - 1, end, requiredSchema);
}
@Override
diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaBatchDataSourceV2.java b/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaBatchDataSourceV2.java
index c55093768105b..97d6176d02559 100644
--- a/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaBatchDataSourceV2.java
+++ b/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaBatchDataSourceV2.java
@@ -42,14 +42,14 @@ public StructType readSchema() {
}
@Override
- public List<DataReaderFactory<ColumnarBatch>> createBatchDataReaderFactories() {
+ public List<InputPartition<ColumnarBatch>> planBatchInputPartitions() {
return java.util.Arrays.asList(
- new JavaBatchDataReaderFactory(0, 50), new JavaBatchDataReaderFactory(50, 90));
+ new JavaBatchInputPartition(0, 50), new JavaBatchInputPartition(50, 90));
}
}
- static class JavaBatchDataReaderFactory
- implements DataReaderFactory<ColumnarBatch>, DataReader<ColumnarBatch> {
+ static class JavaBatchInputPartition
+ implements InputPartition<ColumnarBatch>, InputPartitionReader<ColumnarBatch> {
private int start;
private int end;
@@ -59,13 +59,13 @@ static class JavaBatchDataReaderFactory
private OnHeapColumnVector j;
private ColumnarBatch batch;
- JavaBatchDataReaderFactory(int start, int end) {
+ JavaBatchInputPartition(int start, int end) {
this.start = start;
this.end = end;
}
@Override
- public DataReader<ColumnarBatch> createDataReader() {
+ public InputPartitionReader<ColumnarBatch> createPartitionReader() {
this.i = new OnHeapColumnVector(BATCH_SIZE, DataTypes.IntegerType);
this.j = new OnHeapColumnVector(BATCH_SIZE, DataTypes.IntegerType);
ColumnVector[] vectors = new ColumnVector[2];
diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaPartitionAwareDataSource.java b/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaPartitionAwareDataSource.java
index 32fad59b97ff6..e49c8cf8b9e16 100644
--- a/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaPartitionAwareDataSource.java
+++ b/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaPartitionAwareDataSource.java
@@ -43,10 +43,10 @@ public StructType readSchema() {
}
@Override
- public List<DataReaderFactory<Row>> createDataReaderFactories() {
+ public List<InputPartition<Row>> planInputPartitions() {
return java.util.Arrays.asList(
- new SpecificDataReaderFactory(new int[]{1, 1, 3}, new int[]{4, 4, 6}),
- new SpecificDataReaderFactory(new int[]{2, 4, 4}, new int[]{6, 2, 2}));
+ new SpecificInputPartition(new int[]{1, 1, 3}, new int[]{4, 4, 6}),
+ new SpecificInputPartition(new int[]{2, 4, 4}, new int[]{6, 2, 2}));
}
@Override
@@ -73,12 +73,12 @@ public boolean satisfy(Distribution distribution) {
}
}
- static class SpecificDataReaderFactory implements DataReaderFactory<Row>, DataReader<Row> {
+ static class SpecificInputPartition implements InputPartition<Row>, InputPartitionReader<Row> {
private int[] i;
private int[] j;
private int current = -1;
- SpecificDataReaderFactory(int[] i, int[] j) {
+ SpecificInputPartition(int[] i, int[] j) {
assert i.length == j.length;
this.i = i;
this.j = j;
@@ -101,7 +101,7 @@ public void close() throws IOException {
}
@Override
- public DataReader<Row> createDataReader() {
+ public InputPartitionReader<Row> createPartitionReader() {
return this;
}
}
diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaSchemaRequiredDataSource.java b/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaSchemaRequiredDataSource.java
index 048d078dfaac4..80eeffd95f83b 100644
--- a/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaSchemaRequiredDataSource.java
+++ b/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaSchemaRequiredDataSource.java
@@ -24,7 +24,7 @@
import org.apache.spark.sql.sources.v2.DataSourceV2;
import org.apache.spark.sql.sources.v2.ReadSupportWithSchema;
import org.apache.spark.sql.sources.v2.reader.DataSourceReader;
-import org.apache.spark.sql.sources.v2.reader.DataReaderFactory;
+import org.apache.spark.sql.sources.v2.reader.InputPartition;
import org.apache.spark.sql.types.StructType;
public class JavaSchemaRequiredDataSource implements DataSourceV2, ReadSupportWithSchema {
@@ -42,7 +42,7 @@ public StructType readSchema() {
}
@Override
- public List<DataReaderFactory<Row>> createDataReaderFactories() {
+ public List<InputPartition<Row>> planInputPartitions() {
return java.util.Collections.emptyList();
}
}
diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaSimpleDataSourceV2.java b/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaSimpleDataSourceV2.java
index 96f55b8a76811..8522a63898a3b 100644
--- a/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaSimpleDataSourceV2.java
+++ b/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaSimpleDataSourceV2.java
@@ -25,8 +25,8 @@
import org.apache.spark.sql.sources.v2.DataSourceV2;
import org.apache.spark.sql.sources.v2.DataSourceOptions;
import org.apache.spark.sql.sources.v2.ReadSupport;
-import org.apache.spark.sql.sources.v2.reader.DataReader;
-import org.apache.spark.sql.sources.v2.reader.DataReaderFactory;
+import org.apache.spark.sql.sources.v2.reader.InputPartitionReader;
+import org.apache.spark.sql.sources.v2.reader.InputPartition;
import org.apache.spark.sql.sources.v2.reader.DataSourceReader;
import org.apache.spark.sql.types.StructType;
@@ -41,25 +41,25 @@ public StructType readSchema() {
}
@Override
- public List<DataReaderFactory<Row>> createDataReaderFactories() {
+ public List<InputPartition<Row>> planInputPartitions() {
return java.util.Arrays.asList(
- new JavaSimpleDataReaderFactory(0, 5),
- new JavaSimpleDataReaderFactory(5, 10));
+ new JavaSimpleInputPartition(0, 5),
+ new JavaSimpleInputPartition(5, 10));
}
}
- static class JavaSimpleDataReaderFactory implements DataReaderFactory<Row>, DataReader<Row> {
+ static class JavaSimpleInputPartition implements InputPartition<Row>, InputPartitionReader<Row> {
private int start;
private int end;
- JavaSimpleDataReaderFactory(int start, int end) {
+ JavaSimpleInputPartition(int start, int end) {
this.start = start;
this.end = end;
}
@Override
- public DataReader<Row> createDataReader() {
- return new JavaSimpleDataReaderFactory(start - 1, end);
+ public InputPartitionReader<Row> createPartitionReader() {
+ return new JavaSimpleInputPartition(start - 1, end);
}
@Override
diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaUnsafeRowDataSourceV2.java b/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaUnsafeRowDataSourceV2.java
index c3916e0b370b5..3ad8e7a0104ce 100644
--- a/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaUnsafeRowDataSourceV2.java
+++ b/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaUnsafeRowDataSourceV2.java
@@ -38,20 +38,20 @@ public StructType readSchema() {
}
@Override
- public List<DataReaderFactory<UnsafeRow>> createUnsafeRowReaderFactories() {
+ public List<InputPartition<UnsafeRow>> planUnsafeInputPartitions() {
return java.util.Arrays.asList(
- new JavaUnsafeRowDataReaderFactory(0, 5),
- new JavaUnsafeRowDataReaderFactory(5, 10));
+ new JavaUnsafeRowInputPartition(0, 5),
+ new JavaUnsafeRowInputPartition(5, 10));
}
}
- static class JavaUnsafeRowDataReaderFactory
- implements DataReaderFactory<UnsafeRow>, DataReader<UnsafeRow> {
+ static class JavaUnsafeRowInputPartition
+ implements InputPartition<UnsafeRow>, InputPartitionReader<UnsafeRow> {
private int start;
private int end;
private UnsafeRow row;
- JavaUnsafeRowDataReaderFactory(int start, int end) {
+ JavaUnsafeRowInputPartition(int start, int end) {
this.start = start;
this.end = end;
this.row = new UnsafeRow(2);
@@ -59,8 +59,8 @@ static class JavaUnsafeRowDataReaderFactory
}
@Override
- public DataReader<UnsafeRow> createDataReader() {
- return new JavaUnsafeRowDataReaderFactory(start - 1, end);
+ public InputPartitionReader<UnsafeRow> createPartitionReader() {
+ return new JavaUnsafeRowInputPartition(start - 1, end);
}
@Override
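
The Java test sources above track the DataSourceV2 read-side rename (DataReaderFactory to InputPartition, createDataReader to createPartitionReader). A self-contained Scala sketch of the renamed contract, mirroring the SimpleInputPartition changes made later in this patch; the class names and row values here are illustrative only:

    import java.util.{Arrays => JArrays, List => JList}

    import org.apache.spark.sql.Row
    import org.apache.spark.sql.sources.v2.reader.{DataSourceReader, InputPartition, InputPartitionReader}
    import org.apache.spark.sql.types.StructType

    class RangeReader extends DataSourceReader {
      override def readSchema(): StructType = new StructType().add("i", "int")

      // planInputPartitions replaces createDataReaderFactories: it splits the scan into partitions.
      override def planInputPartitions(): JList[InputPartition[Row]] =
        JArrays.asList(new RangePartition(0, 5), new RangePartition(5, 10))
    }

    class RangePartition(start: Int, end: Int)
      extends InputPartition[Row] with InputPartitionReader[Row] {
      private var current = start - 1

      // createPartitionReader replaces createDataReader; the partition doubles as its own
      // reader, exactly as the test sources in this patch do.
      override def createPartitionReader(): InputPartitionReader[Row] = new RangePartition(start, end)

      override def next(): Boolean = { current += 1; current < end }
      override def get(): Row = Row(current)
      override def close(): Unit = {}
    }
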
diff --git a/sql/core/src/test/resources/sql-tests/inputs/udaf-regrfunctions.sql b/sql/core/src/test/resources/sql-tests/inputs/udaf-regrfunctions.sql
new file mode 100644
index 0000000000000..92c7e26e3add2
--- /dev/null
+++ b/sql/core/src/test/resources/sql-tests/inputs/udaf-regrfunctions.sql
@@ -0,0 +1,56 @@
+--
+-- Licensed to the Apache Software Foundation (ASF) under one or more
+-- contributor license agreements. See the NOTICE file distributed with
+-- this work for additional information regarding copyright ownership.
+-- The ASF licenses this file to You under the Apache License, Version 2.0
+-- (the "License"); you may not use this file except in compliance with
+-- the License. You may obtain a copy of the License at
+--
+-- http://www.apache.org/licenses/LICENSE-2.0
+--
+-- Unless required by applicable law or agreed to in writing, software
+-- distributed under the License is distributed on an "AS IS" BASIS,
+-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+-- See the License for the specific language governing permissions and
+-- limitations under the License.
+--
+
+CREATE OR REPLACE TEMPORARY VIEW t1 AS SELECT * FROM VALUES
+ (101, 1, 1, 1),
+ (201, 2, 1, 1),
+ (301, 3, 1, 1),
+ (401, 4, 1, 11),
+ (501, 5, 1, null),
+ (601, 6, null, 1),
+ (701, 6, null, null),
+ (102, 1, 2, 2),
+ (202, 2, 1, 2),
+ (302, 3, 2, 1),
+ (402, 4, 2, 12),
+ (502, 5, 2, null),
+ (602, 6, null, 2),
+ (702, 6, null, null),
+ (103, 1, 3, 3),
+ (203, 2, 1, 3),
+ (303, 3, 3, 1),
+ (403, 4, 3, 13),
+ (503, 5, 3, null),
+ (603, 6, null, 3),
+ (703, 6, null, null),
+ (104, 1, 4, 4),
+ (204, 2, 1, 4),
+ (304, 3, 4, 1),
+ (404, 4, 4, 14),
+ (504, 5, 4, null),
+ (604, 6, null, 4),
+ (704, 6, null, null),
+ (800, 7, 1, 1)
+as t1(id, px, y, x);
+
+select px, var_pop(x), var_pop(y), corr(y,x), covar_samp(y,x), covar_pop(y,x), regr_count(y,x),
+ regr_slope(y,x), regr_intercept(y,x), regr_r2(y,x), regr_sxx(y,x), regr_syy(y,x), regr_sxy(y,x),
+ regr_avgx(y,x), regr_avgy(y,x), regr_count(y,x)
+from t1 group by px order by px;
+
+
+select id, regr_count(y,x) over (partition by px) from t1 order by id;
diff --git a/sql/core/src/test/resources/sql-tests/results/describe-part-after-analyze.sql.out b/sql/core/src/test/resources/sql-tests/results/describe-part-after-analyze.sql.out
index 51dac111029e8..58ed201e2a60f 100644
--- a/sql/core/src/test/resources/sql-tests/results/describe-part-after-analyze.sql.out
+++ b/sql/core/src/test/resources/sql-tests/results/describe-part-after-analyze.sql.out
@@ -89,7 +89,7 @@ Database default
Table t
Partition Values [ds=2017-08-01, hr=10]
Location [not included in comparison]sql/core/spark-warehouse/t/ds=2017-08-01/hr=10
-Partition Statistics 1067 bytes, 3 rows
+Partition Statistics 1121 bytes, 3 rows
# Storage Information
Location [not included in comparison]sql/core/spark-warehouse/t
@@ -122,7 +122,7 @@ Database default
Table t
Partition Values [ds=2017-08-01, hr=10]
Location [not included in comparison]sql/core/spark-warehouse/t/ds=2017-08-01/hr=10
-Partition Statistics 1067 bytes, 3 rows
+Partition Statistics 1121 bytes, 3 rows
# Storage Information
Location [not included in comparison]sql/core/spark-warehouse/t
@@ -147,7 +147,7 @@ Database default
Table t
Partition Values [ds=2017-08-01, hr=11]
Location [not included in comparison]sql/core/spark-warehouse/t/ds=2017-08-01/hr=11
-Partition Statistics 1080 bytes, 4 rows
+Partition Statistics 1098 bytes, 4 rows
# Storage Information
Location [not included in comparison]sql/core/spark-warehouse/t
@@ -180,7 +180,7 @@ Database default
Table t
Partition Values [ds=2017-08-01, hr=10]
Location [not included in comparison]sql/core/spark-warehouse/t/ds=2017-08-01/hr=10
-Partition Statistics 1067 bytes, 3 rows
+Partition Statistics 1121 bytes, 3 rows
# Storage Information
Location [not included in comparison]sql/core/spark-warehouse/t
@@ -205,7 +205,7 @@ Database default
Table t
Partition Values [ds=2017-08-01, hr=11]
Location [not included in comparison]sql/core/spark-warehouse/t/ds=2017-08-01/hr=11
-Partition Statistics 1080 bytes, 4 rows
+Partition Statistics 1098 bytes, 4 rows
# Storage Information
Location [not included in comparison]sql/core/spark-warehouse/t
@@ -230,7 +230,7 @@ Database default
Table t
Partition Values [ds=2017-09-01, hr=5]
Location [not included in comparison]sql/core/spark-warehouse/t/ds=2017-09-01/hr=5
-Partition Statistics 1054 bytes, 2 rows
+Partition Statistics 1144 bytes, 2 rows
# Storage Information
Location [not included in comparison]sql/core/spark-warehouse/t
diff --git a/sql/core/src/test/resources/sql-tests/results/udaf-regrfunctions.sql.out b/sql/core/src/test/resources/sql-tests/results/udaf-regrfunctions.sql.out
new file mode 100644
index 0000000000000..d7d009a64bf84
--- /dev/null
+++ b/sql/core/src/test/resources/sql-tests/results/udaf-regrfunctions.sql.out
@@ -0,0 +1,93 @@
+-- Automatically generated by SQLQueryTestSuite
+-- Number of queries: 3
+
+
+-- !query 0
+CREATE OR REPLACE TEMPORARY VIEW t1 AS SELECT * FROM VALUES
+ (101, 1, 1, 1),
+ (201, 2, 1, 1),
+ (301, 3, 1, 1),
+ (401, 4, 1, 11),
+ (501, 5, 1, null),
+ (601, 6, null, 1),
+ (701, 6, null, null),
+ (102, 1, 2, 2),
+ (202, 2, 1, 2),
+ (302, 3, 2, 1),
+ (402, 4, 2, 12),
+ (502, 5, 2, null),
+ (602, 6, null, 2),
+ (702, 6, null, null),
+ (103, 1, 3, 3),
+ (203, 2, 1, 3),
+ (303, 3, 3, 1),
+ (403, 4, 3, 13),
+ (503, 5, 3, null),
+ (603, 6, null, 3),
+ (703, 6, null, null),
+ (104, 1, 4, 4),
+ (204, 2, 1, 4),
+ (304, 3, 4, 1),
+ (404, 4, 4, 14),
+ (504, 5, 4, null),
+ (604, 6, null, 4),
+ (704, 6, null, null),
+ (800, 7, 1, 1)
+as t1(id, px, y, x)
+-- !query 0 schema
+struct<>
+-- !query 0 output
+
+
+
+-- !query 1
+select px, var_pop(x), var_pop(y), corr(y,x), covar_samp(y,x), covar_pop(y,x), regr_count(y,x),
+ regr_slope(y,x), regr_intercept(y,x), regr_r2(y,x), regr_sxx(y,x), regr_syy(y,x), regr_sxy(y,x),
+ regr_avgx(y,x), regr_avgy(y,x), regr_count(y,x)
+from t1 group by px order by px
+-- !query 1 schema
+struct
+-- !query 1 output
+1 1.25 1.25 1.0 1.6666666666666667 1.25 4 1.0 0.0 1.0 5.0 5.0 5.0 2.5 2.5 4
+2 1.25 0.0 NULL 0.0 0.0 4 0.0 1.0 1.0 5.0 0.0 0.0 2.5 1.0 4
+3 0.0 1.25 NULL 0.0 0.0 4 NULL NULL NULL 0.0 5.0 0.0 1.0 2.5 4
+4 1.25 1.25 1.0 1.6666666666666667 1.25 4 1.0 -10.0 1.0 5.0 5.0 5.0 12.5 2.5 4
+5 NULL 1.25 NULL NULL NULL 0 NULL NULL NULL NULL NULL NULL NULL NULL 0
+6 1.25 NULL NULL NULL NULL 0 NULL NULL NULL NULL NULL NULL NULL NULL 0
+7 0.0 0.0 NaN NaN 0.0 1 NULL NULL NULL 0.0 0.0 0.0 1.0 1.0 1
+
+
+-- !query 2
+select id, regr_count(y,x) over (partition by px) from t1 order by id
+-- !query 2 schema
+struct
+-- !query 2 output
+101 4
+102 4
+103 4
+104 4
+201 4
+202 4
+203 4
+204 4
+301 4
+302 4
+303 4
+304 4
+401 4
+402 4
+403 4
+404 4
+501 0
+502 0
+503 0
+504 0
+601 0
+602 0
+603 0
+604 0
+701 0
+702 0
+703 0
+704 0
+800 1
diff --git a/sql/core/src/test/resources/test-data/parquet-1217.parquet b/sql/core/src/test/resources/test-data/parquet-1217.parquet
new file mode 100644
index 0000000000000..eb2dc4f799070
Binary files /dev/null and b/sql/core/src/test/resources/test-data/parquet-1217.parquet differ
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala
index e7776e36702ad..96c28961e5aaf 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala
@@ -36,6 +36,8 @@ case class Fact(date: Int, hour: Int, minute: Int, room_name: String, temp: Doub
class DataFrameAggregateSuite extends QueryTest with SharedSQLContext {
import testImplicits._
+ val absTol = 1e-8
+
test("groupBy") {
checkAnswer(
testData2.groupBy("a").agg(sum($"b")),
@@ -416,7 +418,6 @@ class DataFrameAggregateSuite extends QueryTest with SharedSQLContext {
}
test("moments") {
- val absTol = 1e-8
val sparkVariance = testData2.agg(variance('a))
checkAggregatesWithTol(sparkVariance, Row(4.0 / 5.0), absTol)
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala
index 0033acc7cd82c..72fb9a546465c 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala
@@ -887,6 +887,82 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext {
}
}
+ test("array_repeat function") {
+ val dummyFilter = (c: Column) => c.isNull || c.isNotNull // to switch codeGen on
+ val strDF = Seq(
+ ("hi", 2),
+ (null, 2)
+ ).toDF("a", "b")
+
+ val strDFTwiceResult = Seq(
+ Row(Seq("hi", "hi")),
+ Row(Seq(null, null))
+ )
+
+ checkAnswer(strDF.select(array_repeat($"a", 2)), strDFTwiceResult)
+ checkAnswer(strDF.filter(dummyFilter($"a")).select(array_repeat($"a", 2)), strDFTwiceResult)
+ checkAnswer(strDF.select(array_repeat($"a", $"b")), strDFTwiceResult)
+ checkAnswer(strDF.filter(dummyFilter($"a")).select(array_repeat($"a", $"b")), strDFTwiceResult)
+ checkAnswer(strDF.selectExpr("array_repeat(a, 2)"), strDFTwiceResult)
+ checkAnswer(strDF.selectExpr("array_repeat(a, b)"), strDFTwiceResult)
+
+ val intDF = {
+ val schema = StructType(Seq(
+ StructField("a", IntegerType),
+ StructField("b", IntegerType)))
+ val data = Seq(
+ Row(3, 2),
+ Row(null, 2)
+ )
+ spark.createDataFrame(spark.sparkContext.parallelize(data), schema)
+ }
+
+ val intDFTwiceResult = Seq(
+ Row(Seq(3, 3)),
+ Row(Seq(null, null))
+ )
+
+ checkAnswer(intDF.select(array_repeat($"a", 2)), intDFTwiceResult)
+ checkAnswer(intDF.filter(dummyFilter($"a")).select(array_repeat($"a", 2)), intDFTwiceResult)
+ checkAnswer(intDF.select(array_repeat($"a", $"b")), intDFTwiceResult)
+ checkAnswer(intDF.filter(dummyFilter($"a")).select(array_repeat($"a", $"b")), intDFTwiceResult)
+ checkAnswer(intDF.selectExpr("array_repeat(a, 2)"), intDFTwiceResult)
+ checkAnswer(intDF.selectExpr("array_repeat(a, b)"), intDFTwiceResult)
+
+ val nullCountDF = {
+ val schema = StructType(Seq(
+ StructField("a", StringType),
+ StructField("b", IntegerType)))
+ val data = Seq(
+ Row("hi", null),
+ Row(null, null)
+ )
+ spark.createDataFrame(spark.sparkContext.parallelize(data), schema)
+ }
+
+ checkAnswer(
+ nullCountDF.select(array_repeat($"a", $"b")),
+ Seq(
+ Row(null),
+ Row(null)
+ )
+ )
+
+ // Error test cases
+ val invalidTypeDF = Seq(("hi", "1")).toDF("a", "b")
+
+ intercept[AnalysisException] {
+ invalidTypeDF.select(array_repeat($"a", $"b"))
+ }
+ intercept[AnalysisException] {
+ invalidTypeDF.select(array_repeat($"a", lit("1")))
+ }
+ intercept[AnalysisException] {
+ invalidTypeDF.selectExpr("array_repeat(a, 1.0)")
+ }
+
+ }
+
private def assertValuesDoNotChangeAfterCoalesceOrUnion(v: Column): Unit = {
import DataFrameFunctionsSuite.CodegenFallbackExpr
for ((codegenFallback, wholeStage) <- Seq((true, false), (false, false), (false, true))) {
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
index e0f4d2ba685e1..d477d78dc14e3 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
@@ -1425,6 +1425,14 @@ class DatasetSuite extends QueryTest with SharedSQLContext {
}
}
+ test("SPARK-23627: provide isEmpty in DataSet") {
+ val ds1 = spark.emptyDataset[Int]
+ val ds2 = Seq(1, 2, 3).toDS()
+
+ assert(ds1.isEmpty == true)
+ assert(ds2.isEmpty == false)
+ }
+
test("SPARK-22472: add null check for top-level primitive values") {
// If the primitive values are from Option, we need to do runtime null check.
val ds = Seq(Some(1), None).toDS().as[Int]
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/JsonFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/JsonFunctionsSuite.scala
index 00d2acc4a1d8a..055e1fc5640f3 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/JsonFunctionsSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/JsonFunctionsSuite.scala
@@ -326,4 +326,70 @@ class JsonFunctionsSuite extends QueryTest with SharedSQLContext {
assert(errMsg4.getMessage.startsWith(
"A type of keys and values in map() must be string, but got"))
}
+
+ test("SPARK-24027: from_json - map") {
+ val in = Seq("""{"a": 1, "b": 2, "c": 3}""").toDS()
+ val schema =
+ """
+ |{
+ | "type" : "map",
+ | "keyType" : "string",
+ | "valueType" : "integer",
+ | "valueContainsNull" : true
+ |}
+ """.stripMargin
+ val out = in.select(from_json($"value", schema, Map[String, String]()))
+
+ assert(out.columns.head == "entries")
+ checkAnswer(out, Row(Map("a" -> 1, "b" -> 2, "c" -> 3)))
+ }
+
+ test("SPARK-24027: from_json - map") {
+ val in = Seq("""{"a": {"b": 1}}""").toDS()
+ val schema = MapType(StringType, new StructType().add("b", IntegerType), true)
+ val out = in.select(from_json($"value", schema))
+
+ checkAnswer(out, Row(Map("a" -> Row(1))))
+ }
+
+ test("SPARK-24027: from_json - map>") {
+ val in = Seq("""{"a": {"b": 1}}""").toDS()
+ val schema = MapType(StringType, MapType(StringType, IntegerType))
+ val out = in.select(from_json($"value", schema))
+
+ checkAnswer(out, Row(Map("a" -> Map("b" -> 1))))
+ }
+
+ test("SPARK-24027: roundtrip - from_json -> to_json - map") {
+ val json = """{"a":1,"b":2,"c":3}"""
+ val schema = MapType(StringType, IntegerType, true)
+ val out = Seq(json).toDS().select(to_json(from_json($"value", schema)))
+
+ checkAnswer(out, Row(json))
+ }
+
+ test("SPARK-24027: roundtrip - to_json -> from_json - map") {
+ val in = Seq(Map("a" -> 1)).toDF()
+ val schema = MapType(StringType, IntegerType, true)
+ val out = in.select(from_json(to_json($"value"), schema))
+
+ checkAnswer(out, in)
+ }
+
+ test("SPARK-24027: from_json - wrong map") {
+ val in = Seq("""{"a" 1}""").toDS()
+ val schema = MapType(StringType, IntegerType)
+ val out = in.select(from_json($"value", schema, Map[String, String]()))
+
+ checkAnswer(out, Row(null))
+ }
+
+ test("SPARK-24027: from_json of a map with unsupported key type") {
+ val schema = MapType(StructType(StructField("f", IntegerType) :: Nil), StringType)
+
+ checkAnswer(Seq("""{{"f": 1}: "a"}""").toDS().select(from_json($"value", schema)),
+ Row(null))
+ checkAnswer(Seq("""{"{"f": 1}": "a"}""").toDS().select(from_json($"value", schema)),
+ Row(null))
+ }
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/StatisticsCollectionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/StatisticsCollectionSuite.scala
index b91712f4cc25d..60fa951e23178 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/StatisticsCollectionSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/StatisticsCollectionSuite.scala
@@ -50,7 +50,7 @@ class StatisticsCollectionSuite extends StatisticsCollectionTestBase with Shared
}
assert(sizes.size === 1, s"number of Join nodes is wrong:\n ${df.queryExecution}")
- assert(sizes.head === BigInt(96),
+ assert(sizes.head === BigInt(128),
s"expected exact size 96 for table 'test', got: ${sizes.head}")
}
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarQuerySuite.scala
index 863703b15f4f1..efc2f20a907f1 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarQuerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarQuerySuite.scala
@@ -503,7 +503,7 @@ class InMemoryColumnarQuerySuite extends QueryTest with SharedSQLContext {
case plan: InMemoryRelation => plan
}.head
// InMemoryRelation's stats is file size before the underlying RDD is materialized
- assert(inMemoryRelation.computeStats().sizeInBytes === 740)
+ assert(inMemoryRelation.computeStats().sizeInBytes === 800)
// InMemoryRelation's stats is updated after materializing RDD
dfFromFile.collect()
@@ -516,7 +516,7 @@ class InMemoryColumnarQuerySuite extends QueryTest with SharedSQLContext {
// Even CBO enabled, InMemoryRelation's stats keeps as the file size before table's stats
// is calculated
- assert(inMemoryRelation2.computeStats().sizeInBytes === 740)
+ assert(inMemoryRelation2.computeStats().sizeInBytes === 800)
// InMemoryRelation's stats should be updated after calculating stats of the table
// clear cache to simulate a fresh environment
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVBenchmarks.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVBenchmarks.scala
new file mode 100644
index 0000000000000..d442ba7e59c61
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVBenchmarks.scala
@@ -0,0 +1,80 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql.execution.datasources.csv
+
+import java.io.File
+
+import org.apache.spark.SparkConf
+import org.apache.spark.sql.{Column, Row, SparkSession}
+import org.apache.spark.sql.functions.lit
+import org.apache.spark.sql.types._
+import org.apache.spark.util.{Benchmark, Utils}
+
+/**
+ * Benchmark to measure CSV read/write performance.
+ * To run this:
+ * spark-submit --class <this class> --jars <spark sql test jar>
+ */
+object CSVBenchmarks {
+ val conf = new SparkConf()
+
+ val spark = SparkSession.builder
+ .master("local[1]")
+ .appName("benchmark-csv-datasource")
+ .config(conf)
+ .getOrCreate()
+ import spark.implicits._
+
+ def withTempPath(f: File => Unit): Unit = {
+ val path = Utils.createTempDir()
+ path.delete()
+ try f(path) finally Utils.deleteRecursively(path)
+ }
+
+ def quotedValuesBenchmark(rowsNum: Int, numIters: Int): Unit = {
+ val benchmark = new Benchmark(s"Parsing quoted values", rowsNum)
+
+ withTempPath { path =>
+ val str = (0 until 10000).map(i => s""""$i"""").mkString(",")
+
+ spark.range(rowsNum)
+ .map(_ => str)
+ .write.option("header", true)
+ .csv(path.getAbsolutePath)
+
+ val schema = new StructType().add("value", StringType)
+ val ds = spark.read.option("header", true).schema(schema).csv(path.getAbsolutePath)
+
+ benchmark.addCase(s"One quoted string", numIters) { _ =>
+ ds.filter((_: Row) => true).count()
+ }
+
+ /*
+ Intel(R) Core(TM) i7-7920HQ CPU @ 3.10GHz
+
+ Parsing quoted values: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative
+ --------------------------------------------------------------------------------------------
+ One quoted string 30273 / 30549 0.0 605451.2 1.0X
+ */
+ benchmark.run()
+ }
+ }
+
+ def main(args: Array[String]): Unit = {
+ quotedValuesBenchmark(rowsNum = 50 * 1000, numIters = 3)
+ }
+}
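
The new benchmark can also be driven directly, for example from a Scala REPL with the sql test classes on the classpath; the row and iteration counts below are arbitrary, not values from this patch:

    // Hypothetical parameters; main() defaults to 50k rows and 3 iterations.
    CSVBenchmarks.quotedValuesBenchmark(rowsNum = 100 * 1000, numIters = 5)
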
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala
index 461abdd96d3f3..07e6c74b14d0d 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala
@@ -1322,4 +1322,50 @@ class CSVSuite extends QueryTest with SharedSQLContext with SQLTestUtils with Te
val sampled = spark.read.option("inferSchema", true).option("samplingRatio", 1.0).csv(ds)
assert(sampled.count() == ds.count())
}
+
+ test("SPARK-17916: An empty string should not be coerced to null when nullValue is passed.") {
+ val litNull: String = null
+ val df = Seq(
+ (1, "John Doe"),
+ (2, ""),
+ (3, "-"),
+ (4, litNull)
+ ).toDF("id", "name")
+
+ // Checks for new behavior where an empty string is not coerced to null when `nullValue` is
+ // set to anything but an empty string literal.
+ withTempPath { path =>
+ df.write
+ .option("nullValue", "-")
+ .csv(path.getAbsolutePath)
+ val computed = spark.read
+ .option("nullValue", "-")
+ .schema(df.schema)
+ .csv(path.getAbsolutePath)
+ val expected = Seq(
+ (1, "John Doe"),
+ (2, ""),
+ (3, litNull),
+ (4, litNull)
+ ).toDF("id", "name")
+
+ checkAnswer(computed, expected)
+ }
+ // Keeps the old behavior where an empty string is coerced to null when nullValue is not passed.
+ withTempPath { path =>
+ df.write
+ .csv(path.getAbsolutePath)
+ val computed = spark.read
+ .schema(df.schema)
+ .csv(path.getAbsolutePath)
+ val expected = Seq(
+ (1, "John Doe"),
+ (2, litNull),
+ (3, "-"),
+ (4, litNull)
+ ).toDF("id", "name")
+
+ checkAnswer(computed, expected)
+ }
+ }
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala
index 0db688fec9a67..4b3921c61a000 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala
@@ -2313,6 +2313,25 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData {
}
}
+ test("SPARK-23723: write json in UTF-16/32 with multiline off") {
+ Seq("UTF-16", "UTF-32").foreach { encoding =>
+ withTempPath { path =>
+ val ds = spark.createDataset(Seq(
+ ("a", 1), ("b", 2), ("c", 3))
+ ).repartition(2)
+ val e = intercept[IllegalArgumentException] {
+ ds.write
+ .option("encoding", encoding)
+ .option("multiline", "false")
+ .format("json").mode("overwrite")
+ .save(path.getCanonicalPath)
+ }.getMessage
+ assert(e.contains(
+ s"$encoding encoding in the blacklist is not allowed when multiLine is disabled"))
+ }
+ }
+ }
+
def checkReadJson(lineSep: String, encoding: String, inferSchema: Boolean, id: Int): Unit = {
test(s"SPARK-23724: checks reading json in ${encoding} #${id}") {
val schema = new StructType().add("f1", StringType).add("f2", IntegerType)
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala
index 667e0b1760e3d..90da7eb8c4fb5 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala
@@ -648,6 +648,18 @@ class ParquetFilterSuite extends QueryTest with ParquetTest with SharedSQLContex
}
}
}
+
+ test("SPARK-23852: Broken Parquet push-down for partially-written stats") {
+ withSQLConf(SQLConf.PARQUET_FILTER_PUSHDOWN_ENABLED.key -> "true") {
+ // parquet-1217.parquet contains a single column with values -1, 0, 1, 2 and null.
+ // The row-group statistics include null counts, but not min and max values, which
+ // triggers PARQUET-1217.
+ val df = readResourceParquetFile("test-data/parquet-1217.parquet")
+
+ // Will return 0 rows if PARQUET-1217 is not fixed.
+ assert(df.where("col > 0").count() === 2)
+ }
+ }
}
class NumRowGroupsAcc extends AccumulatorV2[Integer, Integer] {
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/MemorySinkSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/MemorySinkSuite.scala
index e8420eee7fe9d..3bc36ce55d902 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/MemorySinkSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/MemorySinkSuite.scala
@@ -220,11 +220,11 @@ class MemorySinkSuite extends StreamTest with BeforeAndAfter {
sink.addBatch(0, 1 to 3)
plan.invalidateStatsCache()
- assert(plan.stats.sizeInBytes === 12)
+ assert(plan.stats.sizeInBytes === 36)
sink.addBatch(1, 4 to 6)
plan.invalidateStatsCache()
- assert(plan.stats.sizeInBytes === 24)
+ assert(plan.stats.sizeInBytes === 72)
}
ignore("stress test") {
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamProviderSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamProviderSuite.scala
index ff14ec38e66a8..39a010f970ce5 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamProviderSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamProviderSuite.scala
@@ -142,9 +142,9 @@ class RateSourceSuite extends StreamTest {
val startOffset = LongOffset(0L)
val endOffset = LongOffset(1L)
reader.setOffsetRange(Optional.of(startOffset), Optional.of(endOffset))
- val tasks = reader.createDataReaderFactories()
+ val tasks = reader.planInputPartitions()
assert(tasks.size == 1)
- val dataReader = tasks.get(0).createDataReader()
+ val dataReader = tasks.get(0).createPartitionReader()
val data = ArrayBuffer[Row]()
while (dataReader.next()) {
data.append(dataReader.get())
@@ -159,11 +159,11 @@ class RateSourceSuite extends StreamTest {
val startOffset = LongOffset(0L)
val endOffset = LongOffset(1L)
reader.setOffsetRange(Optional.of(startOffset), Optional.of(endOffset))
- val tasks = reader.createDataReaderFactories()
+ val tasks = reader.planInputPartitions()
assert(tasks.size == 11)
val readData = tasks.asScala
- .map(_.createDataReader())
+ .map(_.createPartitionReader())
.flatMap { reader =>
val buf = scala.collection.mutable.ListBuffer[Row]()
while (reader.next()) buf.append(reader.get())
@@ -304,7 +304,7 @@ class RateSourceSuite extends StreamTest {
val reader = new RateStreamContinuousReader(
new DataSourceOptions(Map("numPartitions" -> "2", "rowsPerSecond" -> "20").asJava))
reader.setStartOffset(Optional.empty())
- val tasks = reader.createDataReaderFactories()
+ val tasks = reader.planInputPartitions()
assert(tasks.size == 2)
val data = scala.collection.mutable.ListBuffer[Row]()
@@ -314,7 +314,7 @@ class RateSourceSuite extends StreamTest {
.asInstanceOf[RateStreamOffset]
.partitionToValueAndRunTimeMs(t.partitionIndex)
.runTimeMs
- val r = t.createDataReader().asInstanceOf[RateStreamContinuousDataReader]
+ val r = t.createPartitionReader().asInstanceOf[RateStreamContinuousInputPartitionReader]
for (rowIndex <- 0 to 9) {
r.next()
data.append(r.get())
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/BucketedWriteSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/BucketedWriteSuite.scala
index 93f3efe2ccc4a..5ff1ea84d9a7b 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/sources/BucketedWriteSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/BucketedWriteSuite.scala
@@ -60,7 +60,10 @@ abstract class BucketedWriteSuite extends QueryTest with SQLTestUtils {
test("specify sorting columns without bucketing columns") {
val df = Seq(1 -> "a", 2 -> "b").toDF("i", "j")
- intercept[IllegalArgumentException](df.write.sortBy("j").saveAsTable("tt"))
+ val e = intercept[AnalysisException] {
+ df.write.sortBy("j").saveAsTable("tt")
+ }
+ assert(e.getMessage == "sortBy must be used together with bucketBy;")
}
test("sorting by non-orderable column") {
@@ -74,7 +77,16 @@ abstract class BucketedWriteSuite extends QueryTest with SQLTestUtils {
val e = intercept[AnalysisException] {
df.write.bucketBy(2, "i").parquet("/tmp/path")
}
- assert(e.getMessage == "'save' does not support bucketing right now;")
+ assert(e.getMessage == "'save' does not support bucketBy right now;")
+ }
+
+ test("write bucketed and sorted data using save()") {
+ val df = Seq(1 -> "a", 2 -> "b").toDF("i", "j")
+
+ val e = intercept[AnalysisException] {
+ df.write.bucketBy(2, "i").sortBy("i").parquet("/tmp/path")
+ }
+ assert(e.getMessage == "'save' does not support bucketBy and sortBy right now;")
}
test("write bucketed data using insertInto()") {
@@ -83,7 +95,16 @@ abstract class BucketedWriteSuite extends QueryTest with SQLTestUtils {
val e = intercept[AnalysisException] {
df.write.bucketBy(2, "i").insertInto("tt")
}
- assert(e.getMessage == "'insertInto' does not support bucketing right now;")
+ assert(e.getMessage == "'insertInto' does not support bucketBy right now;")
+ }
+
+ test("write bucketed and sorted data using insertInto()") {
+ val df = Seq(1 -> "a", 2 -> "b").toDF("i", "j")
+
+ val e = intercept[AnalysisException] {
+ df.write.bucketBy(2, "i").sortBy("i").insertInto("tt")
+ }
+ assert(e.getMessage == "'insertInto' does not support bucketBy and sortBy right now;")
}
private lazy val df = {
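
The updated assertions above pin down the exact error messages: bucketBy/sortBy remain supported only through saveAsTable, while save() and insertInto() reject them. A small Scala sketch of that behavior, with the message text copied from this patch (table and path names are placeholders):

    import org.apache.spark.sql.{AnalysisException, SparkSession}

    def bucketingDemo(spark: SparkSession): Unit = {
      import spark.implicits._
      val df = Seq(1 -> "a", 2 -> "b").toDF("i", "j")

      // Supported: bucketed and sorted data written as a managed table.
      df.write.bucketBy(2, "i").sortBy("i").saveAsTable("bucketed_ok")

      // Rejected with "'save' does not support bucketBy and sortBy right now;"
      try df.write.bucketBy(2, "i").sortBy("i").parquet("/tmp/path")
      catch { case e: AnalysisException => println(e.getMessage) }
    }
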
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2Suite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2Suite.scala
index e0a53272cd222..505a3f3465c02 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2Suite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2Suite.scala
@@ -346,8 +346,8 @@ class SimpleSinglePartitionSource extends DataSourceV2 with ReadSupport {
class Reader extends DataSourceReader {
override def readSchema(): StructType = new StructType().add("i", "int").add("j", "int")
- override def createDataReaderFactories(): JList[DataReaderFactory[Row]] = {
- java.util.Arrays.asList(new SimpleDataReaderFactory(0, 5))
+ override def planInputPartitions(): JList[InputPartition[Row]] = {
+ java.util.Arrays.asList(new SimpleInputPartition(0, 5))
}
}
@@ -359,20 +359,21 @@ class SimpleDataSourceV2 extends DataSourceV2 with ReadSupport {
class Reader extends DataSourceReader {
override def readSchema(): StructType = new StructType().add("i", "int").add("j", "int")
- override def createDataReaderFactories(): JList[DataReaderFactory[Row]] = {
- java.util.Arrays.asList(new SimpleDataReaderFactory(0, 5), new SimpleDataReaderFactory(5, 10))
+ override def planInputPartitions(): JList[InputPartition[Row]] = {
+ java.util.Arrays.asList(new SimpleInputPartition(0, 5), new SimpleInputPartition(5, 10))
}
}
override def createReader(options: DataSourceOptions): DataSourceReader = new Reader
}
-class SimpleDataReaderFactory(start: Int, end: Int)
- extends DataReaderFactory[Row]
- with DataReader[Row] {
+class SimpleInputPartition(start: Int, end: Int)
+ extends InputPartition[Row]
+ with InputPartitionReader[Row] {
private var current = start - 1
- override def createDataReader(): DataReader[Row] = new SimpleDataReaderFactory(start, end)
+ override def createPartitionReader(): InputPartitionReader[Row] =
+ new SimpleInputPartition(start, end)
override def next(): Boolean = {
current += 1
@@ -413,21 +414,21 @@ class AdvancedDataSourceV2 extends DataSourceV2 with ReadSupport {
requiredSchema
}
- override def createDataReaderFactories(): JList[DataReaderFactory[Row]] = {
+ override def planInputPartitions(): JList[InputPartition[Row]] = {
val lowerBound = filters.collect {
case GreaterThan("i", v: Int) => v
}.headOption
- val res = new ArrayList[DataReaderFactory[Row]]
+ val res = new ArrayList[InputPartition[Row]]
if (lowerBound.isEmpty) {
- res.add(new AdvancedDataReaderFactory(0, 5, requiredSchema))
- res.add(new AdvancedDataReaderFactory(5, 10, requiredSchema))
+ res.add(new AdvancedInputPartition(0, 5, requiredSchema))
+ res.add(new AdvancedInputPartition(5, 10, requiredSchema))
} else if (lowerBound.get < 4) {
- res.add(new AdvancedDataReaderFactory(lowerBound.get + 1, 5, requiredSchema))
- res.add(new AdvancedDataReaderFactory(5, 10, requiredSchema))
+ res.add(new AdvancedInputPartition(lowerBound.get + 1, 5, requiredSchema))
+ res.add(new AdvancedInputPartition(5, 10, requiredSchema))
} else if (lowerBound.get < 9) {
- res.add(new AdvancedDataReaderFactory(lowerBound.get + 1, 10, requiredSchema))
+ res.add(new AdvancedInputPartition(lowerBound.get + 1, 10, requiredSchema))
}
res
@@ -437,13 +438,13 @@ class AdvancedDataSourceV2 extends DataSourceV2 with ReadSupport {
override def createReader(options: DataSourceOptions): DataSourceReader = new Reader
}
-class AdvancedDataReaderFactory(start: Int, end: Int, requiredSchema: StructType)
- extends DataReaderFactory[Row] with DataReader[Row] {
+class AdvancedInputPartition(start: Int, end: Int, requiredSchema: StructType)
+ extends InputPartition[Row] with InputPartitionReader[Row] {
private var current = start - 1
- override def createDataReader(): DataReader[Row] = {
- new AdvancedDataReaderFactory(start, end, requiredSchema)
+ override def createPartitionReader(): InputPartitionReader[Row] = {
+ new AdvancedInputPartition(start, end, requiredSchema)
}
override def close(): Unit = {}
@@ -468,24 +469,24 @@ class UnsafeRowDataSourceV2 extends DataSourceV2 with ReadSupport {
class Reader extends DataSourceReader with SupportsScanUnsafeRow {
override def readSchema(): StructType = new StructType().add("i", "int").add("j", "int")
- override def createUnsafeRowReaderFactories(): JList[DataReaderFactory[UnsafeRow]] = {
- java.util.Arrays.asList(new UnsafeRowDataReaderFactory(0, 5),
- new UnsafeRowDataReaderFactory(5, 10))
+ override def planUnsafeInputPartitions(): JList[InputPartition[UnsafeRow]] = {
+ java.util.Arrays.asList(new UnsafeRowInputPartitionReader(0, 5),
+ new UnsafeRowInputPartitionReader(5, 10))
}
}
override def createReader(options: DataSourceOptions): DataSourceReader = new Reader
}
-class UnsafeRowDataReaderFactory(start: Int, end: Int)
- extends DataReaderFactory[UnsafeRow] with DataReader[UnsafeRow] {
+class UnsafeRowInputPartitionReader(start: Int, end: Int)
+ extends InputPartition[UnsafeRow] with InputPartitionReader[UnsafeRow] {
private val row = new UnsafeRow(2)
row.pointTo(new Array[Byte](8 * 3), 8 * 3)
private var current = start - 1
- override def createDataReader(): DataReader[UnsafeRow] = this
+ override def createPartitionReader(): InputPartitionReader[UnsafeRow] = this
override def next(): Boolean = {
current += 1
@@ -503,7 +504,7 @@ class UnsafeRowDataReaderFactory(start: Int, end: Int)
class SchemaRequiredDataSource extends DataSourceV2 with ReadSupportWithSchema {
class Reader(val readSchema: StructType) extends DataSourceReader {
- override def createDataReaderFactories(): JList[DataReaderFactory[Row]] =
+ override def planInputPartitions(): JList[InputPartition[Row]] =
java.util.Collections.emptyList()
}
@@ -516,16 +517,17 @@ class BatchDataSourceV2 extends DataSourceV2 with ReadSupport {
class Reader extends DataSourceReader with SupportsScanColumnarBatch {
override def readSchema(): StructType = new StructType().add("i", "int").add("j", "int")
- override def createBatchDataReaderFactories(): JList[DataReaderFactory[ColumnarBatch]] = {
- java.util.Arrays.asList(new BatchDataReaderFactory(0, 50), new BatchDataReaderFactory(50, 90))
+ override def planBatchInputPartitions(): JList[InputPartition[ColumnarBatch]] = {
+ java.util.Arrays.asList(
+ new BatchInputPartitionReader(0, 50), new BatchInputPartitionReader(50, 90))
}
}
override def createReader(options: DataSourceOptions): DataSourceReader = new Reader
}
-class BatchDataReaderFactory(start: Int, end: Int)
- extends DataReaderFactory[ColumnarBatch] with DataReader[ColumnarBatch] {
+class BatchInputPartitionReader(start: Int, end: Int)
+ extends InputPartition[ColumnarBatch] with InputPartitionReader[ColumnarBatch] {
private final val BATCH_SIZE = 20
private lazy val i = new OnHeapColumnVector(BATCH_SIZE, IntegerType)
@@ -534,7 +536,7 @@ class BatchDataReaderFactory(start: Int, end: Int)
private var current = start
- override def createDataReader(): DataReader[ColumnarBatch] = this
+ override def createPartitionReader(): InputPartitionReader[ColumnarBatch] = this
override def next(): Boolean = {
i.reset()
@@ -568,11 +570,11 @@ class PartitionAwareDataSource extends DataSourceV2 with ReadSupport {
class Reader extends DataSourceReader with SupportsReportPartitioning {
override def readSchema(): StructType = new StructType().add("a", "int").add("b", "int")
- override def createDataReaderFactories(): JList[DataReaderFactory[Row]] = {
+ override def planInputPartitions(): JList[InputPartition[Row]] = {
// Note that we don't have same value of column `a` across partitions.
java.util.Arrays.asList(
- new SpecificDataReaderFactory(Array(1, 1, 3), Array(4, 4, 6)),
- new SpecificDataReaderFactory(Array(2, 4, 4), Array(6, 2, 2)))
+ new SpecificInputPartitionReader(Array(1, 1, 3), Array(4, 4, 6)),
+ new SpecificInputPartitionReader(Array(2, 4, 4), Array(6, 2, 2)))
}
override def outputPartitioning(): Partitioning = new MyPartitioning
@@ -590,14 +592,14 @@ class PartitionAwareDataSource extends DataSourceV2 with ReadSupport {
override def createReader(options: DataSourceOptions): DataSourceReader = new Reader
}
-class SpecificDataReaderFactory(i: Array[Int], j: Array[Int])
- extends DataReaderFactory[Row]
- with DataReader[Row] {
+class SpecificInputPartitionReader(i: Array[Int], j: Array[Int])
+ extends InputPartition[Row]
+ with InputPartitionReader[Row] {
assert(i.length == j.length)
private var current = -1
- override def createDataReader(): DataReader[Row] = this
+ override def createPartitionReader(): InputPartitionReader[Row] = this
override def next(): Boolean = {
current += 1
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/SimpleWritableDataSource.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/SimpleWritableDataSource.scala
index a5007fa321359..694bb3b95b0f0 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/SimpleWritableDataSource.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/SimpleWritableDataSource.scala
@@ -28,7 +28,7 @@ import org.apache.hadoop.fs.{FileSystem, FSDataInputStream, Path}
import org.apache.spark.SparkContext
import org.apache.spark.sql.{Row, SaveMode}
import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.sources.v2.reader.{DataReader, DataReaderFactory, DataSourceReader}
+import org.apache.spark.sql.sources.v2.reader.{DataSourceReader, InputPartition, InputPartitionReader}
import org.apache.spark.sql.sources.v2.writer._
import org.apache.spark.sql.types.{DataType, StructType}
import org.apache.spark.util.SerializableConfiguration
@@ -45,7 +45,7 @@ class SimpleWritableDataSource extends DataSourceV2 with ReadSupport with WriteS
class Reader(path: String, conf: Configuration) extends DataSourceReader {
override def readSchema(): StructType = schema
- override def createDataReaderFactories(): JList[DataReaderFactory[Row]] = {
+ override def planInputPartitions(): JList[InputPartition[Row]] = {
val dataPath = new Path(path)
val fs = dataPath.getFileSystem(conf)
if (fs.exists(dataPath)) {
@@ -54,9 +54,9 @@ class SimpleWritableDataSource extends DataSourceV2 with ReadSupport with WriteS
name.startsWith("_") || name.startsWith(".")
}.map { f =>
val serializableConf = new SerializableConfiguration(conf)
- new SimpleCSVDataReaderFactory(
+ new SimpleCSVInputPartitionReader(
f.getPath.toUri.toString,
- serializableConf): DataReaderFactory[Row]
+ serializableConf): InputPartition[Row]
}.toList.asJava
} else {
Collections.emptyList()
@@ -156,14 +156,14 @@ class SimpleWritableDataSource extends DataSourceV2 with ReadSupport with WriteS
}
}
-class SimpleCSVDataReaderFactory(path: String, conf: SerializableConfiguration)
- extends DataReaderFactory[Row] with DataReader[Row] {
+class SimpleCSVInputPartitionReader(path: String, conf: SerializableConfiguration)
+ extends InputPartition[Row] with InputPartitionReader[Row] {
@transient private var lines: Iterator[String] = _
@transient private var currentLine: String = _
@transient private var inputStream: FSDataInputStream = _
- override def createDataReader(): DataReader[Row] = {
+ override def createPartitionReader(): InputPartitionReader[Row] = {
val filePath = new Path(path)
val fs = filePath.getFileSystem(conf.value)
inputStream = fs.open(filePath)
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala
index 9d139a927bea5..f348dac1319cb 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala
@@ -199,15 +199,12 @@ trait StreamTest extends QueryTest with SharedSQLContext with TimeLimits with Be
case class CheckAnswerRowsByFunc(
globalCheckFunction: Seq[Row] => Unit,
lastOnly: Boolean) extends StreamAction with StreamMustBeRunning {
- override def toString: String = s"$operatorName"
- private def operatorName = if (lastOnly) "CheckLastBatchByFunc" else "CheckAnswerByFunc"
+ override def toString: String = if (lastOnly) "CheckLastBatchByFunc" else "CheckAnswerByFunc"
}
case class CheckNewAnswerRows(expectedAnswer: Seq[Row])
extends StreamAction with StreamMustBeRunning {
- override def toString: String = s"$operatorName: ${expectedAnswer.mkString(",")}"
-
- private def operatorName = "CheckNewAnswer"
+ override def toString: String = s"CheckNewAnswer: ${expectedAnswer.mkString(",")}"
}
object CheckNewAnswer {
@@ -218,6 +215,8 @@ trait StreamTest extends QueryTest with SharedSQLContext with TimeLimits with Be
val toExternalRow = RowEncoder(encoder.schema).resolveAndBind()
CheckNewAnswerRows((data +: moreData).map(d => toExternalRow.fromRow(encoder.toRow(d))))
}
+
+ def apply(rows: Row*): CheckNewAnswerRows = CheckNewAnswerRows(rows)
}
/** Stops the stream. It must currently be running. */
@@ -747,7 +746,6 @@ trait StreamTest extends QueryTest with SharedSQLContext with TimeLimits with Be
error => failTest(error)
}
}
- pos += 1
}
try {
@@ -761,8 +759,11 @@ trait StreamTest extends QueryTest with SharedSQLContext with TimeLimits with Be
currentStream.asInstanceOf[MicroBatchExecution].withProgressLocked {
actns.foreach(executeAction)
}
+ pos += 1
- case action: StreamAction => executeAction(action)
+ case action: StreamAction =>
+ executeAction(action)
+ pos += 1
}
if (streamThreadDeathCause != null) {
failTest("Stream Thread Died", streamThreadDeathCause)
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingJoinSuite.scala
index da8f9608c1e9c..1f62357e6d09e 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingJoinSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingJoinSuite.scala
@@ -62,20 +62,20 @@ class StreamingInnerJoinSuite extends StreamTest with StateStoreMetricsTest with
AddData(input1, 1),
CheckAnswer(),
AddData(input2, 1, 10), // 1 arrived on input1 first, then input2, should join
- CheckLastBatch((1, 2, 3)),
+ CheckNewAnswer((1, 2, 3)),
AddData(input1, 10), // 10 arrived on input2 first, then input1, should join
- CheckLastBatch((10, 20, 30)),
+ CheckNewAnswer((10, 20, 30)),
AddData(input2, 1), // another 1 in input2 should join with 1 input1
- CheckLastBatch((1, 2, 3)),
+ CheckNewAnswer((1, 2, 3)),
StopStream,
StartStream(),
AddData(input1, 1), // multiple 1s should be kept in state causing multiple (1, 2, 3)
- CheckLastBatch((1, 2, 3), (1, 2, 3)),
+ CheckNewAnswer((1, 2, 3), (1, 2, 3)),
StopStream,
StartStream(),
AddData(input1, 100),
AddData(input2, 100),
- CheckLastBatch((100, 200, 300))
+ CheckNewAnswer((100, 200, 300))
)
}
@@ -97,25 +97,25 @@ class StreamingInnerJoinSuite extends StreamTest with StateStoreMetricsTest with
testStream(joined)(
AddData(input1, 1),
- CheckLastBatch(),
+ CheckNewAnswer(),
AddData(input2, 1),
- CheckLastBatch((1, 10, 2, 3)),
+ CheckNewAnswer((1, 10, 2, 3)),
StopStream,
StartStream(),
AddData(input1, 25),
- CheckLastBatch(),
+ CheckNewAnswer(),
StopStream,
StartStream(),
AddData(input2, 25),
- CheckLastBatch((25, 30, 50, 75)),
+ CheckNewAnswer((25, 30, 50, 75)),
AddData(input1, 1),
- CheckLastBatch((1, 10, 2, 3)), // State for 1 still around as there is no watermark
+ CheckNewAnswer((1, 10, 2, 3)), // State for 1 still around as there is no watermark
StopStream,
StartStream(),
AddData(input1, 5),
- CheckLastBatch(),
+ CheckNewAnswer(),
AddData(input2, 5),
- CheckLastBatch((5, 10, 10, 15)) // No filter by any watermark
+ CheckNewAnswer((5, 10, 10, 15)) // No filter by any watermark
)
}
@@ -142,27 +142,27 @@ class StreamingInnerJoinSuite extends StreamTest with StateStoreMetricsTest with
assertNumStateRows(total = 1, updated = 1),
AddData(input2, 1),
- CheckLastBatch((1, 10, 2, 3)),
+ CheckAnswer((1, 10, 2, 3)),
assertNumStateRows(total = 2, updated = 1),
StopStream,
StartStream(),
AddData(input1, 25),
- CheckLastBatch(), // since there is only 1 watermark operator, the watermark should be 15
- assertNumStateRows(total = 3, updated = 1),
+ CheckNewAnswer(), // watermark = 15, no-data-batch should remove 2 rows having window=[0,10]
+ assertNumStateRows(total = 1, updated = 1),
AddData(input2, 25),
- CheckLastBatch((25, 30, 50, 75)), // watermark = 15 should remove 2 rows having window=[0,10]
+ CheckNewAnswer((25, 30, 50, 75)),
assertNumStateRows(total = 2, updated = 1),
StopStream,
StartStream(),
AddData(input2, 1),
- CheckLastBatch(), // Should not join as < 15 removed
- assertNumStateRows(total = 2, updated = 0), // row not add as 1 < state key watermark = 15
+ CheckNewAnswer(), // Should not join as < 15 removed
+ assertNumStateRows(total = 2, updated = 0), // row not added as 1 < state key watermark = 15
AddData(input1, 5),
- CheckLastBatch(), // Should not join or add to state as < 15 got filtered by watermark
+ CheckNewAnswer(), // Same reason as above
assertNumStateRows(total = 2, updated = 0)
)
}
@@ -189,42 +189,39 @@ class StreamingInnerJoinSuite extends StreamTest with StateStoreMetricsTest with
AddData(leftInput, (1, 5)),
CheckAnswer(),
AddData(rightInput, (1, 11)),
- CheckLastBatch((1, 5, 11)),
+ CheckNewAnswer((1, 5, 11)),
AddData(rightInput, (1, 10)),
- CheckLastBatch(), // no match as neither 5, nor 10 from leftTime is less than rightTime 10 - 5
+ CheckNewAnswer(), // no match as leftTime 5 is not < rightTime 10 - 5
assertNumStateRows(total = 3, updated = 3),
// Increase event time watermark to 20s by adding data with time = 30s on both inputs
AddData(leftInput, (1, 3), (1, 30)),
- CheckLastBatch((1, 3, 10), (1, 3, 11)),
+ CheckNewAnswer((1, 3, 10), (1, 3, 11)),
assertNumStateRows(total = 5, updated = 2),
AddData(rightInput, (0, 30)),
- CheckLastBatch(),
- assertNumStateRows(total = 6, updated = 1),
+ CheckNewAnswer(),
// event time watermark: max event time - 10 ==> 30 - 10 = 20
+ // so left side going to only receive data where leftTime > 20
// right side state constraint: 20 < leftTime < rightTime - 5 ==> rightTime > 25
-
- // Run another batch with event time = 25 to clear right state where rightTime <= 25
- AddData(rightInput, (0, 30)),
- CheckLastBatch(),
- assertNumStateRows(total = 5, updated = 1), // removed (1, 11) and (1, 10), added (0, 30)
+ // right state where rightTime <= 25 will be cleared, (1, 11) and (1, 10) removed
+ assertNumStateRows(total = 4, updated = 1),
// New data to right input should match with left side (1, 3) and (1, 5), as left state should
// not be cleared. But rows rightTime <= 20 should be filtered due to event time watermark and
// state rows with rightTime <= 25 should be removed from state.
// (1, 20) ==> filtered by event time watermark = 20
// (1, 21) ==> passed filter, matched with left (1, 3) and (1, 5), not added to state
- // as state watermark = 25
+ // as 21 < state watermark = 25
// (1, 28) ==> passed filter, matched with left (1, 3) and (1, 5), added to state
AddData(rightInput, (1, 20), (1, 21), (1, 28)),
- CheckLastBatch((1, 3, 21), (1, 5, 21), (1, 3, 28), (1, 5, 28)),
- assertNumStateRows(total = 6, updated = 1),
+ CheckNewAnswer((1, 3, 21), (1, 5, 21), (1, 3, 28), (1, 5, 28)),
+ assertNumStateRows(total = 5, updated = 1),
// New data to left input with leftTime <= 20 should be filtered due to event time watermark
AddData(leftInput, (1, 20), (1, 21)),
- CheckLastBatch((1, 21, 28)),
- assertNumStateRows(total = 7, updated = 1)
+ CheckNewAnswer((1, 21, 28)),
+ assertNumStateRows(total = 6, updated = 1)
)
}
@@ -275,38 +272,39 @@ class StreamingInnerJoinSuite extends StreamTest with StateStoreMetricsTest with
AddData(leftInput, (1, 20)),
CheckAnswer(),
AddData(rightInput, (1, 14), (1, 15), (1, 25), (1, 26), (1, 30), (1, 31)),
- CheckLastBatch((1, 20, 15), (1, 20, 25), (1, 20, 26), (1, 20, 30)),
+ CheckNewAnswer((1, 20, 15), (1, 20, 25), (1, 20, 26), (1, 20, 30)),
assertNumStateRows(total = 7, updated = 7),
// If rightTime = 60, then it matches only leftTime = [50, 65]
AddData(rightInput, (1, 60)),
- CheckLastBatch(), // matches with nothing on the left
+ CheckNewAnswer(), // matches with nothing on the left
AddData(leftInput, (1, 49), (1, 50), (1, 65), (1, 66)),
- CheckLastBatch((1, 50, 60), (1, 65, 60)),
- assertNumStateRows(total = 12, updated = 5),
+ CheckNewAnswer((1, 50, 60), (1, 65, 60)),
// Event time watermark = min(left: 66 - delay 20 = 46, right: 60 - delay 30 = 30) = 30
// Left state value watermark = 30 - 10 = slightly less than 20 (since condition has <=)
// Should drop < 20 from left, i.e., none
// Right state value watermark = 30 - 5 = slightly less than 25 (since condition has <=)
// Should drop < 25 from the right, i.e., 14 and 15
- AddData(leftInput, (1, 30), (1, 31)), // 30 should not be processed or added to stat
- CheckLastBatch((1, 31, 26), (1, 31, 30), (1, 31, 31)),
- assertNumStateRows(total = 11, updated = 1), // 12 - 2 removed + 1 added
+ assertNumStateRows(total = 10, updated = 5), // 12 - 2 removed
+
+ AddData(leftInput, (1, 30), (1, 31)), // 30 should not be processed or added to state
+ CheckNewAnswer((1, 31, 26), (1, 31, 30), (1, 31, 31)),
+ assertNumStateRows(total = 11, updated = 1), // only 31 added
// Advance the watermark
AddData(rightInput, (1, 80)),
- CheckLastBatch(),
- assertNumStateRows(total = 12, updated = 1),
-
+ CheckNewAnswer(),
// Event time watermark = min(left: 66 - delay 20 = 46, right: 80 - delay 30 = 50) = 46
// Left state value watermark = 46 - 10 = slightly less than 36 (since condition has <=)
// Should drop < 36 from left, i.e., 20, 31 (30 was not added)
// Right state value watermark = 46 - 5 = slightly less than 41 (since condition has <=)
// Should drop < 41 from the right, i.e., 25, 26, 30, 31
- AddData(rightInput, (1, 50)),
- CheckLastBatch((1, 49, 50), (1, 50, 50)),
- assertNumStateRows(total = 7, updated = 1) // 12 - 6 removed + 1 added
+ assertNumStateRows(total = 6, updated = 1), // 12 - 6 removed
+
+ AddData(rightInput, (1, 46), (1, 50)), // 46 should not be processed or added to state
+ CheckNewAnswer((1, 49, 50), (1, 50, 50)),
+ assertNumStateRows(total = 7, updated = 1) // 50 added
)
}
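
The arithmetic spelled out in the comments above generalizes to two inputs: the query watermark is the minimum over inputs of (max event time - that input's delay), and each side's state value watermark shifts it by the slack in the join condition, applied strictly because the condition uses <=. A small sketch with made-up names, not taken from the patch:

object TwoInputWatermarkSketch {
  // query watermark = min over inputs of (max event time - watermark delay)
  def queryWatermark(inputs: (Int, Int)*): Int =
    inputs.map { case (maxEventTime, delay) => maxEventTime - delay }.min

  def main(args: Array[String]): Unit = {
    val wm = queryWatermark((66, 20), (80, 30))          // min(46, 50) = 46, as in the comment
    println(s"left state: drop values < ${wm - 10}")     // < 36, strict because the condition has <=
    println(s"right state: drop values < ${wm - 5}")     // < 41
  }
}
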
@@ -322,7 +320,7 @@ class StreamingInnerJoinSuite extends StreamTest with StateStoreMetricsTest with
input1.addData(1)
q.awaitTermination(10000)
}
- assert(e.toString.contains("Stream stream joins without equality predicate is not supported"))
+ assert(e.toString.contains("Stream-stream join without equality predicate is not supported"))
}
test("stream stream self join") {
@@ -404,10 +402,11 @@ class StreamingInnerJoinSuite extends StreamTest with StateStoreMetricsTest with
AddData(input1, 1, 5),
AddData(input2, 1, 5, 10),
AddData(input3, 5, 10),
- CheckLastBatch((5, 10, 5, 15, 5, 25)))
+ CheckNewAnswer((5, 10, 5, 15, 5, 25)))
}
}
+
class StreamingOuterJoinSuite extends StreamTest with StateStoreMetricsTest with BeforeAndAfter {
import testImplicits._
@@ -465,13 +464,13 @@ class StreamingOuterJoinSuite extends StreamTest with StateStoreMetricsTest with
MultiAddData(leftInput, 1, 2, 3)(rightInput, 3, 4, 5),
// The left rows with leftValue <= 4 should generate their outer join row now and
// not get added to the state.
- CheckLastBatch(Row(3, 10, 6, "9"), Row(1, 10, 2, null), Row(2, 10, 4, null)),
+ CheckNewAnswer(Row(3, 10, 6, "9"), Row(1, 10, 2, null), Row(2, 10, 4, null)),
assertNumStateRows(total = 4, updated = 4),
// We shouldn't get more outer join rows when the watermark advances.
MultiAddData(leftInput, 20)(rightInput, 21),
- CheckLastBatch(),
+ CheckNewAnswer(),
AddData(rightInput, 20),
- CheckLastBatch((20, 30, 40, "60"))
+ CheckNewAnswer((20, 30, 40, "60"))
)
}
@@ -492,15 +491,15 @@ class StreamingOuterJoinSuite extends StreamTest with StateStoreMetricsTest with
testStream(joined)(
MultiAddData(leftInput, 3, 4, 5)(rightInput, 1, 2, 3),
- // The right rows with value <= 7 should never be added to the state.
- CheckLastBatch(Row(3, 10, 6, "9")),
+ // The right rows with rightValue <= 7 should never be added to the state.
+ CheckNewAnswer(Row(3, 10, 6, "9")), // rightValue = 9 > 7 hence joined and added to state
assertNumStateRows(total = 4, updated = 4),
// When the watermark advances, we get the outer join rows just as we would if they
// were added but didn't match the full join condition.
- MultiAddData(leftInput, 20)(rightInput, 21),
- CheckLastBatch(),
+ MultiAddData(leftInput, 20)(rightInput, 21), // watermark = 10, no-data-batch computes nulls
+ CheckNewAnswer(Row(4, 10, 8, null), Row(5, 10, 10, null)),
AddData(rightInput, 20),
- CheckLastBatch(Row(20, 30, 40, "60"), Row(4, 10, 8, null), Row(5, 10, 10, null))
+ CheckNewAnswer(Row(20, 30, 40, "60"))
)
}
@@ -521,15 +520,15 @@ class StreamingOuterJoinSuite extends StreamTest with StateStoreMetricsTest with
testStream(joined)(
MultiAddData(leftInput, 1, 2, 3)(rightInput, 3, 4, 5),
- // The left rows with value <= 4 should never be added to the state.
- CheckLastBatch(Row(3, 10, 6, "9")),
+ // The left rows with leftValue <= 4 should never be added to the state.
+ CheckNewAnswer(Row(3, 10, 6, "9")), // leftValue = 7 > 4 hence joined and added to state
assertNumStateRows(total = 4, updated = 4),
// When the watermark advances, we get the outer join rows just as we would if they
// were added but didn't match the full join condition.
- MultiAddData(leftInput, 20)(rightInput, 21),
- CheckLastBatch(),
+ MultiAddData(leftInput, 20)(rightInput, 21), // watermark = 10, no-data-batch computes nulls
+ CheckNewAnswer(Row(4, 10, null, "12"), Row(5, 10, null, "15")),
AddData(rightInput, 20),
- CheckLastBatch(Row(20, 30, 40, "60"), Row(4, 10, null, "12"), Row(5, 10, null, "15"))
+ CheckNewAnswer(Row(20, 30, 40, "60"))
)
}
@@ -552,13 +551,13 @@ class StreamingOuterJoinSuite extends StreamTest with StateStoreMetricsTest with
MultiAddData(leftInput, 3, 4, 5)(rightInput, 1, 2, 3),
// The right rows with rightValue <= 7 should generate their outer join row now and
// not get added to the state.
- CheckLastBatch(Row(3, 10, 6, "9"), Row(1, 10, null, "3"), Row(2, 10, null, "6")),
+ CheckNewAnswer(Row(3, 10, 6, "9"), Row(1, 10, null, "3"), Row(2, 10, null, "6")),
assertNumStateRows(total = 4, updated = 4),
// We shouldn't get more outer join rows when the watermark advances.
MultiAddData(leftInput, 20)(rightInput, 21),
- CheckLastBatch(),
+ CheckNewAnswer(),
AddData(rightInput, 20),
- CheckLastBatch((20, 30, 40, "60"))
+ CheckNewAnswer((20, 30, 40, "60"))
)
}
@@ -568,14 +567,14 @@ class StreamingOuterJoinSuite extends StreamTest with StateStoreMetricsTest with
testStream(joined)(
// Test inner part of the join.
MultiAddData(leftInput, 1, 2, 3, 4, 5)(rightInput, 3, 4, 5, 6, 7),
- CheckLastBatch((3, 10, 6, 9), (4, 10, 8, 12), (5, 10, 10, 15)),
- // Old state doesn't get dropped until the batch *after* it gets introduced, so the
- // nulls won't show up until the next batch after the watermark advances.
- MultiAddData(leftInput, 21)(rightInput, 22),
- CheckLastBatch(),
- assertNumStateRows(total = 12, updated = 12),
+ CheckNewAnswer((3, 10, 6, 9), (4, 10, 8, 12), (5, 10, 10, 15)),
+
+ MultiAddData(leftInput, 21)(rightInput, 22), // watermark = 11, no-data-batch computes nulls
+ CheckNewAnswer(Row(1, 10, 2, null), Row(2, 10, 4, null)),
+ assertNumStateRows(total = 2, updated = 12),
+
AddData(leftInput, 22),
- CheckLastBatch(Row(22, 30, 44, 66), Row(1, 10, 2, null), Row(2, 10, 4, null)),
+ CheckNewAnswer(Row(22, 30, 44, 66)),
assertNumStateRows(total = 3, updated = 1)
)
}
@@ -586,14 +585,14 @@ class StreamingOuterJoinSuite extends StreamTest with StateStoreMetricsTest with
testStream(joined)(
// Test inner part of the join.
MultiAddData(leftInput, 1, 2, 3, 4, 5)(rightInput, 3, 4, 5, 6, 7),
- CheckLastBatch((3, 10, 6, 9), (4, 10, 8, 12), (5, 10, 10, 15)),
- // Old state doesn't get dropped until the batch *after* it gets introduced, so the
- // nulls won't show up until the next batch after the watermark advances.
- MultiAddData(leftInput, 21)(rightInput, 22),
- CheckLastBatch(),
- assertNumStateRows(total = 12, updated = 12),
+ CheckNewAnswer((3, 10, 6, 9), (4, 10, 8, 12), (5, 10, 10, 15)),
+
+ MultiAddData(leftInput, 21)(rightInput, 22), // watermark = 11, no-data-batch computes nulls
+ CheckNewAnswer(Row(6, 10, null, 18), Row(7, 10, null, 21)),
+ assertNumStateRows(total = 2, updated = 12),
+
AddData(leftInput, 22),
- CheckLastBatch(Row(22, 30, 44, 66), Row(6, 10, null, 18), Row(7, 10, null, 21)),
+ CheckNewAnswer(Row(22, 30, 44, 66)),
assertNumStateRows(total = 3, updated = 1)
)
}
@@ -627,21 +626,18 @@ class StreamingOuterJoinSuite extends StreamTest with StateStoreMetricsTest with
AddData(leftInput, (1, 5), (3, 5)),
CheckAnswer(),
AddData(rightInput, (1, 10), (2, 5)),
- CheckLastBatch((1, 1, 5, 10)),
+ CheckNewAnswer((1, 1, 5, 10)),
AddData(rightInput, (1, 11)),
- CheckLastBatch(), // no match as left time is too low
+ CheckNewAnswer(), // no match as left time is too low
assertNumStateRows(total = 5, updated = 5),
// Increase event time watermark to 20s by adding data with time = 30s on both inputs
AddData(leftInput, (1, 7), (1, 30)),
- CheckLastBatch((1, 1, 7, 10), (1, 1, 7, 11)),
+ CheckNewAnswer((1, 1, 7, 10), (1, 1, 7, 11)),
assertNumStateRows(total = 7, updated = 2),
- AddData(rightInput, (0, 30)),
- CheckLastBatch(),
- assertNumStateRows(total = 8, updated = 1),
- AddData(rightInput, (0, 30)),
- CheckLastBatch(outerResult),
- assertNumStateRows(total = 3, updated = 1)
+ AddData(rightInput, (0, 30)), // watermark = 30 - 10 = 20, no-data-batch computes nulls
+ CheckNewAnswer(outerResult),
+ assertNumStateRows(total = 2, updated = 1)
)
}
}
@@ -665,36 +661,41 @@ class StreamingOuterJoinSuite extends StreamTest with StateStoreMetricsTest with
testStream(joined)(
// leftValue <= 10 should generate outer join rows even though it matches right keys
MultiAddData(leftInput, 1, 2, 3)(rightInput, 1, 2, 3),
- CheckLastBatch(Row(1, 10, 2, null), Row(2, 10, 4, null), Row(3, 10, 6, null)),
- MultiAddData(leftInput, 20)(rightInput, 21),
- CheckLastBatch(),
- assertNumStateRows(total = 5, updated = 5), // 1...3 added, but 20 and 21 not added
+ CheckNewAnswer(Row(1, 10, 2, null), Row(2, 10, 4, null), Row(3, 10, 6, null)),
+ assertNumStateRows(total = 3, updated = 3), // only right 1, 2, 3 added
+
+ MultiAddData(leftInput, 20)(rightInput, 21), // watermark = 10, no-data-batch cleared < 10
+ CheckNewAnswer(),
+ assertNumStateRows(total = 2, updated = 2), // only 20 and 21 left in state
+
AddData(rightInput, 20),
- CheckLastBatch(
- Row(20, 30, 40, 60)),
+ CheckNewAnswer(Row(20, 30, 40, 60)),
assertNumStateRows(total = 3, updated = 1),
+
// leftValue and rightValue both satisfying condition should not generate outer join rows
- MultiAddData(leftInput, 40, 41)(rightInput, 40, 41),
- CheckLastBatch((40, 50, 80, 120), (41, 50, 82, 123)),
- MultiAddData(leftInput, 70)(rightInput, 71),
- CheckLastBatch(),
- assertNumStateRows(total = 6, updated = 6), // all inputs added since last check
+ MultiAddData(leftInput, 40, 41)(rightInput, 40, 41), // watermark = 31
+ CheckNewAnswer((40, 50, 80, 120), (41, 50, 82, 123)),
+ assertNumStateRows(total = 4, updated = 4), // only left 40, 41 + right 40, 41 left in state
+
+ MultiAddData(leftInput, 70)(rightInput, 71), // watermark = 60
+ CheckNewAnswer(),
+ assertNumStateRows(total = 2, updated = 2), // only 70, 71 left in state
+
AddData(rightInput, 70),
- CheckLastBatch((70, 80, 140, 210)),
+ CheckNewAnswer((70, 80, 140, 210)),
assertNumStateRows(total = 3, updated = 1),
+
// rightValue between 300 and 1000 should generate outer join rows even though it matches left
- MultiAddData(leftInput, 101, 102, 103)(rightInput, 101, 102, 103),
- CheckLastBatch(),
+ MultiAddData(leftInput, 101, 102, 103)(rightInput, 101, 102, 103), // watermark = 91
+ CheckNewAnswer(),
+ assertNumStateRows(total = 6, updated = 3), // only 101 - 103 left in state
+
MultiAddData(leftInput, 1000)(rightInput, 1001),
- CheckLastBatch(),
- assertNumStateRows(total = 8, updated = 5), // 101...103 added, but 1000 and 1001 not added
- AddData(rightInput, 1000),
- CheckLastBatch(
- Row(1000, 1010, 2000, 3000),
+ CheckNewAnswer(
Row(101, 110, 202, null),
Row(102, 110, 204, null),
Row(103, 110, 206, null)),
- assertNumStateRows(total = 3, updated = 1)
+ assertNumStateRows(total = 2, updated = 2)
)
}
}
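
Several of the outer-join hunks above hinge on the new no-data batch: once the watermark advances, unmatched rows are flushed with nulls without waiting for the next data batch. A minimal sketch, not from the patch, of the kind of query those suites exercise; the rate source and the column names are assumptions for illustration.

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.expr

object LeftOuterStreamJoinSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[2]").appName("left-outer-join-sketch").getOrCreate()

    val left = spark.readStream.format("rate").load()
      .selectExpr("value AS leftKey", "timestamp AS leftTime")
      .withWatermark("leftTime", "10 seconds")
    val right = spark.readStream.format("rate").load()
      .selectExpr("value AS rightKey", "timestamp AS rightTime")
      .withWatermark("rightTime", "10 seconds")

    // Left rows whose match window has passed the watermark are emitted with nulls; with this
    // patch that can happen in a batch that carries no new data.
    val joined = left.join(
      right,
      expr("leftKey = rightKey AND rightTime BETWEEN leftTime AND leftTime + interval 5 seconds"),
      "leftOuter")

    joined.writeStream.format("console").start().awaitTermination()
  }
}
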
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala
index 0cb2375e0a49a..dcf6cb5d609ee 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala
@@ -35,7 +35,7 @@ import org.apache.spark.sql.execution.streaming._
import org.apache.spark.sql.execution.streaming.sources.TestForeachWriter
import org.apache.spark.sql.functions._
import org.apache.spark.sql.internal.SQLConf
-import org.apache.spark.sql.sources.v2.reader.DataReaderFactory
+import org.apache.spark.sql.sources.v2.reader.InputPartition
import org.apache.spark.sql.sources.v2.reader.streaming.{Offset => OffsetV2}
import org.apache.spark.sql.streaming.util.{BlockingSource, MockSourceProvider, StreamManualClock}
import org.apache.spark.sql.types.StructType
@@ -227,10 +227,10 @@ class StreamingQuerySuite extends StreamTest with BeforeAndAfter with Logging wi
}
// getBatch should take 100 ms the first time it is called
- override def createUnsafeRowReaderFactories(): ju.List[DataReaderFactory[UnsafeRow]] = {
+ override def planUnsafeInputPartitions(): ju.List[InputPartition[UnsafeRow]] = {
synchronized {
clock.waitTillTime(1350)
- super.createUnsafeRowReaderFactories()
+ super.planUnsafeInputPartitions()
}
}
}
@@ -290,13 +290,14 @@ class StreamingQuerySuite extends StreamTest with BeforeAndAfter with Logging wi
AdvanceManualClock(100), // time = 1150 to unblock getEndOffset
AssertClockTime(1150),
- AssertStreamExecThreadIsWaitingForTime(1350), // will block on createReadTasks that needs 1350
+ // will block on planInputPartitions that needs 1350
+ AssertStreamExecThreadIsWaitingForTime(1350),
AssertOnQuery(_.status.isDataAvailable === true),
AssertOnQuery(_.status.isTriggerActive === true),
AssertOnQuery(_.status.message === "Processing new data"),
AssertOnQuery(_.recentProgress.count(_.numInputRows > 0) === 0),
- AdvanceManualClock(200), // time = 1350 to unblock createReadTasks
+ AdvanceManualClock(200), // time = 1350 to unblock planInputPartitions
AssertClockTime(1350),
AssertStreamExecThreadIsWaitingForTime(1500), // will block on map task that needs 1500
AssertOnQuery(_.status.isDataAvailable === true),
@@ -831,6 +832,21 @@ class StreamingQuerySuite extends StreamTest with BeforeAndAfter with Logging wi
CheckLastBatch(("A", 1)))
}
+ test("StreamingRelationV2/StreamingExecutionRelation/ContinuousExecutionRelation.toJSON " +
+ "should not fail") {
+ val df = spark.readStream.format("rate").load()
+ assert(df.logicalPlan.toJSON.contains("StreamingRelationV2"))
+
+ testStream(df)(
+ AssertOnQuery(_.logicalPlan.toJSON.contains("StreamingExecutionRelation"))
+ )
+
+ testStream(df, useV2Sink = true)(
+ StartStream(trigger = Trigger.Continuous(100)),
+ AssertOnQuery(_.logicalPlan.toJSON.contains("ContinuousExecutionRelation"))
+ )
+ }
+
/** Create a streaming DF that only execute one batch in which it returns the given static DF */
private def createSingleTriggerStreamingDF(triggerDF: DataFrame): DataFrame = {
require(!triggerDF.isStreaming)
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/ContinuousAggregationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/ContinuousAggregationSuite.scala
new file mode 100644
index 0000000000000..b7ef637f5270e
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/ContinuousAggregationSuite.scala
@@ -0,0 +1,72 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.streaming.continuous
+
+import org.apache.spark.sql.AnalysisException
+import org.apache.spark.sql.execution.streaming.sources.ContinuousMemoryStream
+import org.apache.spark.sql.functions._
+import org.apache.spark.sql.streaming.OutputMode
+
+class ContinuousAggregationSuite extends ContinuousSuiteBase {
+ import testImplicits._
+
+ test("not enabled") {
+ val ex = intercept[AnalysisException] {
+ val input = ContinuousMemoryStream.singlePartition[Int]
+ testStream(input.toDF().agg(max('value)), OutputMode.Complete)()
+ }
+
+ assert(ex.getMessage.contains("Continuous processing does not support Aggregate operations"))
+ }
+
+ test("basic") {
+ withSQLConf(("spark.sql.streaming.unsupportedOperationCheck", "false")) {
+ val input = ContinuousMemoryStream.singlePartition[Int]
+
+ testStream(input.toDF().agg(max('value)), OutputMode.Complete)(
+ AddData(input, 0, 1, 2),
+ CheckAnswer(2),
+ StopStream,
+ AddData(input, 3, 4, 5),
+ StartStream(),
+ CheckAnswer(5),
+ AddData(input, -1, -2, -3),
+ CheckAnswer(5))
+ }
+ }
+
+ test("repeated restart") {
+ withSQLConf(("spark.sql.streaming.unsupportedOperationCheck", "false")) {
+ val input = ContinuousMemoryStream.singlePartition[Int]
+
+ testStream(input.toDF().agg(max('value)), OutputMode.Complete)(
+ AddData(input, 0, 1, 2),
+ CheckAnswer(2),
+ StopStream,
+ StartStream(),
+ StopStream,
+ StartStream(),
+ StopStream,
+ StartStream(),
+ AddData(input, 0),
+ CheckAnswer(2),
+ AddData(input, 5),
+ CheckAnswer(5))
+ }
+ }
+}
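
Outside the test harness, the behaviour the new suite pins down looks roughly like the sketch below (the rate source and console sink are assumptions, and the suites additionally confine input to a single partition): starting a continuous-trigger query that contains an aggregate fails the unsupported-operation check unless spark.sql.streaming.unsupportedOperationCheck is set to false.

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.max
import org.apache.spark.sql.streaming.Trigger

object ContinuousAggregationSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[2]").appName("continuous-agg-sketch").getOrCreate()

    // Uncommenting the next line is what the "basic" test's withSQLConf block does.
    // spark.conf.set("spark.sql.streaming.unsupportedOperationCheck", "false")

    val agg = spark.readStream.format("rate").load().agg(max("value"))

    // With the check enabled (the default) this throws AnalysisException:
    //   "Continuous processing does not support Aggregate operations"
    agg.writeStream
      .outputMode("complete")
      .format("console")
      .trigger(Trigger.Continuous("1 second"))
      .start()
      .awaitTermination()
  }
}
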
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/ContinuousQueuedDataReaderSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/ContinuousQueuedDataReaderSuite.scala
index e755625d09e0f..e663fa8312da4 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/ContinuousQueuedDataReaderSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/ContinuousQueuedDataReaderSuite.scala
@@ -27,8 +27,8 @@ import org.apache.spark.{SparkEnv, SparkFunSuite, TaskContext}
import org.apache.spark.rpc.{RpcEndpointRef, RpcEnv}
import org.apache.spark.sql.catalyst.expressions.{GenericInternalRow, UnsafeProjection, UnsafeRow}
import org.apache.spark.sql.execution.streaming.continuous._
-import org.apache.spark.sql.sources.v2.reader.DataReaderFactory
-import org.apache.spark.sql.sources.v2.reader.streaming.{ContinuousDataReader, ContinuousReader, PartitionOffset}
+import org.apache.spark.sql.sources.v2.reader.InputPartition
+import org.apache.spark.sql.sources.v2.reader.streaming.{ContinuousInputPartitionReader, ContinuousReader, PartitionOffset}
import org.apache.spark.sql.sources.v2.writer.streaming.StreamWriter
import org.apache.spark.sql.streaming.StreamTest
import org.apache.spark.sql.types.{DataType, IntegerType}
@@ -51,6 +51,7 @@ class ContinuousQueuedDataReaderSuite extends StreamTest with MockitoSugar {
startEpoch,
spark,
SparkEnv.get)
+ EpochTracker.initializeCurrentEpoch(0)
}
override def afterEach(): Unit = {
@@ -72,8 +73,8 @@ class ContinuousQueuedDataReaderSuite extends StreamTest with MockitoSugar {
*/
private def setup(): (BlockingQueue[UnsafeRow], ContinuousQueuedDataReader) = {
val queue = new ArrayBlockingQueue[UnsafeRow](1024)
- val factory = new DataReaderFactory[UnsafeRow] {
- override def createDataReader() = new ContinuousDataReader[UnsafeRow] {
+ val factory = new InputPartition[UnsafeRow] {
+ override def createPartitionReader() = new ContinuousInputPartitionReader[UnsafeRow] {
var index = -1
var curr: UnsafeRow = _
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/sources/StreamingDataSourceV2Suite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/sources/StreamingDataSourceV2Suite.scala
index af4618bed5456..c1a28b9bc75ef 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/sources/StreamingDataSourceV2Suite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/sources/StreamingDataSourceV2Suite.scala
@@ -26,7 +26,7 @@ import org.apache.spark.sql.execution.streaming.continuous.ContinuousTrigger
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.sources.{DataSourceRegister, StreamSinkProvider}
import org.apache.spark.sql.sources.v2.{ContinuousReadSupport, DataSourceOptions, MicroBatchReadSupport, StreamWriteSupport}
-import org.apache.spark.sql.sources.v2.reader.DataReaderFactory
+import org.apache.spark.sql.sources.v2.reader.InputPartition
import org.apache.spark.sql.sources.v2.reader.streaming.{ContinuousReader, MicroBatchReader, Offset, PartitionOffset}
import org.apache.spark.sql.sources.v2.writer.streaming.StreamWriter
import org.apache.spark.sql.streaming.{OutputMode, StreamTest, Trigger}
@@ -44,7 +44,7 @@ case class FakeReader() extends MicroBatchReader with ContinuousReader {
def mergeOffsets(offsets: Array[PartitionOffset]): Offset = RateStreamOffset(Map())
def setStartOffset(start: Optional[Offset]): Unit = {}
- def createDataReaderFactories(): java.util.ArrayList[DataReaderFactory[Row]] = {
+ def planInputPartitions(): java.util.ArrayList[InputPartition[Row]] = {
throw new IllegalStateException("fake source - cannot actually read")
}
}
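
The reader-side renames threaded through the hunks above (DataReaderFactory to InputPartition, createDataReader to createPartitionReader, createDataReaderFactories/createUnsafeRowReaderFactories to planInputPartitions/planUnsafeInputPartitions) keep the same shape. A minimal sketch of that shape with invented names, assuming InputPartitionReader keeps the usual next/get/close contract:

import org.apache.spark.sql.Row
import org.apache.spark.sql.sources.v2.reader.{InputPartition, InputPartitionReader}

// Hypothetical partition that emits the integers in [start, end); not part of the patch.
class RangeInputPartition(start: Int, end: Int) extends InputPartition[Row] {
  override def createPartitionReader(): InputPartitionReader[Row] =
    new InputPartitionReader[Row] {
      private var current = start - 1
      override def next(): Boolean = { current += 1; current < end } // advance, report availability
      override def get(): Row = Row(current)                         // row at the current position
      override def close(): Unit = {}                                // nothing to release
    }
}

object RangeInputPartitionSketch {
  def main(args: Array[String]): Unit = {
    val reader = new RangeInputPartition(0, 3).createPartitionReader()
    while (reader.next()) println(reader.get())                      // [0], [1], [2]
    reader.close()
  }
}
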
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/test/DataFrameReaderWriterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/test/DataFrameReaderWriterSuite.scala
index 14b1feb2adc20..b65058fffd339 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/test/DataFrameReaderWriterSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/test/DataFrameReaderWriterSuite.scala
@@ -276,7 +276,7 @@ class DataFrameReaderWriterSuite extends QueryTest with SharedSQLContext with Be
assert(LastOptions.parameters("doubleOpt") == "6.7")
}
- test("check jdbc() does not support partitioning or bucketing") {
+ test("check jdbc() does not support partitioning, bucketBy or sortBy") {
val df = spark.read.text(Utils.createTempDir(namePrefix = "text").getCanonicalPath)
var w = df.write.partitionBy("value")
@@ -287,7 +287,19 @@ class DataFrameReaderWriterSuite extends QueryTest with SharedSQLContext with Be
w = df.write.bucketBy(2, "value")
e = intercept[AnalysisException](w.jdbc(null, null, null))
- Seq("jdbc", "bucketing").foreach { s =>
+ Seq("jdbc", "does not support bucketBy right now").foreach { s =>
+ assert(e.getMessage.toLowerCase(Locale.ROOT).contains(s.toLowerCase(Locale.ROOT)))
+ }
+
+ w = df.write.sortBy("value")
+ e = intercept[AnalysisException](w.jdbc(null, null, null))
+ Seq("sortBy must be used together with bucketBy").foreach { s =>
+ assert(e.getMessage.toLowerCase(Locale.ROOT).contains(s.toLowerCase(Locale.ROOT)))
+ }
+
+ w = df.write.bucketBy(2, "value").sortBy("value")
+ e = intercept[AnalysisException](w.jdbc(null, null, null))
+ Seq("jdbc", "does not support bucketBy and sortBy right now").foreach { s =>
assert(e.getMessage.toLowerCase(Locale.ROOT).contains(s.toLowerCase(Locale.ROOT)))
}
}
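
The reworded assertions above encode two rules of the writer API: sortBy is only legal together with bucketBy, and bucketBy/sortBy are only honoured by saveAsTable, never by jdbc(). A small sketch with made-up table names showing the supported path; the rejected calls are left in comments with the messages the test expects.

import org.apache.spark.sql.SparkSession

object BucketBySortBySketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[2]").appName("bucketby-sortby-sketch").getOrCreate()
    import spark.implicits._

    val df = Seq("a", "b", "c").toDF("value")

    // Supported: bucketing (optionally with sortBy) persisted through saveAsTable.
    df.write.bucketBy(2, "value").sortBy("value").saveAsTable("bucketed_sketch_tbl")

    // Rejected, per the assertions above (both throw AnalysisException):
    //   df.write.sortBy("value").saveAsTable("t")            // "sortBy must be used together with bucketBy"
    //   df.write.bucketBy(2, "value").jdbc(url, "t", props)  // "... does not support bucketBy right now"

    spark.stop()
  }
}
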
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveUtils.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveUtils.scala
index 10c9603745379..bb134bbe68bd9 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveUtils.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveUtils.scala
@@ -105,11 +105,10 @@ private[spark] object HiveUtils extends Logging {
.createWithDefault(false)
val CONVERT_METASTORE_ORC = buildConf("spark.sql.hive.convertMetastoreOrc")
- .internal()
.doc("When set to true, the built-in ORC reader and writer are used to process " +
"ORC tables created by using the HiveQL syntax, instead of Hive serde.")
.booleanConf
- .createWithDefault(false)
+ .createWithDefault(true)
val HIVE_METASTORE_SHARED_PREFIXES = buildConf("spark.sql.hive.metastore.sharedPrefixes")
.doc("A comma separated list of class prefixes that should be loaded using the classloader " +