Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Add type hints to tests/rest/client #12066

Merged
merged 4 commits into from
Feb 23, 2022
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions changelog.d/12066.misc
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Add type hints to `tests/rest/client`.
70 changes: 38 additions & 32 deletions tests/rest/client/test_auth.py
Original file line number Diff line number Diff line change
@@ -13,17 +13,21 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from http import HTTPStatus
from typing import Optional, Tuple, Union
from typing import Any, Dict, List, Optional, Tuple, Union

from twisted.internet.defer import succeed
from twisted.test.proto_helpers import MemoryReactor
from twisted.web.resource import Resource

import synapse.rest.admin
from synapse.api.constants import LoginType
from synapse.handlers.ui_auth.checkers import UserInteractiveAuthChecker
from synapse.rest.client import account, auth, devices, login, logout, register
from synapse.rest.synapse.client import build_synapse_client_resource_tree
from synapse.server import HomeServer
from synapse.storage.database import LoggingTransaction
from synapse.types import JsonDict, UserID
from synapse.util import Clock

from tests import unittest
from tests.handlers.test_oidc import HAS_OIDC
@@ -33,11 +37,11 @@


class DummyRecaptchaChecker(UserInteractiveAuthChecker):
def __init__(self, hs):
def __init__(self, hs: HomeServer) -> None:
super().__init__(hs)
self.recaptcha_attempts = []
self.recaptcha_attempts: List[Tuple[dict, str]] = []

def check_auth(self, authdict, clientip):
def check_auth(self, authdict: dict, clientip: str) -> Any:
self.recaptcha_attempts.append((authdict, clientip))
return succeed(True)

@@ -50,7 +54,7 @@ class FallbackAuthTests(unittest.HomeserverTestCase):
]
hijack_auth = False

def make_homeserver(self, reactor, clock):
def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:

config = self.default_config()

@@ -61,7 +65,7 @@ def make_homeserver(self, reactor, clock):
hs = self.setup_test_homeserver(config=config)
return hs

def prepare(self, reactor, clock, hs):
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
self.recaptcha_checker = DummyRecaptchaChecker(hs)
auth_handler = hs.get_auth_handler()
auth_handler.checkers[LoginType.RECAPTCHA] = self.recaptcha_checker
@@ -101,7 +105,7 @@ def recaptcha(
self.assertEqual(len(attempts), 1)
self.assertEqual(attempts[0][0]["response"], "a")

def test_fallback_captcha(self):
def test_fallback_captcha(self) -> None:
"""Ensure that fallback auth via a captcha works."""
# Returns a 401 as per the spec
channel = self.register(
@@ -132,7 +136,7 @@ def test_fallback_captcha(self):
# We're given a registered user.
self.assertEqual(channel.json_body["user_id"], "@user:test")

def test_complete_operation_unknown_session(self):
def test_complete_operation_unknown_session(self) -> None:
"""
Attempting to mark an invalid session as complete should error.
"""
@@ -165,7 +169,7 @@ class UIAuthTests(unittest.HomeserverTestCase):
register.register_servlets,
]

def default_config(self):
def default_config(self) -> Dict[str, Any]:
config = super().default_config()

# public_baseurl uses an http:// scheme because FakeChannel.isSecure() returns
@@ -182,12 +186,12 @@ def default_config(self):

return config

def create_resource_dict(self):
def create_resource_dict(self) -> Dict[str, Resource]:
resource_dict = super().create_resource_dict()
resource_dict.update(build_synapse_client_resource_tree(self.hs))
return resource_dict

def prepare(self, reactor, clock, hs):
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
self.user_pass = "pass"
self.user = self.register_user("test", self.user_pass)
self.device_id = "dev1"
@@ -229,7 +233,7 @@ def delete_devices(self, expected_response: int, body: JsonDict) -> FakeChannel:

return channel

def test_ui_auth(self):
def test_ui_auth(self) -> None:
"""
Test user interactive authentication outside of registration.
"""
@@ -259,7 +263,7 @@ def test_ui_auth(self):
},
)

def test_grandfathered_identifier(self):
def test_grandfathered_identifier(self) -> None:
"""Check behaviour without "identifier" dict

Synapse used to require clients to submit a "user" field for m.login.password
@@ -286,7 +290,7 @@ def test_grandfathered_identifier(self):
},
)

def test_can_change_body(self):
def test_can_change_body(self) -> None:
"""
The client dict can be modified during the user interactive authentication session.

@@ -325,7 +329,7 @@ def test_can_change_body(self):
},
)

def test_cannot_change_uri(self):
def test_cannot_change_uri(self) -> None:
"""
The initial requested URI cannot be modified during the user interactive authentication session.
"""
@@ -362,7 +366,7 @@ def test_cannot_change_uri(self):
)

@unittest.override_config({"ui_auth": {"session_timeout": "5s"}})
def test_can_reuse_session(self):
def test_can_reuse_session(self) -> None:
"""
The session can be reused if configured.

@@ -409,7 +413,7 @@ def test_can_reuse_session(self):

@skip_unless(HAS_OIDC, "requires OIDC")
@override_config({"oidc_config": TEST_OIDC_CONFIG})
def test_ui_auth_via_sso(self):
def test_ui_auth_via_sso(self) -> None:
"""Test a successful UI Auth flow via SSO

This includes:
@@ -452,7 +456,7 @@ def test_ui_auth_via_sso(self):

@skip_unless(HAS_OIDC, "requires OIDC")
@override_config({"oidc_config": TEST_OIDC_CONFIG})
def test_does_not_offer_password_for_sso_user(self):
def test_does_not_offer_password_for_sso_user(self) -> None:
login_resp = self.helper.login_via_oidc("username")
user_tok = login_resp["access_token"]
device_id = login_resp["device_id"]
@@ -464,7 +468,7 @@ def test_does_not_offer_password_for_sso_user(self):
flows = channel.json_body["flows"]
self.assertEqual(flows, [{"stages": ["m.login.sso"]}])

def test_does_not_offer_sso_for_password_user(self):
def test_does_not_offer_sso_for_password_user(self) -> None:
channel = self.delete_device(
self.user_tok, self.device_id, HTTPStatus.UNAUTHORIZED
)
@@ -474,7 +478,7 @@ def test_does_not_offer_sso_for_password_user(self):

@skip_unless(HAS_OIDC, "requires OIDC")
@override_config({"oidc_config": TEST_OIDC_CONFIG})
def test_offers_both_flows_for_upgraded_user(self):
def test_offers_both_flows_for_upgraded_user(self) -> None:
"""A user that had a password and then logged in with SSO should get both flows"""
login_resp = self.helper.login_via_oidc(UserID.from_string(self.user).localpart)
self.assertEqual(login_resp["user_id"], self.user)
@@ -491,7 +495,7 @@ def test_offers_both_flows_for_upgraded_user(self):

@skip_unless(HAS_OIDC, "requires OIDC")
@override_config({"oidc_config": TEST_OIDC_CONFIG})
def test_ui_auth_fails_for_incorrect_sso_user(self):
def test_ui_auth_fails_for_incorrect_sso_user(self) -> None:
"""If the user tries to authenticate with the wrong SSO user, they get an error"""
# log the user in
login_resp = self.helper.login_via_oidc(UserID.from_string(self.user).localpart)
@@ -534,7 +538,7 @@ class RefreshAuthTests(unittest.HomeserverTestCase):
]
hijack_auth = False

def prepare(self, reactor, clock, hs):
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
self.user_pass = "pass"
self.user = self.register_user("test", self.user_pass)

@@ -548,7 +552,7 @@ def use_refresh_token(self, refresh_token: str) -> FakeChannel:
{"refresh_token": refresh_token},
)

def is_access_token_valid(self, access_token) -> bool:
def is_access_token_valid(self, access_token: str) -> bool:
"""
Checks whether an access token is valid, returning whether it is or not.
"""
@@ -561,7 +565,7 @@ def is_access_token_valid(self, access_token) -> bool:

return code == HTTPStatus.OK

def test_login_issue_refresh_token(self):
def test_login_issue_refresh_token(self) -> None:
"""
A login response should include a refresh_token only if asked.
"""
@@ -591,7 +595,7 @@ def test_login_issue_refresh_token(self):
self.assertIn("refresh_token", login_with_refresh.json_body)
self.assertIn("expires_in_ms", login_with_refresh.json_body)

def test_register_issue_refresh_token(self):
def test_register_issue_refresh_token(self) -> None:
"""
A register response should include a refresh_token only if asked.
"""
@@ -627,7 +631,7 @@ def test_register_issue_refresh_token(self):
self.assertIn("refresh_token", register_with_refresh.json_body)
self.assertIn("expires_in_ms", register_with_refresh.json_body)

def test_token_refresh(self):
def test_token_refresh(self) -> None:
"""
A refresh token can be used to issue a new access token.
"""
@@ -665,7 +669,7 @@ def test_token_refresh(self):
)

@override_config({"refreshable_access_token_lifetime": "1m"})
def test_refreshable_access_token_expiration(self):
def test_refreshable_access_token_expiration(self) -> None:
"""
The access token should have some time as specified in the config.
"""
@@ -722,7 +726,9 @@ def test_refreshable_access_token_expiration(self):
"nonrefreshable_access_token_lifetime": "10m",
}
)
def test_different_expiry_for_refreshable_and_nonrefreshable_access_tokens(self):
def test_different_expiry_for_refreshable_and_nonrefreshable_access_tokens(
self,
) -> None:
"""
Tests that the expiry times for refreshable and non-refreshable access
tokens can be different.
@@ -782,7 +788,7 @@ def test_different_expiry_for_refreshable_and_nonrefreshable_access_tokens(self)
@override_config(
{"refreshable_access_token_lifetime": "1m", "refresh_token_lifetime": "2m"}
)
def test_refresh_token_expiry(self):
def test_refresh_token_expiry(self) -> None:
"""
The refresh token can be configured to have a limited lifetime.
When that lifetime has ended, the refresh token can no longer be used to
@@ -834,7 +840,7 @@ def test_refresh_token_expiry(self):
"session_lifetime": "3m",
}
)
def test_ultimate_session_expiry(self):
def test_ultimate_session_expiry(self) -> None:
"""
The session can be configured to have an ultimate, limited lifetime.
"""
@@ -882,7 +888,7 @@ def test_ultimate_session_expiry(self):
refresh_response.code, HTTPStatus.FORBIDDEN, refresh_response.result
)

def test_refresh_token_invalidation(self):
def test_refresh_token_invalidation(self) -> None:
"""Refresh tokens are invalidated after first use of the next token.

A refresh token is considered invalid if:
@@ -987,7 +993,7 @@ def test_refresh_token_invalidation(self):
fifth_refresh_response.code, HTTPStatus.OK, fifth_refresh_response.result
)

def test_many_token_refresh(self):
def test_many_token_refresh(self) -> None:
"""
If a refresh is performed many times during a session, there shouldn't be
extra 'cruft' built up over time.
Loading