Skip to content

Commit

Permalink
Address potential websocket cross-origin attacks (Fixes #128)
Browse files Browse the repository at this point in the history
  • Loading branch information
miguelgrinberg committed Jul 28, 2019
1 parent f23a405 commit 7548f70
Show file tree
Hide file tree
Showing 4 changed files with 116 additions and 24 deletions.
22 changes: 19 additions & 3 deletions engineio/asyncio_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,9 +37,10 @@ class AsyncServer(server.Server):
is greater than this value.
:param cookie: Name of the HTTP cookie that contains the client session
id. If set to ``None``, a cookie is not sent to the client.
:param cors_allowed_origins: List of origins that are allowed to connect
to this server. All origins are allowed by
default.
:param cors_allowed_origins: Origin or list of origins that are allowed to
connect to this server. Only the same server
is allowed by default. Set this argument to
``'*'`` to allow all origins.
:param cors_credentials: Whether credentials (cookies, authentication) are
allowed in requests to this server.
:param logger: To enable logging set to ``True`` or pass a logger object to
Expand Down Expand Up @@ -181,6 +182,21 @@ async def handle_request(self, *args, **kwargs):
environ = await translate_request(*args, **kwargs)
else:
environ = translate_request(*args, **kwargs)

# Validate the origin header if present
# This is important for WebSocket more than for HTTP, since browsers
# only apply CORS controls to HTTP.
origin = environ.get('HTTP_ORIGIN')
if origin:
allowed_origins = self._cors_allowed_origins(environ)
if allowed_origins is not None and origin not in allowed_origins:
self.logger.info(origin + ' is not an accepted origin.')
r = self._bad_request()
make_response = self._async['make_response']
response = make_response(r['status'], r['headers'],
r['response'], environ)
return response

method = environ['REQUEST_METHOD']
query = urllib.parse.parse_qs(environ.get('QUERY_STRING', ''))

Expand Down
54 changes: 37 additions & 17 deletions engineio/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,9 +48,9 @@ class Server(object):
id. If set to ``None``, a cookie is not sent to the client.
The default is ``'io'``.
:param cors_allowed_origins: Origin or list of origins that are allowed to
connect to this server. All origins are
allowed by default, which is equivalent to
setting this argument to ``'*'``.
connect to this server. Only the same server
is allowed by default. Set this argument to
``'*'`` to allow all origins.
:param cors_credentials: Whether credentials (cookies, authentication) are
allowed in requests to this server. The default
is ``True``.
Expand Down Expand Up @@ -309,6 +309,18 @@ def handle_request(self, environ, start_response):
This function returns the HTTP response body to deliver to the client
as a byte sequence.
"""
# Validate the origin header if present
# This is important for WebSocket more than for HTTP, since browsers
# only apply CORS controls to HTTP.
origin = environ.get('HTTP_ORIGIN')
if origin:
allowed_origins = self._cors_allowed_origins(environ)
if allowed_origins is not None and origin not in allowed_origins:
self.logger.info(origin + ' is not an accepted origin.')
r = self._bad_request()
start_response(r['status'], r['headers'])
return [r['response']]

method = environ['REQUEST_METHOD']
query = urllib.parse.parse_qs(environ.get('QUERY_STRING', ''))

Expand Down Expand Up @@ -572,27 +584,35 @@ def _unauthorized(self):
'headers': [('Content-Type', 'text/plain')],
'response': b'Unauthorized'}

def _cors_headers(self, environ):
"""Return the cross-origin-resource-sharing headers."""
if isinstance(self.cors_allowed_origins, six.string_types):
if self.cors_allowed_origins == '*':
allowed_origins = None
else:
allowed_origins = [self.cors_allowed_origins]
def _cors_allowed_origins(self, environ):
default_origin = None
if 'wsgi.url_scheme' in environ and 'HTTP_HOST' in environ:
default_origin = '{scheme}://{host}'.format(
scheme=environ['wsgi.url_scheme'], host=environ['HTTP_HOST'])
if self.cors_allowed_origins is None:
allowed_origins = [default_origin] \
if default_origin is not None else[]

This comment has been minimized.

Copy link
@zhoufenqin

zhoufenqin Aug 5, 2019

L594 if default_origin is not None else[] may be updated to if default_origin is not None else None, in old python-engineio version, the cors_allowed_origins does not exist so the self.cors_allowed_origins is None, In L318-L322, the allowed_origins is not None and origin not in allowed_origins is True, then bad request will occurred, see Issue #132

elif self.cors_allowed_origins == '*':
allowed_origins = None
elif isinstance(self.cors_allowed_origins, six.string_types):
allowed_origins = [self.cors_allowed_origins]
else:
allowed_origins = self.cors_allowed_origins
if allowed_origins is not None and \
environ.get('HTTP_ORIGIN', '') not in allowed_origins:
return []
if 'HTTP_ORIGIN' in environ:
return allowed_origins

def _cors_headers(self, environ):
"""Return the cross-origin-resource-sharing headers."""
headers = []
allowed_origins = self._cors_allowed_origins(environ)
if allowed_origins is None or \
('HTTP_ORIGIN' in environ and environ['HTTP_ORIGIN'] in
allowed_origins):
headers = [('Access-Control-Allow-Origin', environ['HTTP_ORIGIN'])]
else:
headers = [('Access-Control-Allow-Origin', '*')]
if environ['REQUEST_METHOD'] == 'OPTIONS':
headers += [('Access-Control-Allow-Methods', 'OPTIONS, GET, POST')]
if 'HTTP_ACCESS_CONTROL_REQUEST_HEADERS' in environ:
headers += [('Access-Control-Allow-Headers',
environ['HTTP_ACCESS_CONTROL_REQUEST_HEADERS'])]
environ['HTTP_ACCESS_CONTROL_REQUEST_HEADERS'])]
if self.cors_credentials:
headers += [('Access-Control-Allow-Credentials', 'true')]
return headers
Expand Down
47 changes: 46 additions & 1 deletion tests/asyncio/test_asyncio_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -399,7 +399,6 @@ def test_connect_cors_headers(self, import_module):
s = asyncio_server.AsyncServer()
_run(s.handle_request('request'))
headers = a._async['make_response'].call_args[0][1]
self.assertIn(('Access-Control-Allow-Origin', '*'), headers)
self.assertIn(('Access-Control-Allow-Credentials', 'true'), headers)

@mock.patch('importlib.import_module')
Expand All @@ -423,6 +422,52 @@ def test_connect_cors_not_allowed_origin(self, import_module):
self.assertNotIn(('Access-Control-Allow-Origin', 'c'), headers)
self.assertNotIn(('Access-Control-Allow-Origin', '*'), headers)

@mock.patch('importlib.import_module')
def test_connect_cors_all_origins(self, import_module):
a = self.get_async_mock({'REQUEST_METHOD': 'GET', 'QUERY_STRING': '',
'HTTP_ORIGIN': 'foo'})
import_module.side_effect = [a]
s = asyncio_server.AsyncServer(cors_allowed_origins='*')
_run(s.handle_request('request'))
headers = a._async['make_response'].call_args[0][1]
self.assertIn(('Access-Control-Allow-Origin', 'foo'), headers)
self.assertIn(('Access-Control-Allow-Credentials', 'true'), headers)

@mock.patch('importlib.import_module')
def test_connect_cors_one_origin(self, import_module):
a = self.get_async_mock({'REQUEST_METHOD': 'GET', 'QUERY_STRING': '',
'HTTP_ORIGIN': 'a'})
import_module.side_effect = [a]
s = asyncio_server.AsyncServer(cors_allowed_origins='a')
_run(s.handle_request('request'))
headers = a._async['make_response'].call_args[0][1]
self.assertIn(('Access-Control-Allow-Origin', 'a'), headers)
self.assertIn(('Access-Control-Allow-Credentials', 'true'), headers)

@mock.patch('importlib.import_module')
def test_connect_cors_one_origin_not_allowed(self, import_module):
a = self.get_async_mock({'REQUEST_METHOD': 'GET', 'QUERY_STRING': '',
'HTTP_ORIGIN': 'b'})
import_module.side_effect = [a]
s = asyncio_server.AsyncServer(cors_allowed_origins='a')
_run(s.handle_request('request'))
headers = a._async['make_response'].call_args[0][1]
self.assertNotIn(('Access-Control-Allow-Origin', 'b'), headers)
self.assertNotIn(('Access-Control-Allow-Origin', '*'), headers)

@mock.patch('importlib.import_module')
def test_connect_cors_headers_default_origin(self, import_module):
a = self.get_async_mock({'REQUEST_METHOD': 'GET', 'QUERY_STRING': '',
'wsgi.url_scheme': 'http',
'HTTP_HOST': 'foo',
'HTTP_ORIGIN': 'http://foo'})
import_module.side_effect = [a]
s = asyncio_server.AsyncServer()
_run(s.handle_request('request'))
headers = a._async['make_response'].call_args[0][1]
self.assertIn(('Access-Control-Allow-Origin', 'http://foo'),
headers)

@mock.patch('importlib.import_module')
def test_connect_cors_no_credentials(self, import_module):
a = self.get_async_mock()
Expand Down
17 changes: 14 additions & 3 deletions tests/common/test_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -553,7 +553,6 @@ def test_connect_cors_headers(self):
start_response = mock.MagicMock()
s.handle_request(environ, start_response)
headers = start_response.call_args[0][1]
self.assertIn(('Access-Control-Allow-Origin', '*'), headers)
self.assertIn(('Access-Control-Allow-Credentials', 'true'), headers)

def test_connect_cors_allowed_origin(self):
Expand All @@ -577,11 +576,12 @@ def test_connect_cors_not_allowed_origin(self):

def test_connect_cors_headers_all_origins(self):
s = server.Server(cors_allowed_origins='*')
environ = {'REQUEST_METHOD': 'GET', 'QUERY_STRING': ''}
environ = {'REQUEST_METHOD': 'GET', 'QUERY_STRING': '',
'HTTP_ORIGIN': 'foo'}
start_response = mock.MagicMock()
s.handle_request(environ, start_response)
headers = start_response.call_args[0][1]
self.assertIn(('Access-Control-Allow-Origin', '*'), headers)
self.assertIn(('Access-Control-Allow-Origin', 'foo'), headers)
self.assertIn(('Access-Control-Allow-Credentials', 'true'), headers)

def test_connect_cors_headers_one_origin(self):
Expand All @@ -604,6 +604,17 @@ def test_connect_cors_headers_one_origin_not_allowed(self):
self.assertNotIn(('Access-Control-Allow-Origin', 'b'), headers)
self.assertNotIn(('Access-Control-Allow-Origin', '*'), headers)

def test_connect_cors_headers_default_origin(self):
s = server.Server()
environ = {'REQUEST_METHOD': 'GET', 'QUERY_STRING': '',
'wsgi.url_scheme': 'http', 'HTTP_HOST': 'foo',
'HTTP_ORIGIN': 'http://foo'}
start_response = mock.MagicMock()
s.handle_request(environ, start_response)
headers = start_response.call_args[0][1]
self.assertIn(('Access-Control-Allow-Origin', 'http://foo'),
headers)

def test_connect_cors_no_credentials(self):
s = server.Server(cors_credentials=False)
environ = {'REQUEST_METHOD': 'GET', 'QUERY_STRING': ''}
Expand Down

0 comments on commit 7548f70

Please sign in to comment.