Skip to content

Commit

Permalink
make ping loop task more responsive to cancellation
Browse files Browse the repository at this point in the history
  • Loading branch information
miguelgrinberg committed Dec 19, 2018
1 parent ccf1ddf commit 6a997b9
Show file tree
Hide file tree
Showing 4 changed files with 77 additions and 51 deletions.
27 changes: 20 additions & 7 deletions engineio/asyncio_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,8 +112,6 @@ async def disconnect(self, abort=False):
await self.queue.put(None)
self.state = 'disconnecting'
await self._trigger_event('disconnect')
if not abort:
await self.queue.join()
if self.current_transport == 'websocket':
await self.ws.close()
if not abort:
Expand Down Expand Up @@ -198,7 +196,7 @@ async def _connect_polling(self, url, headers, engineio_path):
# upgrade to websocket succeeded, we're done here
return

self.start_background_task(self._ping_loop)
self.ping_loop_task = self.start_background_task(self._ping_loop)
self.write_loop_task = self.start_background_task(self._write_loop)
self.read_loop_task = self.start_background_task(
self._read_loop_polling)
Expand Down Expand Up @@ -241,6 +239,8 @@ async def _connect_websocket(self, url, headers, engineio_path):
return False
await ws.send(packet.Packet(packet.UPGRADE).encode())
self.current_transport = 'websocket'
if self.http:
await self.http.close()
self.logger.info('WebSocket upgrade was successful')
else:
open_packet = packet.Packet(encoded_packet=await ws.recv())
Expand All @@ -259,7 +259,7 @@ async def _connect_websocket(self, url, headers, engineio_path):
await self._trigger_event('connect')

self.ws = ws
self.start_background_task(self._ping_loop)
self.ping_loop_task = self.start_background_task(self._ping_loop)
self.write_loop_task = self.start_background_task(self._write_loop)
self.read_loop_task = self.start_background_task(
self._read_loop_websocket)
Expand Down Expand Up @@ -306,6 +306,10 @@ def _create_queue(self):
"""Create the client's send queue."""
return asyncio.Queue(), asyncio.QueueEmpty

def _create_event(self):
"""Create an event."""
return asyncio.Event()

async def _trigger_event(self, event, *args, **kwargs):
"""Invoke an event handler."""
run_async = kwargs.pop('run_async', False)
Expand Down Expand Up @@ -348,6 +352,7 @@ async def _ping_loop(self):
interval.
"""
self.pong_received = True
self.ping_loop_event.clear()
while self.state == 'connected':
if not self.pong_received:
self.logger.warning(
Expand All @@ -359,7 +364,8 @@ async def _ping_loop(self):
break
self.pong_received = False
await self._send_packet(packet.Packet(packet.PING))
await self.sleep(self.ping_interval)
await asyncio.wait_for(self.ping_loop_event.wait(),
self.ping_interval)
self.logger.info('Exiting ping task')

async def _read_loop_polling(self):
Expand Down Expand Up @@ -391,6 +397,9 @@ async def _read_loop_polling(self):

self.logger.info('Waiting for write loop task to end')
await self.write_loop_task
self.logger.info('Waiting for ping loop task to end')
self.ping_loop_event.set()
await self.ping_loop_task
if self.state == 'connected':
await self._trigger_event('disconnect')
try:
Expand All @@ -408,7 +417,7 @@ async def _read_loop_websocket(self):
p = await self.ws.recv()
except websockets.exceptions.ConnectionClosed:
self.logger.warning(
'WebSocket connection was closed, aborting')
'Read loop: WebSocket connection was closed, aborting')
await self.queue.put(None)
break
except Exception as e:
Expand All @@ -423,6 +432,9 @@ async def _read_loop_websocket(self):

self.logger.info('Waiting for write loop task to end')
await self.write_loop_task
self.logger.info('Waiting for ping loop task to end')
self.ping_loop_event.set()
await self.ping_loop_task
if self.state == 'connected':
await self._trigger_event('disconnect')
try:
Expand Down Expand Up @@ -486,7 +498,8 @@ async def _write_loop(self):
self.queue.task_done()
except websockets.exceptions.ConnectionClosed:
self.logger.warning(
'WebSocket connection was closed, aborting')
'Write loop: WebSocket connection was closed, '
'aborting')
self._reset()
break
self.logger.info('Exiting write loop task')
24 changes: 16 additions & 8 deletions engineio/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,8 @@ def __init__(self, logger=False, json=None):
self.ws = None
self.read_loop_task = None
self.write_loop_task = None
self.ping_loop_task = None
self.ping_loop_event = self._create_event()
self.queue = None
self.queue_empty = None
self.state = 'disconnected'
Expand Down Expand Up @@ -198,8 +200,6 @@ def disconnect(self, abort=False):
self.queue.put(None)
self.state = 'disconnecting'
self._trigger_event('disconnect')
if not abort:
self.queue.join()
if self.current_transport == 'websocket':
self.ws.close()
if not abort:
Expand Down Expand Up @@ -233,10 +233,7 @@ def start_background_task(self, target, *args, **kwargs):
the Python standard library. The `start()` method on this object is
already called by this function.
"""
daemon = kwargs.pop('_daemon', None)
th = threading.Thread(target=target, args=args, kwargs=kwargs)
if daemon:
th.daemon = daemon
th.start()
return th

Expand Down Expand Up @@ -297,7 +294,7 @@ def _connect_polling(self, url, headers, engineio_path):
return

# start background tasks associated with this client
self.start_background_task(self._ping_loop, _daemon=True)
self.ping_loop_task = self.start_background_task(self._ping_loop)
self.write_loop_task = self.start_background_task(self._write_loop)
self.read_loop_task = self.start_background_task(
self._read_loop_polling)
Expand Down Expand Up @@ -358,7 +355,7 @@ def _connect_websocket(self, url, headers, engineio_path):
self.ws = ws

# start background tasks associated with this client
self.start_background_task(self._ping_loop, _daemon=True)
self.ping_loop_task = self.start_background_task(self._ping_loop)
self.write_loop_task = self.start_background_task(self._write_loop)
self.read_loop_task = self.start_background_task(
self._read_loop_websocket)
Expand Down Expand Up @@ -404,6 +401,10 @@ def _create_queue(self):
"""Create the client's send queue."""
return queue.Queue(), queue.Empty

def _create_event(self):
"""Create an event."""
return threading.Event()

def _trigger_event(self, event, *args, **kwargs):
"""Invoke an event handler."""
run_async = kwargs.pop('run_async', False)
Expand Down Expand Up @@ -446,6 +447,7 @@ def _ping_loop(self):
interval.
"""
self.pong_received = True
self.ping_loop_event.clear()
while self.state == 'connected':
if not self.pong_received:
self.logger.warning(
Expand All @@ -457,7 +459,7 @@ def _ping_loop(self):
break
self.pong_received = False
self._send_packet(packet.Packet(packet.PING))
self.sleep(self.ping_interval)
self.ping_loop_event.wait(timeout=self.ping_interval)
self.logger.info('Exiting ping task')

def _read_loop_polling(self):
Expand Down Expand Up @@ -489,6 +491,9 @@ def _read_loop_polling(self):

self.logger.info('Waiting for write loop task to end')
self.write_loop_task.join()
self.logger.info('Waiting for ping loop task to end')
self.ping_loop_event.set()
self.ping_loop_task.join()
if self.state == 'connected':
self._trigger_event('disconnect')
try:
Expand Down Expand Up @@ -521,6 +526,9 @@ def _read_loop_websocket(self):

self.logger.info('Waiting for write loop task to end')
self.write_loop_task.join()
self.logger.info('Waiting for ping loop task to end')
self.ping_loop_event.set()
self.ping_loop_task.join()
if self.state == 'connected':
self._trigger_event('disconnect')
try:
Expand Down
26 changes: 15 additions & 11 deletions tests/test_asyncio_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -185,7 +185,6 @@ def test_disconnect_polling(self):
c.ws.close = AsyncMock()
c._trigger_event = AsyncMock()
_run(c.disconnect())
c.queue.join.mock.assert_called_once_with()
c.ws.close.mock.assert_not_called()
self.assertNotIn(c, client.connected_clients)
c._trigger_event.mock.assert_called_once_with('disconnect')
Expand All @@ -203,7 +202,6 @@ def test_disconnect_websocket(self):
c.ws.close = AsyncMock()
c._trigger_event = AsyncMock()
_run(c.disconnect())
c.queue.join.mock.assert_called_once_with()
c.ws.close.mock.assert_called_once_with()
self.assertNotIn(c, client.connected_clients)
c._trigger_event.mock.assert_called_once_with('disconnect')
Expand Down Expand Up @@ -724,11 +722,10 @@ def test_ping_loop_disconnect(self):
]

@coroutine
def fake_sleep(interval):
self.assertEqual(interval, 10)
def fake_wait():
c.state, c.pong_received = states.pop(0)

c.sleep = fake_sleep
c.ping_loop_event.wait = fake_wait
_run(c._ping_loop())
self.assertEqual(
c._send_packet.mock.call_args_list[0][0][0].encode(), b'2')
Expand All @@ -746,11 +743,10 @@ def test_ping_loop_missing_pong(self):
]

@coroutine
def fake_sleep(interval):
self.assertEqual(interval, 10)
def fake_wait():
c.state, c.pong_received = states.pop(0)

c.sleep = fake_sleep
c.ping_loop_event.wait = fake_wait
_run(c._ping_loop())
self.assertEqual(c.state, 'disconnected')
c.queue.put.mock.assert_called_once_with(None)
Expand All @@ -770,11 +766,10 @@ def test_ping_loop_missing_pong_websocket(self):
]

@coroutine
def fake_sleep(interval):
self.assertEqual(interval, 10)
def fake_wait():
c.state, c.pong_received = states.pop(0)

c.sleep = fake_sleep
c.ping_loop_event.wait = fake_wait
_run(c._ping_loop())
self.assertEqual(c.state, 'disconnected')
c.queue.put.mock.assert_called_once_with(None)
Expand All @@ -785,6 +780,7 @@ def test_read_loop_polling_disconnected(self):
c.state = 'disconnected'
c._trigger_event = AsyncMock()
c.write_loop_task = AsyncMock()()
c.ping_loop_task = AsyncMock()()
_run(c._read_loop_polling())
c._trigger_event.mock.assert_not_called()
# should not block
Expand All @@ -799,6 +795,7 @@ def test_read_loop_polling_no_response(self, _time):
c._send_request = AsyncMock(return_value=None)
c._trigger_event = AsyncMock()
c.write_loop_task = AsyncMock()()
c.ping_loop_task = AsyncMock()()
_run(c._read_loop_polling())
self.assertEqual(c.state, 'disconnected')
c.queue.put.mock.assert_called_once_with(None)
Expand All @@ -816,6 +813,7 @@ def test_read_loop_polling_bad_status(self, _time):
c._send_request = AsyncMock()
c._send_request.mock.return_value.status = 400
c.write_loop_task = AsyncMock()()
c.ping_loop_task = AsyncMock()()
_run(c._read_loop_polling())
self.assertEqual(c.state, 'disconnected')
c.queue.put.mock.assert_called_once_with(None)
Expand All @@ -834,6 +832,7 @@ def test_read_loop_polling_bad_packet(self, _time):
c._send_request.mock.return_value.read = AsyncMock(
return_value=b'foo')
c.write_loop_task = AsyncMock()()
c.ping_loop_task = AsyncMock()()
_run(c._read_loop_polling())
self.assertEqual(c.state, 'disconnected')
c.queue.put.mock.assert_called_once_with(None)
Expand All @@ -855,6 +854,7 @@ def test_read_loop_polling(self):
None
]
c.write_loop_task = AsyncMock()()
c.ping_loop_task = AsyncMock()()
c._receive_packet = AsyncMock()
_run(c._read_loop_polling())
self.assertEqual(c.state, 'disconnected')
Expand All @@ -870,6 +870,7 @@ def test_read_loop_websocket_disconnected(self):
c = asyncio_client.AsyncClient()
c.state = 'disconnected'
c.write_loop_task = AsyncMock()()
c.ping_loop_task = AsyncMock()()
_run(c._read_loop_websocket())
# should not block

Expand All @@ -883,6 +884,7 @@ def test_read_loop_websocket_no_response(self):
c.ws.recv = AsyncMock(
side_effect=websockets.exceptions.ConnectionClosed(1, 'foo'))
c.write_loop_task = AsyncMock()()
c.ping_loop_task = AsyncMock()()
_run(c._read_loop_websocket())
self.assertEqual(c.state, 'disconnected')
c.queue.put.mock.assert_called_once_with(None)
Expand All @@ -896,6 +898,7 @@ def test_read_loop_websocket_unexpected_error(self):
c.ws = mock.MagicMock()
c.ws.recv = AsyncMock(side_effect=ValueError)
c.write_loop_task = AsyncMock()()
c.ping_loop_task = AsyncMock()()
_run(c._read_loop_websocket())
self.assertEqual(c.state, 'disconnected')
c.queue.put.mock.assert_called_once_with(None)
Expand All @@ -910,6 +913,7 @@ def test_read_loop_websocket(self):
c.ws.recv = AsyncMock(side_effect=[
packet.Packet(packet.PING).encode(), ValueError])
c.write_loop_task = AsyncMock()()
c.ping_loop_task = AsyncMock()()
c._receive_packet = AsyncMock()
_run(c._read_loop_websocket())
self.assertEqual(c.state, 'disconnected')
Expand Down
Loading

0 comments on commit 6a997b9

Please sign in to comment.