Skip to content

Commit

Permalink
Restore CORS disable option
Browse files Browse the repository at this point in the history
  • Loading branch information
miguelgrinberg committed Jul 31, 2019
1 parent 64004d5 commit 9f4cd8c
Show file tree
Hide file tree
Showing 4 changed files with 26 additions and 2 deletions.
3 changes: 2 additions & 1 deletion engineio/asyncio_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,8 @@ class AsyncServer(server.Server):
:param cors_allowed_origins: Origin or list of origins that are allowed to
connect to this server. Only the same origin
is allowed by default. Set this argument to
``'*'`` to allow all origins.
``'*'`` to allow all origins, or to ``[]`` to
disable CORS handling.
:param cors_credentials: Whether credentials (cookies, authentication) are
allowed in requests to this server.
:param logger: To enable logging set to ``True`` or pass a logger object to
Expand Down
6 changes: 5 additions & 1 deletion engineio/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,8 @@ class Server(object):
:param cors_allowed_origins: Origin or list of origins that are allowed to
connect to this server. Only the same origin
is allowed by default. Set this argument to
``'*'`` to allow all origins.
``'*'`` to allow all origins, or to ``[]`` to
disable CORS handling.
:param cors_credentials: Whether credentials (cookies, authentication) are
allowed in requests to this server. The default
is ``True``.
Expand Down Expand Up @@ -602,6 +603,9 @@ def _cors_allowed_origins(self, environ):

def _cors_headers(self, environ):
"""Return the cross-origin-resource-sharing headers."""
if self.cors_allowed_origins == []:
# special case, CORS handling is completely disabled
return []
headers = []
allowed_origins = self._cors_allowed_origins(environ)
if allowed_origins is None or \
Expand Down
10 changes: 10 additions & 0 deletions tests/asyncio/test_asyncio_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -488,6 +488,16 @@ def test_connect_cors_options(self, import_module):
self.assertIn(('Access-Control-Allow-Methods',
'OPTIONS, GET, POST'), headers)

@mock.patch('importlib.import_module')
def test_connect_cors_disabled(self, import_module):
a = self.get_async_mock({'REQUEST_METHOD': 'GET', 'QUERY_STRING': ''})
import_module.side_effect = [a]
s = asyncio_server.AsyncServer(cors_allowed_origins=[])
_run(s.handle_request('request'))
headers = a._async['make_response'].call_args[0][1]
for header in headers:
self.assertFalse(header[0].startswith('Access-Control-'))

@mock.patch('importlib.import_module')
def test_connect_event(self, import_module):
a = self.get_async_mock()
Expand Down
9 changes: 9 additions & 0 deletions tests/common/test_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -641,6 +641,15 @@ def test_cors_request_headers(self):
headers = start_response.call_args[0][1]
self.assertIn(('Access-Control-Allow-Headers', 'Foo, Bar'), headers)

def test_connect_cors_disabled(self):
s = server.Server(cors_allowed_origins=[])
environ = {'REQUEST_METHOD': 'GET', 'QUERY_STRING': ''}
start_response = mock.MagicMock()
s.handle_request(environ, start_response)
headers = start_response.call_args[0][1]
for header in headers:
self.assertFalse(header[0].startswith('Access-Control-'))

def test_connect_event(self):
s = server.Server()
s._generate_id = mock.MagicMock(return_value='123')
Expand Down

0 comments on commit 9f4cd8c

Please sign in to comment.