Skip to content

Commit

Permalink
Caching signing key (#859)
Browse files Browse the repository at this point in the history
Co-authored-by: henry_fool <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
3 people authored Feb 14, 2025
1 parent a01cb8b commit a041ccf
Showing 1 changed file with 20 additions and 4 deletions.
24 changes: 20 additions & 4 deletions rest_framework_simplejwt/backends.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import json
from collections.abc import Iterable
from datetime import timedelta
from functools import cached_property
from typing import Any, Optional, Union

import jwt
Expand Down Expand Up @@ -64,6 +65,21 @@ def __init__(
self.leeway = leeway
self.json_encoder = json_encoder

@cached_property
def prepared_signing_key(self) -> Any:
return self._prepare_key(self.signing_key)

@cached_property
def prepared_verifying_key(self) -> Any:
return self._prepare_key(self.verifying_key)

def _prepare_key(self, key: Optional[str]) -> Any:
# Support for PyJWT 1.7.1 or empty signing key
if key is None or not getattr(jwt.PyJWS, "get_algorithm_by_name", None):
return key
jws_alg = jwt.PyJWS().get_algorithm_by_name(self.algorithm)
return jws_alg.prepare_key(key)

def _validate_algorithm(self, algorithm: str) -> None:
"""
Ensure that the nominated algorithm is recognized, and that cryptography is installed for those
Expand Down Expand Up @@ -98,17 +114,17 @@ def get_leeway(self) -> timedelta:
)
)

def get_verifying_key(self, token: Token) -> Optional[str]:
def get_verifying_key(self, token: Token) -> Any:
if self.algorithm.startswith("HS"):
return self.signing_key
return self.prepared_signing_key

if self.jwks_client:
try:
return self.jwks_client.get_signing_key_from_jwt(token).key
except PyJWKClientError as ex:
raise TokenBackendError(_("Token is invalid")) from ex

return self.verifying_key
return self.prepared_verifying_key

def encode(self, payload: dict[str, Any]) -> str:
"""
Expand All @@ -122,7 +138,7 @@ def encode(self, payload: dict[str, Any]) -> str:

token = jwt.encode(
jwt_payload,
self.signing_key,
self.prepared_signing_key,
algorithm=self.algorithm,
json_encoder=self.json_encoder,
)
Expand Down

0 comments on commit a041ccf

Please sign in to comment.