Skip to content

Commit

Permalink
properly handle crashes in connect/disconnect handlers
Browse files Browse the repository at this point in the history
  • Loading branch information
miguelgrinberg committed Apr 21, 2017
1 parent f772cf6 commit 5b24410
Show file tree
Hide file tree
Showing 8 changed files with 83 additions and 6 deletions.
14 changes: 12 additions & 2 deletions engineio/asyncio_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -225,9 +225,19 @@ async def _handle_connect(self, environ, transport, b64=False):
'pingInterval': int(self.ping_interval * 1000)})
await s.send(pkt)

if await self._trigger_event('connect', sid, environ) is False:
self.logger.warning('Application rejected connection')
reraise_exc = None
try:
ret = await self._trigger_event('connect', sid, environ)
except Exception as e:
ret = False
reraise_exc = e
if ret is False:
del self.sockets[sid]
if reraise_exc is None:
self.logger.warning('Application rejected connection')
else:
self.logger.error('Connect handler raised an exception')
raise reraise_exc
return self._unauthorized()

if transport == 'websocket':
Expand Down
8 changes: 7 additions & 1 deletion engineio/asyncio_socket.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,12 +93,18 @@ async def close(self, wait=True, abort=False):
"""Close the socket connection."""
if not self.closed and not self.closing:
self.closing = True
await self.server._trigger_event('disconnect', self.sid)
reraise_exc = None
try:
await self.server._trigger_event('disconnect', self.sid)
except Exception as e:
reraise_exc = e
if not abort:
await self.send(packet.Packet(packet.CLOSE))
self.closed = True
if wait:
await self.queue.join()
if reraise_exc:
raise reraise_exc

async def _upgrade_websocket(self, environ):
"""Upgrade the connection from polling to websocket."""
Expand Down
14 changes: 12 additions & 2 deletions engineio/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -355,9 +355,19 @@ def _handle_connect(self, environ, start_response, transport, b64=False):
'pingInterval': int(self.ping_interval * 1000)})
s.send(pkt)

if self._trigger_event('connect', sid, environ, async=False) is False:
self.logger.warning('Application rejected connection')
reraise_exc = None
try:
ret = self._trigger_event('connect', sid, environ, async=False)
except Exception as e:
ret = False
reraise_exc = e
if ret is False:
del self.sockets[sid]
if reraise_exc is None:
self.logger.warning('Application rejected connection')
else:
self.logger.error('Connect handler raised an exception')
raise reraise_exc
return self._unauthorized()

if transport == 'websocket':
Expand Down
8 changes: 7 additions & 1 deletion engineio/socket.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,12 +109,18 @@ def close(self, wait=True, abort=False):
"""Close the socket connection."""
if not self.closed and not self.closing:
self.closing = True
self.server._trigger_event('disconnect', self.sid, async=False)
reraise_exc = None
try:
self.server._trigger_event('disconnect', self.sid, async=False)
except Exception as e:
reraise_exc = e
if not abort:
self.send(packet.Packet(packet.CLOSE))
self.closed = True
if wait:
self.queue.join()
if reraise_exc:
raise reraise_exc

def _upgrade_websocket(self, environ, start_response):
"""Upgrade the connection from polling to websocket."""
Expand Down
14 changes: 14 additions & 0 deletions tests/test_asyncio_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -395,6 +395,20 @@ def mock_connect(sid, environ):
self.assertEqual(a._async['make_response'].call_args[0][0],
'401 UNAUTHORIZED')

@mock.patch('importlib.import_module')
def test_connect_event_error(self, import_module):
a = self.get_async_mock()
import_module.side_effect = [a]
s = asyncio_server.AsyncServer()
s._generate_id = mock.MagicMock(return_value='123')

def mock_connect(sid, environ):
return 1 / 0

s.on('connect')(mock_connect)
self.assertRaises(ZeroDivisionError, _run, s.handle_request('request'))
self.assertEqual(len(s.sockets), 0)

@mock.patch('importlib.import_module')
def test_method_not_found(self, import_module):
a = self.get_async_mock({'REQUEST_METHOD': 'PUT', 'QUERY_STRING': ''})
Expand Down
10 changes: 10 additions & 0 deletions tests/test_asyncio_socket.py
Original file line number Diff line number Diff line change
Expand Up @@ -435,3 +435,13 @@ def test_close_without_wait(self):
s.queue.join = AsyncMock()
_run(s.close(wait=False))
self.assertEqual(s.queue.join.mock.call_count, 0)

def test_close_disconnect_error(self):
mock_server = self._get_mock_server()
mock_server._trigger_event.mock.side_effect = ZeroDivisionError
s = asyncio_socket.AsyncSocket(mock_server, 'sid')
self.assertRaises(ZeroDivisionError, _run, s.close(wait=False))
self.assertTrue(s.closed)
self.assertEqual(mock_server._trigger_event.mock.call_count, 1)
mock_server._trigger_event.mock.assert_called_once_with('disconnect',
'sid')
11 changes: 11 additions & 0 deletions tests/test_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -522,6 +522,17 @@ def test_connect_event_rejects(self):
self.assertEqual(len(s.sockets), 0)
self.assertEqual(start_response.call_args[0][0], '401 UNAUTHORIZED')

def test_connect_event_error(self):
s = server.Server()
s._generate_id = mock.MagicMock(return_value='123')
mock_event = mock.MagicMock(side_effect=ZeroDivisionError)
s.on('connect')(mock_event)
environ = {'REQUEST_METHOD': 'GET', 'QUERY_STRING': ''}
start_response = mock.MagicMock()
self.assertRaises(ZeroDivisionError, s.handle_request, environ,
start_response)
self.assertEqual(len(s.sockets), 0)

def test_method_not_found(self):
s = server.Server()
environ = {'REQUEST_METHOD': 'PUT', 'QUERY_STRING': ''}
Expand Down
10 changes: 10 additions & 0 deletions tests/test_socket.py
Original file line number Diff line number Diff line change
Expand Up @@ -419,3 +419,13 @@ def test_close_without_wait(self):
s.queue = mock.MagicMock()
s.close(wait=False)
self.assertEqual(s.queue.join.call_count, 0)

def test_close_disconnect_error(self):
mock_server = self._get_mock_server()
mock_server._trigger_event.side_effect = ZeroDivisionError
s = socket.Socket(mock_server, 'sid')
self.assertRaises(ZeroDivisionError, s.close, wait=False)
self.assertTrue(s.closed)
self.assertEqual(mock_server._trigger_event.call_count, 1)
mock_server._trigger_event.assert_called_once_with('disconnect', 'sid',
async=False)

0 comments on commit 5b24410

Please sign in to comment.