Skip to content

Commit

Permalink
Pass auth data from client in connect event handler (Fixes #1555)
Browse files Browse the repository at this point in the history
  • Loading branch information
miguelgrinberg committed May 23, 2021
1 parent f9036eb commit 43dc6e5
Show file tree
Hide file tree
Showing 4 changed files with 66 additions and 39 deletions.
19 changes: 16 additions & 3 deletions docs/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -164,8 +164,8 @@ Custom named events can also support multiple arguments::
print('received args: ' + arg1 + arg2 + arg3)

When the name of the event is a valid Python identifier that does not collide
with other defined symbols, the ``@socketio.event`` provides a more compact
syntax that takes the event name from the decorated function::
with other defined symbols, the ``@socketio.event`` decorator provides a more
compact syntax that takes the event name from the decorated function::

@socketio.event
def my_custom_event(arg1, arg2, arg3):
Expand Down Expand Up @@ -345,13 +345,19 @@ Flask-SocketIO also dispatches connection and disconnection events. The
following example shows how to register handlers for them::

@socketio.on('connect')
def test_connect():
def test_connect(auth):
emit('my response', {'data': 'Connected'})

@socketio.on('disconnect')
def test_disconnect():
print('Client disconnected')

The ``auth`` argument in the connection handler is optional. The client can
use it to pass authentication data such as tokens in dictionary format. If the
client does not provide authentication details, then this argument is set to
``None``. If the server defines a connection event handler without this
argument, then any authentication data passed by the cient is discarded.

The connection event handler can return ``False`` to reject the connection, or
it can also raise `ConectionRefusedError`. This is so that the client can be
authenticated at this point. When using the exception, any arguments passed to
Expand Down Expand Up @@ -517,6 +523,13 @@ user's identity can then be recorded in the user session or in a cookie, and
later when the SocketIO connection is established that information will be
accessible to SocketIO event handlers.

Recent revisions of the Socket.IO protocol include the ability to pass a
dictionary with authentication information during the connection. This is an
ideal place for the client to include a token or other authentication details.
If the client uses this capability, the server will provide this dictionary as
an argument to the ``connect`` event handler, as shown above.


Using Flask-Login with Flask-SocketIO
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

Expand Down
10 changes: 8 additions & 2 deletions flask_socketio/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -708,7 +708,7 @@ def sleep(self, seconds=0):
return self.server.sleep(seconds)

def test_client(self, app, namespace=None, query_string=None,
headers=None, flask_test_client=None):
headers=None, auth=None, flask_test_client=None):
"""The Socket.IO test client is useful for testing a Flask-SocketIO
server. It works in a similar way to the Flask Test Client, but
adapted to the Socket.IO server.
Expand All @@ -719,6 +719,7 @@ def test_client(self, app, namespace=None, query_string=None,
namespace.
:param query_string: A string with custom query string arguments.
:param headers: A dictionary with custom HTTP headers.
:param auth: Optional authentication data, given as a dictionary.
:param flask_test_client: The instance of the Flask test client
currently in use. Passing the Flask test
client is optional, but is necessary if you
Expand All @@ -728,6 +729,7 @@ def test_client(self, app, namespace=None, query_string=None,
"""
return SocketIOTestClient(app, self, namespace=namespace,
query_string=query_string, headers=headers,
auth=auth,
flask_test_client=flask_test_client)

def _handle_event(self, handler, message, namespace, sid, *args):
Expand Down Expand Up @@ -756,7 +758,11 @@ def _handle_event(self, handler, message, namespace, sid, *args):
flask.request.event = {'message': message, 'args': args}
try:
if message == 'connect':
ret = handler()
auth = args[1] if len(args) > 1 else None
try:
ret = handler(auth)
except TypeError:
ret = handler()
else:
ret = handler(*args)
except:
Expand Down
11 changes: 7 additions & 4 deletions flask_socketio/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ class SocketIOTestClient(object):
connects to the server on the global namespace.
:param query_string: A string with custom query string arguments.
:param headers: A dictionary with custom HTTP headers.
:param auth: Optional authentication data, given as a dictionary.
:param flask_test_client: The instance of the Flask test client
currently in use. Passing the Flask test
client is optional, but is necessary if you
Expand All @@ -27,7 +28,7 @@ class SocketIOTestClient(object):
acks = {}

def __init__(self, app, socketio, namespace=None, query_string=None,
headers=None, flask_test_client=None):
headers=None, auth=None, flask_test_client=None):
def _mock_send_packet(eio_sid, pkt):
# make sure the packet can be encoded and decoded
epkt = pkt.encode()
Expand Down Expand Up @@ -76,7 +77,7 @@ def _mock_send_packet(eio_sid, pkt):
'configuration.')
socketio.server.manager.initialize()
self.connect(namespace=namespace, query_string=query_string,
headers=headers)
headers=headers, auth=auth)

def is_connected(self, namespace=None):
"""Check if a namespace is connected.
Expand All @@ -86,14 +87,16 @@ def is_connected(self, namespace=None):
"""
return self.connected.get(namespace or '/', False)

def connect(self, namespace=None, query_string=None, headers=None):
def connect(self, namespace=None, query_string=None, headers=None,
auth=None):
"""Connect the client.
:param namespace: The namespace for the client. If not provided, the
client connects to the server on the global
namespace.
:param query_string: A string with custom query string arguments.
:param headers: A dictionary with custom HTTP headers.
:param auth: Optional authentication data, given as a dictionary.
Note that it is usually not necessary to explicitly call this method,
since a connection is automatically established when an instance of
Expand All @@ -112,7 +115,7 @@ def connect(self, namespace=None, query_string=None, headers=None):
# inject cookies from Flask
self.flask_test_client.cookie_jar.inject_wsgi(environ)
self.socketio.server._handle_eio_connect(self.eio_sid, environ)
pkt = packet.Packet(packet.CONNECT, namespace=namespace)
pkt = packet.Packet(packet.CONNECT, auth, namespace=namespace)
with self.app.app_context():
self.socketio.server._handle_eio_message(self.eio_sid,
pkt.encode())
Expand Down
65 changes: 35 additions & 30 deletions test_socketio.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,9 @@


@socketio.on('connect')
def on_connect():
def on_connect(auth):
if auth != {'foo': 'bar'}: # pragma: no cover
return False
if request.args.get('fail'):
return False
send('connected')
Expand Down Expand Up @@ -278,8 +280,8 @@ def tearDown(self):
pass

def test_connect(self):
client = socketio.test_client(app)
client2 = socketio.test_client(app)
client = socketio.test_client(app, auth={'foo': 'bar'})
client2 = socketio.test_client(app, auth={'foo': 'bar'})
self.assertTrue(client.is_connected())
self.assertTrue(client2.is_connected())
self.assertNotEqual(client.eio_sid, client2.eio_sid)
Expand All @@ -297,7 +299,8 @@ def test_connect(self):
def test_connect_query_string_and_headers(self):
client = socketio.test_client(
app, query_string='?foo=bar&foo=baz',
headers={'Authorization': 'Bearer foobar'})
headers={'Authorization': 'Bearer foobar'},
auth={'foo': 'bar'})
received = client.get_received()
self.assertEqual(len(received), 3)
self.assertEqual(received[0]['args'], 'connected')
Expand Down Expand Up @@ -329,13 +332,14 @@ def test_connect_namespace_query_string_and_headers(self):
client.disconnect(namespace='/test')

def test_connect_rejected(self):
client = socketio.test_client(app, query_string='fail=1')
client = socketio.test_client(app, query_string='fail=1',
auth={'foo': 'bar'})
self.assertFalse(client.is_connected())

def test_disconnect(self):
global disconnected
disconnected = None
client = socketio.test_client(app)
client = socketio.test_client(app, auth={'foo': 'bar'})
client.disconnect()
self.assertEqual(disconnected, '/')

Expand All @@ -347,16 +351,16 @@ def test_disconnect_namespace(self):
self.assertEqual(disconnected, '/test')

def test_send(self):
client = socketio.test_client(app)
client = socketio.test_client(app, auth={'foo': 'bar'})
client.get_received()
client.send('echo this message back')
received = client.get_received()
self.assertEqual(len(received), 1)
self.assertEqual(received[0]['args'], 'echo this message back')

def test_send_json(self):
client1 = socketio.test_client(app)
client2 = socketio.test_client(app)
client1 = socketio.test_client(app, auth={'foo': 'bar'})
client2 = socketio.test_client(app, auth={'foo': 'bar'})
client1.get_received()
client2.get_received()
client1.send({'a': 'b'}, json=True)
Expand Down Expand Up @@ -384,7 +388,7 @@ def test_send_json_namespace(self):
self.assertEqual(received[0]['args']['a'], 'b')

def test_emit(self):
client = socketio.test_client(app)
client = socketio.test_client(app, auth={'foo': 'bar'})
client.get_received()
client.emit('my custom event', {'a': 'b'})
received = client.get_received()
Expand All @@ -394,7 +398,7 @@ def test_emit(self):
self.assertEqual(received[0]['args'][0]['a'], 'b')

def test_emit_binary(self):
client = socketio.test_client(app)
client = socketio.test_client(app, auth={'foo': 'bar'})
client.get_received()
client.emit('my custom event', {u'a': b'\x01\x02\x03'})
received = client.get_received()
Expand All @@ -404,7 +408,7 @@ def test_emit_binary(self):
self.assertEqual(received[0]['args'][0]['a'], b'\x01\x02\x03')

def test_request_event_data(self):
client = socketio.test_client(app)
client = socketio.test_client(app, auth={'foo': 'bar'})
client.get_received()
global request_event_data
request_event_data = None
Expand All @@ -427,8 +431,8 @@ def test_emit_namespace(self):
self.assertEqual(received[0]['args'][0]['a'], 'b')

def test_broadcast(self):
client1 = socketio.test_client(app)
client2 = socketio.test_client(app)
client1 = socketio.test_client(app, auth={'foo': 'bar'})
client2 = socketio.test_client(app, auth={'foo': 'bar'})
client3 = socketio.test_client(app, namespace='/test')
client2.get_received()
client3.get_received('/test')
Expand All @@ -443,7 +447,7 @@ def test_broadcast(self):
def test_broadcast_namespace(self):
client1 = socketio.test_client(app, namespace='/test')
client2 = socketio.test_client(app, namespace='/test')
client3 = socketio.test_client(app)
client3 = socketio.test_client(app, auth={'foo': 'bar'})
client2.get_received('/test')
client3.get_received()
client1.emit('my custom broadcast namespace event', {'a': 'b'},
Expand All @@ -458,7 +462,8 @@ def test_broadcast_namespace(self):
def test_session(self):
flask_client = app.test_client()
flask_client.get('/session')
client = socketio.test_client(app, flask_test_client=flask_client)
client = socketio.test_client(app, flask_test_client=flask_client,
auth={'foo': 'bar'})
client.get_received()
client.send('echo this message back')
self.assertEqual(
Expand All @@ -470,8 +475,8 @@ def test_session(self):
{'a': 'b', 'foo': 'bar'})

def test_room(self):
client1 = socketio.test_client(app)
client2 = socketio.test_client(app)
client1 = socketio.test_client(app, auth={'foo': 'bar'})
client2 = socketio.test_client(app, auth={'foo': 'bar'})
client3 = socketio.test_client(app, namespace='/test')
client1.get_received()
client2.get_received()
Expand Down Expand Up @@ -516,7 +521,7 @@ def test_room(self):
self.assertEqual(len(received), 0)

def test_error_handling(self):
client = socketio.test_client(app)
client = socketio.test_client(app, auth={'foo': 'bar'})
client.get_received()
global error_testing
error_testing = False
Expand All @@ -540,9 +545,9 @@ def test_error_handling_default(self):
self.assertTrue(error_testing_default)

def test_ack(self):
client1 = socketio.test_client(app)
client2 = socketio.test_client(app)
client3 = socketio.test_client(app)
client1 = socketio.test_client(app, auth={'foo': 'bar'})
client2 = socketio.test_client(app, auth={'foo': 'bar'})
client3 = socketio.test_client(app, auth={'foo': 'bar'})
ack = client1.send('echo this message back', callback=True)
self.assertEqual(ack, 'echo this message back')
ack = client1.send('test noackargs', callback=True)
Expand All @@ -556,9 +561,9 @@ def test_ack(self):
self.assertEqual(ack3, {'a': 'b'})

def test_noack(self):
client1 = socketio.test_client(app)
client2 = socketio.test_client(app)
client3 = socketio.test_client(app)
client1 = socketio.test_client(app, auth={'foo': 'bar'})
client2 = socketio.test_client(app, auth={'foo': 'bar'})
client3 = socketio.test_client(app, auth={'foo': 'bar'})
no_ack_dict = {'noackargs': True}
noack = client1.send("test noackargs", callback=False)
self.assertIsNone(noack)
Expand All @@ -568,7 +573,7 @@ def test_noack(self):
self.assertIsNone(noack3)

def test_error_handling_ack(self):
client1 = socketio.test_client(app)
client1 = socketio.test_client(app, auth={'foo': 'bar'})
client2 = socketio.test_client(app, namespace='/test')
client3 = socketio.test_client(app, namespace='/unused_namespace')
errorack = client1.emit("error testing", "", callback=True)
Expand All @@ -582,7 +587,7 @@ def test_error_handling_ack(self):
self.assertEqual(errorack_default, 'error/default')

def test_on_event(self):
client = socketio.test_client(app)
client = socketio.test_client(app, auth={'foo': 'bar'})
client.get_received()
global request_event_data
request_event_data = None
Expand Down Expand Up @@ -684,13 +689,13 @@ def on_connect():
self.assertFalse(socketio.server.eio.allow_upgrades)
self.assertEqual(socketio.server.eio.cookie, 'foo')

client = socketio.test_client(app)
client = socketio.test_client(app, auth={'foo': 'bar'})
received = client.get_received()
self.assertEqual(len(received), 1)
self.assertEqual(received[0]['args'], {'connected': 'foo'})

def test_encode_decode(self):
client = socketio.test_client(app)
client = socketio.test_client(app, auth={'foo': 'bar'})
client.get_received()
data = {'foo': 'bar', 'invalid': socketio}
self.assertRaises(TypeError, client.emit, 'my custom event', data,
Expand All @@ -704,7 +709,7 @@ def test_encode_decode(self):
self.assertEqual(received[0]['args'][0], {'foo': 'bar'})

def test_encode_decode_2(self):
client = socketio.test_client(app)
client = socketio.test_client(app, auth={'foo': 'bar'})
self.assertRaises(TypeError, client.emit, 'bad response')
self.assertRaises(TypeError, client.emit, 'bad callback',
callback=True)
Expand Down

0 comments on commit 43dc6e5

Please sign in to comment.