Skip to content

Commit

Permalink
Fixed infrequent race condition when upgrading from polling to WebSoc…
Browse files Browse the repository at this point in the history
…ket (Fixes #160)
  • Loading branch information
miguelgrinberg committed Mar 10, 2020
1 parent 59a0cd4 commit f2cce2b
Show file tree
Hide file tree
Showing 4 changed files with 36 additions and 19 deletions.
14 changes: 7 additions & 7 deletions engineio/asyncio_socket.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,8 +74,6 @@ async def send(self, pkt):
"""Send a packet to the client."""
if not await self.check_ping_timeout():
return
if self.upgrading:
self.packet_backlog.append(pkt)
else:
await self.queue.put(pkt)
self.server.logger.info('%s: Sending packet %s data %s',
Expand All @@ -93,6 +91,10 @@ async def handle_get_request(self, environ):
self.server.logger.info('%s: Received request to upgrade to %s',
self.sid, transport)
return await getattr(self, '_upgrade_' + transport)(environ)
if self.upgrading or self.upgraded:
# we are upgrading to WebSocket, do not return any more packets
# through the polling endpoint
return [packet.Packet(packet.NOOP)]
try:
packets = await self.poll()
except exceptions.QueueEmpty:
Expand Down Expand Up @@ -148,6 +150,7 @@ async def _websocket_handler(self, ws):
decoded_pkt.data != 'probe':
self.server.logger.info(
'%s: Failed websocket upgrade, no PING packet', self.sid)
self.upgrading = False
return
await ws.send(packet.Packet(
packet.PONG,
Expand All @@ -157,6 +160,7 @@ async def _websocket_handler(self, ws):
try:
pkt = await ws.wait()
except IOError: # pragma: no cover
self.upgrading = False
return
decoded_pkt = packet.Packet(encoded_packet=pkt)
if decoded_pkt.packet_type != packet.UPGRADE:
Expand All @@ -165,13 +169,9 @@ async def _websocket_handler(self, ws):
('%s: Failed websocket upgrade, expected UPGRADE packet, '
'received %s instead.'),
self.sid, pkt)
self.upgrading = False
return
self.upgraded = True

# flush any packets that were sent during the upgrade
for pkt in self.packet_backlog:
await self.queue.put(pkt)
self.packet_backlog = []
self.upgrading = False
else:
self.connected = True
Expand Down
14 changes: 6 additions & 8 deletions engineio/socket.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@ def __init__(self, server, sid):
self.connected = False
self.upgrading = False
self.upgraded = False
self.packet_backlog = []
self.closing = False
self.closed = False
self.session = {}
Expand Down Expand Up @@ -89,8 +88,6 @@ def send(self, pkt):
"""Send a packet to the client."""
if not self.check_ping_timeout():
return
if self.upgrading:
self.packet_backlog.append(pkt)
else:
self.queue.put(pkt)
self.server.logger.info('%s: Sending packet %s data %s',
Expand All @@ -109,6 +106,10 @@ def handle_get_request(self, environ, start_response):
self.sid, transport)
return getattr(self, '_upgrade_' + transport)(environ,
start_response)
if self.upgrading or self.upgraded:
# we are upgrading to WebSocket, do not return any more packets
# through the polling endpoint
return [packet.Packet(packet.NOOP)]
try:
packets = self.poll()
except exceptions.QueueEmpty:
Expand Down Expand Up @@ -167,6 +168,7 @@ def _websocket_handler(self, ws):
decoded_pkt.data != 'probe':
self.server.logger.info(
'%s: Failed websocket upgrade, no PING packet', self.sid)
self.upgrading = False
return []
ws.send(packet.Packet(
packet.PONG,
Expand All @@ -181,13 +183,9 @@ def _websocket_handler(self, ws):
('%s: Failed websocket upgrade, expected UPGRADE packet, '
'received %s instead.'),
self.sid, pkt)
self.upgrading = False
return []
self.upgraded = True

# flush any packets that were sent during the upgrade
for pkt in self.packet_backlog:
self.queue.put(pkt)
self.packet_backlog = []
self.upgrading = False
else:
self.connected = True
Expand Down
13 changes: 11 additions & 2 deletions tests/asyncio/test_asyncio_socket.py
Original file line number Diff line number Diff line change
Expand Up @@ -363,11 +363,20 @@ def test_websocket_upgrade_with_backlog(self):
always_bytes=False)]
s.upgrading = True
_run(s.send(packet.Packet(packet.MESSAGE, data=foo)))
environ = {'REQUEST_METHOD': 'GET', 'QUERY_STRING': 'sid=sid'}
packets = _run(s.handle_get_request(environ))
self.assertEqual(len(packets), 1)
self.assertEqual(packets[0].encode(), b'6')
packets = _run(s.poll())
self.assertEqual(len(packets), 1)
self.assertEqual(packets[0].encode(), b'4foo')

_run(s._websocket_handler(ws))
self.assertTrue(s.upgraded)
self.assertFalse(s.upgrading)
self.assertEqual(s.packet_backlog, [])
ws.send.mock.assert_called_with('4foo')
packets = _run(s.handle_get_request(environ))
self.assertEqual(len(packets), 1)
self.assertEqual(packets[0].encode(), b'6')

def test_websocket_read_write_wait_fail(self):
mock_server = self._get_mock_server()
Expand Down
14 changes: 12 additions & 2 deletions tests/common/test_socket.py
Original file line number Diff line number Diff line change
Expand Up @@ -357,12 +357,22 @@ def test_websocket_upgrade_with_backlog(self):
always_bytes=False)]
s.upgrading = True
s.send(packet.Packet(packet.MESSAGE, data=foo))
environ = {'REQUEST_METHOD': 'GET', 'QUERY_STRING': 'sid=sid'}
start_response = mock.MagicMock()
packets = s.handle_get_request(environ, start_response)
self.assertEqual(len(packets), 1)
self.assertEqual(packets[0].encode(), b'6')
packets = s.poll()
self.assertEqual(len(packets), 1)
self.assertEqual(packets[0].encode(), b'4foo')

s._websocket_handler(ws)
self._join_bg_tasks()
self.assertTrue(s.upgraded)
self.assertFalse(s.upgrading)
self.assertEqual(s.packet_backlog, [])
ws.send.assert_called_with('4foo')
packets = s.handle_get_request(environ, start_response)
self.assertEqual(len(packets), 1)
self.assertEqual(packets[0].encode(), b'6')

def test_websocket_read_write_wait_fail(self):
mock_server = self._get_mock_server()
Expand Down

0 comments on commit f2cce2b

Please sign in to comment.