Skip to content

Commit

Permalink
invoke disconnect handler when application handler crashes
Browse files Browse the repository at this point in the history
  • Loading branch information
miguelgrinberg committed Apr 18, 2017
1 parent 23859e4 commit 246edc3
Show file tree
Hide file tree
Showing 9 changed files with 68 additions and 14 deletions.
13 changes: 11 additions & 2 deletions engineio/asyncio_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,9 +139,9 @@ async def handle_request(self, *args, **kwargs):
r = self._ok(packets, b64=b64)
else:
r = packets
except IOError:
except EngineIOError:
if sid in self.sockets: # pragma: no cover
del self.sockets[sid]
await self.disconnect(sid)
r = self._bad_request()
if sid in self.sockets and self.sockets[sid].closed:
del self.sockets[sid]
Expand All @@ -155,7 +155,16 @@ async def handle_request(self, *args, **kwargs):
await socket.handle_post_request(environ)
r = self._ok()
except EngineIOError:
if sid in self.sockets: # pragma: no cover
await self.disconnect(sid)
r = self._bad_request()
except Exception as e:
# for any other unexpected errors, we disconnect
# the cient and reraise
print('yo')
if sid in self.sockets: # pragma: no cover
await self.disconnect(sid)
raise e
else:
self.logger.warning('Method %s not supported', method)
r = self._method_not_found()
Expand Down
4 changes: 2 additions & 2 deletions engineio/asyncio_socket.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ async def poll(self):
self.server.ping_timeout)]
self.queue.task_done()
except (asyncio.TimeoutError, asyncio.CancelledError):
raise IOError()
raise exceptions.QueueEmpty()
if packets == [None]:
return []
try:
Expand Down Expand Up @@ -73,7 +73,7 @@ async def handle_get_request(self, environ):
return await getattr(self, '_upgrade_' + transport)(environ)
try:
packets = await self.poll()
except IOError as e:
except exceptions.QueueEmpty as e:
await self.close(wait=False)
raise e
return packets
Expand Down
4 changes: 4 additions & 0 deletions engineio/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,3 +8,7 @@ class ContentTooLongError(EngineIOError):

class UnknownPacketError(EngineIOError):
pass


class QueueEmpty(EngineIOError):
pass
12 changes: 10 additions & 2 deletions engineio/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -263,9 +263,9 @@ def handle_request(self, environ, start_response):
r = self._ok(packets, b64=b64)
else:
r = packets
except IOError:
except EngineIOError:
if sid in self.sockets: # pragma: no cover
del self.sockets[sid]
self.disconnect(sid)
r = self._bad_request()
if sid in self.sockets and self.sockets[sid].closed:
del self.sockets[sid]
Expand All @@ -279,7 +279,15 @@ def handle_request(self, environ, start_response):
socket.handle_post_request(environ)
r = self._ok()
except EngineIOError:
if sid in self.sockets: # pragma: no cover
self.disconnect(sid)
r = self._bad_request()
except Exception as e:
# for any other unexpected errors, we disconnect
# the cient and reraise
if sid in self.sockets: # pragma: no cover
self.disconnect(sid)
raise e
else:
self.logger.warning('Method %s not supported', method)
r = self._method_not_found()
Expand Down
4 changes: 2 additions & 2 deletions engineio/socket.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ def poll(self):
packets = [self.queue.get(timeout=self.server.ping_timeout)]
self.queue.task_done()
except self.server._async['queue'].Empty:
raise IOError()
raise exceptions.QueueEmpty()
if packets == [None]:
return []
try:
Expand Down Expand Up @@ -89,7 +89,7 @@ def handle_get_request(self, environ, start_response):
start_response)
try:
packets = self.poll()
except IOError as e:
except exceptions.QueueEmpty as e:
self.close(wait=False)
raise e
return packets
Expand Down
19 changes: 18 additions & 1 deletion tests/test_asyncio_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,7 @@ def _get_mock_socket(self):
mock_socket = mock.MagicMock()
mock_socket.connected = False
mock_socket.closed = False
mock_socket.closing = False
mock_socket.upgraded = False
mock_socket.send = AsyncMock()
mock_socket.handle_get_request = AsyncMock()
Expand Down Expand Up @@ -501,7 +502,7 @@ def test_get_request_error(self, import_module):

@asyncio.coroutine
def mock_get_request(*args, **kwargs):
raise IOError()
raise exceptions.QueueEmpty()

mock_socket.handle_get_request.mock.return_value = mock_get_request()
_run(s.handle_request('request'))
Expand Down Expand Up @@ -536,6 +537,22 @@ def mock_post_request(*args, **kwargs):
self.assertEqual(a._async['make_response'].call_args[0][0],
'400 BAD REQUEST')

@mock.patch('importlib.import_module')
def test_post_request_application_error(self, import_module):
a = self.get_async_mock({'REQUEST_METHOD': 'POST',
'QUERY_STRING': 'sid=foo'})
import_module.side_effect = [a]
s = asyncio_server.AsyncServer()
s.sockets['foo'] = mock_socket = self._get_mock_socket()

@asyncio.coroutine
def mock_get_request(*args, **kwargs):
raise ZeroDivisionError()

mock_socket.handle_post_request.mock.return_value = mock_get_request()
self.assertRaises(ZeroDivisionError, _run, s.handle_request('request'))
self.assertEqual(len(s.sockets), 0)

@staticmethod
def _gzip_decompress(b):
bytesio = six.BytesIO(b)
Expand Down
5 changes: 3 additions & 2 deletions tests/test_asyncio_socket.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ def test_create(self):
def test_empty_poll(self):
mock_server = self._get_mock_server()
s = asyncio_socket.AsyncSocket(mock_server, 'sid')
self.assertRaises(IOError, _run, s.poll())
self.assertRaises(exceptions.QueueEmpty, _run, s.poll())

def test_poll(self):
mock_server = self._get_mock_server()
Expand Down Expand Up @@ -139,7 +139,8 @@ def test_polling_read_error(self):
mock_server = self._get_mock_server()
s = asyncio_socket.AsyncSocket(mock_server, 'foo')
environ = {'REQUEST_METHOD': 'GET', 'QUERY_STRING': 'sid=foo'}
self.assertRaises(IOError, _run, s.handle_get_request(environ))
self.assertRaises(exceptions.QueueEmpty, _run,
s.handle_get_request(environ))

def test_polling_write(self):
mock_server = self._get_mock_server()
Expand Down
17 changes: 16 additions & 1 deletion tests/test_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ class TestServer(unittest.TestCase):
def _get_mock_socket(self):
mock_socket = mock.MagicMock()
mock_socket.closed = False
mock_socket.closing = False
mock_socket.upgraded = False
return mock_socket

Expand Down Expand Up @@ -603,7 +604,8 @@ def mock_get_request(*args, **kwargs):
def test_get_request_error(self):
s = server.Server()
mock_socket = self._get_mock_socket()
mock_socket.handle_get_request = mock.MagicMock(side_effect=[IOError])
mock_socket.handle_get_request = mock.MagicMock(
side_effect=[exceptions.QueueEmpty])
s.sockets['foo'] = mock_socket
environ = {'REQUEST_METHOD': 'GET', 'QUERY_STRING': 'sid=foo'}
start_response = mock.MagicMock()
Expand Down Expand Up @@ -634,6 +636,19 @@ def test_post_request_error(self):
s.handle_request(environ, start_response)
self.assertEqual(start_response.call_args[0][0],
'400 BAD REQUEST')
self.assertNotIn('foo', s.sockets)

def test_post_request_application_error(self):
s = server.Server()
mock_socket = self._get_mock_socket()
mock_socket.handle_post_request = mock.MagicMock(
side_effect=[ZeroDivisionError])
s.sockets['foo'] = mock_socket
environ = {'REQUEST_METHOD': 'POST', 'QUERY_STRING': 'sid=foo'}
start_response = mock.MagicMock()
self.assertRaises(ZeroDivisionError, s.handle_request, environ,
start_response)
self.assertNotIn('foo', s.sockets)

@staticmethod
def _gzip_decompress(b):
Expand Down
4 changes: 2 additions & 2 deletions tests/test_socket.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ def test_create(self):
def test_empty_poll(self):
mock_server = self._get_mock_server()
s = socket.Socket(mock_server, 'sid')
self.assertRaises(IOError, s.poll)
self.assertRaises(exceptions.QueueEmpty, s.poll)

def test_poll(self):
mock_server = self._get_mock_server()
Expand Down Expand Up @@ -128,7 +128,7 @@ def test_polling_read_error(self):
s = socket.Socket(mock_server, 'foo')
environ = {'REQUEST_METHOD': 'GET', 'QUERY_STRING': 'sid=foo'}
start_response = mock.MagicMock()
self.assertRaises(IOError, s.handle_get_request, environ,
self.assertRaises(exceptions.QueueEmpty, s.handle_get_request, environ,
start_response)

def test_polling_write(self):
Expand Down

0 comments on commit 246edc3

Please sign in to comment.