diff --git a/src/socketio/asyncio_aiopika_manager.py b/src/socketio/asyncio_aiopika_manager.py index 905057d5..96dcec65 100644 --- a/src/socketio/asyncio_aiopika_manager.py +++ b/src/socketio/asyncio_aiopika_manager.py @@ -94,7 +94,7 @@ async def _listen(self): async with self.listener_queue.iterator() as queue_iter: async for message in queue_iter: with message.process(): - return pickle.loads(message.body) + yield pickle.loads(message.body) except Exception: self._get_logger().error('Cannot receive from rabbitmq... ' 'retrying in ' diff --git a/src/socketio/asyncio_pubsub_manager.py b/src/socketio/asyncio_pubsub_manager.py index 916c4a6f..ff37f2df 100644 --- a/src/socketio/asyncio_pubsub_manager.py +++ b/src/socketio/asyncio_pubsub_manager.py @@ -148,35 +148,34 @@ async def _handle_close_room(self, message): async def _thread(self): while True: try: - message = await self._listen() + async for message in self._listen(): # pragma: no branch + data = None + if isinstance(message, dict): + data = message + else: + if isinstance(message, bytes): # pragma: no cover + try: + data = pickle.loads(message) + except: + pass + if data is None: + try: + data = json.loads(message) + except: + pass + if data and 'method' in data: + self._get_logger().info('pubsub message: {}'.format( + data['method'])) + if data['method'] == 'emit': + await self._handle_emit(data) + elif data['method'] == 'callback': + await self._handle_callback(data) + elif data['method'] == 'disconnect': + await self._handle_disconnect(data) + elif data['method'] == 'close_room': + await self._handle_close_room(data) except asyncio.CancelledError: # pragma: no cover break - except: + except: # pragma: no cover import traceback traceback.print_exc() - break - data = None - if isinstance(message, dict): - data = message - else: - if isinstance(message, bytes): # pragma: no cover - try: - data = pickle.loads(message) - except: - pass - if data is None: - try: - data = json.loads(message) - except: - pass - if data and 'method' in data: - self._get_logger().info('pubsub message: {}'.format( - data['method'])) - if data['method'] == 'emit': - await self._handle_emit(data) - elif data['method'] == 'callback': - await self._handle_callback(data) - elif data['method'] == 'disconnect': - await self._handle_disconnect(data) - elif data['method'] == 'close_room': - await self._handle_close_room(data) diff --git a/src/socketio/asyncio_redis_manager.py b/src/socketio/asyncio_redis_manager.py index 9762d3eb..41a62c63 100644 --- a/src/socketio/asyncio_redis_manager.py +++ b/src/socketio/asyncio_redis_manager.py @@ -1,6 +1,5 @@ import asyncio import pickle -from urllib.parse import urlparse try: import aioredis @@ -10,34 +9,18 @@ from .asyncio_pubsub_manager import AsyncPubSubManager -def _parse_redis_url(url): - p = urlparse(url) - if p.scheme not in {'redis', 'rediss'}: - raise ValueError('Invalid redis url') - ssl = p.scheme == 'rediss' - host = p.hostname or 'localhost' - port = p.port or 6379 - password = p.password - if p.path: - db = int(p.path[1:]) - else: - db = 0 - return host, port, password, db, ssl - - class AsyncRedisManager(AsyncPubSubManager): # pragma: no cover """Redis based client manager for asyncio servers. This class implements a Redis backend for event sharing across multiple - processes. Only kept here as one more example of how to build a custom - backend, since the kombu backend is perfectly adequate to support a Redis - message queue. + processes. - To use a Redis backend, initialize the :class:`Server` instance as + To use a Redis backend, initialize the :class:`AsyncServer` instance as follows:: - server = socketio.Server(client_manager=socketio.AsyncRedisManager( - 'redis://hostname:port/0')) + url = 'redis://hostname:port/0' + server = socketio.AsyncServer( + client_manager=socketio.AsyncRedisManager(url)) :param url: The connection URL for the Redis server. For a default Redis store running on the same host, use ``redis://``. To use an @@ -47,62 +30,73 @@ class AsyncRedisManager(AsyncPubSubManager): # pragma: no cover :param write_only: If set to ``True``, only initialize to emit events. The default of ``False`` initializes the class for emitting and receiving. + :param redis_options: additional keyword arguments to be passed to + ``aioredis.from_url()``. """ name = 'aioredis' def __init__(self, url='redis://localhost:6379/0', channel='socketio', - write_only=False, logger=None): + write_only=False, logger=None, redis_options=None): if aioredis is None: raise RuntimeError('Redis package is not installed ' '(Run "pip install aioredis" in your ' 'virtualenv).') - ( - self.host, self.port, self.password, self.db, self.ssl - ) = _parse_redis_url(url) - self.pub = None - self.sub = None + if not hasattr(aioredis.Redis, 'from_url'): + raise RuntimeError('Version 2 of aioredis package is required.') + self.redis_url = url + self.redis_options = redis_options or {} + self._redis_connect() super().__init__(channel=channel, write_only=write_only, logger=logger) + def _redis_connect(self): + self.redis = aioredis.Redis.from_url(self.redis_url, + **self.redis_options) + self.pubsub = self.redis.pubsub() + async def _publish(self, data): retry = True while True: try: - if self.pub is None: - self.pub = await aioredis.create_redis( - (self.host, self.port), db=self.db, - password=self.password, ssl=self.ssl - ) - return await self.pub.publish(self.channel, - pickle.dumps(data)) - except (aioredis.RedisError, OSError): + if not retry: + self._redis_connect() + return await self.redis.publish( + self.channel, pickle.dumps(data)) + except aioredis.exceptions.RedisError: if retry: self._get_logger().error('Cannot publish to redis... ' 'retrying') - self.pub = None retry = False else: self._get_logger().error('Cannot publish to redis... ' 'giving up') break - async def _listen(self): + async def _redis_listen_with_retries(self): retry_sleep = 1 + connect = False while True: try: - if self.sub is None: - self.sub = await aioredis.create_redis( - (self.host, self.port), db=self.db, - password=self.password, ssl=self.ssl - ) - self.ch = (await self.sub.subscribe(self.channel))[0] - retry_sleep = 1 - return await self.ch.get() - except (aioredis.RedisError, OSError): + if connect: + self._redis_connect() + await self.pubsub.subscribe(self.channel) + retry_sleep = 1 + async for message in self.pubsub.listen(): + yield message + except aioredis.exceptions.RedisError: self._get_logger().error('Cannot receive from redis... ' 'retrying in ' '{} secs'.format(retry_sleep)) - self.sub = None + connect = True await asyncio.sleep(retry_sleep) retry_sleep *= 2 if retry_sleep > 60: retry_sleep = 60 + + async def _listen(self): + channel = self.channel.encode('utf-8') + await self.pubsub.subscribe(self.channel) + async for message in self._redis_listen_with_retries(): + if message['channel'] == channel and \ + message['type'] == 'message' and 'data' in message: + yield message['data'] + await self.pubsub.unsubscribe(self.channel) diff --git a/src/socketio/redis_manager.py b/src/socketio/redis_manager.py index 7e99d31e..6ac06018 100644 --- a/src/socketio/redis_manager.py +++ b/src/socketio/redis_manager.py @@ -27,7 +27,8 @@ class RedisManager(PubSubManager): # pragma: no cover server = socketio.Server(client_manager=socketio.RedisManager(url)) :param url: The connection URL for the Redis server. For a default Redis - store running on the same host, use ``redis://``. + store running on the same host, use ``redis://``. To use an + SSL connection, use ``rediss://``. :param channel: The channel name on which the server sends and receives notifications. Must be the same in all the servers. :param write_only: If set to ``True``, only initialize to emit events. The diff --git a/tests/asyncio/test_asyncio_pubsub_manager.py b/tests/asyncio/test_asyncio_pubsub_manager.py index 5cefec87..48480a5b 100644 --- a/tests/asyncio/test_asyncio_pubsub_manager.py +++ b/tests/asyncio/test_asyncio_pubsub_manager.py @@ -417,7 +417,7 @@ def test_background_thread(self): self.pm._handle_disconnect = AsyncMock() self.pm._handle_close_room = AsyncMock() - def messages(): + async def messages(): import pickle yield {'method': 'emit', 'value': 'foo'} @@ -428,12 +428,10 @@ def messages(): yield pickle.dumps({'method': 'close_room', 'value': 'baz'}) yield 'bad json' yield b'bad pickled' + raise asyncio.CancelledError() # force the thread to exit - self.pm._listen = AsyncMock(side_effect=list(messages())) - try: - _run(self.pm._thread()) - except StopIteration: - pass + self.pm._listen = messages + _run(self.pm._thread()) self.pm._handle_emit.mock.assert_called_once_with( {'method': 'emit', 'value': 'foo'} diff --git a/tests/asyncio/test_asyncio_redis_manager.py b/tests/asyncio/test_asyncio_redis_manager.py deleted file mode 100644 index a8cf7d8e..00000000 --- a/tests/asyncio/test_asyncio_redis_manager.py +++ /dev/null @@ -1,73 +0,0 @@ -import sys -import unittest - -import pytest - -from socketio import asyncio_redis_manager - - -@unittest.skipIf(sys.version_info < (3, 5), 'only for Python 3.5+') -class TestAsyncRedisManager(unittest.TestCase): - def test_default_url(self): - assert asyncio_redis_manager._parse_redis_url('redis://') == ( - 'localhost', - 6379, - None, - 0, - False, - ) - - def test_only_host_url(self): - assert asyncio_redis_manager._parse_redis_url( - 'redis://redis.host' - ) == ('redis.host', 6379, None, 0, False) - - def test_no_db_url(self): - assert asyncio_redis_manager._parse_redis_url( - 'redis://redis.host:123/1' - ) == ('redis.host', 123, None, 1, False) - - def test_no_port_url(self): - assert asyncio_redis_manager._parse_redis_url( - 'redis://redis.host/1' - ) == ('redis.host', 6379, None, 1, False) - - def test_password(self): - assert asyncio_redis_manager._parse_redis_url( - 'redis://:pw@redis.host/1' - ) == ('redis.host', 6379, 'pw', 1, False) - - def test_no_host_url(self): - assert asyncio_redis_manager._parse_redis_url('redis://:123/1') == ( - 'localhost', - 123, - None, - 1, - False, - ) - - def test_no_host_password_url(self): - assert asyncio_redis_manager._parse_redis_url( - 'redis://:pw@:123/1' - ) == ('localhost', 123, 'pw', 1, False) - - def test_bad_port_url(self): - with pytest.raises(ValueError): - asyncio_redis_manager._parse_redis_url('redis://localhost:abc/1') - - def test_bad_db_url(self): - with pytest.raises(ValueError): - asyncio_redis_manager._parse_redis_url('redis://localhost:abc/z') - - def test_bad_scheme_url(self): - with pytest.raises(ValueError): - asyncio_redis_manager._parse_redis_url('http://redis.host:123/1') - - def test_ssl_scheme(self): - assert asyncio_redis_manager._parse_redis_url('rediss://') == ( - 'localhost', - 6379, - None, - 0, - True, - )