Skip to content

Commit

Permalink
better error handling strategy
Browse files Browse the repository at this point in the history
Fixes #49 (again and hopefully better)
  • Loading branch information
miguelgrinberg committed Jun 22, 2017
1 parent 66b8f5d commit 8cc004a
Show file tree
Hide file tree
Showing 8 changed files with 94 additions and 194 deletions.
40 changes: 21 additions & 19 deletions engineio/asyncio_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,13 +158,11 @@ async def handle_request(self, *args, **kwargs):
if sid in self.sockets: # pragma: no cover
await self.disconnect(sid)
r = self._bad_request()
except Exception as e:
# for any other unexpected errors, we disconnect
# the cient and reraise
print('yo')
if sid in self.sockets: # pragma: no cover
await self.disconnect(sid)
raise e
except: # pragma: no cover
# for any other unexpected errors, we log the error
# and keep going
self.logger.exception('post request handler error')
r = self._ok()
else:
self.logger.warning('Method %s not supported', method)
r = self._method_not_found()
Expand Down Expand Up @@ -225,19 +223,10 @@ async def _handle_connect(self, environ, transport, b64=False):
'pingInterval': int(self.ping_interval * 1000)})
await s.send(pkt)

reraise_exc = None
try:
ret = await self._trigger_event('connect', sid, environ)
except Exception as e:
ret = False
reraise_exc = e
ret = await self._trigger_event('connect', sid, environ)
if ret is False:
del self.sockets[sid]
if reraise_exc is None:
self.logger.warning('Application rejected connection')
else:
self.logger.error('Connect handler raised an exception')
raise reraise_exc
self.logger.warning('Application rejected connection')
return self._unauthorized()

if transport == 'websocket':
Expand All @@ -262,6 +251,19 @@ async def _trigger_event(self, event, *args, **kwargs):
ret = await self.handlers[event](*args)
except asyncio.CancelledError: # pragma: no cover
pass
except:
self.logger.exception(event + ' async handler error')
if event == 'connect':
# if connect handler raised error we reject the
# connection
return False
else:
ret = self.handlers[event](*args)
try:
return self.handlers[event](*args)
except:
self.logger.exception(event + ' handler error')
if event == 'connect':
# if connect handler raised error we reject the
# connection
return False
return ret
26 changes: 7 additions & 19 deletions engineio/asyncio_socket.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
import asyncio
import six
import sys
import time

from . import exceptions
Expand Down Expand Up @@ -74,9 +73,9 @@ async def handle_get_request(self, environ):
return await getattr(self, '_upgrade_' + transport)(environ)
try:
packets = await self.poll()
except exceptions.QueueEmpty as e:
except exceptions.QueueEmpty:
await self.close(wait=False)
raise e
raise
return packets

async def handle_post_request(self, environ):
Expand All @@ -94,18 +93,12 @@ async def close(self, wait=True, abort=False):
"""Close the socket connection."""
if not self.closed and not self.closing:
self.closing = True
reraise_exc = None
try:
await self.server._trigger_event('disconnect', self.sid)
except:
reraise_exc = sys.exc_info()
await self.server._trigger_event('disconnect', self.sid)
if not abort:
await self.send(packet.Packet(packet.CLOSE))
self.closed = True
if wait:
await self.queue.join()
if reraise_exc:
six.reraise(*reraise_exc)

async def _upgrade_websocket(self, environ):
"""Upgrade the connection from polling to websocket."""
Expand Down Expand Up @@ -173,7 +166,6 @@ async def writer():
self.server.logger.info(
'%s: Upgrade to websocket successful', self.sid)

reraise_exc = None
while True:
p = None
wait_task = asyncio.ensure_future(ws.wait())
Expand All @@ -199,15 +191,11 @@ async def writer():
await self.receive(pkt)
except exceptions.UnknownPacketError:
pass
except:
# if we get an unexpected exception (such as something in an
# application event handler) we close the connection properly
# and then reraise the exception
reraise_exc = sys.exc_info()
break
except: # pragma: no cover
# if we get an unexpected exception we log the error and exit
# the connection properly
self.server.logger.exception('Receive error')

await self.queue.put(None) # unlock the writer task so it can exit
await asyncio.wait_for(writer_task, timeout=None)
await self.close(wait=True, abort=True)
if reraise_exc:
six.reraise(*reraise_exc)
33 changes: 15 additions & 18 deletions engineio/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -282,12 +282,11 @@ def handle_request(self, environ, start_response):
if sid in self.sockets: # pragma: no cover
self.disconnect(sid)
r = self._bad_request()
except Exception as e:
# for any other unexpected errors, we disconnect
# the cient and reraise
if sid in self.sockets: # pragma: no cover
self.disconnect(sid)
raise e
except: # pragma: no cover
# for any other unexpected errors, we log the error
# and keep going
self.logger.exception('post request handler error')
r = self._ok()
else:
self.logger.warning('Method %s not supported', method)
r = self._method_not_found()
Expand Down Expand Up @@ -355,19 +354,10 @@ def _handle_connect(self, environ, start_response, transport, b64=False):
'pingInterval': int(self.ping_interval * 1000)})
s.send(pkt)

reraise_exc = None
try:
ret = self._trigger_event('connect', sid, environ, async=False)
except Exception as e:
ret = False
reraise_exc = e
ret = self._trigger_event('connect', sid, environ, async=False)
if ret is False:
del self.sockets[sid]
if reraise_exc is None:
self.logger.warning('Application rejected connection')
else:
self.logger.error('Connect handler raised an exception')
raise reraise_exc
self.logger.warning('Application rejected connection')
return self._unauthorized()

if transport == 'websocket':
Expand Down Expand Up @@ -398,7 +388,14 @@ def _trigger_event(self, event, *args, **kwargs):
if async:
return self.start_background_task(self.handlers[event], *args)
else:
return self.handlers[event](*args)
try:
return self.handlers[event](*args)
except:
self.logger.exception(event + ' handler error')
if event == 'connect':
# if connect handler raised error we reject the
# connection
return False

def _get_socket(self, sid):
"""Return the socket object for a given session."""
Expand Down
25 changes: 7 additions & 18 deletions engineio/socket.py
100755 → 100644
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import six
import sys
import time

from . import exceptions
Expand Down Expand Up @@ -90,9 +89,9 @@ def handle_get_request(self, environ, start_response):
start_response)
try:
packets = self.poll()
except exceptions.QueueEmpty as e:
except exceptions.QueueEmpty:
self.close(wait=False)
raise e
raise
return packets

def handle_post_request(self, environ):
Expand All @@ -110,18 +109,12 @@ def close(self, wait=True, abort=False):
"""Close the socket connection."""
if not self.closed and not self.closing:
self.closing = True
reraise_exc = None
try:
self.server._trigger_event('disconnect', self.sid, async=False)
except:
reraise_exc = sys.exc_info()
self.server._trigger_event('disconnect', self.sid, async=False)
if not abort:
self.send(packet.Packet(packet.CLOSE))
self.closed = True
if wait:
self.queue.join()
if reraise_exc:
six.reraise(*reraise_exc)

def _upgrade_websocket(self, environ, start_response):
"""Upgrade the connection from polling to websocket."""
Expand Down Expand Up @@ -194,7 +187,6 @@ def writer():
self.server.logger.info(
'%s: Upgrade to websocket successful', self.sid)

reraise_exc = None
while True:
p = None
try:
Expand All @@ -217,17 +209,14 @@ def writer():
self.receive(pkt)
except exceptions.UnknownPacketError:
pass
except:
# if we get an unexpected exception (such as something in an
# application event handler) we close the connection properly
# and then reraise the exception
reraise_exc = sys.exc_info()
except: # pragma: no cover
# if we get an unexpected exception we log the error and exit
# the connection properly
self.server.logger.exception('Receive error')
break

self.queue.put(None) # unlock the writer task so that it can exit
writer_task.join()
self.close(wait=True, abort=True)
if reraise_exc:
six.reraise(*reraise_exc)

return []
58 changes: 28 additions & 30 deletions tests/test_asyncio_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -395,20 +395,6 @@ def mock_connect(sid, environ):
self.assertEqual(a._async['make_response'].call_args[0][0],
'401 UNAUTHORIZED')

@mock.patch('importlib.import_module')
def test_connect_event_error(self, import_module):
a = self.get_async_mock()
import_module.side_effect = [a]
s = asyncio_server.AsyncServer()
s._generate_id = mock.MagicMock(return_value='123')

def mock_connect(sid, environ):
return 1 / 0

s.on('connect')(mock_connect)
self.assertRaises(ZeroDivisionError, _run, s.handle_request('request'))
self.assertEqual(len(s.sockets), 0)

@mock.patch('importlib.import_module')
def test_method_not_found(self, import_module):
a = self.get_async_mock({'REQUEST_METHOD': 'PUT', 'QUERY_STRING': ''})
Expand Down Expand Up @@ -551,22 +537,6 @@ def mock_post_request(*args, **kwargs):
self.assertEqual(a._async['make_response'].call_args[0][0],
'400 BAD REQUEST')

@mock.patch('importlib.import_module')
def test_post_request_application_error(self, import_module):
a = self.get_async_mock({'REQUEST_METHOD': 'POST',
'QUERY_STRING': 'sid=foo'})
import_module.side_effect = [a]
s = asyncio_server.AsyncServer()
s.sockets['foo'] = mock_socket = self._get_mock_socket()

@asyncio.coroutine
def mock_get_request(*args, **kwargs):
raise ZeroDivisionError()

mock_socket.handle_post_request.mock.return_value = mock_get_request()
self.assertRaises(ZeroDivisionError, _run, s.handle_request('request'))
self.assertEqual(len(s.sockets), 0)

@staticmethod
def _gzip_decompress(b):
bytesio = six.BytesIO(b)
Expand Down Expand Up @@ -770,3 +740,31 @@ def foo_handler(arg):
s.on('message', handler=foo_handler)
_run(s._trigger_event('message', 'bar'))
self.assertEqual(result, ['ok', 'bar'])

def test_trigger_event_function_error(self):
def connect_handler(arg):
return 1 / 0

def foo_handler(arg):
return 1 / 0

s = asyncio_server.AsyncServer()
s.on('connect', handler=connect_handler)
s.on('message', handler=foo_handler)
self.assertFalse(_run(s._trigger_event('connect', '123')))
self.assertIsNone(_run(s._trigger_event('message', 'bar')))

def test_trigger_event_coroutine_error(self):
@asyncio.coroutine
def connect_handler(arg):
return 1 / 0

@asyncio.coroutine
def foo_handler(arg):
return 1 / 0

s = asyncio_server.AsyncServer()
s.on('connect', handler=connect_handler)
s.on('message', handler=foo_handler)
self.assertFalse(_run(s._trigger_event('connect', '123')))
self.assertIsNone(_run(s._trigger_event('message', 'bar')))
34 changes: 0 additions & 34 deletions tests/test_asyncio_socket.py
Original file line number Diff line number Diff line change
Expand Up @@ -351,30 +351,6 @@ def test_websocket_read_write_wait_fail(self):
_run(s._websocket_handler(ws))
self.assertEqual(s.closed, True)

def test_websocket_read_write_receive_fail(self):
mock_server = self._get_mock_server()
s = asyncio_socket.AsyncSocket(mock_server, 'sid')
s.connected = False
s.queue.join = AsyncMock(return_value=None)
foo = six.text_type('foo')
bar = six.text_type('bar')
s.poll = AsyncMock(side_effect=[
[packet.Packet(packet.MESSAGE, data=bar)],
[packet.Packet(packet.MESSAGE, data=bar)], exceptions.QueueEmpty])
ws = mock.MagicMock()
ws.send = AsyncMock()
ws.wait = AsyncMock()
ws.wait.mock.side_effect = [
packet.Packet(packet.MESSAGE, data=foo).encode(
always_bytes=False),
packet.Packet(packet.MESSAGE, data=bar).encode(
always_bytes=False)]
ws.send.mock.side_effect = [None, None]
s.receive = AsyncMock()
s.receive.mock.side_effect = [None, ZeroDivisionError]
self.assertRaises(ZeroDivisionError, _run, s._websocket_handler(ws))
self.assertEqual(s.closed, True)

def test_websocket_ignore_invalid_packet(self):
mock_server = self._get_mock_server()
s = asyncio_socket.AsyncSocket(mock_server, 'sid')
Expand Down Expand Up @@ -435,13 +411,3 @@ def test_close_without_wait(self):
s.queue.join = AsyncMock()
_run(s.close(wait=False))
self.assertEqual(s.queue.join.mock.call_count, 0)

def test_close_disconnect_error(self):
mock_server = self._get_mock_server()
mock_server._trigger_event.mock.side_effect = ZeroDivisionError
s = asyncio_socket.AsyncSocket(mock_server, 'sid')
self.assertRaises(ZeroDivisionError, _run, s.close(wait=False))
self.assertTrue(s.closed)
self.assertEqual(mock_server._trigger_event.mock.call_count, 1)
mock_server._trigger_event.mock.assert_called_once_with('disconnect',
'sid')
Loading

0 comments on commit 8cc004a

Please sign in to comment.