Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Replace make_awaitable with AsyncMock #16179

Merged
merged 10 commits into from
Aug 24, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions changelog.d/16179.misc
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Use `AsyncMock` instead of custom code.
5 changes: 2 additions & 3 deletions tests/crypto/test_keyring.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
# limitations under the License.
import time
from typing import Any, Dict, List, Optional, cast
from unittest.mock import Mock
from unittest.mock import AsyncMock, Mock

import attr
import canonicaljson
Expand Down Expand Up @@ -45,7 +45,6 @@
from synapse.util import Clock

from tests import unittest
from tests.test_utils import make_awaitable
from tests.unittest import logcontext_clean, override_config


Expand Down Expand Up @@ -291,7 +290,7 @@ def test_verify_json_for_server_with_null_valid_until_ms(self) -> None:
with a null `ts_valid_until_ms`
"""
mock_fetcher = Mock()
mock_fetcher.get_keys = Mock(return_value=make_awaitable({}))
mock_fetcher.get_keys = AsyncMock(return_value={})

key1 = signedjson.key.generate_signing_key("1")
r = self.hs.get_datastores().main.store_server_signature_keys(
Expand Down
33 changes: 16 additions & 17 deletions tests/federation/test_complexity.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,15 +12,14 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from unittest.mock import Mock
from unittest.mock import AsyncMock

from synapse.api.errors import Codes, SynapseError
from synapse.rest import admin
from synapse.rest.client import login, room
from synapse.types import JsonDict, UserID, create_requester

from tests import unittest
from tests.test_utils import make_awaitable


class RoomComplexityTests(unittest.FederatingHomeserverTestCase):
Expand Down Expand Up @@ -75,9 +74,9 @@ def test_join_too_large(self) -> None:
fed_transport = self.hs.get_federation_transport_client()

# Mock out some things, because we don't want to test the whole join
fed_transport.client.get_json = Mock(return_value=make_awaitable({"v1": 9999})) # type: ignore[assignment]
handler.federation_handler.do_invite_join = Mock( # type: ignore[assignment]
return_value=make_awaitable(("", 1))
fed_transport.client.get_json = AsyncMock(return_value={"v1": 9999}) # type: ignore[assignment]
handler.federation_handler.do_invite_join = AsyncMock( # type: ignore[assignment]
return_value=("", 1)
)

d = handler._remote_join(
Expand Down Expand Up @@ -106,9 +105,9 @@ def test_join_too_large_admin(self) -> None:
fed_transport = self.hs.get_federation_transport_client()

# Mock out some things, because we don't want to test the whole join
fed_transport.client.get_json = Mock(return_value=make_awaitable({"v1": 9999})) # type: ignore[assignment]
handler.federation_handler.do_invite_join = Mock( # type: ignore[assignment]
return_value=make_awaitable(("", 1))
fed_transport.client.get_json = AsyncMock(return_value={"v1": 9999}) # type: ignore[assignment]
handler.federation_handler.do_invite_join = AsyncMock( # type: ignore[assignment]
return_value=("", 1)
)

d = handler._remote_join(
Expand Down Expand Up @@ -143,9 +142,9 @@ def test_join_too_large_once_joined(self) -> None:
fed_transport = self.hs.get_federation_transport_client()

# Mock out some things, because we don't want to test the whole join
fed_transport.client.get_json = Mock(return_value=make_awaitable(None)) # type: ignore[assignment]
handler.federation_handler.do_invite_join = Mock( # type: ignore[assignment]
return_value=make_awaitable(("", 1))
fed_transport.client.get_json = AsyncMock(return_value=None) # type: ignore[assignment]
handler.federation_handler.do_invite_join = AsyncMock( # type: ignore[assignment]
return_value=("", 1)
)

# Artificially raise the complexity
Expand Down Expand Up @@ -200,9 +199,9 @@ def test_join_too_large_no_admin(self) -> None:
fed_transport = self.hs.get_federation_transport_client()

# Mock out some things, because we don't want to test the whole join
fed_transport.client.get_json = Mock(return_value=make_awaitable({"v1": 9999})) # type: ignore[assignment]
handler.federation_handler.do_invite_join = Mock( # type: ignore[assignment]
return_value=make_awaitable(("", 1))
fed_transport.client.get_json = AsyncMock(return_value={"v1": 9999}) # type: ignore[assignment]
handler.federation_handler.do_invite_join = AsyncMock( # type: ignore[assignment]
return_value=("", 1)
)

d = handler._remote_join(
Expand Down Expand Up @@ -230,9 +229,9 @@ def test_join_too_large_admin(self) -> None:
fed_transport = self.hs.get_federation_transport_client()

# Mock out some things, because we don't want to test the whole join
fed_transport.client.get_json = Mock(return_value=make_awaitable({"v1": 9999})) # type: ignore[assignment]
handler.federation_handler.do_invite_join = Mock( # type: ignore[assignment]
return_value=make_awaitable(("", 1))
fed_transport.client.get_json = AsyncMock(return_value={"v1": 9999}) # type: ignore[assignment]
handler.federation_handler.do_invite_join = AsyncMock( # type: ignore[assignment]
return_value=("", 1)
)

d = handler._remote_join(
Expand Down
8 changes: 4 additions & 4 deletions tests/federation/test_federation_catch_up.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from typing import Callable, Collection, List, Optional, Tuple
from unittest import mock
from unittest.mock import Mock
from unittest.mock import AsyncMock, Mock

from twisted.test.proto_helpers import MemoryReactor

Expand All @@ -19,7 +19,7 @@
from synapse.util import Clock
from synapse.util.retryutils import NotRetryingDestination

from tests.test_utils import event_injection, make_awaitable
from tests.test_utils import event_injection
from tests.unittest import FederatingHomeserverTestCase


Expand Down Expand Up @@ -50,8 +50,8 @@ def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
# This mock is crucial for destination_rooms to be populated.
# TODO: this seems to no longer be the case---tests pass with this mock
# commented out.
state_storage_controller.get_current_hosts_in_room = Mock( # type: ignore[assignment]
return_value=make_awaitable({"test", "host2"})
state_storage_controller.get_current_hosts_in_room = AsyncMock( # type: ignore[assignment]
return_value={"test", "host2"}
)

# whenever send_transaction is called, record the pdu data
Expand Down
42 changes: 20 additions & 22 deletions tests/federation/test_federation_sender.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Callable, FrozenSet, List, Optional, Set
from unittest.mock import Mock
from unittest.mock import AsyncMock, Mock

from signedjson import key, sign
from signedjson.types import BaseKey, SigningKey
Expand All @@ -29,7 +29,6 @@
from synapse.types import JsonDict, ReadReceipt
from synapse.util import Clock

from tests.test_utils import make_awaitable
from tests.unittest import HomeserverTestCase


Expand All @@ -43,12 +42,13 @@ class FederationSenderReceiptsTestCases(HomeserverTestCase):

def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
self.federation_transport_client = Mock(spec=["send_transaction"])
self.federation_transport_client.send_transaction = AsyncMock()
hs = self.setup_test_homeserver(
federation_transport_client=self.federation_transport_client,
)

hs.get_storage_controllers().state.get_current_hosts_in_room = Mock( # type: ignore[assignment]
return_value=make_awaitable({"test", "host2"})
hs.get_storage_controllers().state.get_current_hosts_in_room = AsyncMock( # type: ignore[assignment]
return_value={"test", "host2"}
)

hs.get_storage_controllers().state.get_current_hosts_in_room_or_partial_state_approximation = ( # type: ignore[assignment]
Expand All @@ -64,7 +64,7 @@ def default_config(self) -> JsonDict:

def test_send_receipts(self) -> None:
mock_send_transaction = self.federation_transport_client.send_transaction
mock_send_transaction.return_value = make_awaitable({})
mock_send_transaction.return_value = {}

sender = self.hs.get_federation_sender()
receipt = ReadReceipt(
Expand Down Expand Up @@ -104,7 +104,7 @@ def test_send_receipts(self) -> None:

def test_send_receipts_thread(self) -> None:
mock_send_transaction = self.federation_transport_client.send_transaction
mock_send_transaction.return_value = make_awaitable({})
mock_send_transaction.return_value = {}

# Create receipts for:
#
Expand Down Expand Up @@ -180,7 +180,7 @@ def test_send_receipts_with_backoff(self) -> None:
"""Send two receipts in quick succession; the second should be flushed, but
only after 20ms"""
mock_send_transaction = self.federation_transport_client.send_transaction
mock_send_transaction.return_value = make_awaitable({})
mock_send_transaction.return_value = {}

sender = self.hs.get_federation_sender()
receipt = ReadReceipt(
Expand Down Expand Up @@ -276,6 +276,8 @@ def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
self.federation_transport_client = Mock(
spec=["send_transaction", "query_user_devices"]
)
self.federation_transport_client.send_transaction = AsyncMock()
self.federation_transport_client.query_user_devices = AsyncMock()
return self.setup_test_homeserver(
federation_transport_client=self.federation_transport_client,
)
Expand Down Expand Up @@ -317,13 +319,13 @@ async def get_current_hosts_in_room(room_id: str) -> Set[str]:
self.record_transaction
)

def record_transaction(
async def record_transaction(
self, txn: Transaction, json_cb: Optional[Callable[[], JsonDict]] = None
) -> "defer.Deferred[JsonDict]":
) -> JsonDict:
assert json_cb is not None
data = json_cb()
self.edus.extend(data["edus"])
return defer.succeed({})
return {}

def test_send_device_updates(self) -> None:
"""Basic case: each device update should result in an EDU"""
Expand Down Expand Up @@ -354,15 +356,11 @@ def test_dont_send_device_updates_for_remote_users(self) -> None:

# Send the server a device list EDU for the other user, this will cause
# it to try and resync the device lists.
self.federation_transport_client.query_user_devices.return_value = (
make_awaitable(
{
"stream_id": "1",
"user_id": "@user2:host2",
"devices": [{"device_id": "D1"}],
}
)
)
self.federation_transport_client.query_user_devices.return_value = {
"stream_id": "1",
"user_id": "@user2:host2",
"devices": [{"device_id": "D1"}],
}

self.get_success(
self.device_handler.device_list_updater.incoming_device_list_update(
Expand Down Expand Up @@ -533,7 +531,7 @@ def test_unreachable_server(self) -> None:
recovery
"""
mock_send_txn = self.federation_transport_client.send_transaction
mock_send_txn.side_effect = lambda t, cb: defer.fail(AssertionError("fail"))
mock_send_txn.side_effect = AssertionError("fail")

# create devices
u1 = self.register_user("user", "pass")
Expand Down Expand Up @@ -578,7 +576,7 @@ def test_prune_outbound_device_pokes1(self) -> None:
This case tests the behaviour when the server has never been reachable.
"""
mock_send_txn = self.federation_transport_client.send_transaction
mock_send_txn.side_effect = lambda t, cb: defer.fail(AssertionError("fail"))
mock_send_txn.side_effect = AssertionError("fail")

# create devices
u1 = self.register_user("user", "pass")
Expand Down Expand Up @@ -636,7 +634,7 @@ def test_prune_outbound_device_pokes2(self) -> None:

# now the server goes offline
mock_send_txn = self.federation_transport_client.send_transaction
mock_send_txn.side_effect = lambda t, cb: defer.fail(AssertionError("fail"))
mock_send_txn.side_effect = AssertionError("fail")

self.login("user", "pass", device_id="D2")
self.login("user", "pass", device_id="D3")
Expand Down
Loading