Skip to content

Commit

Permalink
Add decode_payload method
Browse files Browse the repository at this point in the history
  • Loading branch information
provinzkraut committed Aug 26, 2024
1 parent 4df8fd5 commit b1d176d
Show file tree
Hide file tree
Showing 4 changed files with 112 additions and 11 deletions.
28 changes: 28 additions & 0 deletions docs/examples/security/jwt/custom_decode_payload.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
import dataclasses
from typing import Any, List, Optional, Sequence, Union

from litestar.security.jwt.token import JWTDecodeOptions, Token


@dataclasses.dataclass
class CustomToken(Token):
@classmethod
def decode_payload(
cls,
encoded_token: str,
secret: str,
algorithms: List[str],
issuer: Optional[List[str]] = None,
audience: Union[str, Sequence[str], None] = None,
options: Optional[JWTDecodeOptions] = None,
) -> Any:
payload = super().decode_payload(
encoded_token=encoded_token,
secret=secret,
algorithms=algorithms,
issuer=issuer,
audience=audience,
options=options,
)
payload["sub"] = payload["sub"].split("@", maxsplit=1)[1]
return payload
14 changes: 14 additions & 0 deletions docs/usage/security/jwt.rst
Original file line number Diff line number Diff line change
Expand Up @@ -80,3 +80,17 @@ a :exc:`NotAuthorizedException` will be raised, returning a response with a

.. literalinclude:: /examples/security/jwt/verify_issuer_audience.py
:caption: Verifying issuer and audience


Customizing token validation
----------------------------

Token decoding / validation can be further customized by overriding the
:meth:`~.security.jwt.Token.decode_payload` method. It will be called by
:meth:`~.security.jwt.Token.decode` with the encoded token string, and must return a
dictionary representing the decoded payload, which will then used by
:meth:`~.security.jwt.Token.decode` to construct an instance of the token class.


.. literalinclude:: /examples/security/jwt/custom_decode_payload.py
:caption: Customizing payload decoding
51 changes: 41 additions & 10 deletions litestar/security/jwt/token.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import dataclasses
from dataclasses import asdict, dataclass, field
from datetime import datetime, timezone
from typing import TYPE_CHECKING, Any, Dict, Optional, Sequence
from typing import TYPE_CHECKING, Any, Dict, Optional, Sequence, TypedDict

import jwt
import msgspec
Expand All @@ -13,7 +13,10 @@
if TYPE_CHECKING:
from typing_extensions import Self

__all__ = ("Token",)
__all__ = (
"Token",
"JWTDecodeOptions",
)


def _normalize_datetime(value: datetime) -> datetime:
Expand All @@ -31,6 +34,15 @@ def _normalize_datetime(value: datetime) -> datetime:
return value.replace(microsecond=0)


class JWTDecodeOptions(TypedDict, total=False):
verify_aud: bool
verify_iss: bool
verify_exp: bool
verify_nbf: bool
strict_aud: bool
require: list[str]


@dataclass
class Token:
"""JWT Token DTO."""
Expand Down Expand Up @@ -70,6 +82,26 @@ def __post_init__(self) -> None:
else:
raise ImproperlyConfiguredException("iat must be a current or past time")

@classmethod
def decode_payload(
cls,
encoded_token: str,
secret: str,
algorithms: list[str],
issuer: list[str] | None = None,
audience: str | Sequence[str] | None = None,
options: JWTDecodeOptions | None = None,
) -> dict[str, Any]:
"""Decode and verify the JWT and return its payload"""
return jwt.decode(
jwt=encoded_token,
key=secret,
algorithms=algorithms,
issuer=issuer,
audience=audience,
options=options, # type: ignore[arg-type]
)

@classmethod
def decode(
cls,
Expand All @@ -83,7 +115,7 @@ def decode(
verify_nbf: bool = True,
strict_audience: bool = False,
) -> Self:
"""Decode a passed in token string and returns a Token instance.
"""Decode a passed in token string and return a Token instance.
Args:
encoded_token: A base64 string containing an encoded JWT.
Expand Down Expand Up @@ -112,7 +144,7 @@ def decode(
NotAuthorizedException: If the token is invalid.
"""

options: dict[str, Any] = {
options: JWTDecodeOptions = {
"verify_aud": bool(audience),
"verify_iss": bool(issuer),
}
Expand All @@ -132,12 +164,12 @@ def decode(
audience = audience[0]

try:
payload: dict[str, Any] = jwt.decode(
jwt=encoded_token,
key=secret,
payload = cls.decode_payload(
encoded_token=encoded_token,
secret=secret,
algorithms=[algorithm],
issuer=list(issuer) if issuer else None,
audience=audience,
issuer=list(issuer) if issuer else None,
options=options,
)
# msgspec can do these conversions as well, but to keep backwards
Expand All @@ -152,8 +184,7 @@ def decode(
return msgspec.convert(payload, cls, strict=False)
except (
KeyError,
jwt.DecodeError,
jwt.exceptions.InvalidAlgorithmError,
jwt.exceptions.InvalidTokenError,
ImproperlyConfiguredException,
msgspec.ValidationError,
) as e:
Expand Down
30 changes: 29 additions & 1 deletion tests/unit/test_security/test_jwt/test_token.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
from __future__ import annotations

import dataclasses
import secrets
import sys
from dataclasses import asdict
from datetime import datetime, timedelta, timezone
from typing import Any
from typing import Any, Sequence
from uuid import uuid4

import jwt
Expand All @@ -14,6 +15,7 @@

from litestar.exceptions import ImproperlyConfiguredException, NotAuthorizedException
from litestar.security.jwt import Token
from litestar.security.jwt.token import JWTDecodeOptions


@pytest.mark.parametrize("algorithm", ["HS256", "HS384", "HS512"])
Expand Down Expand Up @@ -194,3 +196,29 @@ def test_strict_aud_with_one_element_sequence(audience: str | list[str]) -> None
audience=["foo"],
strict_audience=True,
)


def test_custom_decode_payload() -> None:
@dataclasses.dataclass
class CustomToken(Token):
@classmethod
def decode_payload(
cls,
encoded_token: str,
secret: str,
algorithms: list[str],
issuer: list[str] | None = None,
audience: str | Sequence[str] | None = None,
options: JWTDecodeOptions | None = None,
) -> Any:
payload = super().decode_payload(
encoded_token=encoded_token,
secret=secret,
algorithms=algorithms,
)
payload["sub"] = "some-random-value"
return payload

_secret = secrets.token_hex()
encoded = CustomToken(exp=datetime.now() + timedelta(days=1), sub="foo").encode(_secret, "HS256")
assert CustomToken.decode(encoded, secret=_secret, algorithm="HS256").sub == "some-random-value"

0 comments on commit b1d176d

Please sign in to comment.