Skip to content

Commit

Permalink
Reject request with incorrect transport (Fixes #367)
Browse files Browse the repository at this point in the history
  • Loading branch information
miguelgrinberg committed Oct 13, 2024
1 parent 5b5d67d commit 4d614e5
Show file tree
Hide file tree
Showing 4 changed files with 85 additions and 25 deletions.
30 changes: 18 additions & 12 deletions src/engineio/async_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -288,18 +288,24 @@ async def handle_request(self, *args, **kwargs):
r = self._bad_request('Invalid session ' + sid)
else:
socket = self._get_socket(sid)
try:
packets = await socket.handle_get_request(environ)
if isinstance(packets, list):
r = self._ok(packets, jsonp_index=jsonp_index)
else:
r = packets
except exceptions.EngineIOError:
if sid in self.sockets: # pragma: no cover
await self.disconnect(sid)
r = self._bad_request()
if sid in self.sockets and self.sockets[sid].closed:
del self.sockets[sid]
if self.transport(sid) != transport:
self._log_error_once(
'Invalid transport for session ' + sid,
'bad-transport')
r = self._bad_request('Invalid transport')
else:
try:
packets = await socket.handle_get_request(environ)
if isinstance(packets, list):
r = self._ok(packets, jsonp_index=jsonp_index)
else:
r = packets
except exceptions.EngineIOError:
if sid in self.sockets: # pragma: no cover
await self.disconnect(sid)
r = self._bad_request()
if sid in self.sockets and self.sockets[sid].closed:
del self.sockets[sid]
elif method == 'POST':
if sid is None or sid not in self.sockets:
self._log_error_once('Invalid session ' + sid, 'bad-sid')
Expand Down
32 changes: 19 additions & 13 deletions src/engineio/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -270,19 +270,25 @@ def handle_request(self, environ, start_response):
r = self._bad_request('Invalid session')
else:
socket = self._get_socket(sid)
try:
packets = socket.handle_get_request(
environ, start_response)
if isinstance(packets, list):
r = self._ok(packets, jsonp_index=jsonp_index)
else:
r = packets
except exceptions.EngineIOError:
if sid in self.sockets: # pragma: no cover
self.disconnect(sid)
r = self._bad_request()
if sid in self.sockets and self.sockets[sid].closed:
del self.sockets[sid]
if self.transport(sid) != transport:
self._log_error_once(
'Invalid transport for session ' + sid,
'bad-transport')
r = self._bad_request('Invalid transport')
else:
try:
packets = socket.handle_get_request(
environ, start_response)
if isinstance(packets, list):
r = self._ok(packets, jsonp_index=jsonp_index)
else:
r = packets
except exceptions.EngineIOError:
if sid in self.sockets: # pragma: no cover
self.disconnect(sid)
r = self._bad_request()
if sid in self.sockets and self.sockets[sid].closed:
del self.sockets[sid]
elif method == 'POST':
if sid is None or sid not in self.sockets:
self._log_error_once(
Expand Down
26 changes: 26 additions & 0 deletions tests/async/test_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -659,6 +659,32 @@ def test_get_request_with_bad_sid(self, import_module):
assert len(s.sockets) == 0
assert a._async['make_response'].call_args[0][0] == '400 BAD REQUEST'

@mock.patch('importlib.import_module')
def test_get_request_bad_websocket_transport(self, import_module):
a = self.get_async_mock(
{'REQUEST_METHOD': 'GET',
'QUERY_STRING': 'EIO=4&transport=websocket&sid=foo'}
)
import_module.side_effect = [a]
s = async_server.AsyncServer()
s.sockets['foo'] = mock_socket = self._get_mock_socket()
mock_socket.upgraded = False
_run(s.handle_request('request'))
assert a._async['make_response'].call_args[0][0] == '400 BAD REQUEST'

@mock.patch('importlib.import_module')
def test_get_request_bad_polling_transport(self, import_module):
a = self.get_async_mock(
{'REQUEST_METHOD': 'GET',
'QUERY_STRING': 'EIO=4&transport=polling&sid=foo'}
)
import_module.side_effect = [a]
s = async_server.AsyncServer()
s.sockets['foo'] = mock_socket = self._get_mock_socket()
mock_socket.upgraded = True
_run(s.handle_request('request'))
assert a._async['make_response'].call_args[0][0] == '400 BAD REQUEST'

@mock.patch('importlib.import_module')
def test_post_request_with_bad_sid(self, import_module):
a = self.get_async_mock(
Expand Down
22 changes: 22 additions & 0 deletions tests/common/test_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -883,6 +883,28 @@ def test_get_request_error(self):
assert start_response.call_args[0][0] == '400 BAD REQUEST'
assert len(s.sockets) == 0

def test_get_request_bad_websocket_transport(self):
s = server.Server()
mock_socket = self._get_mock_socket()
mock_socket.upgraded = False
s.sockets['foo'] = mock_socket
environ = {'REQUEST_METHOD': 'GET',
'QUERY_STRING': 'EIO=4&transport=websocket&sid=foo'}
start_response = mock.MagicMock()
s.handle_request(environ, start_response)
assert start_response.call_args[0][0] == '400 BAD REQUEST'

def test_get_request_bad_polling_transport(self):
s = server.Server()
mock_socket = self._get_mock_socket()
mock_socket.upgraded = True
s.sockets['foo'] = mock_socket
environ = {'REQUEST_METHOD': 'GET',
'QUERY_STRING': 'EIO=4&transport=polling&sid=foo'}
start_response = mock.MagicMock()
s.handle_request(environ, start_response)
assert start_response.call_args[0][0] == '400 BAD REQUEST'

def test_post_request(self):
s = server.Server()
mock_socket = self._get_mock_socket()
Expand Down

0 comments on commit 4d614e5

Please sign in to comment.