From ecf5925827a916ef52361856290f16087c4e36e9 Mon Sep 17 00:00:00 2001 From: Miguel Grinberg Date: Sat, 5 Aug 2017 10:07:38 -0700 Subject: [PATCH] Support custom headers and query string in test client Fixes #520 --- flask_socketio/__init__.py | 5 +-- flask_socketio/test_client.py | 19 +++++++++--- test_socketio.py | 58 +++++++++++++++++++++++++++++++++-- 3 files changed, 73 insertions(+), 9 deletions(-) diff --git a/flask_socketio/__init__.py b/flask_socketio/__init__.py index 0c2d7bfc..032ae933 100644 --- a/flask_socketio/__init__.py +++ b/flask_socketio/__init__.py @@ -583,9 +583,10 @@ def sleep(self, seconds=0): """ return self.server.sleep(seconds) - def test_client(self, app, namespace=None): + def test_client(self, app, namespace=None, query_string=None, headers=None): """Return a simple SocketIO client that can be used for unit tests.""" - return SocketIOTestClient(app, self, namespace) + return SocketIOTestClient(app, self, namespace=namespace, + query_string=query_string, headers=headers) def _handle_event(self, handler, message, namespace, sid, *args): if sid not in self.server.environ: diff --git a/flask_socketio/test_client.py b/flask_socketio/test_client.py index 6aa5e469..a2966890 100644 --- a/flask_socketio/test_client.py +++ b/flask_socketio/test_client.py @@ -14,11 +14,14 @@ class SocketIOTestClient(object): :param socketio: The application's ``SocketIO`` instance. :param namespace: The namespace for the client. If not provided, the client connects to the server on the global namespace. + :param query_string: A string with custom query string arguments. + :param headers: A dictionary with custom HTTP headers. """ queue = {} ack = None - def __init__(self, app, socketio, namespace=None): + def __init__(self, app, socketio, namespace=None, query_string=None, + headers=None): def _mock_send_packet(sid, pkt): if pkt.packet_type == packet.EVENT or \ pkt.packet_type == packet.BINARY_EVENT: @@ -49,21 +52,29 @@ def _mock_send_packet(sid, pkt): 'queue. Disable the queue on your test ' 'configuration.') socketio.server.manager.initialize() - self.connect(namespace) + self.connect(namespace=namespace, query_string=query_string, + headers=headers) - def connect(self, namespace=None): + def connect(self, namespace=None, query_string=None, headers=None): """Connect the client. :param namespace: The namespace for the client. If not provided, the client connects to the server on the global namespace. + :param query_string: A string with custom query string arguments. + :param headers: A dictionary with custom HTTP headers. Note that it is usually not necessary to explicitly call this method, since a connection is automatically established when an instance of this class is created. An example where it this method would be useful is when the application accepts multiple namespace connections. """ - environ = EnvironBuilder('/socket.io').get_environ() + url = '/socket.io' + if query_string: + if query_string[0] != '?': + query_string = '?' + query_string + url += query_string + environ = EnvironBuilder(url, headers=headers).get_environ() environ['flask.app'] = self.app self.socketio.server._handle_eio_connect(self.sid, environ) if namespace is not None and namespace != '/': diff --git a/test_socketio.py b/test_socketio.py index 22afbc41..4fad4ea8 100755 --- a/test_socketio.py +++ b/test_socketio.py @@ -1,3 +1,4 @@ +import json import unittest import coverage @@ -17,6 +18,9 @@ @socketio.on('connect') def on_connect(): send('connected') + send(json.dumps(dict(request.args))) + send(json.dumps({h: request.headers[h] for h in request.headers.keys() + if h not in ['Host', 'Content-Type', 'Content-Length']})) @socketio.on('disconnect') @@ -28,6 +32,9 @@ def on_disconnect(): @socketio.on('connect', namespace='/test') def on_connect_test(): send('connected-test') + send(json.dumps(dict(request.args))) + send(json.dumps({h: request.headers[h] for h in request.headers.keys() + if h not in ['Host', 'Content-Type', 'Content-Length']})) @socketio.on('disconnect', namespace='/test') @@ -188,6 +195,10 @@ def raise_error_default(data): class MyNamespace(Namespace): def on_connect(self): send('connected-ns') + send(json.dumps(dict(request.args))) + send(json.dumps( + {h: request.headers[h] for h in request.headers.keys() + if h not in ['Host', 'Content-Type', 'Content-Length']})) def on_disconnect(self): global disconnected @@ -238,15 +249,42 @@ def tearDown(self): def test_connect(self): client = socketio.test_client(app) received = client.get_received() - self.assertEqual(len(received), 1) + self.assertEqual(len(received), 3) + self.assertEqual(received[0]['args'], 'connected') + self.assertEqual(received[1]['args'], '{}') + self.assertEqual(received[2]['args'], '{}') + client.disconnect() + + def test_connect_query_string_and_headers(self): + client = socketio.test_client( + app, query_string='?foo=bar&foo=baz', + headers={'Authorization': 'Bearer foobar'}) + received = client.get_received() + self.assertEqual(len(received), 3) self.assertEqual(received[0]['args'], 'connected') + self.assertEqual(received[1]['args'], '{"foo": ["bar", "baz"]}') + self.assertEqual(received[2]['args'], + '{"Authorization": "Bearer foobar"}') client.disconnect() def test_connect_namespace(self): client = socketio.test_client(app, namespace='/test') received = client.get_received('/test') - self.assertEqual(len(received), 1) + self.assertEqual(len(received), 3) self.assertEqual(received[0]['args'], 'connected-test') + self.assertEqual(received[1]['args'], '{}') + self.assertEqual(received[2]['args'], '{}') + client.disconnect(namespace='/test') + + def test_connect_namespace_query_string_and_headers(self): + client = socketio.test_client( + app, namespace='/test', query_string='foo=bar', + headers={'My-Custom-Header': 'Value'}) + received = client.get_received('/test') + self.assertEqual(len(received), 3) + self.assertEqual(received[0]['args'], 'connected-test') + self.assertEqual(received[1]['args'], '{"foo": ["bar"]}') + self.assertEqual(received[2]['args'], '{"My-Custom-Header": "Value"}') client.disconnect(namespace='/test') def test_disconnect(self): @@ -507,8 +545,22 @@ def test_on_event(self): def test_connect_class_based(self): client = socketio.test_client(app, namespace='/ns') received = client.get_received('/ns') - self.assertEqual(len(received), 1) + self.assertEqual(len(received), 3) + self.assertEqual(received[0]['args'], 'connected-ns') + self.assertEqual(received[1]['args'], '{}') + self.assertEqual(received[2]['args'], '{}') + client.disconnect('/ns') + + def test_connect_class_based_query_string_and_headers(self): + client = socketio.test_client( + app, namespace='/ns', query_string='foo=bar', + headers={'Authorization': 'Basic foobar'}) + received = client.get_received('/ns') + self.assertEqual(len(received), 3) self.assertEqual(received[0]['args'], 'connected-ns') + self.assertEqual(received[1]['args'], '{"foo": ["bar"]}') + self.assertEqual(received[2]['args'], + '{"Authorization": "Basic foobar"}') client.disconnect('/ns') def test_disconnect_class_based(self):