Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(JWT): Customised token verification #3695

Merged
merged 8 commits into from
Aug 27, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions docs/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -215,6 +215,7 @@
(PY_RE, r".*R"),
(PY_OBJ, r"litestar.security.jwt.auth.TokenT"),
(PY_CLASS, "ExceptionToProblemDetailMapType"),
(PY_CLASS, "litestar.security.jwt.token.JWTDecodeOptions"),
]

# Warnings about missing references to those targets in the specified location will be ignored.
Expand Down
28 changes: 28 additions & 0 deletions docs/examples/security/jwt/custom_decode_payload.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
import dataclasses
from typing import Any, List, Optional, Sequence, Union

from litestar.security.jwt.token import JWTDecodeOptions, Token


@dataclasses.dataclass
class CustomToken(Token):
@classmethod
def decode_payload(
cls,
encoded_token: str,
secret: str,
algorithms: List[str],
issuer: Optional[List[str]] = None,
audience: Union[str, Sequence[str], None] = None,
options: Optional[JWTDecodeOptions] = None,
) -> Any:
payload = super().decode_payload(
encoded_token=encoded_token,
secret=secret,
algorithms=algorithms,
issuer=issuer,
audience=audience,
options=options,
)
payload["sub"] = payload["sub"].split("@", maxsplit=1)[1]
return payload
32 changes: 32 additions & 0 deletions docs/examples/security/jwt/verify_issuer_audience.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
import dataclasses
import secrets
from typing import Any, Dict

from litestar import Litestar, Request, get
from litestar.connection import ASGIConnection
from litestar.security.jwt import JWTAuth, Token


@dataclasses.dataclass
class User:
id: str


async def retrieve_user_handler(token: Token, connection: ASGIConnection) -> User:
return User(id=token.sub)


jwt_auth = JWTAuth[User](
token_secret=secrets.token_hex(),
retrieve_user_handler=retrieve_user_handler,
accepted_audiences=["https://api.testserver.local"],
accepted_issuers=["https://auth.testserver.local"],
)


@get("/")
def handler(request: Request[User, Token, Any]) -> Dict[str, Any]:
return {"id": request.user.id}


app = Litestar([handler], middleware=[jwt_auth.middleware])
29 changes: 29 additions & 0 deletions docs/usage/security/jwt.rst
Original file line number Diff line number Diff line change
Expand Up @@ -65,3 +65,32 @@ conversions.
converting the token. To support more complex conversions, the
:meth:`~.security.jwt.Token.encode` and :meth:`~.security.jwt.Token.decode` methods
must be overwritten in the subclass.


Verifying issuer and audience
-----------------------------

To verify the JWT ``iss`` (*issuer*) and ``aud`` (*audience*) claim, a list of accepted
provinzkraut marked this conversation as resolved.
Show resolved Hide resolved
issuers or audiences can bet set on the authentication backend. When a JWT is decoded,
the issuer or audience on the token is compared to the list of accepted issuers /
audiences. If the value in the token does not match any value in the respective list,
a :exc:`NotAuthorizedException` will be raised, returning a response with a
``401 Unauthorized`` status.


.. literalinclude:: /examples/security/jwt/verify_issuer_audience.py
:caption: Verifying issuer and audience


Customizing token validation
----------------------------

Token decoding / validation can be further customized by overriding the
:meth:`~.security.jwt.Token.decode_payload` method. It will be called by
:meth:`~.security.jwt.Token.decode` with the encoded token string, and must return a
dictionary representing the decoded payload, which will then used by
:meth:`~.security.jwt.Token.decode` to construct an instance of the token class.


.. literalinclude:: /examples/security/jwt/custom_decode_payload.py
:caption: Customizing payload decoding
102 changes: 102 additions & 0 deletions litestar/security/jwt/auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,27 @@ class BaseJWTAuth(Generic[UserType, TokenT], AbstractSecurityConfig[UserType, To
"""
token_cls: type[Token] = Token
"""Target type the JWT payload will be converted into"""
accepted_audiences: Sequence[str] | None = None
"""Audiences to accept when verifying the token. If given, and the audience in the
token does not match, a 401 response is returned
"""
accepted_issuers: Sequence[str] | None = None
"""Issuers to accept when verifying the token. If given, and the issuer in the
token does not match, a 401 response is returned
"""
require_claims: Sequence[str] | None = None
"""Require these claims to be present in the JWT payload. If any of those claims
is missing, a 401 response is returned
"""
verify_expiry: bool = True
"""Verify that the value of the ``exp`` (*expiration*) claim is in the future"""
verify_not_before: bool = True
"""Verify that the value of the ``nbf`` (*not before*) claim is in the past"""
strict_audience: bool = False
"""Verify that the value of the ``aud`` (*audience*) claim is a single value, and
not a list of values, and matches ``audience`` exactly. Requires that
``accepted_audiences`` is a sequence of length 1
"""

@property
def openapi_components(self) -> Components:
Expand Down Expand Up @@ -120,6 +141,12 @@ def middleware(self) -> DefineMiddleware:
scopes=self.scopes,
token_secret=self.token_secret,
token_cls=self.token_cls,
token_issuer=self.accepted_issuers,
token_audience=self.accepted_audiences,
require_claims=self.require_claims,
verify_expiry=self.verify_expiry,
verify_not_before=self.verify_not_before,
strict_audience=self.strict_audience,
)

def login(
Expand Down Expand Up @@ -290,6 +317,27 @@ class JWTAuth(Generic[UserType, TokenT], BaseJWTAuth[UserType, TokenT]):
"""
token_cls: type[Token] = Token
"""Target type the JWT payload will be converted into"""
accepted_audiences: Sequence[str] | None = None
"""Audiences to accept when verifying the token. If given, and the audience in the
token does not match, a 401 response is returned
"""
accepted_issuers: Sequence[str] | None = None
"""Issuers to accept when verifying the token. If given, and the issuer in the
token does not match, a 401 response is returned
"""
require_claims: Sequence[str] | None = None
"""Require these claims to be present in the JWT payload. If any of those claims
is missing, a 401 response is returned
"""
verify_expiry: bool = True
"""Verify that the value of the ``exp`` (*expiration*) claim is in the future"""
verify_not_before: bool = True
"""Verify that the value of the ``nbf`` (*not before*) claim is in the past"""
strict_audience: bool = False
"""Verify that the value of the ``aud`` (*audience*) claim is a single value, and
not a list of values, and matches ``audience`` exactly. Requires that
``accepted_audiences`` is a sequence of length 1
"""


@dataclass
Expand Down Expand Up @@ -370,6 +418,27 @@ class and adds support for passing JWT tokens ``HttpOnly`` cookies.
"""
token_cls: type[Token] = Token
"""Target type the JWT payload will be converted into"""
accepted_audiences: Sequence[str] | None = None
"""Audiences to accept when verifying the token. If given, and the audience in the
token does not match, a 401 response is returned
"""
accepted_issuers: Sequence[str] | None = None
"""Issuers to accept when verifying the token. If given, and the issuer in the
token does not match, a 401 response is returned
"""
require_claims: Sequence[str] | None = None
"""Require these claims to be present in the JWT payload. If any of those claims
is missing, a 401 response is returned
"""
verify_expiry: bool = True
"""Verify that the value of the ``exp`` (*expiration*) claim is in the future"""
verify_not_before: bool = True
"""Verify that the value of the ``nbf`` (*not before*) claim is in the past"""
strict_audience: bool = False
"""Verify that the value of the ``aud`` (*audience*) claim is a single value, and
not a list of values, and matches ``audience`` exactly. Requires that
``accepted_audiences`` is a sequence of length 1
"""

@property
def openapi_components(self) -> Components:
Expand Down Expand Up @@ -411,6 +480,12 @@ def middleware(self) -> DefineMiddleware:
scopes=self.scopes,
token_secret=self.token_secret,
token_cls=self.token_cls,
token_issuer=self.accepted_issuers,
token_audience=self.accepted_audiences,
require_claims=self.require_claims,
verify_expiry=self.verify_expiry,
verify_not_before=self.verify_not_before,
strict_audience=self.strict_audience,
)

def login(
Expand Down Expand Up @@ -579,6 +654,27 @@ class OAuth2PasswordBearerAuth(Generic[UserType, TokenT], BaseJWTAuth[UserType,
"""
token_cls: type[Token] = Token
"""Target type the JWT payload will be converted into"""
accepted_audiences: Sequence[str] | None = None
"""Audiences to accept when verifying the token. If given, and the audience in the
token does not match, a 401 response is returned
"""
accepted_issuers: Sequence[str] | None = None
"""Issuers to accept when verifying the token. If given, and the issuer in the
token does not match, a 401 response is returned
"""
require_claims: Sequence[str] | None = None
"""Require these claims to be present in the JWT payload. If any of those claims
is missing, a 401 response is returned
"""
verify_expiry: bool = True
"""Verify that the value of the ``exp`` (*expiration*) claim is in the future"""
verify_not_before: bool = True
"""Verify that the value of the ``nbf`` (*not before*) claim is in the past"""
strict_audience: bool = False
"""Verify that the value of the ``aud`` (*audience*) claim is a single value, and
not a list of values, and matches ``audience`` exactly. Requires that
``accepted_audiences`` is a sequence of length 1
"""

@property
def middleware(self) -> DefineMiddleware:
Expand All @@ -600,6 +696,12 @@ def middleware(self) -> DefineMiddleware:
scopes=self.scopes,
token_secret=self.token_secret,
token_cls=self.token_cls,
token_issuer=self.accepted_issuers,
token_audience=self.accepted_audiences,
require_claims=self.require_claims,
verify_expiry=self.verify_expiry,
verify_not_before=self.verify_not_before,
strict_audience=self.strict_audience,
)

@property
Expand Down
60 changes: 60 additions & 0 deletions litestar/security/jwt/middleware.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,12 @@ class JWTAuthenticationMiddleware(AbstractAuthenticationMiddleware):
"retrieve_user_handler",
"token_secret",
"token_cls",
"token_audience",
"token_issuer",
"require_claims",
"verify_expiry",
"verify_not_before",
"strict_audience",
)

def __init__(
Expand All @@ -45,6 +51,12 @@ def __init__(
scopes: Scopes,
token_secret: str,
token_cls: type[Token] = Token,
token_audience: Sequence[str] | None = None,
token_issuer: Sequence[str] | None = None,
require_claims: Sequence[str] | None = None,
verify_expiry: bool = True,
verify_not_before: bool = True,
strict_audience: bool = False,
) -> None:
"""Check incoming requests for an encoded token in the auth header specified, and if present retrieve the user
from persistence using the provided function.
Expand All @@ -62,6 +74,18 @@ def __init__(
token_secret: Secret for decoding the JWT. This value should be equivalent to the secret used to
encode it.
token_cls: Token class used when encoding / decoding JWTs
token_audience: Verify the audience when decoding the token. If the audience
in the token does not match any audience given, raise a
:exc:`NotAuthorizedException`
token_issuer: Verify the issuer when decoding the token. If the issuer in
the token does not match any issuer given, raise a
:exc:`NotAuthorizedException`
require_claims: Require these claims to be present in the JWT payload
verify_expiry: Verify that the value of the ``exp`` (*expiration*) claim is in the future
verify_not_before: Verify that the value of the ``nbf`` (*not before*) claim is in the past
strict_audience: Verify that the value of the ``aud`` (*audience*) claim is a single value, and
not a list of values, and matches ``audience`` exactly. Requires that
``accepted_audiences`` is a sequence of length 1
"""
super().__init__(
app=app,
Expand All @@ -75,6 +99,12 @@ def __init__(
self.retrieve_user_handler = retrieve_user_handler
self.token_secret = token_secret
self.token_cls = token_cls
self.token_audience = token_audience
self.token_issuer = token_issuer
self.require_claims = require_claims
self.verify_expiry = verify_expiry
self.verify_not_before = verify_not_before
self.strict_audience = strict_audience

async def authenticate_request(self, connection: ASGIConnection[Any, Any, Any, Any]) -> AuthenticationResult:
"""Given an HTTP Connection, parse the JWT api key stored in the header and retrieve the user correlating to the
Expand Down Expand Up @@ -114,6 +144,12 @@ async def authenticate_token(
encoded_token=encoded_token,
secret=self.token_secret,
algorithm=self.algorithm,
audience=self.token_audience,
issuer=self.token_issuer,
require_claims=self.require_claims,
verify_exp=self.verify_expiry,
verify_nbf=self.verify_not_before,
strict_audience=self.strict_audience,
)

user = await self.retrieve_user_handler(token, connection)
Expand Down Expand Up @@ -142,6 +178,12 @@ def __init__(
scopes: Scopes,
token_secret: str,
token_cls: type[Token] = Token,
token_audience: Sequence[str] | None = None,
token_issuer: Sequence[str] | None = None,
require_claims: Sequence[str] | None = None,
verify_expiry: bool = True,
verify_not_before: bool = True,
strict_audience: bool = False,
) -> None:
"""Check incoming requests for an encoded token in the auth header or cookie name specified, and if present
retrieves the user from persistence using the provided function.
Expand All @@ -160,6 +202,18 @@ def __init__(
token_secret: Secret for decoding the JWT. This value should be equivalent to the secret used to
encode it.
token_cls: Token class used when encoding / decoding JWTs
token_audience: Verify the audience when decoding the token. If the audience
in the token does not match any audience given, raise a
:exc:`NotAuthorizedException`
token_issuer: Verify the issuer when decoding the token. If the issuer in
the token does not match any issuer given, raise a
:exc:`NotAuthorizedException`
require_claims: Require these claims to be present in the JWT payload
verify_expiry: Verify that the value of the ``exp`` (*expiration*) claim is in the future
verify_not_before: Verify that the value of the ``nbf`` (*not before*) claim is in the past
strict_audience: Verify that the value of the ``aud`` (*audience*) claim is a single value, and
not a list of values, and matches ``audience`` exactly. Requires that
``accepted_audiences`` is a sequence of length 1
"""
super().__init__(
algorithm=algorithm,
Expand All @@ -172,6 +226,12 @@ def __init__(
scopes=scopes,
token_secret=token_secret,
token_cls=token_cls,
token_audience=token_audience,
token_issuer=token_issuer,
require_claims=require_claims,
verify_expiry=verify_expiry,
verify_not_before=verify_not_before,
strict_audience=strict_audience,
)
self.auth_cookie_key = auth_cookie_key

Expand Down
Loading
Loading