Skip to content

Commit

Permalink
invoke disconnect handler when websocket handler crashes
Browse files Browse the repository at this point in the history
  • Loading branch information
miguelgrinberg committed Apr 18, 2017
1 parent 246edc3 commit f772cf6
Show file tree
Hide file tree
Showing 5 changed files with 76 additions and 12 deletions.
11 changes: 10 additions & 1 deletion engineio/asyncio_socket.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,7 +151,7 @@ async def writer():
packets = None
try:
packets = await self.poll()
except IOError:
except exceptions.QueueEmpty:
break
if not packets:
# empty packet list returned -> connection closed
Expand All @@ -166,6 +166,7 @@ async def writer():
self.server.logger.info(
'%s: Upgrade to websocket successful', self.sid)

reraise_exc = None
while True:
p = None
try:
Expand All @@ -182,7 +183,15 @@ async def writer():
await self.receive(pkt)
except exceptions.UnknownPacketError:
pass
except Exception as e:
# if we get an unexpected exception (such as something in an
# application event handler) we close the connection properly
# and then reraise the exception
reraise_exc = e
break

await self.queue.put(None) # unlock the writer task so it can exit
await asyncio.wait_for(writer_task, timeout=None)
await self.close(wait=True, abort=True)
if reraise_exc:
raise reraise_exc
11 changes: 10 additions & 1 deletion engineio/socket.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,7 +166,7 @@ def writer():
while True:
try:
packets = self.poll()
except IOError:
except exceptions.QueueEmpty:
break
if not packets:
# empty packet list returned -> connection closed
Expand All @@ -181,6 +181,7 @@ def writer():
self.server.logger.info(
'%s: Upgrade to websocket successful', self.sid)

reraise_exc = None
while True:
p = None
try:
Expand All @@ -197,9 +198,17 @@ def writer():
self.receive(pkt)
except exceptions.UnknownPacketError:
pass
except Exception as e:
# if we get an unexpected exception (such as something in an
# application event handler) we close the connection properly
# and then reraise the exception
reraise_exc = e
break

self.queue.put(None) # unlock the writer task so that it can exit
writer_task.join()
self.close(wait=True, abort=True)
if reraise_exc:
raise reraise_exc

return []
1 change: 0 additions & 1 deletion tests/test_async_aiohttp.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@ def test_create_route(self, add_route):
app = web.Application()
mock_server = mock.MagicMock()
async_aiohttp.create_route(app, mock_server, '/foo')
print(add_route.call_args_list)
add_route.assert_any_call('GET', '/foo', mock_server.handle_request,
name=None)
add_route.assert_any_call('POST', '/foo', mock_server.handle_request)
Expand Down
32 changes: 28 additions & 4 deletions tests/test_asyncio_socket.py
Original file line number Diff line number Diff line change
Expand Up @@ -294,7 +294,7 @@ def test_websocket_upgrade_read_write(self):
bar = six.text_type('bar')
probe = six.text_type('probe')
s.poll = AsyncMock(side_effect=[
[packet.Packet(packet.MESSAGE, data=bar)], IOError])
[packet.Packet(packet.MESSAGE, data=bar)], exceptions.QueueEmpty])
ws = mock.MagicMock()
ws.send = AsyncMock()
ws.wait = AsyncMock()
Expand Down Expand Up @@ -330,7 +330,7 @@ def test_websocket_upgrade_with_payload(self):
_run(s._websocket_handler(ws))
self.assertTrue(s.upgraded)

def test_websocket_read_write_fail(self):
def test_websocket_read_write_wait_fail(self):
mock_server = self._get_mock_server()
s = asyncio_socket.AsyncSocket(mock_server, 'sid')
s.connected = False
Expand All @@ -339,7 +339,7 @@ def test_websocket_read_write_fail(self):
bar = six.text_type('bar')
s.poll = AsyncMock(side_effect=[
[packet.Packet(packet.MESSAGE, data=bar)],
[packet.Packet(packet.MESSAGE, data=bar)], IOError])
[packet.Packet(packet.MESSAGE, data=bar)], exceptions.QueueEmpty])
ws = mock.MagicMock()
ws.send = AsyncMock()
ws.wait = AsyncMock()
Expand All @@ -351,6 +351,30 @@ def test_websocket_read_write_fail(self):
_run(s._websocket_handler(ws))
self.assertEqual(s.closed, True)

def test_websocket_read_write_receive_fail(self):
mock_server = self._get_mock_server()
s = asyncio_socket.AsyncSocket(mock_server, 'sid')
s.connected = False
s.queue.join = AsyncMock(return_value=None)
foo = six.text_type('foo')
bar = six.text_type('bar')
s.poll = AsyncMock(side_effect=[
[packet.Packet(packet.MESSAGE, data=bar)],
[packet.Packet(packet.MESSAGE, data=bar)], exceptions.QueueEmpty])
ws = mock.MagicMock()
ws.send = AsyncMock()
ws.wait = AsyncMock()
ws.wait.mock.side_effect = [
packet.Packet(packet.MESSAGE, data=foo).encode(
always_bytes=False),
packet.Packet(packet.MESSAGE, data=bar).encode(
always_bytes=False)]
ws.send.mock.side_effect = [None, None]
s.receive = AsyncMock()
s.receive.mock.side_effect = [None, ZeroDivisionError]
self.assertRaises(ZeroDivisionError, _run, s._websocket_handler(ws))
self.assertEqual(s.closed, True)

def test_websocket_ignore_invalid_packet(self):
mock_server = self._get_mock_server()
s = asyncio_socket.AsyncSocket(mock_server, 'sid')
Expand All @@ -359,7 +383,7 @@ def test_websocket_ignore_invalid_packet(self):
foo = six.text_type('foo')
bar = six.text_type('bar')
s.poll = AsyncMock(side_effect=[
[packet.Packet(packet.MESSAGE, data=bar)], IOError])
[packet.Packet(packet.MESSAGE, data=bar)], exceptions.QueueEmpty])
ws = mock.MagicMock()
ws.send = AsyncMock()
ws.wait = AsyncMock()
Expand Down
33 changes: 28 additions & 5 deletions tests/test_socket.py
Original file line number Diff line number Diff line change
Expand Up @@ -263,7 +263,7 @@ def test_websocket_read_write(self):
foo = six.text_type('foo')
bar = six.text_type('bar')
s.poll = mock.MagicMock(side_effect=[
[packet.Packet(packet.MESSAGE, data=bar)], IOError])
[packet.Packet(packet.MESSAGE, data=bar)], exceptions.QueueEmpty])
ws = mock.MagicMock()
ws.wait.side_effect = [
packet.Packet(packet.MESSAGE, data=foo).encode(
Expand All @@ -288,7 +288,7 @@ def test_websocket_upgrade_read_write(self):
bar = six.text_type('bar')
probe = six.text_type('probe')
s.poll = mock.MagicMock(side_effect=[
[packet.Packet(packet.MESSAGE, data=bar)], IOError])
[packet.Packet(packet.MESSAGE, data=bar)], exceptions.QueueEmpty])
ws = mock.MagicMock()
ws.wait.side_effect = [
packet.Packet(packet.PING, data=probe).encode(
Expand Down Expand Up @@ -322,7 +322,7 @@ def test_websocket_upgrade_with_payload(self):
self._join_bg_tasks()
self.assertTrue(s.upgraded)

def test_websocket_read_write_fail(self):
def test_websocket_read_write_wait_fail(self):
mock_server = self._get_mock_server()
s = socket.Socket(mock_server, 'sid')
s.connected = False
Expand All @@ -331,7 +331,7 @@ def test_websocket_read_write_fail(self):
bar = six.text_type('bar')
s.poll = mock.MagicMock(side_effect=[
[packet.Packet(packet.MESSAGE, data=bar)],
[packet.Packet(packet.MESSAGE, data=bar)], IOError])
[packet.Packet(packet.MESSAGE, data=bar)], exceptions.QueueEmpty])
ws = mock.MagicMock()
ws.wait.side_effect = [
packet.Packet(packet.MESSAGE, data=foo).encode(
Expand All @@ -342,6 +342,29 @@ def test_websocket_read_write_fail(self):
self._join_bg_tasks()
self.assertEqual(s.closed, True)

def test_websocket_read_write_receive_fail(self):
mock_server = self._get_mock_server()
s = socket.Socket(mock_server, 'sid')
s.connected = False
s.queue.join = mock.MagicMock(return_value=None)
foo = six.text_type('foo')
bar = six.text_type('bar')
s.poll = mock.MagicMock(side_effect=[
[packet.Packet(packet.MESSAGE, data=bar)],
[packet.Packet(packet.MESSAGE, data=bar)], exceptions.QueueEmpty])
ws = mock.MagicMock()
ws.wait.side_effect = [
packet.Packet(packet.MESSAGE, data=foo).encode(
always_bytes=False),
packet.Packet(packet.MESSAGE, data=bar).encode(
always_bytes=False)]
ws.send.side_effect = [None, None]
s.receive = mock.MagicMock(side_effect=[
None, ZeroDivisionError])
self.assertRaises(ZeroDivisionError, s._websocket_handler, ws)
self._join_bg_tasks()
self.assertEqual(s.closed, True)

def test_websocket_ignore_invalid_packet(self):
mock_server = self._get_mock_server()
s = socket.Socket(mock_server, 'sid')
Expand All @@ -350,7 +373,7 @@ def test_websocket_ignore_invalid_packet(self):
foo = six.text_type('foo')
bar = six.text_type('bar')
s.poll = mock.MagicMock(side_effect=[
[packet.Packet(packet.MESSAGE, data=bar)], IOError])
[packet.Packet(packet.MESSAGE, data=bar)], exceptions.QueueEmpty])
ws = mock.MagicMock()
ws.wait.side_effect = [
packet.Packet(packet.OPEN).encode(always_bytes=False),
Expand Down

0 comments on commit f772cf6

Please sign in to comment.