Skip to content

Commit

Permalink
Showing 7 changed files with 165 additions and 40 deletions.
29 changes: 14 additions & 15 deletions raiden/network/transport/matrix/client.py
Original file line number Diff line number Diff line change
@@ -7,7 +7,6 @@

import gevent
import structlog
from cachetools.func import ttl_cache
from gevent.lock import Semaphore
from matrix_client.api import MatrixHttpApi
from matrix_client.client import CACHE, MatrixClient
@@ -29,21 +28,21 @@ def __init__(self, client, room_id):
# dict of 'type': 'content' key/value pairs
self.account_data: Dict[str, Dict[str, Any]] = dict()

@ttl_cache(ttl=10)
def get_joined_members(self) -> List[User]:
def get_joined_members(self, force_resync=False) -> List[User]:
""" Return a list of members of this room. """
response = self.client.api.get_room_members(self.room_id)
for event in response['chunk']:
if event['content']['membership'] == 'join':
user_id = event["state_key"]
if user_id not in self._members:
self._mkmembers(
User(
self.client.api,
user_id,
event['content'].get('displayname'),
),
)
if force_resync:
response = self.client.api.get_room_members(self.room_id)
for event in response['chunk']:
if event['content']['membership'] == 'join':
user_id = event["state_key"]
if user_id not in self._members:
self._mkmembers(
User(
self.client.api,
user_id,
event['content'].get('displayname'),
),
)
return list(self._members.values())

def _mkmembers(self, member):
45 changes: 41 additions & 4 deletions raiden/network/transport/matrix/transport.py
Original file line number Diff line number Diff line change
@@ -700,11 +700,15 @@ def _handle_invite(self, room_id: _RoomID, state: dict):
# _get_room_ids_for_address will take care of returning only matching rooms and
# _leave_unused_rooms will clear it in the future, if and when needed
last_ex: Optional[Exception] = None
retry_interval = 0.1
for _ in range(JOIN_RETRIES):
try:
room = self._client.join_room(room_id)
except MatrixRequestError as e:
last_ex = e
if self._stop_event.wait(retry_interval):
break
retry_interval = retry_interval * 2
else:
break
else:
@@ -1012,12 +1016,45 @@ def _get_room_for_address(
self.log.error('No valid peer found', peer_address=address_hex)
return None

self._address_to_userids[address].update({user.user_id for user in peers})

if self._private_rooms:
room = self._get_private_room(invitees=peers)
else:
room = self._get_public_room(room_name, invitees=peers)

peer_ids = self._address_to_userids[address]
member_ids = {member.user_id for member in room.get_joined_members(force_resync=True)}
room_is_empty = not bool(peer_ids & member_ids)
if room_is_empty:
last_ex: Optional[Exception] = False
retry_interval = 0.1
self.log.debug(
'Waiting for peer to join from invite',
peer_address=address_hex,
)
for _ in range(JOIN_RETRIES):
try:
member_ids = {member.user_id for member in room.get_joined_members()}
except MatrixRequestError as e:
last_ex = e
room_is_empty = not bool(peer_ids & member_ids)
if room_is_empty or last_ex:
if self._stop_event.wait(retry_interval):
break
retry_interval = retry_interval * 2
else:
break

if room_is_empty or last_ex:
if last_ex:
raise last_ex # re-raise if couldn't succeed in retries
else:
# Inform the client, that currently no one listens:
self.log.error(
'Peer has not joined from invite yet, should join eventually',
peer_address=address_hex,
)

self._address_to_userids[address].update({user.user_id for user in peers})
self._set_room_id_for_address(address, room.room_id)

if not room.listeners:
@@ -1063,7 +1100,7 @@ def _get_public_room(self, room_name, invitees: List[User]):
)
else:
# Invite users to existing room
member_ids = {user.user_id for user in room.get_joined_members()}
member_ids = {user.user_id for user in room.get_joined_members(force_resync=True)}
users_to_invite = set(invitees_uids) - member_ids
self.log.debug('Inviting users', room=room, invitee_ids=users_to_invite)
for invitee_id in users_to_invite:
@@ -1209,7 +1246,7 @@ def _maybe_invite_user(self, user: User):

room = self._client.rooms[room_ids[0]]
if not room._members:
room.get_joined_members()
room.get_joined_members(force_resync=True)
if user.user_id not in room._members:
self.log.debug('Inviting', user=user, room=room)
try:
2 changes: 1 addition & 1 deletion raiden/network/transport/matrix/utils.py
Original file line number Diff line number Diff line change
@@ -22,7 +22,7 @@

log = structlog.get_logger(__name__)

JOIN_RETRIES = 5
JOIN_RETRIES = 10
USERID_RE = re.compile(r'^@(0x[0-9a-f]{40})(?:\.[0-9a-f]{8})?(?::.+)?$')
ROOM_NAME_SEPARATOR = '_'
ROOM_NAME_PREFIX = 'raiden'
50 changes: 38 additions & 12 deletions raiden/tests/conftest.py
Original file line number Diff line number Diff line change
@@ -167,30 +167,56 @@ def insecure_tls():

# Convert `--transport all` to two separate invocations with `matrix` and `udp`
def pytest_generate_tests(metafunc):
if 'transport' in metafunc.fixturenames:
fixtures = metafunc.fixturenames

if 'transport' in fixtures:
transport = metafunc.config.getoption('transport')
parmeterize_private_rooms = True
transport_and_privacy = list()

number_of_transports = list()

# Filter existing parametrization which is already done in the test
for mark in metafunc.definition.own_markers:
if mark.name == 'parametrize':
# Check if 'private_rooms' gets parameterized
if 'private_rooms' in mark.args[0]:
parmeterize_private_rooms = False
# Check if more than one transport is used
if 'number_of_transports' in mark.args[0]:
number_of_transports = mark.args[1]
# avoid collecting test if 'skip_if_not_*'
if transport in ('udp', 'all') and 'skip_if_not_matrix' not in metafunc.fixturenames:
if transport in ('udp', 'all') and 'skip_if_not_matrix' not in fixtures:
transport_and_privacy.append(('udp', None))

if transport in ('matrix', 'all') and 'skip_if_not_udp' not in metafunc.fixturenames:
if 'public_and_private_rooms' in metafunc.fixturenames:
transport_and_privacy.extend([('matrix', False), ('matrix', True)])
if transport in ('matrix', 'all') and 'skip_if_not_udp' not in fixtures:

if 'public_and_private_rooms' in fixtures:
if number_of_transports:
transport_and_privacy.extend([
('matrix', [False for _ in range(number_of_transports[0])]),
('matrix', [True for _ in range(number_of_transports[0])]),
])
else:
transport_and_privacy.extend([('matrix', False), ('matrix', True)])
else:
transport_and_privacy.append(('matrix', False))
if number_of_transports:
transport_and_privacy.extend([
('matrix', [False for _ in range(number_of_transports[0])]),
])
else:
transport_and_privacy.append(('matrix', False))

if 'private_rooms' in metafunc.fixturenames:
metafunc.parametrize('transport,private_rooms', transport_and_privacy)
else:
# If the test function isn't taking the `private_rooms` fixture only give the
# transport values
if not parmeterize_private_rooms or 'private_rooms' not in fixtures:
# If the test does not expect the private_rooms parameter or parametrizes
# `private_rooms` itself, only give he transport values
metafunc.parametrize(
'transport',
list(set(transport_type for transport_type, _ in transport_and_privacy)),
)

else:
metafunc.parametrize('transport,private_rooms', transport_and_privacy)


if sys.platform == 'darwin':
# On macOS the temp directory base path is already very long.
5 changes: 0 additions & 5 deletions raiden/tests/fixtures/variables.py
Original file line number Diff line number Diff line change
@@ -319,11 +319,6 @@ def database_paths(tmpdir, private_keys):
return database_paths


@pytest.fixture
def private_rooms():
return False


@pytest.fixture
def environment_type():
"""Specifies the environment type"""
2 changes: 1 addition & 1 deletion raiden/tests/integration/fixtures/transport.py
Original file line number Diff line number Diff line change
@@ -69,7 +69,7 @@ def matrix_transports(
'server': server,
'server_name': server.netloc,
'available_servers': local_matrix_servers,
'private_rooms': private_rooms,
'private_rooms': private_rooms[transport_index],
}),
)

72 changes: 70 additions & 2 deletions raiden/tests/integration/test_matrix_transport.py
Original file line number Diff line number Diff line change
@@ -110,6 +110,50 @@ def mock_receive_message(klass, message):
return transport


def ping_pong_message_success(transport0, transport1):
queueid0 = QueueIdentifier(
recipient=transport0._raiden_service.address,
channel_identifier=CHANNEL_IDENTIFIER_GLOBAL_QUEUE,
)

queueid1 = QueueIdentifier(
recipient=transport1._raiden_service.address,
channel_identifier=CHANNEL_IDENTIFIER_GLOBAL_QUEUE,
)

received_messages0 = transport0._raiden_service.message_handler.bag
received_messages1 = transport1._raiden_service.message_handler.bag
number_of_received_messages0 = len(received_messages0)
number_of_received_messages1 = len(received_messages1)

message = Processed(message_identifier=number_of_received_messages0)
transport0._raiden_service.sign(message)

transport0.send_async(queueid1, message)
with Timeout(20, exception=False):
all_messages_received = False
while not all_messages_received:
all_messages_received = (
len(received_messages0) == number_of_received_messages0 + 1 and
len(received_messages1) == number_of_received_messages1 + 1
)
gevent.sleep(.1)
message = Processed(message_identifier=number_of_received_messages1)
transport1._raiden_service.sign(message)
transport1.send_async(queueid0, message)

with Timeout(20, exception=False):
all_messages_received = False
while not all_messages_received:
all_messages_received = (
len(received_messages0) == number_of_received_messages0 + 2 and
len(received_messages1) == number_of_received_messages1 + 2
)
gevent.sleep(.1)

return all_messages_received


@pytest.fixture()
def skip_userid_validation(monkeypatch):
import raiden.network.transport.matrix
@@ -297,10 +341,12 @@ def test_matrix_message_sync(
queue_identifier,
message,
)

gevent.sleep(2)
with Timeout(retry_interval * 20, exception=False):
while not len(received_messages) == 10:
gevent.sleep(.1)

assert len(received_messages) == 10

for i in range(5):
assert any(getattr(m, 'message_identifier', -1) == i for m in received_messages)

@@ -778,3 +824,25 @@ def make_unsigned_balance_proof(nonce):
)
transport.stop()
transport.get()


@pytest.mark.parametrize('private_rooms', [[True, True]])
@pytest.mark.parametrize('matrix_server_count', [2])
@pytest.mark.parametrize('number_of_transports', [2])
def test_reproduce_handle_invite_send_race_issue_3588(matrix_transports):
transport0, transport1 = matrix_transports
received_messages0 = set()
received_messages1 = set()

message_handler0 = MessageHandler(received_messages0)
message_handler1 = MessageHandler(received_messages1)

raiden_service0 = MockRaidenService(message_handler0)
raiden_service1 = MockRaidenService(message_handler1)

transport0.start(raiden_service0, message_handler0, '')
transport1.start(raiden_service1, message_handler1, '')

transport0.start_health_check(raiden_service1.address)
transport1.start_health_check(raiden_service0.address)
assert ping_pong_message_success(transport0, transport1)

0 comments on commit 61bdaff

Please sign in to comment.