diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala b/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala
index a97fab970ac6e..a05722f3d3db9 100644
--- a/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala
@@ -280,7 +280,7 @@ object SparkSubmit {
if (!Utils.classIsLoadable("org.apache.spark.deploy.yarn.Client") && !Utils.isTesting) {
printErrorAndExit(
"Could not load YARN classes. " +
- "This copy of Spark may not have been compiled with YARN support.")
+ "This copy of Spark may not have been compiled with YARN support.")
}
}
@@ -296,11 +296,11 @@ object SparkSubmit {
// Resolve maven dependencies if there are any and add classpath to jars. Add them to py-files
// too for packages that include Python code
val exclusions: Seq[String] =
- if (!StringUtils.isBlank(args.packagesExclusions)) {
- args.packagesExclusions.split(",")
- } else {
- Nil
- }
+ if (!StringUtils.isBlank(args.packagesExclusions)) {
+ args.packagesExclusions.split(",")
+ } else {
+ Nil
+ }
val resolvedMavenCoordinates = SparkSubmitUtils.resolveMavenCoordinates(args.packages,
Option(args.repositories), Option(args.ivyRepoPath), exclusions = exclusions)
if (!StringUtils.isBlank(resolvedMavenCoordinates)) {
@@ -498,17 +498,25 @@ object SparkSubmit {
if (isUserJar(args.primaryResource)) {
childClasspath += args.primaryResource
}
- if (args.jars != null) { childClasspath ++= args.jars.split(",") }
- if (args.childArgs != null) { childArgs ++= args.childArgs }
+ if (args.jars != null) {
+ childClasspath ++= args.jars.split(",")
+ }
+ if (args.childArgs != null) {
+ childArgs ++= args.childArgs
+ }
}
// Map all arguments to command-line options or system properties for our chosen mode
for (opt <- options) {
if (opt.value != null &&
- (deployMode & opt.deployMode) != 0 &&
- (clusterManager & opt.clusterManager) != 0) {
- if (opt.clOption != null) { childArgs += (opt.clOption, opt.value) }
- if (opt.sysProp != null) { sysProps.put(opt.sysProp, opt.value) }
+ (deployMode & opt.deployMode) != 0 &&
+ (clusterManager & opt.clusterManager) != 0) {
+ if (opt.clOption != null) {
+ childArgs += (opt.clOption, opt.value)
+ }
+ if (opt.sysProp != null) {
+ sysProps.put(opt.sysProp, opt.value)
+ }
}
}
@@ -532,7 +540,9 @@ object SparkSubmit {
} else {
// In legacy standalone cluster mode, use Client as a wrapper around the user class
childMainClass = "org.apache.spark.deploy.Client"
- if (args.supervise) { childArgs += "--supervise" }
+ if (args.supervise) {
+ childArgs += "--supervise"
+ }
Option(args.driverMemory).foreach { m => childArgs += ("--memory", m) }
Option(args.driverCores).foreach { c => childArgs += ("--cores", c) }
childArgs += "launch"
@@ -654,32 +664,66 @@ object SparkSubmit {
sysProps("spark.submit.pyFiles") = formattedPyFiles
}
- val (pincipal, keytab) = if (
- (args.sparkProperties.get("spark.secret.roleID").isDefined &&
- args.sparkProperties.get("spark.secret.secretID").isDefined)
- || args.sparkProperties.get("spark.secret.vault.tempToken").isDefined
- || sys.env.get("VAULT_TEMP_TOKEN").isDefined) {
-
- val vaultUrl = s"${args.sparkProperties("spark.secret.vault.protocol")}://" +
- s"${args.sparkProperties("spark.secret.vault.hosts").split(",")
- .map(host => s"$host:${args.sparkProperties("spark.secret.vault.port")}").mkString(",")}"
- val vaultToken = if (args.sparkProperties.get("spark.secret.vault.tempToken").isDefined
- || sys.env.get("VAULT_TEMP_TOKEN").isDefined) {
- VaultHelper.getRealToken(vaultUrl, args.sparkProperties.getOrElse(
- "spark.secret.vault.tempToken", sys.env("VAULT_TEMP_TOKEN")))
- } else {
- val roleID = args.sparkProperties("spark.secret.roleID")
- val secretID = args.sparkProperties("spark.secret.secretID")
- VaultHelper.getTokenFromAppRole(vaultUrl, roleID, secretID)
- }
- val environment = ConfigSecurity.prepareEnvironment(Option(vaultToken), Option(vaultUrl))
- val principal = environment.get("principal").getOrElse(args.principal)
- val keytab = environment.get("keytabPath").getOrElse(args.keytab)
+ val mesosRoleEnv = (sys.env.get("VAULT_ROLE_ID"),
+ sys.env.get("VAULT_SECRET_ID"))
- environment.foreach{case (key, value) => sysProps.put(key, value)}
- (principal, keytab)
- } else (args.principal, args.keytab)
+ val sparkRoleOpts = (args.sparkProperties.get("spark.secret.roleID"),
+ args.sparkProperties.get("spark.secret.secretID"))
+
+ val tempToken = args.sparkProperties.get("spark.secret.vault.tempToken")
+
+
+ val sysEnvToken = sys.env.get("VAULT_TEMP_TOKEN")
+
+ val vaultUrl = s"${args.sparkProperties ("spark.secret.vault.protocol")}://" +
+ s"${
+ args.sparkProperties ("spark.secret.vault.hosts").split (",")
+ .map (host => s"$host:${
+ args.sparkProperties ("spark.secret.vault.port")
+ }").mkString (",")
+ }"
+
+
+ val (pincipal, keytab) =
+ (mesosRoleEnv, sparkRoleOpts, tempToken, sysEnvToken) match {
+
+ case ((roleIdEnv, secretIdEnv), (roleIdProp, secretIdProp), _, _)
+ if ((roleIdEnv.isDefined || roleIdProp.isDefined) &&
+ (secretIdEnv.isDefined || secretIdProp.isDefined)) =>
+ //scalastyle:off
+ println("Role ID and SecretId found")
+ val roleId = roleIdEnv.getOrElse(roleIdProp.get)
+ val secretId = secretIdEnv.getOrElse(secretIdProp.get)
+ val vaultToken = VaultHelper.getTokenFromAppRole (vaultUrl, roleId, secretId)
+ val environment = ConfigSecurity.prepareEnvironment(
+ Option(vaultToken), Option (vaultUrl) )
+ val principal = environment.get ("principal").getOrElse (args.principal)
+ val keytab = environment.get ("keytabPath").getOrElse (args.keytab)
+
+ environment.foreach {
+ case (key, value) => sysProps.put (key, value)
+ }
+ (principal, keytab)
+
+ case (_, _, tempTokenProp, tempTokenEnv)
+ if (tempTokenProp.isDefined || tempTokenEnv.isDefined) =>
+ //scalastyle:off
+ println("TempToken found")
+ val tempToken = tempTokenProp.getOrElse(tempTokenEnv.get)
+ val vaultToken = VaultHelper.getRealToken (vaultUrl, tempToken)
+ val environment = ConfigSecurity.prepareEnvironment(
+ Option (vaultToken), Option (vaultUrl))
+ val principal = environment.get ("principal").getOrElse (args.principal)
+ val keytab = environment.get ("keytabPath").getOrElse (args.keytab)
+
+ environment.foreach {
+ case (key, value) => sysProps.put (key, value)
+ }
+ (principal, keytab)
+
+ case _ => (args.principal, args.keytab)
+ }
(childArgs, childClasspath, sysProps, childMainClass, pincipal, keytab)
}
diff --git a/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala b/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala
index 827a2cca841b0..8fe9203a2698c 100644
--- a/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala
+++ b/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala
@@ -24,7 +24,6 @@ import java.util.concurrent.atomic.AtomicBoolean
import scala.collection.mutable
import scala.util.{Failure, Success}
import scala.util.control.NonFatal
-
import org.apache.spark._
import org.apache.spark.TaskState.TaskState
import org.apache.spark.deploy.SparkHadoopUtil
@@ -33,7 +32,7 @@ import org.apache.spark.internal.Logging
import org.apache.spark.rpc._
import org.apache.spark.scheduler.{ExecutorLossReason, TaskDescription}
import org.apache.spark.scheduler.cluster.CoarseGrainedClusterMessages._
-import org.apache.spark.security.ConfigSecurity
+import org.apache.spark.security.{ConfigSecurity, VaultHelper}
import org.apache.spark.serializer.SerializerInstance
import org.apache.spark.util.{ThreadUtils, Utils}
@@ -288,7 +287,10 @@ private[spark] object CoarseGrainedExecutorBackend extends Logging {
appId == null) {
printUsageAndExit()
}
- ConfigSecurity.prepareEnvironment(Option(System.getenv("VAULT_TEMP_TOKEN")))
+ ConfigSecurity.prepareEnvironment(scala.util.Try{
+ VaultHelper.getRealToken(ConfigSecurity.vaultUri.get,
+ sys.env("VAULT_TEMP_TOKEN"))}.toOption)
+
run(driverUrl, executorId, hostname, cores, appId, workerUrl, userClassPath)
System.exit(0)
}
diff --git a/core/src/main/scala/org/apache/spark/security/ConfigSecurity.scala b/core/src/main/scala/org/apache/spark/security/ConfigSecurity.scala
index 63f493a267952..e90bb40666d33 100644
--- a/core/src/main/scala/org/apache/spark/security/ConfigSecurity.scala
+++ b/core/src/main/scala/org/apache/spark/security/ConfigSecurity.scala
@@ -21,12 +21,16 @@ import org.apache.spark.internal.Logging
object ConfigSecurity extends Logging{
var vaultToken: Option[String] = None
- val vaultHost: Option[String] = sys.env.get("VAULT_HOSTS")
+ val vaultHost: Option[String] = sys.env.get("VAULT_HOST")
val vaultUri: Option[String] = {
(sys.env.get("VAULT_PROTOCOL"), vaultHost, sys.env.get("VAULT_PORT")) match {
case (Some(vaultProtocol), Some(vaultHost), Some(vaultPort)) =>
- Option(s"$vaultProtocol://$vaultHost:$vaultPort")
- case _ => None
+ val vaultUri = s"$vaultProtocol://$vaultHost:$vaultPort"
+ logDebug(s"vault uri: $vaultUri found, any Vault Connection will use it")
+ Option(vaultUri)
+ case _ =>
+ logDebug("No Vault information found, any Vault Connection will fail")
+ None
}
}
@@ -37,7 +41,9 @@ object ConfigSecurity extends Logging{
val secretOptionsMap = ConfigSecurity.extractSecretFromEnv(sys.env)
logDebug(s"secretOptionsMap: ${secretOptionsMap.mkString("\n")}")
loadingConf(secretOptionsMap)
- vaultToken = if (vaultToken.isDefined) vaultAppToken else Option(System.getenv("VAULT_TOKEN"))
+ vaultToken = if (vaultAppToken.isDefined) {
+ vaultAppToken
+ } else sys.env.get("VAULT_TOKEN")
if(vaultToken.isDefined) {
require(vaultUri.isDefined, "A proper vault host is required")
logDebug(s"env VAR: ${sys.env.mkString("\n")}")
diff --git a/core/src/main/scala/org/apache/spark/security/SSLConfig.scala b/core/src/main/scala/org/apache/spark/security/SSLConfig.scala
index ff0a6703cfd0d..5ec4b5da0a3ac 100644
--- a/core/src/main/scala/org/apache/spark/security/SSLConfig.scala
+++ b/core/src/main/scala/org/apache/spark/security/SSLConfig.scala
@@ -28,7 +28,7 @@ import sun.security.util.DerInputStream
import org.apache.spark.internal.Logging
-object SSLConfig extends Logging{
+object SSLConfig extends Logging {
val sslTypeDataStore = "DATASTORE"
val sslTypeKafkaStore = "KAFKA"
@@ -37,19 +37,23 @@ object SSLConfig extends Logging{
vaultToken: String,
sslType: String,
options: Map[String, String]): Map[String, String] = {
- val rootCA = VaultHelper.getRootCA(vaultHost, vaultToken)
- val rootCAPath = writeRootCA(rootCA)
- val certPass = VaultHelper.getCertPassFromVault(vaultHost, vaultToken)
- val trustStorePath = generateTrustStore(sslType, rootCA, certPass)
+
+ val sparkSSLPrefix = "spark.ssl."
+
+ val vaultTrustStorePath = options.get(s"${sslType}_VAULT_TRUSTSTORE_PATH")
+ val vaultTrustStorePassPath = options.get(s"${sslType}_VAULT_TRUSTSTORE_PASS_PATH")
+ val trustStore = VaultHelper.getTrustStore(vaultHost, vaultToken, vaultTrustStorePath.get)
+ val trustPass = VaultHelper.getCertPassForAppFromVault(
+ vaultHost, vaultTrustStorePassPath.get, vaultToken)
+ val trustStorePath = generateTrustStore(sslType, trustStore, trustPass)
logInfo(s"Setting SSL values for $sslType")
val trustStoreOptions =
- Map(s"spark.ssl.${sslType.toLowerCase}.enabled" -> "true",
- s"spark.ssl.${sslType.toLowerCase}.trustStore" -> trustStorePath,
- s"spark.ssl.${sslType.toLowerCase}.trustStorePassword" -> certPass,
- s"spark.ssl.${sslType.toLowerCase}.rootCaPath" -> rootCAPath,
- s"spark.ssl.${sslType.toLowerCase}.security.protocol" -> "SSL")
+ Map(s"$sparkSSLPrefix${sslType.toLowerCase}.enabled" -> "true",
+ s"$sparkSSLPrefix${sslType.toLowerCase}.trustStore" -> trustStorePath,
+ s"$sparkSSLPrefix${sslType.toLowerCase}.trustStorePassword" -> trustPass,
+ s"$sparkSSLPrefix${sslType.toLowerCase}.security.protocol" -> "SSL")
val vaultKeystorePath = options.get(s"${sslType}_VAULT_CERT_PATH")
@@ -65,10 +69,10 @@ object SSLConfig extends Logging{
val keyStorePath = generateKeyStore(sslType, certs, key, pass)
- Map(s"spark.ssl${sslType.toLowerCase}.keyStore" -> keyStorePath,
- s"spark.ssl${sslType.toLowerCase}.keyStorePassword" -> pass,
- s"spark.ssl${sslType.toLowerCase}.protocol" -> "TLSv1.2",
- s"spark.ssl${sslType.toLowerCase}.needClientAuth" -> "true"
+ Map(s"$sparkSSLPrefix${sslType.toLowerCase}.keyStore" -> keyStorePath,
+ s"$sparkSSLPrefix${sslType.toLowerCase}.keyStorePassword" -> pass,
+ s"$sparkSSLPrefix${sslType.toLowerCase}.protocol" -> "TLSv1.2",
+ s"$sparkSSLPrefix${sslType.toLowerCase}.needClientAuth" -> "true"
)
} else {
@@ -79,13 +83,13 @@ object SSLConfig extends Logging{
val vaultKeyPassPath = options.get(s"${sslType}_VAULT_KEY_PASS_PATH")
- val keyPass = Map(s"spark.ssl.${sslType.toLowerCase}.keyPassword"
+ val keyPass = Map(s"$sparkSSLPrefix${sslType.toLowerCase}.keyPassword"
-> VaultHelper.getCertPassForAppFromVault(vaultHost, vaultKeyPassPath.get, vaultToken))
val certFilesPath =
- Map("spark.ssl.cert.path" -> s"${sys.env.get("SPARK_SSL_CERT_PATH")}/cert.crt",
- "spark.ssl.key.pkcs8" -> s"${sys.env.get("SPARK_SSL_CERT_PATH")}/key.pkcs8",
- "spark.ssl.root.cert" -> s"${sys.env.get("SPARK_SSL_CERT_PATH")}/caroot.crt")
+ Map(sparkSSLPrefix + "cert.path" -> s"${sys.env.get("SPARK_SSL_CERT_PATH")}/cert.crt",
+ sparkSSLPrefix + "key.pkcs8" -> s"${sys.env.get("SPARK_SSL_CERT_PATH")}/key.pkcs8",
+ sparkSSLPrefix + "root.cert" -> s"${sys.env.get("SPARK_SSL_CERT_PATH")}/caroot.crt")
trustStoreOptions ++ keyStoreOptions ++ keyPass ++ certFilesPath
}
@@ -96,7 +100,7 @@ object SSLConfig extends Logging{
keystore.load(null)
val certs = getBase64FromCAs(cas)
- certs.zipWithIndex.foreach{case (cert, index) =>
+ certs.zipWithIndex.foreach { case (cert, index) =>
val key = s"cert-${index}"
keystore.setCertificateEntry(key, generateCertificateFromDER(cert))
}
@@ -161,7 +165,7 @@ object SSLConfig extends Logging{
val certs = getBase64FromCAs(cas)
val arrayCert = certs.map(cert => generateCertificateFromDER(cert))
val alias = "key-alias"
- keystore.setKeyEntry(alias, key, password.toCharArray, arrayCert )
+ keystore.setKeyEntry(alias, key, password.toCharArray, arrayCert)
val fileName = "keystore.jks"
val dir = new File(s"/tmp/$sslType")
@@ -180,7 +184,7 @@ object SSLConfig extends Logging{
CertificateFactory.getInstance("X.509").generateCertificate(new ByteArrayInputStream(certBytes))
private def getArrayFromCA(ca: String): Array[String] = {
- val splittedBy = ca.takeWhile(_=='-')
+ val splittedBy = ca.takeWhile(_ == '-')
val begin = s"$splittedBy${ca.split(splittedBy).tail.head}$splittedBy"
val end = begin.replace("BEGIN", "END")
ca.split(begin).tail.map(_.split(end).head)
@@ -192,29 +196,4 @@ object SSLConfig extends Logging{
DatatypeConverter.parseBase64Binary(value)
})
}
-
- def writeRootCA(rootCA: String): String = {
- def getCertFromOnLine(certBadFormat: String): String = {
- var text1 = certBadFormat
- var arg = Seq[String]()
- while (text1.size != 0) {
- val (toStore, toUpdate) = text1.splitAt(64)
- text1 = toUpdate
- arg = arg ++ Seq(toStore)
- }
- arg.mkString("\n")
- }
-
- val path = "/tmp/root.crt"
- val splitter = rootCA.split("BEGIN").head
- val Array(_, head, certBadFormat, tail) = rootCA.split(splitter)
- val cert = getCertFromOnLine(certBadFormat)
- val writableCert = Seq(s"$splitter$head$splitter", cert, s"$splitter$tail$splitter")
- .mkString("\n")
- val downloadFile = Files.createFile(Paths.get(path),
- PosixFilePermissions.asFileAttribute(PosixFilePermissions.fromString("rw-------")))
- downloadFile.toFile.deleteOnExit()
- Files.write(downloadFile, writableCert.getBytes)
- path
- }
}
diff --git a/core/src/main/scala/org/apache/spark/security/VaultHelper.scala b/core/src/main/scala/org/apache/spark/security/VaultHelper.scala
index 724a24b12f889..23b9c3f6484e7 100644
--- a/core/src/main/scala/org/apache/spark/security/VaultHelper.scala
+++ b/core/src/main/scala/org/apache/spark/security/VaultHelper.scala
@@ -66,7 +66,7 @@ object VaultHelper extends Logging {
logDebug(s"Requesting Secret ID from Vault: $requestUrl")
HTTPHelper.executePost(requestUrl, "data",
Some(Seq(("X-Vault-Token", token.get))))("secret_id").asInstanceOf[String]
- }
+ }
def getTemporalToken(vaultHost: String, token: String): String = {
val requestUrl = s"$vaultHost/v1/sys/wrapping/wrap"
@@ -90,6 +90,7 @@ object VaultHelper extends Logging {
(keytab64, principal)
}
+ @deprecated
def getRootCA(vaultUrl: String, token: String): String = {
val certVaultPath = "/v1/ca-trust/certificates/"
val requestUrl = s"$vaultUrl/$certVaultPath"
@@ -97,7 +98,7 @@ object VaultHelper extends Logging {
logDebug(s"Requesting Cert List: $listCertKeysVaultPath")
val keys = HTTPHelper.executeGet(listCertKeysVaultPath,
- "data", Some(Seq(("X-Vault-Token", token))))("pass").asInstanceOf[List[String]]
+ "data", Some(Seq(("X-Vault-Token", token))))("keys").asInstanceOf[List[String]]
keys.flatMap(key => {
HTTPHelper.executeGet(s"$requestUrl$key",
@@ -105,6 +106,17 @@ object VaultHelper extends Logging {
}).map(_._2).mkString
}
+ def getTrustStore(vaultUrl: String, token: String, certVaultPath: String): String = {
+ val requestUrl = s"$vaultUrl/$certVaultPath"
+ val truststoreVaultPath = s"$requestUrl"
+
+ logDebug(s"Requesting truststore: $truststoreVaultPath")
+ val data = HTTPHelper.executeGet(requestUrl,
+ "data", Some(Seq(("X-Vault-Token", token))))
+ val trustStore = data.find(_._1.endsWith("_crt")).get._2.asInstanceOf[String]
+ trustStore
+ }
+
def getCertPassFromVault(vaultUrl: String, token: String): String = {
val certPassVaultPath = "/v1/ca-trust/passwords/default/keystore"
logDebug(s"Requesting Cert Pass: $certPassVaultPath")
diff --git a/pom.xml b/pom.xml
index c46dc1bcd3505..4f155eaf529ee 100644
--- a/pom.xml
+++ b/pom.xml
@@ -2336,7 +2336,7 @@
0.8.0
false
- true
+ false
false
false
${basedir}/src/main/scala