Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Feature parity for a_decode_token and decode_token #616

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
59 changes: 38 additions & 21 deletions src/keycloak/keycloak_openid.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ class to handle authentication and token manipulation.
"""

import json
from typing import Optional
from typing import Optional, Union

from jwcrypto import jwk, jwt

Expand Down Expand Up @@ -581,6 +581,33 @@ def introspect(self, token, rpt=None, token_type_hint=None):
)
return raise_error_from_response(data_raw, KeycloakPostError)

@staticmethod
def _verify_token(token, key: Union[jwk.JWK, jwk.JWKSet, None], **kwargs):
"""Decode and optionally validate a token.

:param token: The token to verify
:type token: str
:param key: Which key should be used for validation.
If not provided, the validation is not performed and the token is implicitly valid.
:type key: Union[jwk.JWK, jwk.JWKSet, None]
:param kwargs: Additional keyword arguments for jwcrypto's JWT object
:type kwargs: dict
:returns: Decoded token
"""
# keep the function free of IO
# this way it can be used by `decode_token` and `a_decode_token`

if key is not None:
leeway = kwargs.pop("leeway", 60)
full_jwt = jwt.JWT(jwt=token, **kwargs)
full_jwt.leeway = leeway
full_jwt.validate(key)
return jwt.json_decode(full_jwt.claims)
else:
full_jwt = jwt.JWT(jwt=token, **kwargs)
full_jwt.token.objects["valid"] = True
return json.loads(full_jwt.token.payload.decode("utf-8"))

def decode_token(self, token, validate: bool = True, **kwargs):
"""Decode user token.

Expand All @@ -603,26 +630,19 @@ def decode_token(self, token, validate: bool = True, **kwargs):
:returns: Decoded token
:rtype: dict
"""
key = kwargs.pop("key", None)
if validate:
if "key" not in kwargs:
if key is None:
key = (
"-----BEGIN PUBLIC KEY-----\n"
+ self.public_key()
+ "\n-----END PUBLIC KEY-----"
)
key = jwk.JWK.from_pem(key.encode("utf-8"))
kwargs["key"] = key

key = kwargs.pop("key")
leeway = kwargs.pop("leeway", 60)
full_jwt = jwt.JWT(jwt=token, **kwargs)
full_jwt.leeway = leeway
full_jwt.validate(key)
return jwt.json_decode(full_jwt.claims)
else:
full_jwt = jwt.JWT(jwt=token, **kwargs)
full_jwt.token.objects["valid"] = True
return json.loads(full_jwt.token.payload.decode("utf-8"))
key = None

return self._verify_token(token, key, **kwargs)

def load_authorization_config(self, path):
"""Load Keycloak settings (authorization).
Expand Down Expand Up @@ -1254,22 +1274,19 @@ async def a_decode_token(self, token, validate: bool = True, **kwargs):
:returns: Decoded token
:rtype: dict
"""
key = kwargs.pop("key", None)
if validate:
if "key" not in kwargs:
if key is None:
key = (
"-----BEGIN PUBLIC KEY-----\n"
+ await self.a_public_key()
+ "\n-----END PUBLIC KEY-----"
)
key = jwk.JWK.from_pem(key.encode("utf-8"))
kwargs["key"] = key

full_jwt = jwt.JWT(jwt=token, **kwargs)
return jwt.json_decode(full_jwt.claims)
else:
full_jwt = jwt.JWT(jwt=token, **kwargs)
full_jwt.token.objects["valid"] = True
return json.loads(full_jwt.token.payload.decode("utf-8"))
key = None

return self._verify_token(token, key, **kwargs)

async def a_load_authorization_config(self, path):
"""Load Keycloak settings (authorization) asynchronously.
Expand Down
75 changes: 74 additions & 1 deletion tests/test_keycloak_openid.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@
from typing import Tuple
from unittest import mock

import jwcrypto.jwk
import jwcrypto.jws
import pytest

from keycloak import KeycloakAdmin, KeycloakOpenID
Expand Down Expand Up @@ -317,6 +319,39 @@ def test_decode_token(oid_with_credentials: Tuple[KeycloakOpenID, str, str]):
assert decoded_refresh_token["typ"] == "Refresh", decoded_refresh_token


def test_decode_token_invalid_token(oid_with_credentials: Tuple[KeycloakOpenID, str, str]):
"""Test decode token with an invalid token.

:param oid_with_credentials: Keycloak OpenID client with pre-configured user credentials
:type oid_with_credentials: Tuple[KeycloakOpenID, str, str]
"""
oid, username, password = oid_with_credentials
token = oid.token(username=username, password=password)
access_token = token["access_token"]
decoded_access_token = oid.decode_token(token=access_token)

key = oid.public_key()
key = "-----BEGIN PUBLIC KEY-----\n" + key + "\n-----END PUBLIC KEY-----"
key = jwcrypto.jwk.JWK.from_pem(key.encode("utf-8"))

invalid_access_token = access_token + "a"
with pytest.raises(jwcrypto.jws.InvalidJWSSignature):
decoded_invalid_access_token = oid.decode_token(token=invalid_access_token, validate=True)

with pytest.raises(jwcrypto.jws.InvalidJWSSignature):
decoded_invalid_access_token = oid.decode_token(
token=invalid_access_token, validate=True, key=key
)

decoded_invalid_access_token = oid.decode_token(token=invalid_access_token, validate=False)
assert decoded_access_token == decoded_invalid_access_token

decoded_invalid_access_token = oid.decode_token(
token=invalid_access_token, validate=False, key=key
)
assert decoded_access_token == decoded_invalid_access_token


def test_load_authorization_config(oid_with_credentials_authz: Tuple[KeycloakOpenID, str, str]):
"""Test load authorization config.

Expand Down Expand Up @@ -765,7 +800,7 @@ async def test_a_introspect(oid_with_credentials: Tuple[KeycloakOpenID, str, str

@pytest.mark.asyncio
async def test_a_decode_token(oid_with_credentials: Tuple[KeycloakOpenID, str, str]):
"""Test decode token.
"""Test decode token asynchronously.

:param oid_with_credentials: Keycloak OpenID client with pre-configured user credentials
:type oid_with_credentials: Tuple[KeycloakOpenID, str, str]
Expand All @@ -781,6 +816,44 @@ async def test_a_decode_token(oid_with_credentials: Tuple[KeycloakOpenID, str, s
assert decoded_refresh_token["typ"] == "Refresh", decoded_refresh_token


@pytest.mark.asyncio
async def test_a_decode_token_invalid_token(oid_with_credentials: Tuple[KeycloakOpenID, str, str]):
"""Test decode token asynchronously an invalid token.

:param oid_with_credentials: Keycloak OpenID client with pre-configured user credentials
:type oid_with_credentials: Tuple[KeycloakOpenID, str, str]
"""
oid, username, password = oid_with_credentials
token = await oid.a_token(username=username, password=password)
access_token = token["access_token"]
decoded_access_token = await oid.a_decode_token(token=access_token)

key = await oid.a_public_key()
key = "-----BEGIN PUBLIC KEY-----\n" + key + "\n-----END PUBLIC KEY-----"
key = jwcrypto.jwk.JWK.from_pem(key.encode("utf-8"))

invalid_access_token = access_token + "a"
with pytest.raises(jwcrypto.jws.InvalidJWSSignature):
decoded_invalid_access_token = await oid.a_decode_token(
token=invalid_access_token, validate=True
)

with pytest.raises(jwcrypto.jws.InvalidJWSSignature):
decoded_invalid_access_token = await oid.a_decode_token(
token=invalid_access_token, validate=True, key=key
)

decoded_invalid_access_token = await oid.a_decode_token(
token=invalid_access_token, validate=False
)
assert decoded_access_token == decoded_invalid_access_token

decoded_invalid_access_token = await oid.a_decode_token(
token=invalid_access_token, validate=False, key=key
)
assert decoded_access_token == decoded_invalid_access_token


@pytest.mark.asyncio
async def test_a_load_authorization_config(
oid_with_credentials_authz: Tuple[KeycloakOpenID, str, str]
Expand Down
Loading