diff --git a/.gitignore b/.gitignore index d619cd70..588cd806 100644 --- a/.gitignore +++ b/.gitignore @@ -39,6 +39,8 @@ docs/_build venv* .eggs .ropeproject +tags .idea +.vscode htmlcov *.swp diff --git a/engineio/__init__.py b/engineio/__init__.py index e3d2e905..0915a596 100644 --- a/engineio/__init__.py +++ b/engineio/__init__.py @@ -4,7 +4,7 @@ from .server import Server if sys.version_info >= (3, 5): # pragma: no cover from .asyncio_server import AsyncServer -else: +else: # pragma: no cover AsyncServer = None __version__ = '1.1.2' diff --git a/engineio/async_aiohttp.py b/engineio/async_aiohttp.py index 7138e8d1..cf943d6d 100644 --- a/engineio/async_aiohttp.py +++ b/engineio/async_aiohttp.py @@ -77,7 +77,7 @@ def make_response(status, headers, payload): headers=headers) -class WebSocket(object): +class WebSocket(object): # pragma: no cover """ This wrapper class provides a aiohttp WebSocket interface that is somewhat compatible with eventlet's implementation. diff --git a/engineio/asyncio_server.py b/engineio/asyncio_server.py index a42b70b4..5e2c4d40 100644 --- a/engineio/asyncio_server.py +++ b/engineio/asyncio_server.py @@ -84,11 +84,11 @@ async def disconnect(self, sid=None): is not given, then all clients are closed. """ if sid is not None: - self._get_socket(sid).close() + await self._get_socket(sid).close() del self.sockets[sid] else: for client in six.itervalues(self.sockets): - client.close() + await client.close() self.sockets = {} async def handle_request(self, *args, **kwargs): diff --git a/engineio/asyncio_socket.py b/engineio/asyncio_socket.py index ef85eb1c..2b4256fc 100644 --- a/engineio/asyncio_socket.py +++ b/engineio/asyncio_socket.py @@ -145,6 +145,7 @@ async def _websocket_handler(self, ws): # start separate writer thread async def writer(): while True: + packets = None try: packets = await self.poll() except IOError: diff --git a/tests/test_async_aiohttp.py b/tests/test_async_aiohttp.py new file mode 100644 index 00000000..690ac06e --- /dev/null +++ b/tests/test_async_aiohttp.py @@ -0,0 +1,56 @@ +import sys +import unittest + +import six +if six.PY3: + from unittest import mock +else: + import mock + +if sys.version_info >= (3, 5): + from aiohttp import web + from engineio import async_aiohttp + + +@unittest.skipIf(sys.version_info < (3, 5), 'only for Python 3.5+') +class AiohttpTests(unittest.TestCase): + @mock.patch('aiohttp.web_urldispatcher.UrlDispatcher.add_route') + def test_create_route(self, add_route): + app = web.Application() + mock_server = mock.MagicMock() + async_aiohttp.create_route(app, mock_server, '/foo') + self.assertEqual(add_route.call_count, 2) + add_route.assert_any_call('GET', '/foo', mock_server.handle_request) + add_route.assert_any_call('POST', '/foo', mock_server.handle_request) + + def test_translate_request(self): + request = mock.MagicMock() + request._message.method = 'PUT' + request._message.path = '/foo/bar?baz=1' + request._message.version = (1, 1) + request._message.headers = {'a': 'b', 'c-c': 'd', 'c_c': 'e', + 'content-type': 'application/json', + 'content-length': 123} + request._payload = b'hello world' + environ = async_aiohttp.translate_request(request) + expected_environ = { + 'REQUEST_METHOD': 'PUT', + 'PATH_INFO': '/foo/bar', + 'QUERY_STRING': 'baz=1', + 'CONTENT_TYPE': 'application/json', + 'CONTENT_LENGTH': 123, + 'HTTP_A': 'b', + # 'HTTP_C_C': 'd,e', + 'RAW_URI': '/foo/bar?baz=1', + 'SERVER_PROTOCOL': 'HTTP/1.1', + 'wsgi.input': b'hello world', + 'aiohttp.request': request, + } + for k, v in expected_environ.items(): + self.assertEqual(v, environ[k]) + self.assertTrue(environ['HTTP_C_C'] == 'd,e' or environ['HTTP_C_C'] == 'e,d') + @mock.patch('engineio.async_aiohttp.aiohttp.web.Response') + def test_make_response(self, Response): + async_aiohttp.make_response('202 ACCEPTED', 'headers', 'payload') + Response.assert_called_once_with(body='payload', status=202, + headers='headers') diff --git a/tests/test_asyncio_server.py b/tests/test_asyncio_server.py index ab15331a..8b6176db 100644 --- a/tests/test_asyncio_server.py +++ b/tests/test_asyncio_server.py @@ -23,55 +23,52 @@ def coroutine(f): return f -mock_coro_args = {} -mock_coro_kwargs = {} +def AsyncMock(*args, **kwargs): + """Return a mock asynchronous function.""" + m = mock.MagicMock(*args, **kwargs) -def get_mock_coro(name, return_value=None): @coroutine - def coro(*args, **kwargs): - global mock_coro_args - global mock_coro_kwargs - mock_coro_args[name] = args - mock_coro_kwargs[name] = kwargs - return return_value - - return coro() - - -def get_async_mock(environ={'REQUEST_METHOD': 'GET', 'QUERY_STRING': ''}): - a = mock.MagicMock() - a.async = { - 'asyncio': True, - 'create_route': mock.MagicMock(), - 'translate_request': mock.MagicMock(), - 'make_response': mock.MagicMock(), - 'websocket': 'w', - 'websocket_class': 'wc' - } - a.async['translate_request'].return_value = environ - a.async['make_response'].return_value = 'response' - return a + def mock_coro(*args, **kwargs): + return m(*args, **kwargs) + + mock_coro.mock = m + return mock_coro + + +def _run(coro): + """Run the given coroutine.""" + return asyncio.get_event_loop().run_until_complete(coro) @unittest.skipIf(sys.version_info < (3, 5), 'only for Python 3.5+') class TestAsyncServer(unittest.TestCase): + @staticmethod + def get_async_mock(environ={'REQUEST_METHOD': 'GET', 'QUERY_STRING': ''}): + a = mock.MagicMock() + a.async = { + 'asyncio': True, + 'create_route': mock.MagicMock(), + 'translate_request': mock.MagicMock(), + 'make_response': mock.MagicMock(), + 'websocket': 'w', + 'websocket_class': 'wc' + } + a.async['translate_request'].return_value = environ + a.async['make_response'].return_value = 'response' + return a + def _get_mock_socket(self): mock_socket = mock.MagicMock() mock_socket.connected = False mock_socket.closed = False mock_socket.upgraded = False - mock_socket.send.return_value = get_mock_coro('socket.send') - mock_socket.handle_get_request.return_value = get_mock_coro( - 'socket.handle_get_request') - mock_socket.handle_post_request.return_value = get_mock_coro( - 'socket.handle_post_request') - mock_socket.close.return_value = get_mock_coro('socket.close') + mock_socket.send = AsyncMock() + mock_socket.handle_get_request = AsyncMock() + mock_socket.handle_post_request = AsyncMock() + mock_socket.close = AsyncMock() return mock_socket - def _run(self, coro): - return asyncio.get_event_loop().run_until_complete(coro) - def setUp(self): logging.getLogger('engineio').setLevel(logging.NOTSET) @@ -99,8 +96,9 @@ def test_async_mode_aiohttp(self): self.assertEqual(s._async['websocket'], async_aiohttp) self.assertEqual(s._async['websocket_class'], 'WebSocket') - @mock.patch('importlib.import_module', side_effect=[get_async_mock()]) + @mock.patch('importlib.import_module') def test_async_mode_auto_aiohttp(self, import_module): + import_module.side_effect = [self.get_async_mock()] s = asyncio_server.AsyncServer() self.assertEqual(s.async_mode, 'aiohttp') @@ -116,7 +114,7 @@ def test_async_modes_wsgi(self): @mock.patch('importlib.import_module') def test_attach(self, import_module): - a = get_async_mock() + a = self.get_async_mock() import_module.side_effect = [a] s = asyncio_server.AsyncServer() s.attach('app', engineio_path='path') @@ -125,29 +123,30 @@ def test_attach(self, import_module): def test_disconnect(self): s = asyncio_server.AsyncServer() s.sockets['foo'] = mock_socket = self._get_mock_socket() - self._run(s.disconnect('foo')) - self.assertEqual(mock_socket.close.call_count, 1) - mock_socket.close.assert_called_once_with() + _run(s.disconnect('foo')) + self.assertEqual(mock_socket.close.mock.call_count, 1) + mock_socket.close.mock.assert_called_once_with() self.assertNotIn('foo', s.sockets) def test_disconnect_all(self): s = asyncio_server.AsyncServer() s.sockets['foo'] = mock_foo = self._get_mock_socket() s.sockets['bar'] = mock_bar = self._get_mock_socket() - self._run(s.disconnect()) - self.assertEqual(mock_foo.close.call_count, 1) - self.assertEqual(mock_bar.close.call_count, 1) - mock_foo.close.assert_called_once_with() - mock_bar.close.assert_called_once_with() + _run(s.disconnect()) + self.assertEqual(mock_foo.close.mock.call_count, 1) + self.assertEqual(mock_bar.close.mock.call_count, 1) + mock_foo.close.mock.assert_called_once_with() + mock_bar.close.mock.assert_called_once_with() self.assertNotIn('foo', s.sockets) self.assertNotIn('bar', s.sockets) @mock.patch('importlib.import_module') def test_jsonp_not_supported(self, import_module): - a = get_async_mock({'REQUEST_METHOD': 'GET', 'QUERY_STRING': 'j=abc'}) + a = self.get_async_mock({'REQUEST_METHOD': 'GET', + 'QUERY_STRING': 'j=abc'}) import_module.side_effect = [a] s = asyncio_server.AsyncServer() - response = self._run(s.handle_request('request')) + response = _run(s.handle_request('request')) self.assertEqual(response, 'response') a.async['translate_request'].assert_called_once_with('request') self.assertEqual(a.async['make_response'].call_count, 1) @@ -156,10 +155,10 @@ def test_jsonp_not_supported(self, import_module): @mock.patch('importlib.import_module') def test_connect(self, import_module): - a = get_async_mock() + a = self.get_async_mock() import_module.side_effect = [a] s = asyncio_server.AsyncServer() - self._run(s.handle_request('request')) + _run(s.handle_request('request')) self.assertEqual(len(s.sockets), 1) self.assertEqual(a.async['make_response'].call_count, 1) self.assertEqual(a.async['make_response'].call_args[0][0], '200 OK') @@ -175,94 +174,96 @@ def test_connect(self, import_module): @mock.patch('importlib.import_module') def test_connect_no_upgrades(self, import_module): - a = get_async_mock() + a = self.get_async_mock() import_module.side_effect = [a] s = asyncio_server.AsyncServer(allow_upgrades=False) - self._run(s.handle_request('request')) + _run(s.handle_request('request')) packets = payload.Payload( encoded_payload=a.async['make_response'].call_args[0][2]).packets self.assertEqual(packets[0].data['upgrades'], []) @mock.patch('importlib.import_module') def test_connect_b64_with_1(self, import_module): - a = get_async_mock({'REQUEST_METHOD': 'GET', 'QUERY_STRING': 'b64=1'}) + a = self.get_async_mock({'REQUEST_METHOD': 'GET', + 'QUERY_STRING': 'b64=1'}) import_module.side_effect = [a] s = asyncio_server.AsyncServer(allow_upgrades=False) s._generate_id = mock.MagicMock(return_value='1') - self._run(s.handle_request('request')) + _run(s.handle_request('request')) self.assertEqual(a.async['make_response'].call_count, 1) self.assertEqual(a.async['make_response'].call_args[0][0], '200 OK') self.assertIn(('Content-Type', 'text/plain; charset=UTF-8'), a.async['make_response'].call_args[0][1]) - self._run(s.send('1', b'\x00\x01\x02', binary=True)) + _run(s.send('1', b'\x00\x01\x02', binary=True)) a.async['translate_request'].return_value = { 'REQUEST_METHOD': 'GET', 'QUERY_STRING': 'sid=1&b64=1'} - self._run(s.handle_request('request')) + _run(s.handle_request('request')) self.assertEqual(a.async['make_response'].call_args[0][2], b'6:b4AAEC') @mock.patch('importlib.import_module') def test_connect_b64_with_true(self, import_module): - a = get_async_mock({'REQUEST_METHOD': 'GET', - 'QUERY_STRING': 'b64=true'}) + a = self.get_async_mock({'REQUEST_METHOD': 'GET', + 'QUERY_STRING': 'b64=true'}) import_module.side_effect = [a] s = asyncio_server.AsyncServer(allow_upgrades=False) s._generate_id = mock.MagicMock(return_value='1') - self._run(s.handle_request('request')) + _run(s.handle_request('request')) self.assertEqual(a.async['make_response'].call_count, 1) self.assertEqual(a.async['make_response'].call_args[0][0], '200 OK') self.assertIn(('Content-Type', 'text/plain; charset=UTF-8'), a.async['make_response'].call_args[0][1]) - self._run(s.send('1', b'\x00\x01\x02', binary=True)) + _run(s.send('1', b'\x00\x01\x02', binary=True)) a.async['translate_request'].return_value = { 'REQUEST_METHOD': 'GET', 'QUERY_STRING': 'sid=1&b64=true'} - self._run(s.handle_request('request')) + _run(s.handle_request('request')) self.assertEqual(a.async['make_response'].call_args[0][2], b'6:b4AAEC') @mock.patch('importlib.import_module') def test_connect_b64_with_0(self, import_module): - a = get_async_mock({'REQUEST_METHOD': 'GET', 'QUERY_STRING': 'b64=0'}) + a = self.get_async_mock({'REQUEST_METHOD': 'GET', + 'QUERY_STRING': 'b64=0'}) import_module.side_effect = [a] s = asyncio_server.AsyncServer(allow_upgrades=False) s._generate_id = mock.MagicMock(return_value='1') - self._run(s.handle_request('request')) + _run(s.handle_request('request')) self.assertEqual(a.async['make_response'].call_count, 1) self.assertEqual(a.async['make_response'].call_args[0][0], '200 OK') self.assertIn(('Content-Type', 'application/octet-stream'), a.async['make_response'].call_args[0][1]) - self._run(s.send('1', b'\x00\x01\x02', binary=True)) + _run(s.send('1', b'\x00\x01\x02', binary=True)) a.async['translate_request'].return_value = { 'REQUEST_METHOD': 'GET', 'QUERY_STRING': 'sid=1&b64=0'} - self._run(s.handle_request('request')) + _run(s.handle_request('request')) self.assertEqual(a.async['make_response'].call_args[0][2], b'\x01\x04\xff\x04\x00\x01\x02') @mock.patch('importlib.import_module') def test_connect_b64_with_false(self, import_module): - a = get_async_mock({'REQUEST_METHOD': 'GET', - 'QUERY_STRING': 'b64=false'}) + a = self.get_async_mock({'REQUEST_METHOD': 'GET', + 'QUERY_STRING': 'b64=false'}) import_module.side_effect = [a] s = asyncio_server.AsyncServer(allow_upgrades=False) s._generate_id = mock.MagicMock(return_value='1') - self._run(s.handle_request('request')) + _run(s.handle_request('request')) self.assertEqual(a.async['make_response'].call_count, 1) self.assertEqual(a.async['make_response'].call_args[0][0], '200 OK') self.assertIn(('Content-Type', 'application/octet-stream'), a.async['make_response'].call_args[0][1]) - self._run(s.send('1', b'\x00\x01\x02', binary=True)) + _run(s.send('1', b'\x00\x01\x02', binary=True)) a.async['translate_request'].return_value = { 'REQUEST_METHOD': 'GET', 'QUERY_STRING': 'sid=1&b64=false'} - self._run(s.handle_request('request')) + _run(s.handle_request('request')) self.assertEqual(a.async['make_response'].call_args[0][2], b'\x01\x04\xff\x04\x00\x01\x02') @mock.patch('importlib.import_module') def test_connect_custom_ping_times(self, import_module): - a = get_async_mock() + a = self.get_async_mock() import_module.side_effect = [a] s = asyncio_server.AsyncServer(ping_timeout=123, ping_interval=456) - self._run(s.handle_request('request')) + _run(s.handle_request('request')) packets = payload.Payload( encoded_payload=a.async['make_response'].call_args[0][2]).packets self.assertEqual(packets[0].data['pingTimeout'], 123000) @@ -271,70 +272,71 @@ def test_connect_custom_ping_times(self, import_module): @mock.patch('engineio.asyncio_socket.AsyncSocket') @mock.patch('importlib.import_module') def test_connect_transport_websocket(self, import_module, AsyncSocket): - a = get_async_mock({'REQUEST_METHOD': 'GET', - 'QUERY_STRING': 'transport=websocket'}) + a = self.get_async_mock({'REQUEST_METHOD': 'GET', + 'QUERY_STRING': 'transport=websocket'}) import_module.side_effect = [a] AsyncSocket.return_value = self._get_mock_socket() s = asyncio_server.AsyncServer() s._generate_id = mock.MagicMock(return_value='123') - self._run(s.handle_request('request')) - self.assertEqual(s.sockets['123'].send.call_args[0][0].packet_type, - packet.OPEN) + _run(s.handle_request('request')) + self.assertEqual( + s.sockets['123'].send.mock.call_args[0][0].packet_type, + packet.OPEN) @mock.patch('importlib.import_module') def test_connect_transport_invalid(self, import_module): - a = get_async_mock({'REQUEST_METHOD': 'GET', - 'QUERY_STRING': 'transport=foo'}) + a = self.get_async_mock({'REQUEST_METHOD': 'GET', + 'QUERY_STRING': 'transport=foo'}) import_module.side_effect = [a] s = asyncio_server.AsyncServer() - self._run(s.handle_request('request')) + _run(s.handle_request('request')) self.assertEqual(a.async['make_response'].call_count, 1) self.assertEqual(a.async['make_response'].call_args[0][0], '400 BAD REQUEST') @mock.patch('importlib.import_module') def test_connect_cors_headers(self, import_module): - a = get_async_mock() + a = self.get_async_mock() import_module.side_effect = [a] s = asyncio_server.AsyncServer() - self._run(s.handle_request('request')) + _run(s.handle_request('request')) headers = a.async['make_response'].call_args[0][1] self.assertIn(('Access-Control-Allow-Origin', '*'), headers) self.assertIn(('Access-Control-Allow-Credentials', 'true'), headers) @mock.patch('importlib.import_module') def test_connect_cors_allowed_origin(self, import_module): - a = get_async_mock({'REQUEST_METHOD': 'GET', 'QUERY_STRING': '', - 'HTTP_ORIGIN': 'b'}) + a = self.get_async_mock({'REQUEST_METHOD': 'GET', 'QUERY_STRING': '', + 'HTTP_ORIGIN': 'b'}) import_module.side_effect = [a] s = asyncio_server.AsyncServer(cors_allowed_origins=['a', 'b']) - self._run(s.handle_request('request')) + _run(s.handle_request('request')) headers = a.async['make_response'].call_args[0][1] self.assertIn(('Access-Control-Allow-Origin', 'b'), headers) @mock.patch('importlib.import_module') def test_connect_cors_not_allowed_origin(self, import_module): - a = get_async_mock({'REQUEST_METHOD': 'GET', 'QUERY_STRING': '', - 'HTTP_ORIGIN': 'c'}) + a = self.get_async_mock({'REQUEST_METHOD': 'GET', 'QUERY_STRING': '', + 'HTTP_ORIGIN': 'c'}) import_module.side_effect = [a] s = asyncio_server.AsyncServer(cors_allowed_origins=['a', 'b']) - self._run(s.handle_request('request')) + _run(s.handle_request('request')) headers = a.async['make_response'].call_args[0][1] self.assertNotIn(('Access-Control-Allow-Origin', 'c'), headers) self.assertNotIn(('Access-Control-Allow-Origin', '*'), headers) @mock.patch('importlib.import_module') def test_connect_cors_no_credentials(self, import_module): - a = get_async_mock() + a = self.get_async_mock() import_module.side_effect = [a] s = asyncio_server.AsyncServer(cors_credentials=False) - self._run(s.handle_request('request')) + _run(s.handle_request('request')) headers = a.async['make_response'].call_args[0][1] self.assertNotIn(('Access-Control-Allow-Credentials', 'true'), headers) @mock.patch('importlib.import_module') def test_connect_event(self, import_module): - a = get_async_mock() + a = self.get_async_mock() import_module.side_effect = [a] s = asyncio_server.AsyncServer() s._generate_id = mock.MagicMock(return_value='123') @@ -343,12 +345,12 @@ def mock_connect(sid, environ): return True s.on('connect', handler=mock_connect) - self._run(s.handle_request('request')) + _run(s.handle_request('request')) self.assertEqual(len(s.sockets), 1) @mock.patch('importlib.import_module') def test_connect_event_rejects(self, import_module): - a = get_async_mock() + a = self.get_async_mock() import_module.side_effect = [a] s = asyncio_server.AsyncServer() s._generate_id = mock.MagicMock(return_value='123') @@ -357,74 +359,73 @@ def mock_connect(sid, environ): return False s.on('connect')(mock_connect) - self._run(s.handle_request('request')) + _run(s.handle_request('request')) self.assertEqual(len(s.sockets), 0) self.assertEqual(a.async['make_response'].call_args[0][0], '401 UNAUTHORIZED') @mock.patch('importlib.import_module') def test_method_not_found(self, import_module): - a = get_async_mock({'REQUEST_METHOD': 'PUT', 'QUERY_STRING': ''}) + a = self.get_async_mock({'REQUEST_METHOD': 'PUT', 'QUERY_STRING': ''}) import_module.side_effect = [a] s = asyncio_server.AsyncServer() - self._run(s.handle_request('request')) + _run(s.handle_request('request')) self.assertEqual(len(s.sockets), 0) self.assertEqual(a.async['make_response'].call_args[0][0], '405 METHOD NOT FOUND') @mock.patch('importlib.import_module') def test_get_request_with_bad_sid(self, import_module): - a = get_async_mock({'REQUEST_METHOD': 'GET', - 'QUERY_STRING': 'sid=foo'}) + a = self.get_async_mock({'REQUEST_METHOD': 'GET', + 'QUERY_STRING': 'sid=foo'}) import_module.side_effect = [a] s = asyncio_server.AsyncServer() - self._run(s.handle_request('request')) + _run(s.handle_request('request')) self.assertEqual(len(s.sockets), 0) self.assertEqual(a.async['make_response'].call_args[0][0], '400 BAD REQUEST') @mock.patch('importlib.import_module') def test_post_request_with_bad_sid(self, import_module): - a = get_async_mock({'REQUEST_METHOD': 'POST', - 'QUERY_STRING': 'sid=foo'}) + a = self.get_async_mock({'REQUEST_METHOD': 'POST', + 'QUERY_STRING': 'sid=foo'}) import_module.side_effect = [a] s = asyncio_server.AsyncServer() - self._run(s.handle_request('request')) + _run(s.handle_request('request')) self.assertEqual(len(s.sockets), 0) self.assertEqual(a.async['make_response'].call_args[0][0], '400 BAD REQUEST') @mock.patch('importlib.import_module') def test_send(self, import_module): - a = get_async_mock() + a = self.get_async_mock() import_module.side_effect = [a] s = asyncio_server.AsyncServer() s.sockets['foo'] = mock_socket = self._get_mock_socket() - self._run(s.send('foo', 'hello')) - self.assertEqual(mock_socket.send.call_count, 1) - self.assertEqual(mock_socket.send.call_args[0][0].packet_type, + _run(s.send('foo', 'hello')) + self.assertEqual(mock_socket.send.mock.call_count, 1) + self.assertEqual(mock_socket.send.mock.call_args[0][0].packet_type, packet.MESSAGE) - self.assertEqual(mock_socket.send.call_args[0][0].data, 'hello') + self.assertEqual(mock_socket.send.mock.call_args[0][0].data, 'hello') @mock.patch('importlib.import_module') def test_send_unknown_socket(self, import_module): - a = get_async_mock() + a = self.get_async_mock() import_module.side_effect = [a] s = asyncio_server.AsyncServer() # just ensure no exceptions are raised - self._run(s.send('foo', 'hello')) + _run(s.send('foo', 'hello')) @mock.patch('importlib.import_module') def test_get_request(self, import_module): - a = get_async_mock({'REQUEST_METHOD': 'GET', - 'QUERY_STRING': 'sid=foo'}) + a = self.get_async_mock({'REQUEST_METHOD': 'GET', + 'QUERY_STRING': 'sid=foo'}) import_module.side_effect = [a] s = asyncio_server.AsyncServer() s.sockets['foo'] = mock_socket = self._get_mock_socket() - mock_socket.handle_get_request.return_value = get_mock_coro( - 'socket.handle_get_request', - return_value=[packet.Packet(packet.MESSAGE, data='hello')]) - self._run(s.handle_request('request')) + mock_socket.handle_get_request.mock.return_value = \ + [packet.Packet(packet.MESSAGE, data='hello')] + _run(s.handle_request('request')) self.assertEqual(a.async['make_response'].call_args[0][0], '200 OK') packets = payload.Payload( encoded_payload=a.async['make_response'].call_args[0][2]).packets @@ -433,20 +434,19 @@ def test_get_request(self, import_module): @mock.patch('importlib.import_module') def test_get_request_custom_response(self, import_module): - a = get_async_mock({'REQUEST_METHOD': 'GET', - 'QUERY_STRING': 'sid=foo'}) + a = self.get_async_mock({'REQUEST_METHOD': 'GET', + 'QUERY_STRING': 'sid=foo'}) import_module.side_effect = [a] s = asyncio_server.AsyncServer() s.sockets['foo'] = mock_socket = self._get_mock_socket() - mock_socket.handle_get_request.return_value = get_mock_coro( - 'socket.handle_get_request', return_value='resp') - r = self._run(s.handle_request('request')) + mock_socket.handle_get_request.mock.return_value = 'resp' + r = _run(s.handle_request('request')) self.assertEqual(r, 'resp') @mock.patch('importlib.import_module') def test_get_request_closes_socket(self, import_module): - a = get_async_mock({'REQUEST_METHOD': 'GET', - 'QUERY_STRING': 'sid=foo'}) + a = self.get_async_mock({'REQUEST_METHOD': 'GET', + 'QUERY_STRING': 'sid=foo'}) import_module.side_effect = [a] s = asyncio_server.AsyncServer() s.sockets['foo'] = mock_socket = self._get_mock_socket() @@ -456,15 +456,15 @@ def mock_get_request(*args, **kwargs): mock_socket.closed = True return 'resp' - mock_socket.handle_get_request.return_value = mock_get_request() - r = self._run(s.handle_request('request')) + mock_socket.handle_get_request.mock.return_value = mock_get_request() + r = _run(s.handle_request('request')) self.assertEqual(r, 'resp') self.assertNotIn('foo', s.sockets) @mock.patch('importlib.import_module') def test_get_request_error(self, import_module): - a = get_async_mock({'REQUEST_METHOD': 'GET', - 'QUERY_STRING': 'sid=foo'}) + a = self.get_async_mock({'REQUEST_METHOD': 'GET', + 'QUERY_STRING': 'sid=foo'}) import_module.side_effect = [a] s = asyncio_server.AsyncServer() s.sockets['foo'] = mock_socket = self._get_mock_socket() @@ -473,26 +473,26 @@ def test_get_request_error(self, import_module): def mock_get_request(*args, **kwargs): raise IOError() - mock_socket.handle_get_request.return_value = mock_get_request() - self._run(s.handle_request('request')) + mock_socket.handle_get_request.mock.return_value = mock_get_request() + _run(s.handle_request('request')) self.assertEqual(a.async['make_response'].call_args[0][0], '400 BAD REQUEST') self.assertEqual(len(s.sockets), 0) @mock.patch('importlib.import_module') def test_post_request(self, import_module): - a = get_async_mock({'REQUEST_METHOD': 'POST', - 'QUERY_STRING': 'sid=foo'}) + a = self.get_async_mock({'REQUEST_METHOD': 'POST', + 'QUERY_STRING': 'sid=foo'}) import_module.side_effect = [a] s = asyncio_server.AsyncServer() s.sockets['foo'] = self._get_mock_socket() - self._run(s.handle_request('request')) + _run(s.handle_request('request')) self.assertEqual(a.async['make_response'].call_args[0][0], '200 OK') @mock.patch('importlib.import_module') def test_post_request_error(self, import_module): - a = get_async_mock({'REQUEST_METHOD': 'POST', - 'QUERY_STRING': 'sid=foo'}) + a = self.get_async_mock({'REQUEST_METHOD': 'POST', + 'QUERY_STRING': 'sid=foo'}) import_module.side_effect = [a] s = asyncio_server.AsyncServer() s.sockets['foo'] = mock_socket = self._get_mock_socket() @@ -501,8 +501,8 @@ def test_post_request_error(self, import_module): def mock_post_request(*args, **kwargs): raise ValueError() - mock_socket.handle_post_request.return_value = mock_post_request() - self._run(s.handle_request('request')) + mock_socket.handle_post_request.mock.return_value = mock_post_request() + _run(s.handle_request('request')) self.assertEqual(a.async['make_response'].call_args[0][0], '400 BAD REQUEST') @@ -514,48 +514,45 @@ def _gzip_decompress(b): @mock.patch('importlib.import_module') def test_gzip_compression(self, import_module): - a = get_async_mock({'REQUEST_METHOD': 'GET', - 'QUERY_STRING': 'sid=foo', - 'ACCEPT_ENCODING': 'gzip,deflate'}) + a = self.get_async_mock({'REQUEST_METHOD': 'GET', + 'QUERY_STRING': 'sid=foo', + 'ACCEPT_ENCODING': 'gzip,deflate'}) import_module.side_effect = [a] s = asyncio_server.AsyncServer(compression_threshold=0) s.sockets['foo'] = mock_socket = self._get_mock_socket() - mock_socket.handle_get_request.return_value = get_mock_coro( - 'socket.handle_get_request', - return_value=[packet.Packet(packet.MESSAGE, data='hello')]) - self._run(s.handle_request('request')) + mock_socket.handle_get_request.mock.return_value = \ + [packet.Packet(packet.MESSAGE, data='hello')] + _run(s.handle_request('request')) headers = a.async['make_response'].call_args[0][1] self.assertIn(('Content-Encoding', 'gzip'), headers) self._gzip_decompress(a.async['make_response'].call_args[0][2]) @mock.patch('importlib.import_module') def test_deflate_compression(self, import_module): - a = get_async_mock({'REQUEST_METHOD': 'GET', - 'QUERY_STRING': 'sid=foo', - 'ACCEPT_ENCODING': 'deflate;q=1,gzip'}) + a = self.get_async_mock({'REQUEST_METHOD': 'GET', + 'QUERY_STRING': 'sid=foo', + 'ACCEPT_ENCODING': 'deflate;q=1,gzip'}) import_module.side_effect = [a] s = asyncio_server.AsyncServer(compression_threshold=0) s.sockets['foo'] = mock_socket = self._get_mock_socket() - mock_socket.handle_get_request.return_value = get_mock_coro( - 'socket.handle_get_request', - return_value=[packet.Packet(packet.MESSAGE, data='hello')]) - self._run(s.handle_request('request')) + mock_socket.handle_get_request.mock.return_value = \ + [packet.Packet(packet.MESSAGE, data='hello')] + _run(s.handle_request('request')) headers = a.async['make_response'].call_args[0][1] self.assertIn(('Content-Encoding', 'deflate'), headers) zlib.decompress(a.async['make_response'].call_args[0][2]) @mock.patch('importlib.import_module') def test_gzip_compression_threshold(self, import_module): - a = get_async_mock({'REQUEST_METHOD': 'GET', - 'QUERY_STRING': 'sid=foo', - 'ACCEPT_ENCODING': 'gzip'}) + a = self.get_async_mock({'REQUEST_METHOD': 'GET', + 'QUERY_STRING': 'sid=foo', + 'ACCEPT_ENCODING': 'gzip'}) import_module.side_effect = [a] s = asyncio_server.AsyncServer(compression_threshold=1000) s.sockets['foo'] = mock_socket = self._get_mock_socket() - mock_socket.handle_get_request.return_value = get_mock_coro( - 'socket.handle_get_request', - return_value=[packet.Packet(packet.MESSAGE, data='hello')]) - self._run(s.handle_request('request')) + mock_socket.handle_get_request.mock.return_value = \ + [packet.Packet(packet.MESSAGE, data='hello')] + _run(s.handle_request('request')) headers = a.async['make_response'].call_args[0][1] for header, value in headers: self.assertNotEqual(header, 'Content-Encoding') @@ -564,17 +561,16 @@ def test_gzip_compression_threshold(self, import_module): @mock.patch('importlib.import_module') def test_compression_disabled(self, import_module): - a = get_async_mock({'REQUEST_METHOD': 'GET', - 'QUERY_STRING': 'sid=foo', - 'ACCEPT_ENCODING': 'gzip'}) + a = self.get_async_mock({'REQUEST_METHOD': 'GET', + 'QUERY_STRING': 'sid=foo', + 'ACCEPT_ENCODING': 'gzip'}) import_module.side_effect = [a] s = asyncio_server.AsyncServer(http_compression=False, compression_threshold=0) s.sockets['foo'] = mock_socket = self._get_mock_socket() - mock_socket.handle_get_request.return_value = get_mock_coro( - 'socket.handle_get_request', - return_value=[packet.Packet(packet.MESSAGE, data='hello')]) - self._run(s.handle_request('request')) + mock_socket.handle_get_request.mock.return_value = \ + [packet.Packet(packet.MESSAGE, data='hello')] + _run(s.handle_request('request')) headers = a.async['make_response'].call_args[0][1] for header, value in headers: self.assertNotEqual(header, 'Content-Encoding') @@ -583,16 +579,15 @@ def test_compression_disabled(self, import_module): @mock.patch('importlib.import_module') def test_compression_unknown(self, import_module): - a = get_async_mock({'REQUEST_METHOD': 'GET', - 'QUERY_STRING': 'sid=foo', - 'ACCEPT_ENCODING': 'rar'}) + a = self.get_async_mock({'REQUEST_METHOD': 'GET', + 'QUERY_STRING': 'sid=foo', + 'ACCEPT_ENCODING': 'rar'}) import_module.side_effect = [a] s = asyncio_server.AsyncServer(compression_threshold=0) s.sockets['foo'] = mock_socket = self._get_mock_socket() - mock_socket.handle_get_request.return_value = get_mock_coro( - 'socket.handle_get_request', - return_value=[packet.Packet(packet.MESSAGE, data='hello')]) - self._run(s.handle_request('request')) + mock_socket.handle_get_request.mock.return_value = \ + [packet.Packet(packet.MESSAGE, data='hello')] + _run(s.handle_request('request')) headers = a.async['make_response'].call_args[0][1] for header, value in headers: self.assertNotEqual(header, 'Content-Encoding') @@ -601,16 +596,15 @@ def test_compression_unknown(self, import_module): @mock.patch('importlib.import_module') def test_compression_no_encoding(self, import_module): - a = get_async_mock({'REQUEST_METHOD': 'GET', - 'QUERY_STRING': 'sid=foo', - 'ACCEPT_ENCODING': ''}) + a = self.get_async_mock({'REQUEST_METHOD': 'GET', + 'QUERY_STRING': 'sid=foo', + 'ACCEPT_ENCODING': ''}) import_module.side_effect = [a] s = asyncio_server.AsyncServer(compression_threshold=0) s.sockets['foo'] = mock_socket = self._get_mock_socket() - mock_socket.handle_get_request.return_value = get_mock_coro( - 'socket.handle_get_request', - return_value=[packet.Packet(packet.MESSAGE, data='hello')]) - self._run(s.handle_request('request')) + mock_socket.handle_get_request.mock.return_value = \ + [packet.Packet(packet.MESSAGE, data='hello')] + _run(s.handle_request('request')) headers = a.async['make_response'].call_args[0][1] for header, value in headers: self.assertNotEqual(header, 'Content-Encoding') @@ -619,21 +613,21 @@ def test_compression_no_encoding(self, import_module): @mock.patch('importlib.import_module') def test_cookie(self, import_module): - a = get_async_mock() + a = self.get_async_mock() import_module.side_effect = [a] s = asyncio_server.AsyncServer(cookie='sid') s._generate_id = mock.MagicMock(return_value='123') - self._run(s.handle_request('request')) + _run(s.handle_request('request')) headers = a.async['make_response'].call_args[0][1] self.assertIn(('Set-Cookie', 'sid=123'), headers) @mock.patch('importlib.import_module') def test_no_cookie(self, import_module): - a = get_async_mock() + a = self.get_async_mock() import_module.side_effect = [a] s = asyncio_server.AsyncServer(cookie=None) s._generate_id = mock.MagicMock(return_value='123') - self._run(s.handle_request('request')) + _run(s.handle_request('request')) headers = a.async['make_response'].call_args[0][1] for header, value in headers: self.assertNotEqual(header, 'Set-Cookie') @@ -691,7 +685,7 @@ def foo_handler(arg): s = asyncio_server.AsyncServer() s.on('message', handler=foo_handler) - self._run(s._trigger_event('message', 'bar')) + _run(s._trigger_event('message', 'bar')) self.assertEqual(result, ['ok', 'bar']) def test_trigger_event_coroutine(self): @@ -704,5 +698,5 @@ def foo_handler(arg): s = asyncio_server.AsyncServer() s.on('message', handler=foo_handler) - self._run(s._trigger_event('message', 'bar')) + _run(s._trigger_event('message', 'bar')) self.assertEqual(result, ['ok', 'bar']) diff --git a/tests/test_asyncio_socket.py b/tests/test_asyncio_socket.py new file mode 100644 index 00000000..d5d7d9ad --- /dev/null +++ b/tests/test_asyncio_socket.py @@ -0,0 +1,403 @@ +import sys +import time +import unittest + +import six +if six.PY3: + from unittest import mock +else: + import mock + +from engineio import packet +from engineio import payload +if sys.version_info >= (3, 5): + import asyncio + from asyncio import coroutine + from engineio import asyncio_socket +else: + # mock coroutine so that Python 2 doesn't complain + def coroutine(f): + return f + + +def AsyncMock(*args, **kwargs): + """Return a mock asynchronous function.""" + m = mock.MagicMock(*args, **kwargs) + + @coroutine + def mock_coro(*args, **kwargs): + return m(*args, **kwargs) + + mock_coro.mock = m + return mock_coro + + +def _run(coro): + """Run the given coroutine.""" + return asyncio.get_event_loop().run_until_complete(coro) + + +@unittest.skipIf(sys.version_info < (3, 5), 'only for Python 3.5+') +class TestSocket(unittest.TestCase): + def _get_read_mock_coro(self, payload): + mock_input = mock.MagicMock() + mock_input.read = AsyncMock() + mock_input.read.mock.return_value = payload + return mock_input + + def _get_mock_server(self): + mock_server = mock.Mock() + mock_server.ping_timeout = 0.2 + mock_server.ping_interval = 0.2 + mock_server.async_handlers = True + mock_server._async = {'asyncio': True, + 'create_route': mock.MagicMock(), + 'translate_request': mock.MagicMock(), + 'make_response': mock.MagicMock(), + 'websocket': 'w', + 'websocket_class': 'wc'} + mock_server._async['translate_request'].return_value = 'request' + mock_server._async['make_response'].return_value = 'response' + mock_server._trigger_event = AsyncMock() + return mock_server + + def test_create(self): + mock_server = self._get_mock_server() + s = asyncio_socket.AsyncSocket(mock_server, 'sid') + self.assertEqual(s.server, mock_server) + self.assertEqual(s.sid, 'sid') + self.assertFalse(s.upgraded) + self.assertFalse(s.closed) + self.assertTrue(hasattr(s.queue, 'get')) + self.assertTrue(hasattr(s.queue, 'put')) + self.assertTrue(hasattr(s.queue, 'task_done')) + self.assertTrue(hasattr(s.queue, 'join')) + + def test_empty_poll(self): + mock_server = self._get_mock_server() + s = asyncio_socket.AsyncSocket(mock_server, 'sid') + self.assertRaises(IOError, _run, s.poll()) + + def test_poll(self): + mock_server = self._get_mock_server() + s = asyncio_socket.AsyncSocket(mock_server, 'sid') + pkt1 = packet.Packet(packet.MESSAGE, data='hello') + pkt2 = packet.Packet(packet.MESSAGE, data='bye') + _run(s.send(pkt1)) + _run(s.send(pkt2)) + self.assertEqual(_run(s.poll()), [pkt1, pkt2]) + + def test_poll_none(self): + mock_server = self._get_mock_server() + s = asyncio_socket.AsyncSocket(mock_server, 'sid') + _run(s.queue.put(None)) + self.assertEqual(_run(s.poll()), []) + + def test_ping_pong(self): + mock_server = self._get_mock_server() + s = asyncio_socket.AsyncSocket(mock_server, 'sid') + _run(s.receive(packet.Packet(packet.PING, data='abc'))) + r = _run(s.poll()) + self.assertEqual(len(r), 1) + self.assertTrue(r[0].encode(), b'3abc') + + def test_message_handler(self): + mock_server = self._get_mock_server() + s = asyncio_socket.AsyncSocket(mock_server, 'sid') + _run(s.receive(packet.Packet(packet.MESSAGE, data='foo'))) + mock_server._trigger_event.mock.assert_called_once_with( + 'message', 'sid', 'foo') + + def test_invalid_packet(self): + mock_server = self._get_mock_server() + s = asyncio_socket.AsyncSocket(mock_server, 'sid') + self.assertRaises(ValueError, _run, + s.receive(packet.Packet(packet.OPEN))) + + def test_timeout(self): + mock_server = self._get_mock_server() + mock_server.ping_interval = -0.1 + s = asyncio_socket.AsyncSocket(mock_server, 'sid') + s.last_ping = time.time() - 1 + s.close = AsyncMock() + _run(s.send('packet')) + s.close.mock.assert_called_once_with(wait=False, abort=True) + + def test_polling_read(self): + mock_server = self._get_mock_server() + s = asyncio_socket.AsyncSocket(mock_server, 'foo') + pkt1 = packet.Packet(packet.MESSAGE, data='hello') + pkt2 = packet.Packet(packet.MESSAGE, data='bye') + _run(s.send(pkt1)) + _run(s.send(pkt2)) + environ = {'REQUEST_METHOD': 'GET', 'QUERY_STRING': 'sid=foo'} + packets = _run(s.handle_get_request(environ)) + self.assertEqual(packets, [pkt1, pkt2]) + + def test_polling_read_error(self): + mock_server = self._get_mock_server() + s = asyncio_socket.AsyncSocket(mock_server, 'foo') + environ = {'REQUEST_METHOD': 'GET', 'QUERY_STRING': 'sid=foo'} + self.assertRaises(IOError, _run, s.handle_get_request(environ)) + + def test_polling_write(self): + mock_server = self._get_mock_server() + mock_server.max_http_buffer_size = 1000 + pkt1 = packet.Packet(packet.MESSAGE, data='hello') + pkt2 = packet.Packet(packet.MESSAGE, data='bye') + p = payload.Payload(packets=[pkt1, pkt2]).encode() + s = asyncio_socket.AsyncSocket(mock_server, 'foo') + s.receive = AsyncMock() + environ = {'REQUEST_METHOD': 'POST', 'QUERY_STRING': 'sid=foo', + 'CONTENT_LENGTH': len(p), + 'wsgi.input': self._get_read_mock_coro(p)} + _run(s.handle_post_request(environ)) + self.assertEqual(s.receive.mock.call_count, 2) + + def test_polling_write_too_large(self): + mock_server = self._get_mock_server() + pkt1 = packet.Packet(packet.MESSAGE, data='hello') + pkt2 = packet.Packet(packet.MESSAGE, data='bye') + p = payload.Payload(packets=[pkt1, pkt2]).encode() + mock_server.max_http_buffer_size = len(p) - 1 + s = asyncio_socket.AsyncSocket(mock_server, 'foo') + s.receive = AsyncMock() + environ = {'REQUEST_METHOD': 'POST', 'QUERY_STRING': 'sid=foo', + 'CONTENT_LENGTH': len(p), + 'wsgi.input': self._get_read_mock_coro(p)} + self.assertRaises(ValueError, _run, + s.handle_post_request(environ)) + + def test_upgrade_handshake(self): + mock_server = self._get_mock_server() + s = asyncio_socket.AsyncSocket(mock_server, 'foo') + s._upgrade_websocket = AsyncMock() + environ = {'REQUEST_METHOD': 'GET', 'QUERY_STRING': 'sid=foo', + 'HTTP_CONNECTION': 'Foo,Upgrade,Bar', + 'HTTP_UPGRADE': 'websocket'} + _run(s.handle_get_request(environ)) + s._upgrade_websocket.mock.assert_called_once_with(environ) + + def test_upgrade(self): + mock_server = self._get_mock_server() + mock_server._async['websocket'] = mock.MagicMock() + mock_server._async['websocket_class'] = 'WebSocket' + mock_ws = AsyncMock() + mock_server._async['websocket'].WebSocket.return_value = mock_ws + s = asyncio_socket.AsyncSocket(mock_server, 'sid') + s.connected = True + environ = "foo" + _run(s._upgrade_websocket(environ)) + mock_server._async['websocket'].WebSocket.assert_called_once_with( + s._websocket_handler) + mock_ws.mock.assert_called_once_with(environ) + + def test_upgrade_twice(self): + mock_server = self._get_mock_server() + mock_server._async['websocket'] = mock.MagicMock() + s = asyncio_socket.AsyncSocket(mock_server, 'sid') + s.connected = True + s.upgraded = True + environ = "foo" + self.assertRaises(IOError, _run, s._upgrade_websocket(environ)) + + def test_upgrade_packet(self): + mock_server = self._get_mock_server() + s = asyncio_socket.AsyncSocket(mock_server, 'sid') + s.connected = True + _run(s.receive(packet.Packet(packet.UPGRADE))) + r = _run(s.poll()) + self.assertEqual(len(r), 1) + self.assertEqual(r[0].encode(), packet.Packet(packet.NOOP).encode()) + + def test_upgrade_no_probe(self): + mock_server = self._get_mock_server() + s = asyncio_socket.AsyncSocket(mock_server, 'sid') + s.connected = True + ws = mock.MagicMock() + ws.wait = AsyncMock() + ws.wait.mock.return_value = packet.Packet(packet.NOOP).encode( + always_bytes=False) + _run(s._websocket_handler(ws)) + self.assertFalse(s.upgraded) + + def test_upgrade_no_upgrade_packet(self): + mock_server = self._get_mock_server() + s = asyncio_socket.AsyncSocket(mock_server, 'sid') + s.connected = True + s.queue.join = AsyncMock(return_value=None) + ws = mock.MagicMock() + ws.send = AsyncMock() + ws.wait = AsyncMock() + probe = six.text_type('probe') + ws.wait.mock.side_effect = [ + packet.Packet(packet.PING, data=probe).encode( + always_bytes=False), + packet.Packet(packet.NOOP).encode(always_bytes=False)] + _run(s._websocket_handler(ws)) + ws.send.mock.assert_called_once_with(packet.Packet( + packet.PONG, data=probe).encode(always_bytes=False)) + self.assertEqual(_run(s.queue.get()).packet_type, packet.NOOP) + self.assertFalse(s.upgraded) + + def test_upgrade_not_supported(self): + mock_server = self._get_mock_server() + mock_server._async['websocket'] = None + mock_server._async['websocket_class'] = None + s = asyncio_socket.AsyncSocket(mock_server, 'sid') + s.connected = True + environ = "foo" + _run(s._upgrade_websocket(environ)) + mock_server._bad_request.assert_called_once_with() + + def test_websocket_read_write(self): + mock_server = self._get_mock_server() + s = asyncio_socket.AsyncSocket(mock_server, 'sid') + s.connected = False + s.queue.join = AsyncMock(return_value=None) + foo = six.text_type('foo') + bar = six.text_type('bar') + s.poll = AsyncMock(side_effect=[ + [packet.Packet(packet.MESSAGE, data=bar)], None]) + ws = mock.MagicMock() + ws.send = AsyncMock() + ws.wait = AsyncMock() + ws.wait.mock.side_effect = [ + packet.Packet(packet.MESSAGE, data=foo).encode( + always_bytes=False), + None] + _run(s._websocket_handler(ws)) + self.assertTrue(s.connected) + self.assertTrue(s.upgraded) + self.assertEqual(mock_server._trigger_event.mock.call_count, 2) + mock_server._trigger_event.mock.assert_has_calls([ + mock.call('message', 'sid', 'foo'), + mock.call('disconnect', 'sid')]) + ws.send.mock.assert_called_with('4bar') + + def test_websocket_upgrade_read_write(self): + mock_server = self._get_mock_server() + s = asyncio_socket.AsyncSocket(mock_server, 'sid') + s.connected = True + s.queue.join = AsyncMock(return_value=None) + foo = six.text_type('foo') + bar = six.text_type('bar') + probe = six.text_type('probe') + s.poll = AsyncMock(side_effect=[ + [packet.Packet(packet.MESSAGE, data=bar)], IOError]) + ws = mock.MagicMock() + ws.send = AsyncMock() + ws.wait = AsyncMock() + ws.wait.mock.side_effect = [ + packet.Packet(packet.PING, data=probe).encode( + always_bytes=False), + packet.Packet(packet.UPGRADE).encode(always_bytes=False), + packet.Packet(packet.MESSAGE, data=foo).encode( + always_bytes=False), + None] + _run(s._websocket_handler(ws)) + self.assertTrue(s.upgraded) + self.assertEqual(mock_server._trigger_event.mock.call_count, 2) + mock_server._trigger_event.mock.assert_has_calls([ + mock.call('message', 'sid', 'foo'), + mock.call('disconnect', 'sid')]) + ws.send.mock.assert_called_with('4bar') + + def test_websocket_upgrade_with_payload(self): + mock_server = self._get_mock_server() + s = asyncio_socket.AsyncSocket(mock_server, 'sid') + s.connected = True + s.queue.join = AsyncMock(return_value=None) + probe = six.text_type('probe') + ws = mock.MagicMock() + ws.send = AsyncMock() + ws.wait = AsyncMock() + ws.wait.mock.side_effect = [ + packet.Packet(packet.PING, data=probe).encode( + always_bytes=False), + packet.Packet(packet.UPGRADE, data=b'2').encode( + always_bytes=False)] + _run(s._websocket_handler(ws)) + self.assertTrue(s.upgraded) + + def test_websocket_read_write_fail(self): + mock_server = self._get_mock_server() + s = asyncio_socket.AsyncSocket(mock_server, 'sid') + s.connected = False + s.queue.join = AsyncMock(return_value=None) + foo = six.text_type('foo') + bar = six.text_type('bar') + s.poll = AsyncMock(side_effect=[ + [packet.Packet(packet.MESSAGE, data=bar)], + [packet.Packet(packet.MESSAGE, data=bar)], IOError]) + ws = mock.MagicMock() + ws.send = AsyncMock() + ws.wait = AsyncMock() + ws.wait.mock.side_effect = [ + packet.Packet(packet.MESSAGE, data=foo).encode( + always_bytes=False), + RuntimeError] + ws.send.mock.side_effect = [None, RuntimeError] + _run(s._websocket_handler(ws)) + self.assertEqual(s.closed, True) + + def test_websocket_ignore_invalid_packet(self): + mock_server = self._get_mock_server() + s = asyncio_socket.AsyncSocket(mock_server, 'sid') + s.connected = False + s.queue.join = AsyncMock(return_value=None) + foo = six.text_type('foo') + bar = six.text_type('bar') + s.poll = AsyncMock(side_effect=[ + [packet.Packet(packet.MESSAGE, data=bar)], IOError]) + ws = mock.MagicMock() + ws.send = AsyncMock() + ws.wait = AsyncMock() + ws.wait.mock.side_effect = [ + packet.Packet(packet.OPEN).encode(always_bytes=False), + packet.Packet(packet.MESSAGE, data=foo).encode( + always_bytes=False), + None] + _run(s._websocket_handler(ws)) + self.assertTrue(s.connected) + self.assertEqual(mock_server._trigger_event.mock.call_count, 2) + mock_server._trigger_event.mock.assert_has_calls([ + mock.call('message', 'sid', foo), + mock.call('disconnect', 'sid')]) + ws.send.mock.assert_called_with('4bar') + + def test_send_after_close(self): + mock_server = self._get_mock_server() + s = asyncio_socket.AsyncSocket(mock_server, 'sid') + _run(s.close(wait=False)) + self.assertRaises(IOError, _run, + s.send(packet.Packet(packet.NOOP))) + + def test_close_after_close(self): + mock_server = self._get_mock_server() + s = asyncio_socket.AsyncSocket(mock_server, 'sid') + _run(s.close(wait=False)) + self.assertTrue(s.closed) + self.assertEqual(mock_server._trigger_event.mock.call_count, 1) + mock_server._trigger_event.mock.assert_called_once_with('disconnect', + 'sid') + _run(s.close()) + self.assertEqual(mock_server._trigger_event.mock.call_count, 1) + + def test_close_and_wait(self): + mock_server = self._get_mock_server() + s = asyncio_socket.AsyncSocket(mock_server, 'sid') + s.queue = mock.MagicMock() + s.queue.put = AsyncMock() + s.queue.join = AsyncMock() + _run(s.close(wait=True)) + s.queue.join.mock.assert_called_once_with() + + def test_close_without_wait(self): + mock_server = self._get_mock_server() + s = asyncio_socket.AsyncSocket(mock_server, 'sid') + s.queue = mock.MagicMock() + s.queue.put = AsyncMock() + s.queue.join = AsyncMock() + _run(s.close(wait=False)) + self.assertEqual(s.queue.join.mock.call_count, 0) diff --git a/tests/test_server.py b/tests/test_server.py index f46a75d8..f0b54ee4 100644 --- a/tests/test_server.py +++ b/tests/test_server.py @@ -21,9 +21,9 @@ original_import_module = importlib.import_module -def _mock_import(module, pkg=None): +def _mock_import(module, *args, **kwargs): if module.startswith('engineio.'): - return original_import_module(module, pkg) + return original_import_module(module, *args, **kwargs) return module diff --git a/tests/test_socket.py b/tests/test_socket.py index 6d68899f..3d2d5d4b 100644 --- a/tests/test_socket.py +++ b/tests/test_socket.py @@ -170,8 +170,7 @@ def test_upgrade(self): mock_server._async['websocket'] = mock.MagicMock() mock_server._async['websocket_class'] = 'WebSocket' mock_ws = mock.MagicMock() - mock_server._async['websocket'].WebSocket.configure_mock( - return_value=mock_ws) + mock_server._async['websocket'].WebSocket.return_value = mock_ws s = socket.Socket(mock_server, 'sid') s.connected = True environ = "foo" diff --git a/tox.ini b/tox.ini index e1771e4e..b784257a 100644 --- a/tox.ini +++ b/tox.ini @@ -13,11 +13,14 @@ deps= eventlet aiohttp basepython = + flake8: python3.6 py27: python2.7 py34: python3.4 py35: python3.5 py36: python3.6 pypy: pypy + coverage: python3.6 + docs: python3.6 [testenv:py27] deps= @@ -32,14 +35,12 @@ deps= eventlet [testenv:flake8] -basepython=python deps= flake8 commands= flake8 --exclude=".*" --ignore=E402 engineio tests [testenv:docs] -basepython=python2.7 changedir=docs deps= sphinx @@ -49,7 +50,6 @@ commands= make html [testenv:coverage] -basepython=python commands= coverage run --branch --source=engineio setup.py test coverage html