diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala b/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala index a97fab970ac6e..a05722f3d3db9 100644 --- a/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala +++ b/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala @@ -280,7 +280,7 @@ object SparkSubmit { if (!Utils.classIsLoadable("org.apache.spark.deploy.yarn.Client") && !Utils.isTesting) { printErrorAndExit( "Could not load YARN classes. " + - "This copy of Spark may not have been compiled with YARN support.") + "This copy of Spark may not have been compiled with YARN support.") } } @@ -296,11 +296,11 @@ object SparkSubmit { // Resolve maven dependencies if there are any and add classpath to jars. Add them to py-files // too for packages that include Python code val exclusions: Seq[String] = - if (!StringUtils.isBlank(args.packagesExclusions)) { - args.packagesExclusions.split(",") - } else { - Nil - } + if (!StringUtils.isBlank(args.packagesExclusions)) { + args.packagesExclusions.split(",") + } else { + Nil + } val resolvedMavenCoordinates = SparkSubmitUtils.resolveMavenCoordinates(args.packages, Option(args.repositories), Option(args.ivyRepoPath), exclusions = exclusions) if (!StringUtils.isBlank(resolvedMavenCoordinates)) { @@ -498,17 +498,25 @@ object SparkSubmit { if (isUserJar(args.primaryResource)) { childClasspath += args.primaryResource } - if (args.jars != null) { childClasspath ++= args.jars.split(",") } - if (args.childArgs != null) { childArgs ++= args.childArgs } + if (args.jars != null) { + childClasspath ++= args.jars.split(",") + } + if (args.childArgs != null) { + childArgs ++= args.childArgs + } } // Map all arguments to command-line options or system properties for our chosen mode for (opt <- options) { if (opt.value != null && - (deployMode & opt.deployMode) != 0 && - (clusterManager & opt.clusterManager) != 0) { - if (opt.clOption != null) { childArgs += (opt.clOption, opt.value) } - if (opt.sysProp != null) { sysProps.put(opt.sysProp, opt.value) } + (deployMode & opt.deployMode) != 0 && + (clusterManager & opt.clusterManager) != 0) { + if (opt.clOption != null) { + childArgs += (opt.clOption, opt.value) + } + if (opt.sysProp != null) { + sysProps.put(opt.sysProp, opt.value) + } } } @@ -532,7 +540,9 @@ object SparkSubmit { } else { // In legacy standalone cluster mode, use Client as a wrapper around the user class childMainClass = "org.apache.spark.deploy.Client" - if (args.supervise) { childArgs += "--supervise" } + if (args.supervise) { + childArgs += "--supervise" + } Option(args.driverMemory).foreach { m => childArgs += ("--memory", m) } Option(args.driverCores).foreach { c => childArgs += ("--cores", c) } childArgs += "launch" @@ -654,32 +664,66 @@ object SparkSubmit { sysProps("spark.submit.pyFiles") = formattedPyFiles } - val (pincipal, keytab) = if ( - (args.sparkProperties.get("spark.secret.roleID").isDefined && - args.sparkProperties.get("spark.secret.secretID").isDefined) - || args.sparkProperties.get("spark.secret.vault.tempToken").isDefined - || sys.env.get("VAULT_TEMP_TOKEN").isDefined) { - - val vaultUrl = s"${args.sparkProperties("spark.secret.vault.protocol")}://" + - s"${args.sparkProperties("spark.secret.vault.hosts").split(",") - .map(host => s"$host:${args.sparkProperties("spark.secret.vault.port")}").mkString(",")}" - val vaultToken = if (args.sparkProperties.get("spark.secret.vault.tempToken").isDefined - || sys.env.get("VAULT_TEMP_TOKEN").isDefined) { - VaultHelper.getRealToken(vaultUrl, args.sparkProperties.getOrElse( - "spark.secret.vault.tempToken", sys.env("VAULT_TEMP_TOKEN"))) - } else { - val roleID = args.sparkProperties("spark.secret.roleID") - val secretID = args.sparkProperties("spark.secret.secretID") - VaultHelper.getTokenFromAppRole(vaultUrl, roleID, secretID) - } - val environment = ConfigSecurity.prepareEnvironment(Option(vaultToken), Option(vaultUrl)) - val principal = environment.get("principal").getOrElse(args.principal) - val keytab = environment.get("keytabPath").getOrElse(args.keytab) + val mesosRoleEnv = (sys.env.get("VAULT_ROLE_ID"), + sys.env.get("VAULT_SECRET_ID")) - environment.foreach{case (key, value) => sysProps.put(key, value)} - (principal, keytab) - } else (args.principal, args.keytab) + val sparkRoleOpts = (args.sparkProperties.get("spark.secret.roleID"), + args.sparkProperties.get("spark.secret.secretID")) + + val tempToken = args.sparkProperties.get("spark.secret.vault.tempToken") + + + val sysEnvToken = sys.env.get("VAULT_TEMP_TOKEN") + + val vaultUrl = s"${args.sparkProperties ("spark.secret.vault.protocol")}://" + + s"${ + args.sparkProperties ("spark.secret.vault.hosts").split (",") + .map (host => s"$host:${ + args.sparkProperties ("spark.secret.vault.port") + }").mkString (",") + }" + + + val (pincipal, keytab) = + (mesosRoleEnv, sparkRoleOpts, tempToken, sysEnvToken) match { + + case ((roleIdEnv, secretIdEnv), (roleIdProp, secretIdProp), _, _) + if ((roleIdEnv.isDefined || roleIdProp.isDefined) && + (secretIdEnv.isDefined || secretIdProp.isDefined)) => + //scalastyle:off + println("Role ID and SecretId found") + val roleId = roleIdEnv.getOrElse(roleIdProp.get) + val secretId = secretIdEnv.getOrElse(secretIdProp.get) + val vaultToken = VaultHelper.getTokenFromAppRole (vaultUrl, roleId, secretId) + val environment = ConfigSecurity.prepareEnvironment( + Option(vaultToken), Option (vaultUrl) ) + val principal = environment.get ("principal").getOrElse (args.principal) + val keytab = environment.get ("keytabPath").getOrElse (args.keytab) + + environment.foreach { + case (key, value) => sysProps.put (key, value) + } + (principal, keytab) + + case (_, _, tempTokenProp, tempTokenEnv) + if (tempTokenProp.isDefined || tempTokenEnv.isDefined) => + //scalastyle:off + println("TempToken found") + val tempToken = tempTokenProp.getOrElse(tempTokenEnv.get) + val vaultToken = VaultHelper.getRealToken (vaultUrl, tempToken) + val environment = ConfigSecurity.prepareEnvironment( + Option (vaultToken), Option (vaultUrl)) + val principal = environment.get ("principal").getOrElse (args.principal) + val keytab = environment.get ("keytabPath").getOrElse (args.keytab) + + environment.foreach { + case (key, value) => sysProps.put (key, value) + } + (principal, keytab) + + case _ => (args.principal, args.keytab) + } (childArgs, childClasspath, sysProps, childMainClass, pincipal, keytab) } diff --git a/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala b/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala index 827a2cca841b0..8fe9203a2698c 100644 --- a/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala +++ b/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala @@ -24,7 +24,6 @@ import java.util.concurrent.atomic.AtomicBoolean import scala.collection.mutable import scala.util.{Failure, Success} import scala.util.control.NonFatal - import org.apache.spark._ import org.apache.spark.TaskState.TaskState import org.apache.spark.deploy.SparkHadoopUtil @@ -33,7 +32,7 @@ import org.apache.spark.internal.Logging import org.apache.spark.rpc._ import org.apache.spark.scheduler.{ExecutorLossReason, TaskDescription} import org.apache.spark.scheduler.cluster.CoarseGrainedClusterMessages._ -import org.apache.spark.security.ConfigSecurity +import org.apache.spark.security.{ConfigSecurity, VaultHelper} import org.apache.spark.serializer.SerializerInstance import org.apache.spark.util.{ThreadUtils, Utils} @@ -288,7 +287,10 @@ private[spark] object CoarseGrainedExecutorBackend extends Logging { appId == null) { printUsageAndExit() } - ConfigSecurity.prepareEnvironment(Option(System.getenv("VAULT_TEMP_TOKEN"))) + ConfigSecurity.prepareEnvironment(scala.util.Try{ + VaultHelper.getRealToken(ConfigSecurity.vaultUri.get, + sys.env("VAULT_TEMP_TOKEN"))}.toOption) + run(driverUrl, executorId, hostname, cores, appId, workerUrl, userClassPath) System.exit(0) } diff --git a/core/src/main/scala/org/apache/spark/security/ConfigSecurity.scala b/core/src/main/scala/org/apache/spark/security/ConfigSecurity.scala index 63f493a267952..e90bb40666d33 100644 --- a/core/src/main/scala/org/apache/spark/security/ConfigSecurity.scala +++ b/core/src/main/scala/org/apache/spark/security/ConfigSecurity.scala @@ -21,12 +21,16 @@ import org.apache.spark.internal.Logging object ConfigSecurity extends Logging{ var vaultToken: Option[String] = None - val vaultHost: Option[String] = sys.env.get("VAULT_HOSTS") + val vaultHost: Option[String] = sys.env.get("VAULT_HOST") val vaultUri: Option[String] = { (sys.env.get("VAULT_PROTOCOL"), vaultHost, sys.env.get("VAULT_PORT")) match { case (Some(vaultProtocol), Some(vaultHost), Some(vaultPort)) => - Option(s"$vaultProtocol://$vaultHost:$vaultPort") - case _ => None + val vaultUri = s"$vaultProtocol://$vaultHost:$vaultPort" + logDebug(s"vault uri: $vaultUri found, any Vault Connection will use it") + Option(vaultUri) + case _ => + logDebug("No Vault information found, any Vault Connection will fail") + None } } @@ -37,7 +41,9 @@ object ConfigSecurity extends Logging{ val secretOptionsMap = ConfigSecurity.extractSecretFromEnv(sys.env) logDebug(s"secretOptionsMap: ${secretOptionsMap.mkString("\n")}") loadingConf(secretOptionsMap) - vaultToken = if (vaultToken.isDefined) vaultAppToken else Option(System.getenv("VAULT_TOKEN")) + vaultToken = if (vaultAppToken.isDefined) { + vaultAppToken + } else sys.env.get("VAULT_TOKEN") if(vaultToken.isDefined) { require(vaultUri.isDefined, "A proper vault host is required") logDebug(s"env VAR: ${sys.env.mkString("\n")}") diff --git a/core/src/main/scala/org/apache/spark/security/SSLConfig.scala b/core/src/main/scala/org/apache/spark/security/SSLConfig.scala index ff0a6703cfd0d..5ec4b5da0a3ac 100644 --- a/core/src/main/scala/org/apache/spark/security/SSLConfig.scala +++ b/core/src/main/scala/org/apache/spark/security/SSLConfig.scala @@ -28,7 +28,7 @@ import sun.security.util.DerInputStream import org.apache.spark.internal.Logging -object SSLConfig extends Logging{ +object SSLConfig extends Logging { val sslTypeDataStore = "DATASTORE" val sslTypeKafkaStore = "KAFKA" @@ -37,19 +37,23 @@ object SSLConfig extends Logging{ vaultToken: String, sslType: String, options: Map[String, String]): Map[String, String] = { - val rootCA = VaultHelper.getRootCA(vaultHost, vaultToken) - val rootCAPath = writeRootCA(rootCA) - val certPass = VaultHelper.getCertPassFromVault(vaultHost, vaultToken) - val trustStorePath = generateTrustStore(sslType, rootCA, certPass) + + val sparkSSLPrefix = "spark.ssl." + + val vaultTrustStorePath = options.get(s"${sslType}_VAULT_TRUSTSTORE_PATH") + val vaultTrustStorePassPath = options.get(s"${sslType}_VAULT_TRUSTSTORE_PASS_PATH") + val trustStore = VaultHelper.getTrustStore(vaultHost, vaultToken, vaultTrustStorePath.get) + val trustPass = VaultHelper.getCertPassForAppFromVault( + vaultHost, vaultTrustStorePassPath.get, vaultToken) + val trustStorePath = generateTrustStore(sslType, trustStore, trustPass) logInfo(s"Setting SSL values for $sslType") val trustStoreOptions = - Map(s"spark.ssl.${sslType.toLowerCase}.enabled" -> "true", - s"spark.ssl.${sslType.toLowerCase}.trustStore" -> trustStorePath, - s"spark.ssl.${sslType.toLowerCase}.trustStorePassword" -> certPass, - s"spark.ssl.${sslType.toLowerCase}.rootCaPath" -> rootCAPath, - s"spark.ssl.${sslType.toLowerCase}.security.protocol" -> "SSL") + Map(s"$sparkSSLPrefix${sslType.toLowerCase}.enabled" -> "true", + s"$sparkSSLPrefix${sslType.toLowerCase}.trustStore" -> trustStorePath, + s"$sparkSSLPrefix${sslType.toLowerCase}.trustStorePassword" -> trustPass, + s"$sparkSSLPrefix${sslType.toLowerCase}.security.protocol" -> "SSL") val vaultKeystorePath = options.get(s"${sslType}_VAULT_CERT_PATH") @@ -65,10 +69,10 @@ object SSLConfig extends Logging{ val keyStorePath = generateKeyStore(sslType, certs, key, pass) - Map(s"spark.ssl${sslType.toLowerCase}.keyStore" -> keyStorePath, - s"spark.ssl${sslType.toLowerCase}.keyStorePassword" -> pass, - s"spark.ssl${sslType.toLowerCase}.protocol" -> "TLSv1.2", - s"spark.ssl${sslType.toLowerCase}.needClientAuth" -> "true" + Map(s"$sparkSSLPrefix${sslType.toLowerCase}.keyStore" -> keyStorePath, + s"$sparkSSLPrefix${sslType.toLowerCase}.keyStorePassword" -> pass, + s"$sparkSSLPrefix${sslType.toLowerCase}.protocol" -> "TLSv1.2", + s"$sparkSSLPrefix${sslType.toLowerCase}.needClientAuth" -> "true" ) } else { @@ -79,13 +83,13 @@ object SSLConfig extends Logging{ val vaultKeyPassPath = options.get(s"${sslType}_VAULT_KEY_PASS_PATH") - val keyPass = Map(s"spark.ssl.${sslType.toLowerCase}.keyPassword" + val keyPass = Map(s"$sparkSSLPrefix${sslType.toLowerCase}.keyPassword" -> VaultHelper.getCertPassForAppFromVault(vaultHost, vaultKeyPassPath.get, vaultToken)) val certFilesPath = - Map("spark.ssl.cert.path" -> s"${sys.env.get("SPARK_SSL_CERT_PATH")}/cert.crt", - "spark.ssl.key.pkcs8" -> s"${sys.env.get("SPARK_SSL_CERT_PATH")}/key.pkcs8", - "spark.ssl.root.cert" -> s"${sys.env.get("SPARK_SSL_CERT_PATH")}/caroot.crt") + Map(sparkSSLPrefix + "cert.path" -> s"${sys.env.get("SPARK_SSL_CERT_PATH")}/cert.crt", + sparkSSLPrefix + "key.pkcs8" -> s"${sys.env.get("SPARK_SSL_CERT_PATH")}/key.pkcs8", + sparkSSLPrefix + "root.cert" -> s"${sys.env.get("SPARK_SSL_CERT_PATH")}/caroot.crt") trustStoreOptions ++ keyStoreOptions ++ keyPass ++ certFilesPath } @@ -96,7 +100,7 @@ object SSLConfig extends Logging{ keystore.load(null) val certs = getBase64FromCAs(cas) - certs.zipWithIndex.foreach{case (cert, index) => + certs.zipWithIndex.foreach { case (cert, index) => val key = s"cert-${index}" keystore.setCertificateEntry(key, generateCertificateFromDER(cert)) } @@ -161,7 +165,7 @@ object SSLConfig extends Logging{ val certs = getBase64FromCAs(cas) val arrayCert = certs.map(cert => generateCertificateFromDER(cert)) val alias = "key-alias" - keystore.setKeyEntry(alias, key, password.toCharArray, arrayCert ) + keystore.setKeyEntry(alias, key, password.toCharArray, arrayCert) val fileName = "keystore.jks" val dir = new File(s"/tmp/$sslType") @@ -180,7 +184,7 @@ object SSLConfig extends Logging{ CertificateFactory.getInstance("X.509").generateCertificate(new ByteArrayInputStream(certBytes)) private def getArrayFromCA(ca: String): Array[String] = { - val splittedBy = ca.takeWhile(_=='-') + val splittedBy = ca.takeWhile(_ == '-') val begin = s"$splittedBy${ca.split(splittedBy).tail.head}$splittedBy" val end = begin.replace("BEGIN", "END") ca.split(begin).tail.map(_.split(end).head) @@ -192,29 +196,4 @@ object SSLConfig extends Logging{ DatatypeConverter.parseBase64Binary(value) }) } - - def writeRootCA(rootCA: String): String = { - def getCertFromOnLine(certBadFormat: String): String = { - var text1 = certBadFormat - var arg = Seq[String]() - while (text1.size != 0) { - val (toStore, toUpdate) = text1.splitAt(64) - text1 = toUpdate - arg = arg ++ Seq(toStore) - } - arg.mkString("\n") - } - - val path = "/tmp/root.crt" - val splitter = rootCA.split("BEGIN").head - val Array(_, head, certBadFormat, tail) = rootCA.split(splitter) - val cert = getCertFromOnLine(certBadFormat) - val writableCert = Seq(s"$splitter$head$splitter", cert, s"$splitter$tail$splitter") - .mkString("\n") - val downloadFile = Files.createFile(Paths.get(path), - PosixFilePermissions.asFileAttribute(PosixFilePermissions.fromString("rw-------"))) - downloadFile.toFile.deleteOnExit() - Files.write(downloadFile, writableCert.getBytes) - path - } } diff --git a/core/src/main/scala/org/apache/spark/security/VaultHelper.scala b/core/src/main/scala/org/apache/spark/security/VaultHelper.scala index 724a24b12f889..23b9c3f6484e7 100644 --- a/core/src/main/scala/org/apache/spark/security/VaultHelper.scala +++ b/core/src/main/scala/org/apache/spark/security/VaultHelper.scala @@ -66,7 +66,7 @@ object VaultHelper extends Logging { logDebug(s"Requesting Secret ID from Vault: $requestUrl") HTTPHelper.executePost(requestUrl, "data", Some(Seq(("X-Vault-Token", token.get))))("secret_id").asInstanceOf[String] - } + } def getTemporalToken(vaultHost: String, token: String): String = { val requestUrl = s"$vaultHost/v1/sys/wrapping/wrap" @@ -90,6 +90,7 @@ object VaultHelper extends Logging { (keytab64, principal) } + @deprecated def getRootCA(vaultUrl: String, token: String): String = { val certVaultPath = "/v1/ca-trust/certificates/" val requestUrl = s"$vaultUrl/$certVaultPath" @@ -97,7 +98,7 @@ object VaultHelper extends Logging { logDebug(s"Requesting Cert List: $listCertKeysVaultPath") val keys = HTTPHelper.executeGet(listCertKeysVaultPath, - "data", Some(Seq(("X-Vault-Token", token))))("pass").asInstanceOf[List[String]] + "data", Some(Seq(("X-Vault-Token", token))))("keys").asInstanceOf[List[String]] keys.flatMap(key => { HTTPHelper.executeGet(s"$requestUrl$key", @@ -105,6 +106,17 @@ object VaultHelper extends Logging { }).map(_._2).mkString } + def getTrustStore(vaultUrl: String, token: String, certVaultPath: String): String = { + val requestUrl = s"$vaultUrl/$certVaultPath" + val truststoreVaultPath = s"$requestUrl" + + logDebug(s"Requesting truststore: $truststoreVaultPath") + val data = HTTPHelper.executeGet(requestUrl, + "data", Some(Seq(("X-Vault-Token", token)))) + val trustStore = data.find(_._1.endsWith("_crt")).get._2.asInstanceOf[String] + trustStore + } + def getCertPassFromVault(vaultUrl: String, token: String): String = { val certPassVaultPath = "/v1/ca-trust/passwords/default/keystore" logDebug(s"Requesting Cert Pass: $certPassVaultPath") diff --git a/pom.xml b/pom.xml index c46dc1bcd3505..4f155eaf529ee 100644 --- a/pom.xml +++ b/pom.xml @@ -2336,7 +2336,7 @@ 0.8.0 false - true + false false false ${basedir}/src/main/scala