Skip to content

Commit

Permalink
Support custom headers and query string in test client
Browse files Browse the repository at this point in the history
Fixes #520
  • Loading branch information
miguelgrinberg committed Aug 5, 2017
1 parent 8a09692 commit ecf5925
Show file tree
Hide file tree
Showing 3 changed files with 73 additions and 9 deletions.
5 changes: 3 additions & 2 deletions flask_socketio/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -583,9 +583,10 @@ def sleep(self, seconds=0):
"""
return self.server.sleep(seconds)

def test_client(self, app, namespace=None):
def test_client(self, app, namespace=None, query_string=None, headers=None):
"""Return a simple SocketIO client that can be used for unit tests."""
return SocketIOTestClient(app, self, namespace)
return SocketIOTestClient(app, self, namespace=namespace,
query_string=query_string, headers=headers)

def _handle_event(self, handler, message, namespace, sid, *args):
if sid not in self.server.environ:
Expand Down
19 changes: 15 additions & 4 deletions flask_socketio/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,11 +14,14 @@ class SocketIOTestClient(object):
:param socketio: The application's ``SocketIO`` instance.
:param namespace: The namespace for the client. If not provided, the client
connects to the server on the global namespace.
:param query_string: A string with custom query string arguments.
:param headers: A dictionary with custom HTTP headers.
"""
queue = {}
ack = None

def __init__(self, app, socketio, namespace=None):
def __init__(self, app, socketio, namespace=None, query_string=None,
headers=None):
def _mock_send_packet(sid, pkt):
if pkt.packet_type == packet.EVENT or \
pkt.packet_type == packet.BINARY_EVENT:
Expand Down Expand Up @@ -49,21 +52,29 @@ def _mock_send_packet(sid, pkt):
'queue. Disable the queue on your test '
'configuration.')
socketio.server.manager.initialize()
self.connect(namespace)
self.connect(namespace=namespace, query_string=query_string,
headers=headers)

def connect(self, namespace=None):
def connect(self, namespace=None, query_string=None, headers=None):
"""Connect the client.
:param namespace: The namespace for the client. If not provided, the
client connects to the server on the global
namespace.
:param query_string: A string with custom query string arguments.
:param headers: A dictionary with custom HTTP headers.
Note that it is usually not necessary to explicitly call this method,
since a connection is automatically established when an instance of
this class is created. An example where it this method would be useful
is when the application accepts multiple namespace connections.
"""
environ = EnvironBuilder('/socket.io').get_environ()
url = '/socket.io'
if query_string:
if query_string[0] != '?':
query_string = '?' + query_string
url += query_string
environ = EnvironBuilder(url, headers=headers).get_environ()
environ['flask.app'] = self.app
self.socketio.server._handle_eio_connect(self.sid, environ)
if namespace is not None and namespace != '/':
Expand Down
58 changes: 55 additions & 3 deletions test_socketio.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import json
import unittest
import coverage

Expand All @@ -17,6 +18,9 @@
@socketio.on('connect')
def on_connect():
send('connected')
send(json.dumps(dict(request.args)))
send(json.dumps({h: request.headers[h] for h in request.headers.keys()
if h not in ['Host', 'Content-Type', 'Content-Length']}))


@socketio.on('disconnect')
Expand All @@ -28,6 +32,9 @@ def on_disconnect():
@socketio.on('connect', namespace='/test')
def on_connect_test():
send('connected-test')
send(json.dumps(dict(request.args)))
send(json.dumps({h: request.headers[h] for h in request.headers.keys()
if h not in ['Host', 'Content-Type', 'Content-Length']}))


@socketio.on('disconnect', namespace='/test')
Expand Down Expand Up @@ -188,6 +195,10 @@ def raise_error_default(data):
class MyNamespace(Namespace):
def on_connect(self):
send('connected-ns')
send(json.dumps(dict(request.args)))
send(json.dumps(
{h: request.headers[h] for h in request.headers.keys()
if h not in ['Host', 'Content-Type', 'Content-Length']}))

def on_disconnect(self):
global disconnected
Expand Down Expand Up @@ -238,15 +249,42 @@ def tearDown(self):
def test_connect(self):
client = socketio.test_client(app)
received = client.get_received()
self.assertEqual(len(received), 1)
self.assertEqual(len(received), 3)
self.assertEqual(received[0]['args'], 'connected')
self.assertEqual(received[1]['args'], '{}')
self.assertEqual(received[2]['args'], '{}')
client.disconnect()

def test_connect_query_string_and_headers(self):
client = socketio.test_client(
app, query_string='?foo=bar&foo=baz',
headers={'Authorization': 'Bearer foobar'})
received = client.get_received()
self.assertEqual(len(received), 3)
self.assertEqual(received[0]['args'], 'connected')
self.assertEqual(received[1]['args'], '{"foo": ["bar", "baz"]}')
self.assertEqual(received[2]['args'],
'{"Authorization": "Bearer foobar"}')
client.disconnect()

def test_connect_namespace(self):
client = socketio.test_client(app, namespace='/test')
received = client.get_received('/test')
self.assertEqual(len(received), 1)
self.assertEqual(len(received), 3)
self.assertEqual(received[0]['args'], 'connected-test')
self.assertEqual(received[1]['args'], '{}')
self.assertEqual(received[2]['args'], '{}')
client.disconnect(namespace='/test')

def test_connect_namespace_query_string_and_headers(self):
client = socketio.test_client(
app, namespace='/test', query_string='foo=bar',
headers={'My-Custom-Header': 'Value'})
received = client.get_received('/test')
self.assertEqual(len(received), 3)
self.assertEqual(received[0]['args'], 'connected-test')
self.assertEqual(received[1]['args'], '{"foo": ["bar"]}')
self.assertEqual(received[2]['args'], '{"My-Custom-Header": "Value"}')
client.disconnect(namespace='/test')

def test_disconnect(self):
Expand Down Expand Up @@ -507,8 +545,22 @@ def test_on_event(self):
def test_connect_class_based(self):
client = socketio.test_client(app, namespace='/ns')
received = client.get_received('/ns')
self.assertEqual(len(received), 1)
self.assertEqual(len(received), 3)
self.assertEqual(received[0]['args'], 'connected-ns')
self.assertEqual(received[1]['args'], '{}')
self.assertEqual(received[2]['args'], '{}')
client.disconnect('/ns')

def test_connect_class_based_query_string_and_headers(self):
client = socketio.test_client(
app, namespace='/ns', query_string='foo=bar',
headers={'Authorization': 'Basic foobar'})
received = client.get_received('/ns')
self.assertEqual(len(received), 3)
self.assertEqual(received[0]['args'], 'connected-ns')
self.assertEqual(received[1]['args'], '{"foo": ["bar"]}')
self.assertEqual(received[2]['args'],
'{"Authorization": "Basic foobar"}')
client.disconnect('/ns')

def test_disconnect_class_based(self):
Expand Down

0 comments on commit ecf5925

Please sign in to comment.