From 49fb62cb96cd9afc854f5102313f16e27c0eb2b4 Mon Sep 17 00:00:00 2001 From: Filip Skokan Date: Sat, 25 Sep 2021 17:27:03 +0200 Subject: [PATCH] feat: return resolved key when verify and decrypt resolve functions are used --- src/jwe/compact/decrypt.ts | 28 +++++++++++++++++++++++++--- src/jwe/flattened/decrypt.ts | 26 ++++++++++++++++++++++++-- src/jwe/general/decrypt.ts | 20 ++++++++++++++++++-- src/jws/compact/verify.ts | 28 +++++++++++++++++++++++++--- src/jws/flattened/verify.ts | 26 ++++++++++++++++++++++++-- src/jws/general/verify.ts | 20 ++++++++++++++++++-- src/jwt/decrypt.ts | 30 ++++++++++++++++++++++++++---- src/jwt/verify.ts | 28 ++++++++++++++++++++++++---- src/types.d.ts | 7 +++++++ test/jwk/embedded.test.mjs | 6 +++++- test/jwks/remote.test.mjs | 6 +++++- test/jwt/encrypt.test.mjs | 7 ++++++- test/jwt/sign.test.mjs | 3 ++- 13 files changed, 209 insertions(+), 26 deletions(-) diff --git a/src/jwe/compact/decrypt.ts b/src/jwe/compact/decrypt.ts index add51353c9..0968321541 100644 --- a/src/jwe/compact/decrypt.ts +++ b/src/jwe/compact/decrypt.ts @@ -8,6 +8,7 @@ import type { GetKeyFunction, FlattenedJWE, CompactDecryptResult, + ResolvedKey, } from '../../types.d' /** @@ -20,7 +21,7 @@ export interface CompactDecryptGetKey extends GetKeyFunction +/** + * @param jwe Compact JWE. + * @param getKey Function resolving Private Key or Secret to decrypt the JWE with. + * @param options JWE Decryption options. + */ +async function compactDecrypt( + jwe: string | Uint8Array, + getKey: CompactDecryptGetKey, + options?: DecryptOptions, +): Promise async function compactDecrypt( jwe: string | Uint8Array, key: KeyLike | CompactDecryptGetKey, options?: DecryptOptions, -): Promise { +) { if (jwe instanceof Uint8Array) { jwe = decoder.decode(jwe) } @@ -86,7 +102,13 @@ async function compactDecrypt( options, ) - return { plaintext: decrypted.plaintext, protectedHeader: decrypted.protectedHeader! } + const result = { plaintext: decrypted.plaintext, protectedHeader: decrypted.protectedHeader! } + + if (typeof key === 'function') { + return { ...result, key: decrypted.key } + } + + return result } export { compactDecrypt } diff --git a/src/jwe/flattened/decrypt.ts b/src/jwe/flattened/decrypt.ts index 1f6afc8bd4..f52a0c3778 100644 --- a/src/jwe/flattened/decrypt.ts +++ b/src/jwe/flattened/decrypt.ts @@ -14,6 +14,7 @@ import type { JWEHeaderParameters, DecryptOptions, GetKeyFunction, + ResolvedKey, } from '../../types.d' import { encoder, decoder, concat } from '../../lib/buffer_utils.js' import cekFactory from '../../lib/cek.js' @@ -36,7 +37,7 @@ export interface FlattenedDecryptGetKey * Decrypts a Flattened JWE. * * @param jwe Flattened JWE. - * @param key Public Key or Secret, or a function resolving one, to decrypt the JWE with. + * @param key Private Key or Secret to decrypt the JWE with. * @param options JWE Decryption options. * * @example ESM import @@ -77,11 +78,26 @@ export interface FlattenedDecryptGetKey * console.log(decoder.decode(additionalAuthenticatedData)) * ``` */ +function flattenedDecrypt( + jwe: FlattenedJWE, + key: KeyLike, + options?: DecryptOptions, +): Promise +/** + * @param jwe Flattened JWE. + * @param getKey Function resolving Private Key or Secret to decrypt the JWE with. + * @param options JWE Decryption options. + */ +function flattenedDecrypt( + jwe: FlattenedJWE, + getKey: FlattenedDecryptGetKey, + options?: DecryptOptions, +): Promise async function flattenedDecrypt( jwe: FlattenedJWE, key: KeyLike | FlattenedDecryptGetKey, options?: DecryptOptions, -): Promise { +) { if (!isObject(jwe)) { throw new JWEInvalid('Flattened JWE must be an object') } @@ -183,8 +199,10 @@ async function flattenedDecrypt( encryptedKey = base64url(jwe.encrypted_key!) } + let resolvedKey = false if (typeof key === 'function') { key = await key(parsedProt, jwe) + resolvedKey = true } let cek: KeyLike @@ -240,6 +258,10 @@ async function flattenedDecrypt( result.unprotectedHeader = jwe.header } + if (resolvedKey) { + return { ...result, key } + } + return result } diff --git a/src/jwe/general/decrypt.ts b/src/jwe/general/decrypt.ts index 95cc31f81d..d7f07e4a3f 100644 --- a/src/jwe/general/decrypt.ts +++ b/src/jwe/general/decrypt.ts @@ -8,6 +8,7 @@ import type { FlattenedJWE, GeneralJWE, GeneralDecryptResult, + ResolvedKey, } from '../../types.d' import isObject from '../../lib/is_object.js' @@ -21,7 +22,7 @@ export interface GeneralDecryptGetKey extends GetKeyFunction +/** + * @param jwe General JWE. + * @param getKey Function resolving Private Key or Secret to decrypt the JWE with. + * @param options JWE Decryption options. + */ +function generalDecrypt( + jwe: GeneralJWE, + getKey: GeneralDecryptGetKey, + options?: DecryptOptions, +): Promise async function generalDecrypt( jwe: GeneralJWE, key: KeyLike | GeneralDecryptGetKey, options?: DecryptOptions, -): Promise { +) { if (!isObject(jwe)) { throw new JWEInvalid('General JWE must be an object') } diff --git a/src/jws/compact/verify.ts b/src/jws/compact/verify.ts index 16d7dd1a81..5af7c969de 100644 --- a/src/jws/compact/verify.ts +++ b/src/jws/compact/verify.ts @@ -8,6 +8,7 @@ import type { JWSHeaderParameters, KeyLike, VerifyOptions, + ResolvedKey, } from '../../types.d' /** @@ -24,7 +25,7 @@ export interface CompactVerifyGetKey * Verifies the signature and format of and afterwards decodes the Compact JWS. * * @param jws Compact JWS. - * @param key Key, or a function resolving a key, to verify the JWS with. + * @param key Key to verify the JWS with. * @param options JWS Verify options. * * @example ESM import @@ -53,11 +54,26 @@ export interface CompactVerifyGetKey * console.log(decoder.decode(payload)) * ``` */ +function compactVerify( + jws: string | Uint8Array, + key: KeyLike, + options?: VerifyOptions, +): Promise +/** + * @param jws Compact JWS. + * @param getKey Function resolving a key to verify the JWS with. + * @param options JWS Verify options. + */ +function compactVerify( + jws: string | Uint8Array, + getKey: CompactVerifyGetKey, + options?: VerifyOptions, +): Promise async function compactVerify( jws: string | Uint8Array, key: KeyLike | CompactVerifyGetKey, options?: VerifyOptions, -): Promise { +) { if (jws instanceof Uint8Array) { jws = decoder.decode(jws) } @@ -81,7 +97,13 @@ async function compactVerify( options, ) - return { payload: verified.payload, protectedHeader: verified.protectedHeader! } + const result = { payload: verified.payload, protectedHeader: verified.protectedHeader! } + + if (typeof key === 'function') { + return { ...result, key: verified.key } + } + + return result } export { compactVerify } diff --git a/src/jws/flattened/verify.ts b/src/jws/flattened/verify.ts index b64c68dc05..dab39fb170 100644 --- a/src/jws/flattened/verify.ts +++ b/src/jws/flattened/verify.ts @@ -16,6 +16,7 @@ import type { JWSHeaderParameters, VerifyOptions, GetKeyFunction, + ResolvedKey, } from '../../types.d' const checkExtensions = validateCrit.bind(undefined, JWSInvalid, new Map([['b64', true]])) @@ -35,7 +36,7 @@ export interface FlattenedVerifyGetKey * Verifies the signature and format of and afterwards decodes the Flattened JWS. * * @param jws Flattened JWS. - * @param key Key, or a function resolving a key, to verify the JWS with. + * @param key Key to verify the JWS with. * @param options JWS Verify options. * * @example ESM import @@ -68,11 +69,26 @@ export interface FlattenedVerifyGetKey * console.log(decoder.decode(payload)) * ``` */ +function flattenedVerify( + jws: FlattenedJWSInput, + key: KeyLike, + options?: VerifyOptions, +): Promise +/** + * @param jws Flattened JWS. + * @param getKey Function resolving a key to verify the JWS with. + * @param options JWS Verify options. + */ +function flattenedVerify( + jws: FlattenedJWSInput, + getKey: FlattenedVerifyGetKey, + options?: VerifyOptions, +): Promise async function flattenedVerify( jws: FlattenedJWSInput, key: KeyLike | FlattenedVerifyGetKey, options?: VerifyOptions, -): Promise { +) { if (!isObject(jws)) { throw new JWSInvalid('Flattened JWS must be an object') } @@ -149,8 +165,10 @@ async function flattenedVerify( throw new JWSInvalid('JWS Payload must be a string or an Uint8Array instance') } + let resolvedKey = false if (typeof key === 'function') { key = await key(parsedProt, jws) + resolvedKey = true } checkKeyType(alg, key, 'verify') @@ -186,6 +204,10 @@ async function flattenedVerify( result.unprotectedHeader = jws.header } + if (resolvedKey) { + return { ...result, key } + } + return result } diff --git a/src/jws/general/verify.ts b/src/jws/general/verify.ts index 900f53157a..57fcda5459 100644 --- a/src/jws/general/verify.ts +++ b/src/jws/general/verify.ts @@ -7,6 +7,7 @@ import type { JWSHeaderParameters, KeyLike, VerifyOptions, + ResolvedKey, } from '../../types.d' import { JWSInvalid, JWSSignatureVerificationFailed } from '../../util/errors.js' import isObject from '../../lib/is_object.js' @@ -25,7 +26,7 @@ export interface GeneralVerifyGetKey * Verifies the signature and format of and afterwards decodes the General JWS. * * @param jws General JWS. - * @param key Key, or a function resolving a key, to verify the JWS with. + * @param key Key to verify the JWS with. * @param options JWS Verify options. * * @example ESM import @@ -62,11 +63,26 @@ export interface GeneralVerifyGetKey * console.log(decoder.decode(payload)) * ``` */ +function generalVerify( + jws: GeneralJWSInput, + key: KeyLike, + options?: VerifyOptions, +): Promise +/** + * @param jws General JWS. + * @param getKey Function resolving a key to verify the JWS with. + * @param options JWS Verify options. + */ +function generalVerify( + jws: GeneralJWSInput, + getKey: GeneralVerifyGetKey, + options?: VerifyOptions, +): Promise async function generalVerify( jws: GeneralJWSInput, key: KeyLike | GeneralVerifyGetKey, options?: VerifyOptions, -): Promise { +) { if (!isObject(jws)) { throw new JWSInvalid('General JWS must be an object') } diff --git a/src/jwt/decrypt.ts b/src/jwt/decrypt.ts index d2d4ed0a98..e7a9aa5491 100644 --- a/src/jwt/decrypt.ts +++ b/src/jwt/decrypt.ts @@ -8,6 +8,7 @@ import type { JWEHeaderParameters, FlattenedJWE, JWTDecryptResult, + ResolvedKey, } from '../types.d' import jwtPayload from '../lib/jwt_claims_set.js' import { JWTClaimValidationFailed } from '../util/errors.js' @@ -27,7 +28,7 @@ export interface JWTDecryptGetKey extends GetKeyFunction +/** + * @param jwt JSON Web Token value (encoded as JWE). + * @param getKey Function resolving Private Key or Secret to decrypt and verify the JWT with. + * @param options JWT Decryption and JWT Claims Set validation options. + */ +async function jwtDecrypt( + jwt: string | Uint8Array, + getKey: JWTDecryptGetKey, + options?: JWTDecryptOptions, +): Promise async function jwtDecrypt( jwt: string | Uint8Array, key: KeyLike | JWTDecryptGetKey, options?: JWTDecryptOptions, -): Promise { - const decrypted = await decrypt(jwt, key, options) +) { + const decrypted = await decrypt(jwt, [1]>key, options) const payload = jwtPayload(decrypted.protectedHeader, decrypted.plaintext, options) const { protectedHeader } = decrypted @@ -95,7 +111,13 @@ async function jwtDecrypt( ) } - return { payload, protectedHeader } + const result = { payload, protectedHeader } + + if (typeof key === 'function') { + return { ...result, key: decrypted.key } + } + + return result } export { jwtDecrypt } diff --git a/src/jwt/verify.ts b/src/jwt/verify.ts index f8f2b8a274..c5f88fd1ef 100644 --- a/src/jwt/verify.ts +++ b/src/jwt/verify.ts @@ -8,6 +8,7 @@ import type { GetKeyFunction, FlattenedJWSInput, JWTVerifyResult, + ResolvedKey, } from '../types.d' import jwtPayload from '../lib/jwt_claims_set.js' import { JWTInvalid } from '../util/errors.js' @@ -30,7 +31,7 @@ export interface JWTVerifyGetKey extends GetKeyFunction +/** + * @param jwt JSON Web Token value (encoded as JWS). + * @param getKey Function resolving a key to verify the JWT with. + * @param options JWT Decryption and JWT Claims Set validation options. + */ +async function jwtVerify( + jwt: string | Uint8Array, + getKey: JWTVerifyGetKey, + options?: JWTVerifyOptions, +): Promise async function jwtVerify( jwt: string | Uint8Array, key: KeyLike | JWTVerifyGetKey, options?: JWTVerifyOptions, -): Promise { - const verified = await verify(jwt, key, options) +) { + const verified = await verify(jwt, [1]>key, options) if (verified.protectedHeader.crit?.includes('b64') && verified.protectedHeader.b64 === false) { throw new JWTInvalid('JWTs MUST NOT use unencoded payload') } const payload = jwtPayload(verified.protectedHeader, verified.payload, options) - return { payload, protectedHeader: verified.protectedHeader } + const result = { payload, protectedHeader: verified.protectedHeader } + if (typeof key === 'function') { + return { ...result, key: verified.key } + } + return result } export { jwtVerify } diff --git a/src/types.d.ts b/src/types.d.ts index 5b077484f6..5dbb4291e4 100644 --- a/src/types.d.ts +++ b/src/types.d.ts @@ -721,3 +721,10 @@ export interface JWTDecryptResult { */ protectedHeader: JWEHeaderParameters } + +export interface ResolvedKey { + /** + * Key resolved from the key resolver function. + */ + key: KeyLike +} diff --git a/test/jwk/embedded.test.mjs b/test/jwk/embedded.test.mjs index 66c25c08db..75874a6f3e 100644 --- a/test/jwk/embedded.test.mjs +++ b/test/jwk/embedded.test.mjs @@ -65,7 +65,11 @@ Promise.all([ }); test('EmbeddedJWK', async (t) => { - await t.notThrowsAsync(flattenedVerify(t.context.token, EmbeddedJWK)); + await t.notThrowsAsync(async () => { + const { key: resolvedKey } = await flattenedVerify(t.context.token, EmbeddedJWK); + t.truthy(resolvedKey); + t.is(resolvedKey.type, 'public'); + }); }); test('EmbeddedJWK requires "jwk" to be an object', async (t) => { diff --git a/test/jwks/remote.test.mjs b/test/jwks/remote.test.mjs index d5b6981af7..d357cdcc61 100644 --- a/test/jwks/remote.test.mjs +++ b/test/jwks/remote.test.mjs @@ -135,7 +135,11 @@ Promise.all([ const jwt = await new SignJWT({}) .setProtectedHeader({ alg: 'PS256', kid: jwk.kid }) .sign(key); - await t.notThrowsAsync(jwtVerify(jwt, JWKS)); + await t.notThrowsAsync(async () => { + const { key: resolvedKey } = await jwtVerify(jwt, JWKS); + t.truthy(resolvedKey); + t.is(resolvedKey.type, 'public'); + }); } { const [jwk] = keys; diff --git a/test/jwt/encrypt.test.mjs b/test/jwt/encrypt.test.mjs index 26eb5cdd47..5819e23714 100644 --- a/test/jwt/encrypt.test.mjs +++ b/test/jwt/encrypt.test.mjs @@ -97,7 +97,11 @@ Promise.all([ const jwt = await enc.encrypt(t.context.secret); - const { plaintext, protectedHeader } = await compactDecrypt(jwt, async (header, token) => { + const { + plaintext, + protectedHeader, + key: resolvedKey, + } = await compactDecrypt(jwt, async (header, token) => { t.true('alg' in header); t.true('enc' in header); t.is(header.alg, 'dir'); @@ -108,6 +112,7 @@ Promise.all([ t.true('tag' in token); return t.context.secret; }); + t.is(resolvedKey, t.context.secret); const payload = JSON.parse(new TextDecoder().decode(plaintext)); t.is(payload[claim], expected); if (duplicate) { diff --git a/test/jwt/sign.test.mjs b/test/jwt/sign.test.mjs index e28e5b4b29..992b26d355 100644 --- a/test/jwt/sign.test.mjs +++ b/test/jwt/sign.test.mjs @@ -92,7 +92,7 @@ Promise.all([ .setProtectedHeader({ alg: 'HS256' }) [method](value) .sign(t.context.secret); - const { payload } = await compactVerify(jwt, async (header, token) => { + const { payload, key: resolvedKey } = await compactVerify(jwt, async (header, token) => { t.true('alg' in header); t.is(header.alg, 'HS256'); t.true('payload' in token); @@ -100,6 +100,7 @@ Promise.all([ t.true('signature' in token); return t.context.secret; }); + t.is(resolvedKey, t.context.secret); const claims = JSON.parse(new TextDecoder().decode(payload)); t.true(claim in claims); t.is(claims[claim], expected);