Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Branch 2.2 merge #232

Merged
merged 14 commits into from
May 15, 2018
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion LICENSE
Original file line number Diff line number Diff line change
Expand Up @@ -263,7 +263,7 @@ The text of each license is also included at licenses/LICENSE-[project].txt.
(New BSD license) Protocol Buffer Java API (org.spark-project.protobuf:protobuf-java:2.4.1-shaded - http://code.google.com/p/protobuf)
(The BSD License) Fortran to Java ARPACK (net.sourceforge.f2j:arpack_combined_all:0.1 - http://f2j.sourceforge.net)
(The BSD License) xmlenc Library (xmlenc:xmlenc:0.52 - http://xmlenc.sourceforge.net)
(The New BSD License) Py4J (net.sf.py4j:py4j:0.10.4 - http://py4j.sourceforge.net/)
(The New BSD License) Py4J (net.sf.py4j:py4j:0.10.7 - http://py4j.sourceforge.net/)
(Two-clause BSD-style license) JUnit-Interface (com.novocode:junit-interface:0.10 - http://github.com/szeiger/junit-interface/)
(BSD licence) sbt and sbt-launch-lib.bash
(BSD 3 Clause) d3.min.js (https://github.com/mbostock/d3/blob/master/LICENSE)
Expand Down
4 changes: 2 additions & 2 deletions R/pkg/R/client.R
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@

# Creates a SparkR client connection object
# if one doesn't already exist
connectBackend <- function(hostname, port, timeout) {
connectBackend <- function(hostname, port, timeout, authSecret) {
if (exists(".sparkRcon", envir = .sparkREnv)) {
if (isOpen(.sparkREnv[[".sparkRCon"]])) {
cat("SparkRBackend client connection already exists\n")
Expand All @@ -29,7 +29,7 @@ connectBackend <- function(hostname, port, timeout) {

con <- socketConnection(host = hostname, port = port, server = FALSE,
blocking = TRUE, open = "wb", timeout = timeout)

doServerAuth(con, authSecret)
assign(".sparkRCon", con, envir = .sparkREnv)
con
}
Expand Down
10 changes: 7 additions & 3 deletions R/pkg/R/deserialize.R
Original file line number Diff line number Diff line change
Expand Up @@ -60,14 +60,18 @@ readTypedObject <- function(con, type) {
stop(paste("Unsupported type for deserialization", type)))
}

readString <- function(con) {
stringLen <- readInt(con)
raw <- readBin(con, raw(), stringLen, endian = "big")
readStringData <- function(con, len) {
raw <- readBin(con, raw(), len, endian = "big")
string <- rawToChar(raw)
Encoding(string) <- "UTF-8"
string
}

readString <- function(con) {
stringLen <- readInt(con)
readStringData(con, stringLen)
}

readInt <- function(con) {
readBin(con, integer(), n = 1, endian = "big")
}
Expand Down
39 changes: 34 additions & 5 deletions R/pkg/R/sparkR.R
Original file line number Diff line number Diff line change
Expand Up @@ -161,6 +161,10 @@ sparkR.sparkContext <- function(
" please use the --packages commandline instead", sep = ","))
}
backendPort <- existingPort
authSecret <- Sys.getenv("SPARKR_BACKEND_AUTH_SECRET")
if (nchar(authSecret) == 0) {
stop("Auth secret not provided in environment.")
}
} else {
path <- tempfile(pattern = "backend_port")
submitOps <- getClientModeSparkSubmitOpts(
Expand Down Expand Up @@ -189,16 +193,27 @@ sparkR.sparkContext <- function(
monitorPort <- readInt(f)
rLibPath <- readString(f)
connectionTimeout <- readInt(f)

# Don't use readString() so that we can provide a useful
# error message if the R and Java versions are mismatched.
authSecretLen = readInt(f)
if (length(authSecretLen) == 0 || authSecretLen == 0) {
stop("Unexpected EOF in JVM connection data. Mismatched versions?")
}
authSecret <- readStringData(f, authSecretLen)
close(f)
file.remove(path)
if (length(backendPort) == 0 || backendPort == 0 ||
length(monitorPort) == 0 || monitorPort == 0 ||
length(rLibPath) != 1) {
length(rLibPath) != 1 || length(authSecret) == 0) {
stop("JVM failed to launch")
}
assign(".monitorConn",
socketConnection(port = monitorPort, timeout = connectionTimeout),
envir = .sparkREnv)

monitorConn <- socketConnection(port = monitorPort, blocking = TRUE,
timeout = connectionTimeout, open = "wb")
doServerAuth(monitorConn, authSecret)

assign(".monitorConn", monitorConn, envir = .sparkREnv)
assign(".backendLaunched", 1, envir = .sparkREnv)
if (rLibPath != "") {
assign(".libPath", rLibPath, envir = .sparkREnv)
Expand All @@ -208,7 +223,7 @@ sparkR.sparkContext <- function(

.sparkREnv$backendPort <- backendPort
tryCatch({
connectBackend("localhost", backendPort, timeout = connectionTimeout)
connectBackend("localhost", backendPort, timeout = connectionTimeout, authSecret = authSecret)
},
error = function(err) {
stop("Failed to connect JVM\n")
Expand Down Expand Up @@ -632,3 +647,17 @@ sparkCheckInstall <- function(sparkHome, master, deployMode) {
NULL
}
}

# Utility function for sending auth data over a socket and checking the server's reply.
doServerAuth <- function(con, authSecret) {
if (nchar(authSecret) == 0) {
stop("Auth secret not provided.")
}
writeString(con, authSecret)
flush(con)
reply <- readString(con)
if (reply != "ok") {
close(con)
stop("Unexpected reply from server.")
}
}
4 changes: 3 additions & 1 deletion R/pkg/inst/worker/daemon.R
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,9 @@ suppressPackageStartupMessages(library(SparkR))

port <- as.integer(Sys.getenv("SPARKR_WORKER_PORT"))
inputCon <- socketConnection(
port = port, open = "rb", blocking = TRUE, timeout = connectionTimeout)
port = port, open = "wb", blocking = TRUE, timeout = connectionTimeout)

SparkR:::doServerAuth(inputCon, Sys.getenv("SPARKR_WORKER_SECRET"))

while (TRUE) {
ready <- socketSelect(list(inputCon))
Expand Down
5 changes: 4 additions & 1 deletion R/pkg/inst/worker/worker.R
Original file line number Diff line number Diff line change
Expand Up @@ -100,9 +100,12 @@ suppressPackageStartupMessages(library(SparkR))

port <- as.integer(Sys.getenv("SPARKR_WORKER_PORT"))
inputCon <- socketConnection(
port = port, blocking = TRUE, open = "rb", timeout = connectionTimeout)
port = port, blocking = TRUE, open = "wb", timeout = connectionTimeout)
SparkR:::doServerAuth(inputCon, Sys.getenv("SPARKR_WORKER_SECRET"))

outputCon <- socketConnection(
port = port, blocking = TRUE, open = "wb", timeout = connectionTimeout)
SparkR:::doServerAuth(outputCon, Sys.getenv("SPARKR_WORKER_SECRET"))

# read the index of the current partition inside the RDD
partition <- SparkR:::readInt(inputCon)
Expand Down
6 changes: 3 additions & 3 deletions bin/pyspark
Original file line number Diff line number Diff line change
Expand Up @@ -25,14 +25,14 @@ source "${SPARK_HOME}"/bin/load-spark-env.sh
export _SPARK_CMD_USAGE="Usage: ./bin/pyspark [options]"

# In Spark 2.0, IPYTHON and IPYTHON_OPTS are removed and pyspark fails to launch if either option
# is set in the user's environment. Instead, users should set PYSPARK_DRIVER_PYTHON=ipython
# is set in the user's environment. Instead, users should set PYSPARK_DRIVER_PYTHON=ipython
# to use IPython and set PYSPARK_DRIVER_PYTHON_OPTS to pass options when starting the Python driver
# (e.g. PYSPARK_DRIVER_PYTHON_OPTS='notebook'). This supports full customization of the IPython
# and executor Python executables.

# Fail noisily if removed options are set
if [[ -n "$IPYTHON" || -n "$IPYTHON_OPTS" ]]; then
echo "Error in pyspark startup:"
echo "Error in pyspark startup:"
echo "IPYTHON and IPYTHON_OPTS are removed in Spark 2.0+. Remove these from the environment and set PYSPARK_DRIVER_PYTHON and PYSPARK_DRIVER_PYTHON_OPTS instead."
exit 1
fi
Expand All @@ -57,7 +57,7 @@ export PYSPARK_PYTHON

# Add the PySpark classes to the Python path:
export PYTHONPATH="${SPARK_HOME}/python/:$PYTHONPATH"
export PYTHONPATH="${SPARK_HOME}/python/lib/py4j-0.10.4-src.zip:$PYTHONPATH"
export PYTHONPATH="${SPARK_HOME}/python/lib/py4j-0.10.7-src.zip:$PYTHONPATH"

# Load the PySpark shell.py script when ./pyspark is used interactively:
export OLD_PYTHONSTARTUP="$PYTHONSTARTUP"
Expand Down
2 changes: 1 addition & 1 deletion bin/pyspark2.cmd
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ if "x%PYSPARK_DRIVER_PYTHON%"=="x" (
)

set PYTHONPATH=%SPARK_HOME%\python;%PYTHONPATH%
set PYTHONPATH=%SPARK_HOME%\python\lib\py4j-0.10.4-src.zip;%PYTHONPATH%
set PYTHONPATH=%SPARK_HOME%\python\lib\py4j-0.10.7-src.zip;%PYTHONPATH%

set OLD_PYTHONSTARTUP=%PYTHONSTARTUP%
set PYTHONSTARTUP=%SPARK_HOME%\python\pyspark\shell.py
Expand Down
2 changes: 1 addition & 1 deletion core/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -335,7 +335,7 @@
<dependency>
<groupId>net.sf.py4j</groupId>
<artifactId>py4j</artifactId>
<version>0.10.4</version>
<version>0.10.7</version>
</dependency>
<dependency>
<groupId>org.apache.spark</groupId>
Expand Down
11 changes: 2 additions & 9 deletions core/src/main/scala/org/apache/spark/SecurityManager.scala
Original file line number Diff line number Diff line change
Expand Up @@ -17,13 +17,11 @@

package org.apache.spark

import java.lang.{Byte => JByte}
import java.net.{Authenticator, PasswordAuthentication}
import java.security.{KeyStore, SecureRandom}
import java.security.KeyStore
import java.security.cert.X509Certificate
import javax.net.ssl._

import com.google.common.hash.HashCodes
import com.google.common.io.Files
import org.apache.hadoop.io.Text

Expand Down Expand Up @@ -435,12 +433,7 @@ private[spark] class SecurityManager(
val secretKey = SparkHadoopUtil.get.getSecretKeyFromUserCredentials(SECRET_LOOKUP_KEY)
if (secretKey == null || secretKey.length == 0) {
logDebug("generateSecretKey: yarn mode, secret key from credentials is null")
val rnd = new SecureRandom()
val length = sparkConf.getInt("spark.authenticate.secretBitLength", 256) / JByte.SIZE
val secret = new Array[Byte](length)
rnd.nextBytes(secret)

val cookie = HashCodes.fromBytes(secret).toString()
val cookie = Utils.createSecret(sparkConf)
SparkHadoopUtil.get.addSecretKeyToUserCredentials(SECRET_LOOKUP_KEY, cookie)
cookie
} else {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,26 +17,39 @@

package org.apache.spark.api.python

import java.io.DataOutputStream
import java.net.Socket
import java.io.{DataOutputStream, File, FileOutputStream}
import java.net.InetAddress
import java.nio.charset.StandardCharsets.UTF_8
import java.nio.file.Files

import py4j.GatewayServer

import org.apache.spark.SparkConf
import org.apache.spark.internal.Logging
import org.apache.spark.util.Utils

/**
* Process that starts a Py4J GatewayServer on an ephemeral port and communicates the bound port
* back to its caller via a callback port specified by the caller.
* Process that starts a Py4J GatewayServer on an ephemeral port.
*
* This process is launched (via SparkSubmit) by the PySpark driver (see java_gateway.py).
*/
private[spark] object PythonGatewayServer extends Logging {
initializeLogIfNecessary(true)

def main(args: Array[String]): Unit = Utils.tryOrExit {
// Start a GatewayServer on an ephemeral port
val gatewayServer: GatewayServer = new GatewayServer(null, 0)
def main(args: Array[String]): Unit = {
val secret = Utils.createSecret(new SparkConf())

// Start a GatewayServer on an ephemeral port. Make sure the callback client is configured
// with the same secret, in case the app needs callbacks from the JVM to the underlying
// python processes.
val localhost = InetAddress.getLoopbackAddress()
val gatewayServer: GatewayServer = new GatewayServer.GatewayServerBuilder()
.authToken(secret)
.javaPort(0)
.javaAddress(localhost)
.callbackClient(GatewayServer.DEFAULT_PYTHON_PORT, localhost, secret)
.build()

gatewayServer.start()
val boundPort: Int = gatewayServer.getListeningPort
if (boundPort == -1) {
Expand All @@ -46,15 +59,24 @@ private[spark] object PythonGatewayServer extends Logging {
logDebug(s"Started PythonGatewayServer on port $boundPort")
}

// Communicate the bound port back to the caller via the caller-specified callback port
val callbackHost = sys.env("_PYSPARK_DRIVER_CALLBACK_HOST")
val callbackPort = sys.env("_PYSPARK_DRIVER_CALLBACK_PORT").toInt
logDebug(s"Communicating GatewayServer port to Python driver at $callbackHost:$callbackPort")
val callbackSocket = new Socket(callbackHost, callbackPort)
val dos = new DataOutputStream(callbackSocket.getOutputStream)
// Communicate the connection information back to the python process by writing the
// information in the requested file. This needs to match the read side in java_gateway.py.
val connectionInfoPath = new File(sys.env("_PYSPARK_DRIVER_CONN_INFO_PATH"))
val tmpPath = Files.createTempFile(connectionInfoPath.getParentFile().toPath(),
"connection", ".info").toFile()

val dos = new DataOutputStream(new FileOutputStream(tmpPath))
dos.writeInt(boundPort)

val secretBytes = secret.getBytes(UTF_8)
dos.writeInt(secretBytes.length)
dos.write(secretBytes, 0, secretBytes.length)
dos.close()
callbackSocket.close()

if (!tmpPath.renameTo(connectionInfoPath)) {
logError(s"Unable to write connection information to $connectionInfoPath.")
System.exit(1)
}

// Exit on EOF or broken pipe to ensure that this process dies when the Python driver dies:
while (System.in.read() != -1) {
Expand Down
29 changes: 22 additions & 7 deletions core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ import org.apache.spark.broadcast.Broadcast
import org.apache.spark.input.PortableDataStream
import org.apache.spark.internal.Logging
import org.apache.spark.rdd.RDD
import org.apache.spark.security.SocketAuthHelper
import org.apache.spark.util._


Expand Down Expand Up @@ -421,6 +422,12 @@ private[spark] object PythonRDD extends Logging {
// remember the broadcasts sent to each worker
private val workerBroadcasts = new mutable.WeakHashMap[Socket, mutable.Set[Long]]()

// Authentication helper used when serving iterator data.
private lazy val authHelper = {
val conf = Option(SparkEnv.get).map(_.conf).getOrElse(new SparkConf())
new SocketAuthHelper(conf)
}

def getWorkerBroadcasts(worker: Socket): mutable.Set[Long] = {
synchronized {
workerBroadcasts.getOrElseUpdate(worker, new mutable.HashSet[Long]())
Expand All @@ -443,12 +450,13 @@ private[spark] object PythonRDD extends Logging {
* (effectively a collect()), but allows you to run on a certain subset of partitions,
* or to enable local execution.
*
* @return the port number of a local socket which serves the data collected from this job.
* @return 2-tuple (as a Java array) with the port number of a local socket which serves the
* data collected from this job, and the secret for authentication.
*/
def runJob(
sc: SparkContext,
rdd: JavaRDD[Array[Byte]],
partitions: JArrayList[Int]): Int = {
partitions: JArrayList[Int]): Array[Any] = {
type ByteArray = Array[Byte]
type UnrolledPartition = Array[ByteArray]
val allPartitions: Array[UnrolledPartition] =
Expand All @@ -461,13 +469,14 @@ private[spark] object PythonRDD extends Logging {
/**
* A helper function to collect an RDD as an iterator, then serve it via socket.
*
* @return the port number of a local socket which serves the data collected from this job.
* @return 2-tuple (as a Java array) with the port number of a local socket which serves the
* data collected from this job, and the secret for authentication.
*/
def collectAndServe[T](rdd: RDD[T]): Int = {
def collectAndServe[T](rdd: RDD[T]): Array[Any] = {
serveIterator(rdd.collect().iterator, s"serve RDD ${rdd.id}")
}

def toLocalIteratorAndServe[T](rdd: RDD[T]): Int = {
def toLocalIteratorAndServe[T](rdd: RDD[T]): Array[Any] = {
serveIterator(rdd.toLocalIterator, s"serve toLocalIterator")
}

Expand Down Expand Up @@ -698,8 +707,11 @@ private[spark] object PythonRDD extends Logging {
* and send them into this connection.
*
* The thread will terminate after all the data are sent or any exceptions happen.
*
* @return 2-tuple (as a Java array) with the port number of a local socket which serves the
* data collected from this job, and the secret for authentication.
*/
def serveIterator[T](items: Iterator[T], threadName: String): Int = {
def serveIterator(items: Iterator[_], threadName: String): Array[Any] = {
val serverSocket = new ServerSocket(0, 1, InetAddress.getByName("localhost"))
// Close the socket if no connection in 15 seconds
serverSocket.setSoTimeout(15000)
Expand All @@ -709,11 +721,14 @@ private[spark] object PythonRDD extends Logging {
override def run() {
try {
val sock = serverSocket.accept()
authHelper.authClient(sock)

val out = new DataOutputStream(new BufferedOutputStream(sock.getOutputStream))
Utils.tryWithSafeFinally {
writeIteratorToStream(items, out)
} {
out.close()
sock.close()
}
} catch {
case NonFatal(e) =>
Expand All @@ -724,7 +739,7 @@ private[spark] object PythonRDD extends Logging {
}
}.start()

serverSocket.getLocalPort
Array(serverSocket.getLocalPort, authHelper.secret)
}

private def getMergedConf(confAsMap: java.util.HashMap[String, String],
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ private[spark] object PythonUtils {
val pythonPath = new ArrayBuffer[String]
for (sparkHome <- sys.env.get("SPARK_HOME")) {
pythonPath += Seq(sparkHome, "python", "lib", "pyspark.zip").mkString(File.separator)
pythonPath += Seq(sparkHome, "python", "lib", "py4j-0.10.4-src.zip").mkString(File.separator)
pythonPath += Seq(sparkHome, "python", "lib", "py4j-0.10.7-src.zip").mkString(File.separator)
}
pythonPath ++= SparkContext.jarOfObject(this)
pythonPath.mkString(File.pathSeparator)
Expand Down
Loading