From 45e97b8cf885e998168857c46e29a7e257754f3e Mon Sep 17 00:00:00 2001 From: Bruce Yu Date: Tue, 25 Oct 2022 10:19:40 -0400 Subject: [PATCH] Allow configuring underlying websocket connection with custom options (Fixes #293) --- src/engineio/asyncio_client.py | 31 +++++++++++++++++----------- src/engineio/client.py | 20 ++++++++++++++---- tests/asyncio/test_asyncio_client.py | 24 +++++++++++++++++++++ tests/common/test_client.py | 20 ++++++++++++++++++ 4 files changed, 79 insertions(+), 16 deletions(-) diff --git a/src/engineio/asyncio_client.py b/src/engineio/asyncio_client.py index 976b85af..d82e10dd 100644 --- a/src/engineio/asyncio_client.py +++ b/src/engineio/asyncio_client.py @@ -62,7 +62,11 @@ class AsyncClient(client.Client): leave interrupt handling to the calling application. Interrupt handling can only be enabled when the client instance is created in the main thread. + :param websocket_extra_options: Dictionary containing additional keyword + arguments passed to + ``aiohttp.ws_connect()``. """ + def is_asyncio_based(self): return True @@ -297,19 +301,22 @@ async def _connect_websocket(self, url, headers, engineio_path): break self.http.cookie_jar.update_cookies(cookies) + extra_options = {'timeout': self.request_timeout} + if not self.ssl_verify: + ssl_context = ssl.create_default_context() + ssl_context.check_hostname = False + ssl_context.verify_mode = ssl.CERT_NONE + extra_options['ssl'] = ssl_context + + # combine internally generated options with the ones supplied by the + # caller. The caller's options take precedence. + headers.update(self.websocket_extra_options.pop('headers', {})) + extra_options['headers'] = headers + extra_options.update(self.websocket_extra_options) + try: - if not self.ssl_verify: - ssl_context = ssl.create_default_context() - ssl_context.check_hostname = False - ssl_context.verify_mode = ssl.CERT_NONE - ws = await self.http.ws_connect( - websocket_url + self._get_url_timestamp(), - headers=headers, ssl=ssl_context, - timeout=self.request_timeout) - else: - ws = await self.http.ws_connect( - websocket_url + self._get_url_timestamp(), - headers=headers, timeout=self.request_timeout) + ws = await self.http.ws_connect( + websocket_url + self._get_url_timestamp(), **extra_options) except (aiohttp.client_exceptions.WSServerHandshakeError, aiohttp.client_exceptions.ServerConnectionError, aiohttp.client_exceptions.ClientConnectionError): diff --git a/src/engineio/client.py b/src/engineio/client.py index df4da079..ad01b758 100644 --- a/src/engineio/client.py +++ b/src/engineio/client.py @@ -71,11 +71,15 @@ class Client(object): leave interrupt handling to the calling application. Interrupt handling can only be enabled when the client instance is created in the main thread. + :param websocket_extra_options: Dictionary containing additional keyword + arguments passed to + ``websocket.create_connection()``. """ event_names = ['connect', 'disconnect', 'message'] def __init__(self, logger=False, json=None, request_timeout=5, - http_session=None, ssl_verify=True, handle_sigint=True): + http_session=None, ssl_verify=True, handle_sigint=True, + websocket_extra_options=None): global original_signal_handler if handle_sigint and original_signal_handler is None and \ threading.current_thread() == threading.main_thread(): @@ -97,6 +101,7 @@ def __init__(self, logger=False, json=None, request_timeout=5, self.queue = None self.state = 'disconnected' self.ssl_verify = ssl_verify + self.websocket_extra_options = websocket_extra_options or {} if json is not None: packet.Packet.json = json @@ -414,11 +419,18 @@ def _connect_websocket(self, url, headers, engineio_path): if not self.ssl_verify: extra_options['sslopt'] = {"cert_reqs": ssl.CERT_NONE} + + # combine internally generated options with the ones supplied by the + # caller. The caller's options take precedence. + headers.update(self.websocket_extra_options.pop('header', {})) + extra_options['header'] = headers + extra_options['cookie'] = cookies + extra_options['enable_multithread'] = True + extra_options['timeout'] = self.request_timeout + extra_options.update(self.websocket_extra_options) try: ws = websocket.create_connection( - websocket_url + self._get_url_timestamp(), header=headers, - cookie=cookies, enable_multithread=True, - timeout=self.request_timeout, **extra_options) + websocket_url + self._get_url_timestamp(), **extra_options) except (ConnectionError, IOError, websocket.WebSocketException): if upgrade: self.logger.warning( diff --git a/tests/asyncio/test_asyncio_client.py b/tests/asyncio/test_asyncio_client.py index f86483db..c3371ce2 100644 --- a/tests/asyncio/test_asyncio_client.py +++ b/tests/asyncio/test_asyncio_client.py @@ -542,6 +542,30 @@ def test_websocket_connection_failed(self, _time): timeout=5 ) + @mock.patch('engineio.client.time.time', return_value=123.456) + def test_websocket_connection_extra(self, _time): + c = asyncio_client.AsyncClient(websocket_extra_options={ + 'headers': {'Baz': 'Qux'}, + 'timeout': 10 + }) + c.http = mock.MagicMock(closed=False) + c.http.ws_connect = AsyncMock( + side_effect=[aiohttp.client_exceptions.ServerConnectionError()] + ) + with pytest.raises(exceptions.ConnectionError): + _run( + c.connect( + 'http://foo', + transports=['websocket'], + headers={'Foo': 'Bar'}, + ) + ) + c.http.ws_connect.mock.assert_called_once_with( + 'ws://foo/engine.io/?transport=websocket&EIO=4&t=123.456', + headers={'Foo': 'Bar', 'Baz': 'Qux'}, + timeout=10, + ) + @mock.patch('engineio.client.time.time', return_value=123.456) def test_websocket_upgrade_failed(self, _time): c = asyncio_client.AsyncClient() diff --git a/tests/common/test_client.py b/tests/common/test_client.py index 1af4957f..6d932c91 100644 --- a/tests/common/test_client.py +++ b/tests/common/test_client.py @@ -561,6 +561,26 @@ def test_websocket_connection_failed(self, create_connection, _time): timeout=5 ) + @mock.patch('engineio.client.time.time', return_value=123.456) + @mock.patch( + 'engineio.client.websocket.create_connection', + side_effect=[ConnectionError], + ) + def test_websocket_connection_extra(self, create_connection, _time): + c = client.Client(websocket_extra_options={'header': {'Baz': 'Qux'}, + 'timeout': 10}) + with pytest.raises(exceptions.ConnectionError): + c.connect( + 'http://foo', transports=['websocket'], headers={'Foo': 'Bar'} + ) + create_connection.assert_called_once_with( + 'ws://foo/engine.io/?transport=websocket&EIO=4&t=123.456', + header={'Foo': 'Bar', 'Baz': 'Qux'}, + cookie=None, + enable_multithread=True, + timeout=10 + ) + @mock.patch('engineio.client.time.time', return_value=123.456) @mock.patch( 'engineio.client.websocket.create_connection',