Skip to content

Commit

Permalink
Feature/truststore (apache#27)
Browse files Browse the repository at this point in the history
* Added RoleID environment variable

* dot bug fixing && appTOken in executors

* added changes
  • Loading branch information
mpenate authored and jlopezmalla committed Jul 6, 2017
1 parent b6aac20 commit 4e92f82
Show file tree
Hide file tree
Showing 6 changed files with 136 additions and 93 deletions.
118 changes: 81 additions & 37 deletions core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala
Original file line number Diff line number Diff line change
Expand Up @@ -280,7 +280,7 @@ object SparkSubmit {
if (!Utils.classIsLoadable("org.apache.spark.deploy.yarn.Client") && !Utils.isTesting) {
printErrorAndExit(
"Could not load YARN classes. " +
"This copy of Spark may not have been compiled with YARN support.")
"This copy of Spark may not have been compiled with YARN support.")
}
}

Expand All @@ -296,11 +296,11 @@ object SparkSubmit {
// Resolve maven dependencies if there are any and add classpath to jars. Add them to py-files
// too for packages that include Python code
val exclusions: Seq[String] =
if (!StringUtils.isBlank(args.packagesExclusions)) {
args.packagesExclusions.split(",")
} else {
Nil
}
if (!StringUtils.isBlank(args.packagesExclusions)) {
args.packagesExclusions.split(",")
} else {
Nil
}
val resolvedMavenCoordinates = SparkSubmitUtils.resolveMavenCoordinates(args.packages,
Option(args.repositories), Option(args.ivyRepoPath), exclusions = exclusions)
if (!StringUtils.isBlank(resolvedMavenCoordinates)) {
Expand Down Expand Up @@ -498,17 +498,25 @@ object SparkSubmit {
if (isUserJar(args.primaryResource)) {
childClasspath += args.primaryResource
}
if (args.jars != null) { childClasspath ++= args.jars.split(",") }
if (args.childArgs != null) { childArgs ++= args.childArgs }
if (args.jars != null) {
childClasspath ++= args.jars.split(",")
}
if (args.childArgs != null) {
childArgs ++= args.childArgs
}
}

// Map all arguments to command-line options or system properties for our chosen mode
for (opt <- options) {
if (opt.value != null &&
(deployMode & opt.deployMode) != 0 &&
(clusterManager & opt.clusterManager) != 0) {
if (opt.clOption != null) { childArgs += (opt.clOption, opt.value) }
if (opt.sysProp != null) { sysProps.put(opt.sysProp, opt.value) }
(deployMode & opt.deployMode) != 0 &&
(clusterManager & opt.clusterManager) != 0) {
if (opt.clOption != null) {
childArgs += (opt.clOption, opt.value)
}
if (opt.sysProp != null) {
sysProps.put(opt.sysProp, opt.value)
}
}
}

Expand All @@ -532,7 +540,9 @@ object SparkSubmit {
} else {
// In legacy standalone cluster mode, use Client as a wrapper around the user class
childMainClass = "org.apache.spark.deploy.Client"
if (args.supervise) { childArgs += "--supervise" }
if (args.supervise) {
childArgs += "--supervise"
}
Option(args.driverMemory).foreach { m => childArgs += ("--memory", m) }
Option(args.driverCores).foreach { c => childArgs += ("--cores", c) }
childArgs += "launch"
Expand Down Expand Up @@ -654,32 +664,66 @@ object SparkSubmit {
sysProps("spark.submit.pyFiles") = formattedPyFiles
}

val (pincipal, keytab) = if (
(args.sparkProperties.get("spark.secret.roleID").isDefined &&
args.sparkProperties.get("spark.secret.secretID").isDefined)
|| args.sparkProperties.get("spark.secret.vault.tempToken").isDefined
|| sys.env.get("VAULT_TEMP_TOKEN").isDefined) {

val vaultUrl = s"${args.sparkProperties("spark.secret.vault.protocol")}://" +
s"${args.sparkProperties("spark.secret.vault.hosts").split(",")
.map(host => s"$host:${args.sparkProperties("spark.secret.vault.port")}").mkString(",")}"
val vaultToken = if (args.sparkProperties.get("spark.secret.vault.tempToken").isDefined
|| sys.env.get("VAULT_TEMP_TOKEN").isDefined) {
VaultHelper.getRealToken(vaultUrl, args.sparkProperties.getOrElse(
"spark.secret.vault.tempToken", sys.env("VAULT_TEMP_TOKEN")))
} else {
val roleID = args.sparkProperties("spark.secret.roleID")
val secretID = args.sparkProperties("spark.secret.secretID")
VaultHelper.getTokenFromAppRole(vaultUrl, roleID, secretID)
}

val environment = ConfigSecurity.prepareEnvironment(Option(vaultToken), Option(vaultUrl))
val principal = environment.get("principal").getOrElse(args.principal)
val keytab = environment.get("keytabPath").getOrElse(args.keytab)
val mesosRoleEnv = (sys.env.get("VAULT_ROLE_ID"),
sys.env.get("VAULT_SECRET_ID"))

environment.foreach{case (key, value) => sysProps.put(key, value)}
(principal, keytab)
} else (args.principal, args.keytab)
val sparkRoleOpts = (args.sparkProperties.get("spark.secret.roleID"),
args.sparkProperties.get("spark.secret.secretID"))

val tempToken = args.sparkProperties.get("spark.secret.vault.tempToken")


val sysEnvToken = sys.env.get("VAULT_TEMP_TOKEN")

val vaultUrl = s"${args.sparkProperties ("spark.secret.vault.protocol")}://" +
s"${
args.sparkProperties ("spark.secret.vault.hosts").split (",")
.map (host => s"$host:${
args.sparkProperties ("spark.secret.vault.port")
}").mkString (",")
}"


val (pincipal, keytab) =
(mesosRoleEnv, sparkRoleOpts, tempToken, sysEnvToken) match {

case ((roleIdEnv, secretIdEnv), (roleIdProp, secretIdProp), _, _)
if ((roleIdEnv.isDefined || roleIdProp.isDefined) &&
(secretIdEnv.isDefined || secretIdProp.isDefined)) =>
//scalastyle:off
println("Role ID and SecretId found")
val roleId = roleIdEnv.getOrElse(roleIdProp.get)
val secretId = secretIdEnv.getOrElse(secretIdProp.get)
val vaultToken = VaultHelper.getTokenFromAppRole (vaultUrl, roleId, secretId)
val environment = ConfigSecurity.prepareEnvironment(
Option(vaultToken), Option (vaultUrl) )
val principal = environment.get ("principal").getOrElse (args.principal)
val keytab = environment.get ("keytabPath").getOrElse (args.keytab)

environment.foreach {
case (key, value) => sysProps.put (key, value)
}
(principal, keytab)

case (_, _, tempTokenProp, tempTokenEnv)
if (tempTokenProp.isDefined || tempTokenEnv.isDefined) =>
//scalastyle:off
println("TempToken found")
val tempToken = tempTokenProp.getOrElse(tempTokenEnv.get)
val vaultToken = VaultHelper.getRealToken (vaultUrl, tempToken)
val environment = ConfigSecurity.prepareEnvironment(
Option (vaultToken), Option (vaultUrl))
val principal = environment.get ("principal").getOrElse (args.principal)
val keytab = environment.get ("keytabPath").getOrElse (args.keytab)

environment.foreach {
case (key, value) => sysProps.put (key, value)
}
(principal, keytab)

case _ => (args.principal, args.keytab)
}

(childArgs, childClasspath, sysProps, childMainClass, pincipal, keytab)
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@ import java.util.concurrent.atomic.AtomicBoolean
import scala.collection.mutable
import scala.util.{Failure, Success}
import scala.util.control.NonFatal

import org.apache.spark._
import org.apache.spark.TaskState.TaskState
import org.apache.spark.deploy.SparkHadoopUtil
Expand All @@ -33,7 +32,7 @@ import org.apache.spark.internal.Logging
import org.apache.spark.rpc._
import org.apache.spark.scheduler.{ExecutorLossReason, TaskDescription}
import org.apache.spark.scheduler.cluster.CoarseGrainedClusterMessages._
import org.apache.spark.security.ConfigSecurity
import org.apache.spark.security.{ConfigSecurity, VaultHelper}
import org.apache.spark.serializer.SerializerInstance
import org.apache.spark.util.{ThreadUtils, Utils}

Expand Down Expand Up @@ -288,7 +287,10 @@ private[spark] object CoarseGrainedExecutorBackend extends Logging {
appId == null) {
printUsageAndExit()
}
ConfigSecurity.prepareEnvironment(Option(System.getenv("VAULT_TEMP_TOKEN")))
ConfigSecurity.prepareEnvironment(scala.util.Try{
VaultHelper.getRealToken(ConfigSecurity.vaultUri.get,
sys.env("VAULT_TEMP_TOKEN"))}.toOption)

run(driverUrl, executorId, hostname, cores, appId, workerUrl, userClassPath)
System.exit(0)
}
Expand Down
14 changes: 10 additions & 4 deletions core/src/main/scala/org/apache/spark/security/ConfigSecurity.scala
Original file line number Diff line number Diff line change
Expand Up @@ -21,12 +21,16 @@ import org.apache.spark.internal.Logging
object ConfigSecurity extends Logging{

var vaultToken: Option[String] = None
val vaultHost: Option[String] = sys.env.get("VAULT_HOSTS")
val vaultHost: Option[String] = sys.env.get("VAULT_HOST")
val vaultUri: Option[String] = {
(sys.env.get("VAULT_PROTOCOL"), vaultHost, sys.env.get("VAULT_PORT")) match {
case (Some(vaultProtocol), Some(vaultHost), Some(vaultPort)) =>
Option(s"$vaultProtocol://$vaultHost:$vaultPort")
case _ => None
val vaultUri = s"$vaultProtocol://$vaultHost:$vaultPort"
logDebug(s"vault uri: $vaultUri found, any Vault Connection will use it")
Option(vaultUri)
case _ =>
logDebug("No Vault information found, any Vault Connection will fail")
None
}
}

Expand All @@ -37,7 +41,9 @@ object ConfigSecurity extends Logging{
val secretOptionsMap = ConfigSecurity.extractSecretFromEnv(sys.env)
logDebug(s"secretOptionsMap: ${secretOptionsMap.mkString("\n")}")
loadingConf(secretOptionsMap)
vaultToken = if (vaultToken.isDefined) vaultAppToken else Option(System.getenv("VAULT_TOKEN"))
vaultToken = if (vaultAppToken.isDefined) {
vaultAppToken
} else sys.env.get("VAULT_TOKEN")
if(vaultToken.isDefined) {
require(vaultUri.isDefined, "A proper vault host is required")
logDebug(s"env VAR: ${sys.env.mkString("\n")}")
Expand Down
71 changes: 25 additions & 46 deletions core/src/main/scala/org/apache/spark/security/SSLConfig.scala
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ import sun.security.util.DerInputStream

import org.apache.spark.internal.Logging

object SSLConfig extends Logging{
object SSLConfig extends Logging {

val sslTypeDataStore = "DATASTORE"
val sslTypeKafkaStore = "KAFKA"
Expand All @@ -37,19 +37,23 @@ object SSLConfig extends Logging{
vaultToken: String,
sslType: String,
options: Map[String, String]): Map[String, String] = {
val rootCA = VaultHelper.getRootCA(vaultHost, vaultToken)
val rootCAPath = writeRootCA(rootCA)
val certPass = VaultHelper.getCertPassFromVault(vaultHost, vaultToken)
val trustStorePath = generateTrustStore(sslType, rootCA, certPass)

val sparkSSLPrefix = "spark.ssl."

val vaultTrustStorePath = options.get(s"${sslType}_VAULT_TRUSTSTORE_PATH")
val vaultTrustStorePassPath = options.get(s"${sslType}_VAULT_TRUSTSTORE_PASS_PATH")
val trustStore = VaultHelper.getTrustStore(vaultHost, vaultToken, vaultTrustStorePath.get)
val trustPass = VaultHelper.getCertPassForAppFromVault(
vaultHost, vaultTrustStorePassPath.get, vaultToken)
val trustStorePath = generateTrustStore(sslType, trustStore, trustPass)

logInfo(s"Setting SSL values for $sslType")

val trustStoreOptions =
Map(s"spark.ssl.${sslType.toLowerCase}.enabled" -> "true",
s"spark.ssl.${sslType.toLowerCase}.trustStore" -> trustStorePath,
s"spark.ssl.${sslType.toLowerCase}.trustStorePassword" -> certPass,
s"spark.ssl.${sslType.toLowerCase}.rootCaPath" -> rootCAPath,
s"spark.ssl.${sslType.toLowerCase}.security.protocol" -> "SSL")
Map(s"$sparkSSLPrefix${sslType.toLowerCase}.enabled" -> "true",
s"$sparkSSLPrefix${sslType.toLowerCase}.trustStore" -> trustStorePath,
s"$sparkSSLPrefix${sslType.toLowerCase}.trustStorePassword" -> trustPass,
s"$sparkSSLPrefix${sslType.toLowerCase}.security.protocol" -> "SSL")

val vaultKeystorePath = options.get(s"${sslType}_VAULT_CERT_PATH")

Expand All @@ -65,10 +69,10 @@ object SSLConfig extends Logging{

val keyStorePath = generateKeyStore(sslType, certs, key, pass)

Map(s"spark.ssl${sslType.toLowerCase}.keyStore" -> keyStorePath,
s"spark.ssl${sslType.toLowerCase}.keyStorePassword" -> pass,
s"spark.ssl${sslType.toLowerCase}.protocol" -> "TLSv1.2",
s"spark.ssl${sslType.toLowerCase}.needClientAuth" -> "true"
Map(s"$sparkSSLPrefix${sslType.toLowerCase}.keyStore" -> keyStorePath,
s"$sparkSSLPrefix${sslType.toLowerCase}.keyStorePassword" -> pass,
s"$sparkSSLPrefix${sslType.toLowerCase}.protocol" -> "TLSv1.2",
s"$sparkSSLPrefix${sslType.toLowerCase}.needClientAuth" -> "true"
)

} else {
Expand All @@ -79,13 +83,13 @@ object SSLConfig extends Logging{

val vaultKeyPassPath = options.get(s"${sslType}_VAULT_KEY_PASS_PATH")

val keyPass = Map(s"spark.ssl.${sslType.toLowerCase}.keyPassword"
val keyPass = Map(s"$sparkSSLPrefix${sslType.toLowerCase}.keyPassword"
-> VaultHelper.getCertPassForAppFromVault(vaultHost, vaultKeyPassPath.get, vaultToken))

val certFilesPath =
Map("spark.ssl.cert.path" -> s"${sys.env.get("SPARK_SSL_CERT_PATH")}/cert.crt",
"spark.ssl.key.pkcs8" -> s"${sys.env.get("SPARK_SSL_CERT_PATH")}/key.pkcs8",
"spark.ssl.root.cert" -> s"${sys.env.get("SPARK_SSL_CERT_PATH")}/caroot.crt")
Map(sparkSSLPrefix + "cert.path" -> s"${sys.env.get("SPARK_SSL_CERT_PATH")}/cert.crt",
sparkSSLPrefix + "key.pkcs8" -> s"${sys.env.get("SPARK_SSL_CERT_PATH")}/key.pkcs8",
sparkSSLPrefix + "root.cert" -> s"${sys.env.get("SPARK_SSL_CERT_PATH")}/caroot.crt")

trustStoreOptions ++ keyStoreOptions ++ keyPass ++ certFilesPath
}
Expand All @@ -96,7 +100,7 @@ object SSLConfig extends Logging{
keystore.load(null)
val certs = getBase64FromCAs(cas)

certs.zipWithIndex.foreach{case (cert, index) =>
certs.zipWithIndex.foreach { case (cert, index) =>
val key = s"cert-${index}"
keystore.setCertificateEntry(key, generateCertificateFromDER(cert))
}
Expand Down Expand Up @@ -161,7 +165,7 @@ object SSLConfig extends Logging{
val certs = getBase64FromCAs(cas)
val arrayCert = certs.map(cert => generateCertificateFromDER(cert))
val alias = "key-alias"
keystore.setKeyEntry(alias, key, password.toCharArray, arrayCert )
keystore.setKeyEntry(alias, key, password.toCharArray, arrayCert)

val fileName = "keystore.jks"
val dir = new File(s"/tmp/$sslType")
Expand All @@ -180,7 +184,7 @@ object SSLConfig extends Logging{
CertificateFactory.getInstance("X.509").generateCertificate(new ByteArrayInputStream(certBytes))

private def getArrayFromCA(ca: String): Array[String] = {
val splittedBy = ca.takeWhile(_=='-')
val splittedBy = ca.takeWhile(_ == '-')
val begin = s"$splittedBy${ca.split(splittedBy).tail.head}$splittedBy"
val end = begin.replace("BEGIN", "END")
ca.split(begin).tail.map(_.split(end).head)
Expand All @@ -192,29 +196,4 @@ object SSLConfig extends Logging{
DatatypeConverter.parseBase64Binary(value)
})
}

def writeRootCA(rootCA: String): String = {
def getCertFromOnLine(certBadFormat: String): String = {
var text1 = certBadFormat
var arg = Seq[String]()
while (text1.size != 0) {
val (toStore, toUpdate) = text1.splitAt(64)
text1 = toUpdate
arg = arg ++ Seq(toStore)
}
arg.mkString("\n")
}

val path = "/tmp/root.crt"
val splitter = rootCA.split("BEGIN").head
val Array(_, head, certBadFormat, tail) = rootCA.split(splitter)
val cert = getCertFromOnLine(certBadFormat)
val writableCert = Seq(s"$splitter$head$splitter", cert, s"$splitter$tail$splitter")
.mkString("\n")
val downloadFile = Files.createFile(Paths.get(path),
PosixFilePermissions.asFileAttribute(PosixFilePermissions.fromString("rw-------")))
downloadFile.toFile.deleteOnExit()
Files.write(downloadFile, writableCert.getBytes)
path
}
}
16 changes: 14 additions & 2 deletions core/src/main/scala/org/apache/spark/security/VaultHelper.scala
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ object VaultHelper extends Logging {
logDebug(s"Requesting Secret ID from Vault: $requestUrl")
HTTPHelper.executePost(requestUrl, "data",
Some(Seq(("X-Vault-Token", token.get))))("secret_id").asInstanceOf[String]
}
}

def getTemporalToken(vaultHost: String, token: String): String = {
val requestUrl = s"$vaultHost/v1/sys/wrapping/wrap"
Expand All @@ -90,21 +90,33 @@ object VaultHelper extends Logging {
(keytab64, principal)
}

@deprecated
def getRootCA(vaultUrl: String, token: String): String = {
val certVaultPath = "/v1/ca-trust/certificates/"
val requestUrl = s"$vaultUrl/$certVaultPath"
val listCertKeysVaultPath = s"$requestUrl?list=true"

logDebug(s"Requesting Cert List: $listCertKeysVaultPath")
val keys = HTTPHelper.executeGet(listCertKeysVaultPath,
"data", Some(Seq(("X-Vault-Token", token))))("pass").asInstanceOf[List[String]]
"data", Some(Seq(("X-Vault-Token", token))))("keys").asInstanceOf[List[String]]

keys.flatMap(key => {
HTTPHelper.executeGet(s"$requestUrl$key",
"data", Some(Seq(("X-Vault-Token", token)))).find(_._1.endsWith("_crt"))
}).map(_._2).mkString
}

def getTrustStore(vaultUrl: String, token: String, certVaultPath: String): String = {
val requestUrl = s"$vaultUrl/$certVaultPath"
val truststoreVaultPath = s"$requestUrl"

logDebug(s"Requesting truststore: $truststoreVaultPath")
val data = HTTPHelper.executeGet(requestUrl,
"data", Some(Seq(("X-Vault-Token", token))))
val trustStore = data.find(_._1.endsWith("_crt")).get._2.asInstanceOf[String]
trustStore
}

def getCertPassFromVault(vaultUrl: String, token: String): String = {
val certPassVaultPath = "/v1/ca-trust/passwords/default/keystore"
logDebug(s"Requesting Cert Pass: $certPassVaultPath")
Expand Down
Loading

0 comments on commit 4e92f82

Please sign in to comment.