Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve testing #688

Merged
merged 15 commits into from
Jun 21, 2023
Merged
98 changes: 51 additions & 47 deletions tests/test_authentication.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,25 +40,19 @@ def test_get_header(self):
)
self.assertEqual(self.backend.get_header(request), self.fake_header)

# Should work with the x_access_token
with override_api_settings(AUTH_HEADER_NAME="HTTP_X_ACCESS_TOKEN"):
# Should pull correct header off request when using X_ACCESS_TOKEN
request = self.factory.get(
"/test-url/", HTTP_X_ACCESS_TOKEN=self.fake_header
)
self.assertEqual(self.backend.get_header(request), self.fake_header)

# Should work for unicode headers when using
request = self.factory.get(
"/test-url/", HTTP_X_ACCESS_TOKEN=self.fake_header.decode("utf-8")
)
self.assertEqual(self.backend.get_header(request), self.fake_header)
@override_api_settings(AUTH_HEADER_NAME="HTTP_X_ACCESS_TOKEN")
def test_get_header_x_access_token(self):
# Should pull correct header off request when using X_ACCESS_TOKEN
request = self.factory.get("/test-url/", HTTP_X_ACCESS_TOKEN=self.fake_header)
self.assertEqual(self.backend.get_header(request), self.fake_header)

# Should work for unicode headers when using
request = self.factory.get(
"/test-url/", HTTP_X_ACCESS_TOKEN=self.fake_header.decode("utf-8")
)
self.assertEqual(self.backend.get_header(request), self.fake_header)

def test_get_raw_token(self):
# Should return None if header lacks correct type keyword
with override_api_settings(AUTH_HEADER_TYPES="JWT"):
reload(authentication)
self.assertIsNone(self.backend.get_raw_token(self.fake_header))
reload(authentication)

# Should return None if an empty AUTHORIZATION header is sent
Expand All @@ -74,14 +68,21 @@ def test_get_raw_token(self):
# Otherwise, should return unvalidated token in header
self.assertEqual(self.backend.get_raw_token(self.fake_header), self.fake_token)

@override_api_settings(AUTH_HEADER_TYPES="JWT")
def test_get_raw_token_incorrect_header_keyword(self):
# Should return None if header lacks correct type keyword
# AUTH_HEADER_TYPES is "JWT", but header is "Bearer"
reload(authentication)
self.assertIsNone(self.backend.get_raw_token(self.fake_header))

@override_api_settings(AUTH_HEADER_TYPES=("JWT", "Bearer"))
def test_get_raw_token_multi_header_keyword(self):
# Should return token if header has one of many valid token types
with override_api_settings(AUTH_HEADER_TYPES=("JWT", "Bearer")):
reload(authentication)
self.assertEqual(
self.backend.get_raw_token(self.fake_header),
self.fake_token,
)
reload(authentication)
self.assertEqual(
self.backend.get_raw_token(self.fake_header),
self.fake_token,
)

def test_get_validated_token(self):
# Should raise InvalidToken if token not valid
Expand All @@ -96,36 +97,39 @@ def test_get_validated_token(self):
self.backend.get_validated_token(str(token)).payload, token.payload
)

@override_api_settings(
AUTH_TOKEN_CLASSES=("rest_framework_simplejwt.tokens.AccessToken",),
)
def test_get_validated_token_reject_unknown_token(self):
# Should not accept tokens not included in AUTH_TOKEN_CLASSES
sliding_token = SlidingToken()
with override_api_settings(
AUTH_TOKEN_CLASSES=("rest_framework_simplejwt.tokens.AccessToken",)
):
with self.assertRaises(InvalidToken) as e:
self.backend.get_validated_token(str(sliding_token))

messages = e.exception.detail["messages"]
self.assertEqual(1, len(messages))
self.assertEqual(
{
"token_class": "AccessToken",
"token_type": "access",
"message": "Token has wrong type",
},
messages[0],
)
with self.assertRaises(InvalidToken) as e:
self.backend.get_validated_token(str(sliding_token))

messages = e.exception.detail["messages"]
self.assertEqual(1, len(messages))
self.assertEqual(
{
"token_class": "AccessToken",
"token_type": "access",
"message": "Token has wrong type",
},
messages[0],
)

@override_api_settings(
AUTH_TOKEN_CLASSES=(
"rest_framework_simplejwt.tokens.AccessToken",
"rest_framework_simplejwt.tokens.SlidingToken",
),
)
def test_get_validated_token_accept_known_token(self):
# Should accept tokens included in AUTH_TOKEN_CLASSES
access_token = AccessToken()
sliding_token = SlidingToken()
with override_api_settings(
AUTH_TOKEN_CLASSES=(
"rest_framework_simplejwt.tokens.AccessToken",
"rest_framework_simplejwt.tokens.SlidingToken",
)
):
self.backend.get_validated_token(str(access_token))
self.backend.get_validated_token(str(sliding_token))

self.backend.get_validated_token(str(access_token))
self.backend.get_validated_token(str(sliding_token))

def test_get_user(self):
payload = {"some_other_id": "foo"}
Expand Down
42 changes: 20 additions & 22 deletions tests/test_integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

from django.contrib.auth import get_user_model
from django.urls import reverse
from rest_framework.status import HTTP_200_OK, HTTP_401_UNAUTHORIZED

from rest_framework_simplejwt.settings import api_settings
from rest_framework_simplejwt.tokens import AccessToken
Expand All @@ -26,7 +27,7 @@ def setUp(self):
def test_no_authorization(self):
res = self.view_get()

self.assertEqual(res.status_code, 401)
self.assertEqual(res.status_code, HTTP_401_UNAUTHORIZED)
self.assertIn("credentials were not provided", res.data["detail"])

def test_wrong_auth_type(self):
Expand All @@ -43,9 +44,12 @@ def test_wrong_auth_type(self):

res = self.view_get()

self.assertEqual(res.status_code, 401)
self.assertEqual(res.status_code, HTTP_401_UNAUTHORIZED)
self.assertIn("credentials were not provided", res.data["detail"])

@override_api_settings(
AUTH_TOKEN_CLASSES=("rest_framework_simplejwt.tokens.AccessToken",),
)
def test_expired_token(self):
old_lifetime = AccessToken.lifetime
AccessToken.lifetime = timedelta(seconds=0)
Expand All @@ -63,14 +67,14 @@ def test_expired_token(self):
access = res.data["access"]
self.authenticate_with_token(api_settings.AUTH_HEADER_TYPES[0], access)

with override_api_settings(
AUTH_TOKEN_CLASSES=("rest_framework_simplejwt.tokens.AccessToken",)
):
res = self.view_get()
res = self.view_get()

self.assertEqual(res.status_code, 401)
self.assertEqual(res.status_code, HTTP_401_UNAUTHORIZED)
self.assertEqual("token_not_valid", res.data["code"])

@override_api_settings(
AUTH_TOKEN_CLASSES=("rest_framework_simplejwt.tokens.SlidingToken",),
)
def test_user_can_get_sliding_token_and_use_it(self):
res = self.client.post(
reverse("token_obtain_sliding"),
Expand All @@ -83,14 +87,14 @@ def test_user_can_get_sliding_token_and_use_it(self):
token = res.data["token"]
self.authenticate_with_token(api_settings.AUTH_HEADER_TYPES[0], token)

with override_api_settings(
AUTH_TOKEN_CLASSES=("rest_framework_simplejwt.tokens.SlidingToken",)
):
res = self.view_get()
res = self.view_get()

self.assertEqual(res.status_code, 200)
self.assertEqual(res.status_code, HTTP_200_OK)
self.assertEqual(res.data["foo"], "bar")

@override_api_settings(
AUTH_TOKEN_CLASSES=("rest_framework_simplejwt.tokens.AccessToken",),
)
def test_user_can_get_access_and_refresh_tokens_and_use_them(self):
res = self.client.post(
reverse("token_obtain_pair"),
Expand All @@ -105,12 +109,9 @@ def test_user_can_get_access_and_refresh_tokens_and_use_them(self):

self.authenticate_with_token(api_settings.AUTH_HEADER_TYPES[0], access)

with override_api_settings(
AUTH_TOKEN_CLASSES=("rest_framework_simplejwt.tokens.AccessToken",)
):
res = self.view_get()
res = self.view_get()

self.assertEqual(res.status_code, 200)
self.assertEqual(res.status_code, HTTP_200_OK)
self.assertEqual(res.data["foo"], "bar")

res = self.client.post(
Expand All @@ -122,10 +123,7 @@ def test_user_can_get_access_and_refresh_tokens_and_use_them(self):

self.authenticate_with_token(api_settings.AUTH_HEADER_TYPES[0], access)

with override_api_settings(
AUTH_TOKEN_CLASSES=("rest_framework_simplejwt.tokens.AccessToken",)
):
res = self.view_get()
res = self.view_get()

self.assertEqual(res.status_code, 200)
self.assertEqual(res.status_code, HTTP_200_OK)
self.assertEqual(res.data["foo"], "bar")
30 changes: 14 additions & 16 deletions tests/test_serializers.py
Original file line number Diff line number Diff line change
Expand Up @@ -285,6 +285,10 @@ def test_it_should_return_access_token_if_everything_ok(self):
access["exp"], datetime_to_epoch(now + api_settings.ACCESS_TOKEN_LIFETIME)
)

@override_api_settings(
ROTATE_REFRESH_TOKENS=True,
BLACKLIST_AFTER_ROTATION=False,
)
def test_it_should_return_refresh_token_if_tokens_should_be_rotated(self):
refresh = RefreshToken()

Expand All @@ -298,14 +302,9 @@ def test_it_should_return_refresh_token_if_tokens_should_be_rotated(self):

now = aware_utcnow() - api_settings.ACCESS_TOKEN_LIFETIME / 2

with override_api_settings(
ROTATE_REFRESH_TOKENS=True, BLACKLIST_AFTER_ROTATION=False
):
with patch(
"rest_framework_simplejwt.tokens.aware_utcnow"
) as fake_aware_utcnow:
fake_aware_utcnow.return_value = now
self.assertTrue(ser.is_valid())
with patch("rest_framework_simplejwt.tokens.aware_utcnow") as fake_aware_utcnow:
fake_aware_utcnow.return_value = now
self.assertTrue(ser.is_valid())

access = AccessToken(ser.validated_data["access"])
new_refresh = RefreshToken(ser.validated_data["refresh"])
Expand All @@ -324,6 +323,10 @@ def test_it_should_return_refresh_token_if_tokens_should_be_rotated(self):
datetime_to_epoch(now + api_settings.REFRESH_TOKEN_LIFETIME),
)

@override_api_settings(
ROTATE_REFRESH_TOKENS=True,
BLACKLIST_AFTER_ROTATION=True,
)
def test_it_should_blacklist_refresh_token_if_tokens_should_be_rotated_and_blacklisted(
self,
):
Expand All @@ -342,14 +345,9 @@ def test_it_should_blacklist_refresh_token_if_tokens_should_be_rotated_and_black

now = aware_utcnow() - api_settings.ACCESS_TOKEN_LIFETIME / 2

with override_api_settings(
ROTATE_REFRESH_TOKENS=True, BLACKLIST_AFTER_ROTATION=True
):
with patch(
"rest_framework_simplejwt.tokens.aware_utcnow"
) as fake_aware_utcnow:
fake_aware_utcnow.return_value = now
self.assertTrue(ser.is_valid())
with patch("rest_framework_simplejwt.tokens.aware_utcnow") as fake_aware_utcnow:
fake_aware_utcnow.return_value = now
self.assertTrue(ser.is_valid())

access = AccessToken(ser.validated_data["access"])
new_refresh = RefreshToken(ser.validated_data["refresh"])
Expand Down
20 changes: 10 additions & 10 deletions tests/test_token_blacklist.py
Original file line number Diff line number Diff line change
Expand Up @@ -237,25 +237,25 @@ def setUp(self):

super().setUp()

@override_api_settings(BLACKLIST_AFTER_ROTATION=True)
def test_token_verify_serializer_should_honour_blacklist_if_blacklisting_enabled(
self,
):
with override_api_settings(BLACKLIST_AFTER_ROTATION=True):
refresh_token = RefreshToken.for_user(self.user)
refresh_token.blacklist()
refresh_token = RefreshToken.for_user(self.user)
refresh_token.blacklist()

serializer = TokenVerifySerializer(data={"token": str(refresh_token)})
self.assertFalse(serializer.is_valid())
serializer = TokenVerifySerializer(data={"token": str(refresh_token)})
self.assertFalse(serializer.is_valid())

@override_api_settings(BLACKLIST_AFTER_ROTATION=False)
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this really needed since BLACKLIST_AFTER_ROTATION is default to False, maybe for explicitly sake.

def test_token_verify_serializer_should_not_honour_blacklist_if_blacklisting_not_enabled(
self,
):
with override_api_settings(BLACKLIST_AFTER_ROTATION=False):
refresh_token = RefreshToken.for_user(self.user)
refresh_token.blacklist()
refresh_token = RefreshToken.for_user(self.user)
refresh_token.blacklist()

serializer = TokenVerifySerializer(data={"token": str(refresh_token)})
self.assertTrue(serializer.is_valid())
serializer = TokenVerifySerializer(data={"token": str(refresh_token)})
self.assertTrue(serializer.is_valid())


class TestBigAutoFieldIDMigration(MigrationTestCase):
Expand Down
34 changes: 18 additions & 16 deletions tests/test_tokens.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,14 @@ class TestToken(TestCase):
def setUp(self):
self.token = MyToken()

@classmethod
def setUpTestData(cls):
cls.username = "test_user"
cls.user = User.objects.create_user(
username=cls.username,
password="test_password",
)

def test_init_no_token_type_or_lifetime(self):
class MyTestToken(Token):
pass
Expand Down Expand Up @@ -225,14 +233,14 @@ def test_set_jti(self):
self.assertIn("jti", token)
self.assertNotEqual(old_jti, token["jti"])

@override_api_settings(JTI_CLAIM=None)
def test_optional_jti(self):
with override_api_settings(JTI_CLAIM=None):
token = MyToken()
token = MyToken()
self.assertNotIn("jti", token)

@override_api_settings(TOKEN_TYPE_CLAIM=None)
def test_optional_type_token(self):
with override_api_settings(TOKEN_TYPE_CLAIM=None):
token = MyToken()
token = MyToken()
self.assertNotIn("type", token)

def test_set_exp(self):
Expand Down Expand Up @@ -355,25 +363,19 @@ def test_check_token_if_wrong_type_leeway(self):
token.token_backend.leeway = 0

def test_for_user(self):
username = "test_user"
user = User.objects.create_user(
username=username,
password="test_password",
)
token = MyToken.for_user(self.user)

token = MyToken.for_user(user)

user_id = getattr(user, api_settings.USER_ID_FIELD)
user_id = getattr(self.user, api_settings.USER_ID_FIELD)
if not isinstance(user_id, int):
user_id = str(user_id)

self.assertEqual(token[api_settings.USER_ID_CLAIM], user_id)

@override_api_settings(USER_ID_FIELD="username")
def test_for_user_with_username(self):
# Test with non-int user id
with override_api_settings(USER_ID_FIELD="username"):
token = MyToken.for_user(user)

self.assertEqual(token[api_settings.USER_ID_CLAIM], username)
token = MyToken.for_user(self.user)
self.assertEqual(token[api_settings.USER_ID_CLAIM], self.username)

def test_get_token_backend(self):
token = MyToken()
Expand Down
Loading