Skip to content

Commit

Permalink
Graceful failure when websocket is request and the async mode does no…
Browse files Browse the repository at this point in the history
…t support it
  • Loading branch information
miguelgrinberg committed Jan 10, 2016
1 parent 5af2c8b commit 2a5cdf2
Show file tree
Hide file tree
Showing 3 changed files with 16 additions and 2 deletions.
3 changes: 1 addition & 2 deletions engineio/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -305,8 +305,7 @@ def _handle_connect(self, environ, start_response, transport, b64=False):
return self._unauthorized()

if transport == 'websocket':
s.handle_get_request(environ, start_response)
return self._ok()
return s.handle_get_request(environ, start_response)
else:
s.connected = True
headers = None
Expand Down
4 changes: 4 additions & 0 deletions engineio/socket.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,10 @@ def _upgrade_websocket(self, environ, start_response):
"""Upgrade the connection from polling to websocket."""
if self.upgraded:
raise IOError('Socket has been upgraded already')
if self.server.async['websocket'] is None or \
self.server.async['websocket_class'] is None:
# the selected async mode does not support websocket
return self.server._bad_request()
websocket_class = getattr(self.server.async['websocket'],
self.server.async['websocket_class'])
ws = websocket_class(self._websocket_handler)
Expand Down
11 changes: 11 additions & 0 deletions tests/test_socket.py
Original file line number Diff line number Diff line change
Expand Up @@ -204,6 +204,17 @@ def test_upgrade_no_upgrade_packet(self):
self.assertEqual(s.queue.get().packet_type, packet.NOOP)
self.assertFalse(s.upgraded)

def test_upgrade_not_supported(self):
mock_server = self._get_mock_server()
mock_server.async['websocket'] = None
mock_server.async['websocket_class'] = None
s = socket.Socket(mock_server, 'sid')
s.connected = True
environ = "foo"
start_response = "bar"
s._upgrade_websocket(environ, start_response)
mock_server._bad_request.assert_called_once_with()

def test_websocket_read_write(self):
mock_server = self._get_mock_server()
s = socket.Socket(mock_server, 'sid')
Expand Down

0 comments on commit 2a5cdf2

Please sign in to comment.