Skip to content

Commit

Permalink
Correct handling of CORS origin header
Browse files Browse the repository at this point in the history
  • Loading branch information
miguelgrinberg committed Aug 17, 2015
1 parent c02a587 commit 394a878
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 5 deletions.
7 changes: 4 additions & 3 deletions engineio/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -332,10 +332,11 @@ def _unauthorized(self):
def _cors_headers(self, environ):
"""Return the cross-origin-resource-sharing headers."""
if self.cors_allowed_origins is not None and \
environ.get('ORIGIN', '') not in self.cors_allowed_origins:
environ.get('HTTP_ORIGIN', '') not in \
self.cors_allowed_origins:
return []
if 'ORIGIN' in environ:
headers = [('Access-Control-Allow-Origin', environ['ORIGIN'])]
if 'HTTP_ORIGIN' in environ:
headers = [('Access-Control-Allow-Origin', environ['HTTP_ORIGIN'])]
else:
headers = [('Access-Control-Allow-Origin', '*')]
if self.cors_credentials:
Expand Down
7 changes: 5 additions & 2 deletions tests/test_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -200,19 +200,22 @@ def test_connect_cors_headers(self):

def test_connect_cors_allowed_origin(self):
s = server.Server(cors_allowed_origins=['a', 'b'])
environ = {'REQUEST_METHOD': 'GET', 'QUERY_STRING': '', 'ORIGIN': 'b'}
environ = {'REQUEST_METHOD': 'GET', 'QUERY_STRING': '',
'HTTP_ORIGIN': 'b'}
start_response = mock.MagicMock()
s.handle_request(environ, start_response)
headers = start_response.call_args[0][1]
self.assertIn(('Access-Control-Allow-Origin', 'b'), headers)

def test_connect_cors_not_allowed_origin(self):
s = server.Server(cors_allowed_origins=['a', 'b'])
environ = {'REQUEST_METHOD': 'GET', 'QUERY_STRING': '', 'ORIGIN': 'c'}
environ = {'REQUEST_METHOD': 'GET', 'QUERY_STRING': '',
'HTTP_ORIGIN': 'c'}
start_response = mock.MagicMock()
s.handle_request(environ, start_response)
headers = start_response.call_args[0][1]
self.assertNotIn(('Access-Control-Allow-Origin', 'c'), headers)
self.assertNotIn(('Access-Control-Allow-Origin', '*'), headers)

def test_connect_cors_no_credentials(self):
s = server.Server(cors_credentials=False)
Expand Down

0 comments on commit 394a878

Please sign in to comment.