Skip to content

Commit

Permalink
Replace usages of JWKSet.load(URL) with a Ktor get request + `JWKSe…
Browse files Browse the repository at this point in the history
…t.parse(String)`
  • Loading branch information
dzarras committed Nov 27, 2023
1 parent 8a8f6bc commit 68895a7
Show file tree
Hide file tree
Showing 4 changed files with 62 additions and 48 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ internal class ClientMetaDataResolver(
private val httpClientFactory: KtorHttpClientFactory = DefaultHttpClientFactory,
walletOpenId4VPConfig: WalletOpenId4VPConfig,
) {
private val clientMetadataValidator = ClientMetadataValidator(walletOpenId4VPConfig)
private val clientMetadataValidator = ClientMetadataValidator(walletOpenId4VPConfig, httpClientFactory)
suspend fun resolve(clientMetaDataSource: ClientMetaDataSource): Result<ClientMetaData> {
val unvalidatedClientMetaData = when (clientMetaDataSource) {
is ClientMetaDataSource.ByValue -> clientMetaDataSource.metaData
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,12 +21,15 @@ import com.nimbusds.jose.JWSAlgorithm
import com.nimbusds.jose.jwk.JWKSet
import eu.europa.ec.eudi.openid4vp.*
import eu.europa.ec.eudi.openid4vp.internal.success
import io.ktor.client.call.*
import io.ktor.client.request.*
import java.io.IOException
import java.net.URL
import java.text.ParseException

internal class ClientMetadataValidator(
private val walletOpenId4VPConfig: WalletOpenId4VPConfig,
private val httpClientFactory: KtorHttpClientFactory = DefaultHttpClientFactory,
) {

suspend fun validate(unvalidatedClientMetadata: UnvalidatedClientMetaData): Result<ClientMetaData> = runCatching {
Expand Down Expand Up @@ -130,7 +133,7 @@ internal class ClientMetadataValidator(
if (encryptionMethod.isNullOrEmpty()) RequestValidationError.IdTokenEncryptionMethodMissing.asFailure()
else Result.success(EncryptionMethod.parse(encryptionMethod))

private fun parseRequiredJwks(clientMetadata: UnvalidatedClientMetaData): Result<JWKSet> {
private suspend fun parseRequiredJwks(clientMetadata: UnvalidatedClientMetaData): Result<JWKSet> {
val atLeastOneJwkSourceDefined = !clientMetadata.jwks.isNullOrEmpty() || !clientMetadata.jwksUri.isNullOrEmpty()
if (!atLeastOneJwkSourceDefined) {
return RequestValidationError.MissingClientMetadataJwksSource.asFailure()
Expand All @@ -146,8 +149,12 @@ internal class ClientMetadataValidator(
ResolutionError.ClientMetadataJwkUriUnparsable(ex).asFailure()
}

fun requiredJwksUri() = try {
Result.success(JWKSet.load(URL(clientMetadata.jwksUri)))
suspend fun requiredJwksUri() = try {
val unparsed = httpClientFactory().use { client ->
client.get(URL(clientMetadata.jwksUri)).body<String>()
}
val jwkSet = JWKSet.parse(unparsed)
Result.success(jwkSet)
} catch (ex: IOException) {
ResolutionError.ClientMetadataJwkResolutionFailed(ex).asFailure()
} catch (ex: ParseException) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -201,7 +201,7 @@ internal class DefaultAuthorizationRequestResolver(
* @param clientId The client that placed request
*/
private suspend fun requestObjectFromJwt(clientId: String, jwt: Jwt): Result<RequestObject> {
val validator = JarJwtSignatureValidator(walletOpenId4VPConfig)
val validator = JarJwtSignatureValidator(walletOpenId4VPConfig, httpClientFactory)
return validator.validate(clientId, jwt)
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,8 @@ import eu.europa.ec.eudi.openid4vp.*
import eu.europa.ec.eudi.openid4vp.SupportedClientIdScheme.IsoX509
import eu.europa.ec.eudi.openid4vp.SupportedClientIdScheme.Preregistered
import eu.europa.ec.eudi.openid4vp.internal.success
import io.ktor.client.call.*
import io.ktor.client.request.*
import kotlinx.serialization.json.Json
import kotlinx.serialization.json.JsonObject
import kotlinx.serialization.json.jsonObject
Expand All @@ -43,12 +45,14 @@ import java.text.ParseException
* Validates a JWT that represents an Authorization Request according to RFC9101
*
* @param walletOpenId4VPConfig wallet's configuration
* @param httpClientFactory a factory to obtain a Ktor http client
*/
internal class JarJwtSignatureValidator(
private val walletOpenId4VPConfig: WalletOpenId4VPConfig,
private val httpClientFactory: KtorHttpClientFactory = DefaultHttpClientFactory,
) {

fun validate(clientId: String, jwt: Jwt): Result<RequestObject> = runCatching {
suspend fun validate(clientId: String, jwt: Jwt): Result<RequestObject> = runCatching {
val signedJwt = parse(jwt).getOrThrow()
val error = doValidate(clientId, signedJwt)
if (null == error) signedJwt.jwtClaimsSet.toType { requestObject(it) }
Expand All @@ -62,7 +66,7 @@ internal class JarJwtSignatureValidator(
RequestValidationError.InvalidJarJwt("JAR JWT parse error").asFailure()
}

private fun doValidate(clientId: String, signedJwt: SignedJWT): AuthorizationRequestError? {
private suspend fun doValidate(clientId: String, signedJwt: SignedJWT): AuthorizationRequestError? {
val untrustedClaimSet = signedJwt.jwtClaimsSet
val jwtClientId = untrustedClaimSet.getStringClaim("client_id")

Expand All @@ -87,57 +91,60 @@ internal class JarJwtSignatureValidator(
}
}
}
}

private fun invalidJarJwt(cause: String): AuthorizationRequestError = RequestValidationError.InvalidJarJwt(cause)

private fun validatePreregistered(
supportedClientIdScheme: Preregistered,
clientId: String,
signedJwt: SignedJWT,
): AuthorizationRequestError? {
fun PreregisteredClient.verifySignature() =
try {
val jwtProcessor = jwtProcessor(this)
jwtProcessor.process(signedJwt, null)
null
} catch (e: JOSEException) {
throw RuntimeException(e)
} catch (e: BadJOSEException) {
invalidJarJwt("Invalid signature ${e.message}")
}

val trustedClient = supportedClientIdScheme.clients[clientId]
return if (null == trustedClient) invalidJarJwt("Client with client_id $clientId is not pre-registered")
else trustedClient.verifySignature()
}
private suspend fun validatePreregistered(
supportedClientIdScheme: Preregistered,
clientId: String,
signedJwt: SignedJWT,
): AuthorizationRequestError? {
suspend fun PreregisteredClient.verifySignature() =
try {
val jwtProcessor = jwtProcessor(this)
jwtProcessor.process(signedJwt, null)
null
} catch (e: JOSEException) {
throw RuntimeException(e)
} catch (e: BadJOSEException) {
invalidJarJwt("Invalid signature ${e.message}")
}

private fun jwtProcessor(client: PreregisteredClient): ConfigurableJWTProcessor<SecurityContext> =
DefaultJWTProcessor<SecurityContext>().also {
it.jwsTypeVerifier = DefaultJOSEObjectTypeVerifier(
JOSEObjectType("oauth-authz-req+jwt"),
)
it.jwsKeySelector = JWSVerificationKeySelector(
client.jarSigningAlg.toNimbusJWSAlgorithm(),
client.jwkSetSource.toNimbus(),
)
val trustedClient = supportedClientIdScheme.clients[clientId]
return if (null == trustedClient) invalidJarJwt("Client with client_id $clientId is not pre-registered")
else trustedClient.verifySignature()
}

private fun String.toNimbusJWSAlgorithm() = JWSAlgorithm.parse(this)

internal fun JwkSetSource.toNimbus(): JWKSource<SecurityContext> {
val jwkSet = when (this) {
is JwkSetSource.ByValue -> {
JWKSet.parse(jwks.toString())
private suspend fun jwtProcessor(client: PreregisteredClient): ConfigurableJWTProcessor<SecurityContext> =
DefaultJWTProcessor<SecurityContext>().also {
it.jwsTypeVerifier = DefaultJOSEObjectTypeVerifier(
JOSEObjectType("oauth-authz-req+jwt"),
)
it.jwsKeySelector = JWSVerificationKeySelector(
client.jarSigningAlg.toNimbusJWSAlgorithm(),
client.jwkSetSource.toNimbus(),
)
}

is JwkSetSource.ByReference -> {
JWKSet.load(jwksUri.toURL())
private suspend fun JwkSetSource.toNimbus(): JWKSource<SecurityContext> {
val jwkSet = when (this) {
is JwkSetSource.ByValue -> {
JWKSet.parse(jwks.toString())
}

is JwkSetSource.ByReference -> {
val unparsed = httpClientFactory().use { client ->
client.get(jwksUri.toURL()).body<String>()
}
JWKSet.parse(unparsed)
}
}
return ImmutableJWKSet(jwkSet)
}
return ImmutableJWKSet(jwkSet)
}

private fun invalidJarJwt(cause: String): AuthorizationRequestError = RequestValidationError.InvalidJarJwt(cause)

private fun String.toNimbusJWSAlgorithm() = JWSAlgorithm.parse(this)

private fun requestObject(cs: JWTClaimsSet): RequestObject {
fun Map<String, Any?>.asJsonObject(): JsonObject {
val jsonStr = Gson().toJson(this)
Expand Down

0 comments on commit 68895a7

Please sign in to comment.