Skip to content

Commit

Permalink
correct handling of disconnect event
Browse files Browse the repository at this point in the history
  • Loading branch information
miguelgrinberg committed Dec 16, 2018
1 parent aeabccd commit ccf1ddf
Show file tree
Hide file tree
Showing 4 changed files with 46 additions and 10 deletions.
20 changes: 15 additions & 5 deletions engineio/asyncio_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,7 @@ async def disconnect(self, abort=False):
await self._send_packet(packet.Packet(packet.CLOSE))
await self.queue.put(None)
self.state = 'disconnecting'
await self._trigger_event('disconnect')
if not abort:
await self.queue.join()
if self.current_transport == 'websocket':
Expand Down Expand Up @@ -372,27 +373,31 @@ async def _read_loop_polling(self):
self.logger.warning(
'Connection refused by the server, aborting')
await self.queue.put(None)
self._reset()
break
if r.status != 200:
self.logger.warning('Unexpected status code %s in server '
'response, aborting', r.status)
await self.queue.put(None)
self._reset()
break
try:
p = payload.Payload(encoded_payload=await r.read())
except ValueError:
self.logger.warning(
'Unexpected packet from server, aborting')
await self.queue.put(None)
self._reset()
break
for pkt in p.packets:
await self._receive_packet(pkt)

self.logger.info('Waiting for write loop task to end')
await self.write_loop_task
if self.state == 'connected':
await self._trigger_event('disconnect')
try:
client.connected_clients.remove(self)
except ValueError: # pragma: no cover
pass
self._reset()
self.logger.info('Exiting read loop task')

async def _read_loop_websocket(self):
Expand All @@ -405,13 +410,11 @@ async def _read_loop_websocket(self):
self.logger.warning(
'WebSocket connection was closed, aborting')
await self.queue.put(None)
self._reset()
break
except Exception as e:
self.logger.info(
'Unexpected error "%s", aborting', str(e))
await self.queue.put(None)
self._reset()
break
if isinstance(p, six.text_type): # pragma: no cover
p = p.encode('utf-8')
Expand All @@ -420,6 +423,13 @@ async def _read_loop_websocket(self):

self.logger.info('Waiting for write loop task to end')
await self.write_loop_task
if self.state == 'connected':
await self._trigger_event('disconnect')
try:
client.connected_clients.remove(self)
except ValueError: # pragma: no cover
pass
self._reset()
self.logger.info('Exiting read loop task')

async def _write_loop(self):
Expand Down
20 changes: 15 additions & 5 deletions engineio/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -197,6 +197,7 @@ def disconnect(self, abort=False):
self._send_packet(packet.Packet(packet.CLOSE))
self.queue.put(None)
self.state = 'disconnecting'
self._trigger_event('disconnect')
if not abort:
self.queue.join()
if self.current_transport == 'websocket':
Expand Down Expand Up @@ -470,27 +471,31 @@ def _read_loop_polling(self):
self.logger.warning(
'Connection refused by the server, aborting')
self.queue.put(None)
self._reset()
break
if r.status != 200:
self.logger.warning('Unexpected status code %s in server '
'response, aborting', r.status)
self.queue.put(None)
self._reset()
break
try:
p = payload.Payload(encoded_payload=r.data)
except ValueError:
self.logger.warning(
'Unexpected packet from server, aborting')
self.queue.put(None)
self._reset()
break
for pkt in p.packets:
self._receive_packet(pkt)

self.logger.info('Waiting for write loop task to end')
self.write_loop_task.join()
if self.state == 'connected':
self._trigger_event('disconnect')
try:
connected_clients.remove(self)
except ValueError: # pragma: no cover
pass
self._reset()
self.logger.info('Exiting read loop task')

def _read_loop_websocket(self):
Expand All @@ -503,13 +508,11 @@ def _read_loop_websocket(self):
self.logger.warning(
'WebSocket connection was closed, aborting')
self.queue.put(None)
self._reset()
break
except Exception as e:
self.logger.info(
'Unexpected error "%s", aborting', str(e))
self.queue.put(None)
self._reset()
break
if isinstance(p, six.text_type): # pragma: no cover
p = p.encode('utf-8')
Expand All @@ -518,6 +521,13 @@ def _read_loop_websocket(self):

self.logger.info('Waiting for write loop task to end')
self.write_loop_task.join()
if self.state == 'connected':
self._trigger_event('disconnect')
try:
connected_clients.remove(self)
except ValueError: # pragma: no cover
pass
self._reset()
self.logger.info('Exiting read loop task')

def _write_loop(self):
Expand Down
8 changes: 8 additions & 0 deletions tests/test_asyncio_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -183,10 +183,12 @@ def test_disconnect_polling(self):
c.read_loop_task = AsyncMock()()
c.ws = mock.MagicMock()
c.ws.close = AsyncMock()
c._trigger_event = AsyncMock()
_run(c.disconnect())
c.queue.join.mock.assert_called_once_with()
c.ws.close.mock.assert_not_called()
self.assertNotIn(c, client.connected_clients)
c._trigger_event.mock.assert_called_once_with('disconnect')

def test_disconnect_websocket(self):
c = asyncio_client.AsyncClient()
Expand All @@ -199,10 +201,12 @@ def test_disconnect_websocket(self):
c.read_loop_task = AsyncMock()()
c.ws = mock.MagicMock()
c.ws.close = AsyncMock()
c._trigger_event = AsyncMock()
_run(c.disconnect())
c.queue.join.mock.assert_called_once_with()
c.ws.close.mock.assert_called_once_with()
self.assertNotIn(c, client.connected_clients)
c._trigger_event.mock.assert_called_once_with('disconnect')

def test_disconnect_polling_abort(self):
c = asyncio_client.AsyncClient()
Expand Down Expand Up @@ -779,8 +783,10 @@ def fake_sleep(interval):
def test_read_loop_polling_disconnected(self):
c = asyncio_client.AsyncClient()
c.state = 'disconnected'
c._trigger_event = AsyncMock()
c.write_loop_task = AsyncMock()()
_run(c._read_loop_polling())
c._trigger_event.mock.assert_not_called()
# should not block

@mock.patch('engineio.client.time.time', return_value=123.456)
Expand All @@ -791,12 +797,14 @@ def test_read_loop_polling_no_response(self, _time):
c.queue = mock.MagicMock()
c.queue.put = AsyncMock()
c._send_request = AsyncMock(return_value=None)
c._trigger_event = AsyncMock()
c.write_loop_task = AsyncMock()()
_run(c._read_loop_polling())
self.assertEqual(c.state, 'disconnected')
c.queue.put.mock.assert_called_once_with(None)
c._send_request.mock.assert_called_once_with(
'GET', 'http://foo&t=123.456')
c._trigger_event.mock.assert_called_once_with('disconnect')

@mock.patch('engineio.client.time.time', return_value=123.456)
def test_read_loop_polling_bad_status(self, _time):
Expand Down
8 changes: 8 additions & 0 deletions tests/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -186,11 +186,13 @@ def test_disconnect_polling(self):
c.queue = mock.MagicMock()
c.read_loop_task = mock.MagicMock()
c.ws = mock.MagicMock()
c._trigger_event = mock.MagicMock()
c.disconnect()
c.queue.join.assert_called_once_with()
c.read_loop_task.join.assert_called_once_with()
c.ws.mock.assert_not_called()
self.assertNotIn(c, client.connected_clients)
c._trigger_event.assert_called_once_with('disconnect')

def test_disconnect_websocket(self):
c = client.Client()
Expand All @@ -200,11 +202,13 @@ def test_disconnect_websocket(self):
c.queue = mock.MagicMock()
c.read_loop_task = mock.MagicMock()
c.ws = mock.MagicMock()
c._trigger_event = mock.MagicMock()
c.disconnect()
c.queue.join.assert_called_once_with()
c.read_loop_task.join.assert_called_once_with()
c.ws.close.assert_called_once_with()
self.assertNotIn(c, client.connected_clients)
c._trigger_event.assert_called_once_with('disconnect')

def test_disconnect_polling_abort(self):
c = client.Client()
Expand Down Expand Up @@ -722,9 +726,11 @@ def fake_sleep(interval):
def test_read_loop_polling_disconnected(self):
c = client.Client()
c.state = 'disconnected'
c._trigger_event = mock.MagicMock()
c.write_loop_task = mock.MagicMock()
c._read_loop_polling()
c.write_loop_task.join.assert_called_once_with()
c._trigger_event.assert_not_called()

@mock.patch('engineio.client.time.time', return_value=123.456)
def test_read_loop_polling_no_response(self, _time):
Expand All @@ -733,12 +739,14 @@ def test_read_loop_polling_no_response(self, _time):
c.base_url = 'http://foo'
c.queue = mock.MagicMock()
c._send_request = mock.MagicMock(return_value=None)
c._trigger_event = mock.MagicMock()
c.write_loop_task = mock.MagicMock()
c._read_loop_polling()
self.assertEqual(c.state, 'disconnected')
c.queue.put.assert_called_once_with(None)
c.write_loop_task.join.assert_called_once_with()
c._send_request.assert_called_once_with('GET', 'http://foo&t=123.456')
c._trigger_event.assert_called_once_with('disconnect')

@mock.patch('engineio.client.time.time', return_value=123.456)
def test_read_loop_polling_bad_status(self, _time):
Expand Down

0 comments on commit ccf1ddf

Please sign in to comment.