Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Add type hints for tests/unittest.py. #12347

Merged
merged 9 commits into from
Apr 1, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions changelog.d/12347.misc
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Add type annotations for `tests/unittest.py`.
1 change: 0 additions & 1 deletion mypy.ini
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,6 @@ exclude = (?x)
|tests/test_server.py
|tests/test_state.py
|tests/test_terms_auth.py
|tests/unittest.py
|tests/util/caches/test_cached_call.py
|tests/util/caches/test_deferred_cache.py
|tests/util/caches/test_descriptors.py
Expand Down
6 changes: 4 additions & 2 deletions tests/handlers/test_e2e_keys.py
Original file line number Diff line number Diff line change
Expand Up @@ -463,8 +463,10 @@ def test_self_signing_key_doesnt_show_up_as_device(self) -> None:
res = e.value.code
self.assertEqual(res, 400)

res = self.get_success(self.handler.query_local_devices({local_user: None}))
self.assertDictEqual(res, {local_user: {}})
query_res = self.get_success(
self.handler.query_local_devices({local_user: None})
)
self.assertDictEqual(query_res, {local_user: {}})

def test_upload_signatures(self) -> None:
"""should check signatures that are uploaded"""
Expand Down
5 changes: 3 additions & 2 deletions tests/handlers/test_federation.py
Original file line number Diff line number Diff line change
Expand Up @@ -375,7 +375,8 @@ def test_backfill_floating_outlier_membership_auth(self) -> None:
member_event.signatures = member_event_dict["signatures"]

# Add the new member_event to the StateMap
prev_state_map[
updated_state_map = dict(prev_state_map)
updated_state_map[
(member_event.type, member_event.state_key)
] = member_event.event_id
auth_events.append(member_event)
Expand All @@ -399,7 +400,7 @@ def test_backfill_floating_outlier_membership_auth(self) -> None:
prev_event_ids=message_event_dict["prev_events"],
auth_event_ids=self._event_auth_handler.compute_auth_events(
builder,
prev_state_map,
updated_state_map,
for_verification=False,
),
depth=message_event_dict["depth"],
Expand Down
7 changes: 4 additions & 3 deletions tests/handlers/test_oidc.py
Original file line number Diff line number Diff line change
Expand Up @@ -354,10 +354,11 @@ def test_redirect_request(self) -> None:
req = Mock(spec=["cookies"])
req.cookies = []

url = self.get_success(
self.provider.handle_redirect_request(req, b"http://client/redirect")
url = urlparse(
self.get_success(
self.provider.handle_redirect_request(req, b"http://client/redirect")
)
)
url = urlparse(url)
auth_endpoint = urlparse(AUTHORIZATION_ENDPOINT)

self.assertEqual(url.scheme, auth_endpoint.scheme)
Expand Down
2 changes: 2 additions & 0 deletions tests/handlers/test_user_directory.py
Original file line number Diff line number Diff line change
Expand Up @@ -351,6 +351,7 @@ def test_handle_local_profile_change_with_support_user(self) -> None:
self.handler.handle_local_profile_change(regular_user_id, profile_info)
)
profile = self.get_success(self.store.get_user_in_directory(regular_user_id))
assert profile is not None
self.assertTrue(profile["display_name"] == display_name)

def test_handle_local_profile_change_with_deactivated_user(self) -> None:
Expand All @@ -369,6 +370,7 @@ def test_handle_local_profile_change_with_deactivated_user(self) -> None:

# profile is in directory
profile = self.get_success(self.store.get_user_in_directory(r_user_id))
assert profile is not None
self.assertTrue(profile["display_name"] == display_name)

# deactivate user
Expand Down
8 changes: 8 additions & 0 deletions tests/rest/admin/test_media.py
Original file line number Diff line number Diff line change
Expand Up @@ -702,6 +702,7 @@ def test_quarantine_media(self) -> None:
"""

media_info = self.get_success(self.store.get_local_media(self.media_id))
assert media_info is not None
self.assertFalse(media_info["quarantined_by"])

# quarantining
Expand All @@ -715,6 +716,7 @@ def test_quarantine_media(self) -> None:
self.assertFalse(channel.json_body)

media_info = self.get_success(self.store.get_local_media(self.media_id))
assert media_info is not None
self.assertTrue(media_info["quarantined_by"])

# remove from quarantine
Expand All @@ -728,6 +730,7 @@ def test_quarantine_media(self) -> None:
self.assertFalse(channel.json_body)

media_info = self.get_success(self.store.get_local_media(self.media_id))
assert media_info is not None
self.assertFalse(media_info["quarantined_by"])

def test_quarantine_protected_media(self) -> None:
Expand All @@ -740,6 +743,7 @@ def test_quarantine_protected_media(self) -> None:

# verify protection
media_info = self.get_success(self.store.get_local_media(self.media_id))
assert media_info is not None
self.assertTrue(media_info["safe_from_quarantine"])

# quarantining
Expand All @@ -754,6 +758,7 @@ def test_quarantine_protected_media(self) -> None:

# verify that is not in quarantine
media_info = self.get_success(self.store.get_local_media(self.media_id))
assert media_info is not None
self.assertFalse(media_info["quarantined_by"])


Expand Down Expand Up @@ -830,6 +835,7 @@ def test_protect_media(self) -> None:
"""

media_info = self.get_success(self.store.get_local_media(self.media_id))
assert media_info is not None
self.assertFalse(media_info["safe_from_quarantine"])

# protect
Expand All @@ -843,6 +849,7 @@ def test_protect_media(self) -> None:
self.assertFalse(channel.json_body)

media_info = self.get_success(self.store.get_local_media(self.media_id))
assert media_info is not None
self.assertTrue(media_info["safe_from_quarantine"])

# unprotect
Expand All @@ -856,6 +863,7 @@ def test_protect_media(self) -> None:
self.assertFalse(channel.json_body)

media_info = self.get_success(self.store.get_local_media(self.media_id))
assert media_info is not None
self.assertFalse(media_info["safe_from_quarantine"])


Expand Down
15 changes: 9 additions & 6 deletions tests/rest/admin/test_user.py
Original file line number Diff line number Diff line change
Expand Up @@ -1590,10 +1590,9 @@ def test_create_user_email_notif_for_new_users(self) -> None:
self.assertEqual("email", channel.json_body["threepids"][0]["medium"])
self.assertEqual("[email protected]", channel.json_body["threepids"][0]["address"])

pushers = self.get_success(
self.store.get_pushers_by({"user_name": "@bob:test"})
pushers = list(
self.get_success(self.store.get_pushers_by({"user_name": "@bob:test"}))
)
pushers = list(pushers)
self.assertEqual(len(pushers), 1)
self.assertEqual("@bob:test", pushers[0].user_name)

Expand Down Expand Up @@ -1632,10 +1631,9 @@ def test_create_user_email_no_notif_for_new_users(self) -> None:
self.assertEqual("email", channel.json_body["threepids"][0]["medium"])
self.assertEqual("[email protected]", channel.json_body["threepids"][0]["address"])

pushers = self.get_success(
self.store.get_pushers_by({"user_name": "@bob:test"})
pushers = list(
self.get_success(self.store.get_pushers_by({"user_name": "@bob:test"}))
)
pushers = list(pushers)
self.assertEqual(len(pushers), 0)

def test_set_password(self) -> None:
Expand Down Expand Up @@ -2144,6 +2142,7 @@ def test_change_name_deactivate_user_user_directory(self) -> None:

# is in user directory
profile = self.get_success(self.store.get_user_in_directory(self.other_user))
assert profile is not None
self.assertTrue(profile["display_name"] == "User")

# Deactivate user
Expand Down Expand Up @@ -2711,6 +2710,7 @@ def test_get_pushers(self) -> None:
user_tuple = self.get_success(
self.store.get_user_by_access_token(other_user_token)
)
assert user_tuple is not None
token_id = user_tuple.token_id

self.get_success(
Expand Down Expand Up @@ -3676,6 +3676,7 @@ def test_success(self) -> None:
# The user starts off as not shadow-banned.
other_user_token = self.login("user", "pass")
result = self.get_success(self.store.get_user_by_access_token(other_user_token))
assert result is not None
self.assertFalse(result.shadow_banned)

channel = self.make_request("POST", self.url, access_token=self.admin_user_tok)
Expand All @@ -3684,6 +3685,7 @@ def test_success(self) -> None:

# Ensure the user is shadow-banned (and the cache was cleared).
result = self.get_success(self.store.get_user_by_access_token(other_user_token))
assert result is not None
self.assertTrue(result.shadow_banned)

# Un-shadow-ban the user.
Expand All @@ -3695,6 +3697,7 @@ def test_success(self) -> None:

# Ensure the user is no longer shadow-banned (and the cache was cleared).
result = self.get_success(self.store.get_user_by_access_token(other_user_token))
assert result is not None
self.assertFalse(result.shadow_banned)


Expand Down
6 changes: 4 additions & 2 deletions tests/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@
from collections import deque
from io import SEEK_END, BytesIO
from typing import (
AnyStr,
Callable,
Dict,
Iterable,
Expand Down Expand Up @@ -86,6 +85,9 @@

logger = logging.getLogger(__name__)

# the type of thing that can be passed into `make_request` in the headers list
CustomHeaderType = Tuple[Union[str, bytes], Union[str, bytes]]


class TimedOutException(Exception):
"""
Expand Down Expand Up @@ -260,7 +262,7 @@ def make_request(
federation_auth_origin: Optional[bytes] = None,
content_is_form: bool = False,
await_result: bool = True,
custom_headers: Optional[Iterable[Tuple[AnyStr, AnyStr]]] = None,
custom_headers: Optional[Iterable[CustomHeaderType]] = None,
client_ip: str = "127.0.0.1",
) -> FakeChannel:
"""
Expand Down
8 changes: 4 additions & 4 deletions tests/storage/databases/main/test_lock.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ def test_simple_lock(self):
"""
# First to acquire this lock, so it should complete
lock = self.get_success(self.store.try_acquire_lock("name", "key"))
self.assertIsNotNone(lock)
assert lock is not None
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I find it quite annoying that mypy doesn't handle assertIs(Not)None the same way...

python/mypy#5088 for cross-refs.


# Enter the context manager
self.get_success(lock.__aenter__())
Expand All @@ -45,15 +45,15 @@ def test_simple_lock(self):

# We can now acquire the lock again.
lock3 = self.get_success(self.store.try_acquire_lock("name", "key"))
self.assertIsNotNone(lock3)
assert lock3 is not None
self.get_success(lock3.__aenter__())
self.get_success(lock3.__aexit__(None, None, None))

def test_maintain_lock(self):
"""Test that we don't time out locks while they're still active"""

lock = self.get_success(self.store.try_acquire_lock("name", "key"))
self.assertIsNotNone(lock)
assert lock is not None

self.get_success(lock.__aenter__())

Expand All @@ -69,7 +69,7 @@ def test_timeout_lock(self):
"""Test that we time out locks if they're not updated for ages"""

lock = self.get_success(self.store.try_acquire_lock("name", "key"))
self.assertIsNotNone(lock)
assert lock is not None

self.get_success(lock.__aenter__())

Expand Down
1 change: 1 addition & 0 deletions tests/storage/test_appservice.py
Original file line number Diff line number Diff line change
Expand Up @@ -358,6 +358,7 @@ def test_get_oldest_unsent_txn(self) -> None:
self.get_success(self._insert_txn(service.id, 12, other_events))

txn = self.get_success(self.store.get_oldest_unsent_txn(service))
assert txn is not None
self.assertEqual(service, txn.service)
self.assertEqual(10, txn.id)
self.assertEqual(events, txn.events)
Expand Down
Loading