Skip to content

Commit

Permalink
Actively monitor clients for disconnections
Browse files Browse the repository at this point in the history
  • Loading branch information
miguelgrinberg committed Sep 23, 2018
1 parent 6735659 commit 3f583c8
Show file tree
Hide file tree
Showing 6 changed files with 128 additions and 8 deletions.
27 changes: 27 additions & 0 deletions engineio/asyncio_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -225,6 +225,11 @@ async def sleep(self, seconds=0):

async def _handle_connect(self, environ, transport, b64=False):
"""Handle a client connection request."""
if self.start_service_task:
# start the service task to monitor connected clients
self.start_service_task = False
self.start_background_task(self._service_task)

sid = self._generate_id()
s = asyncio_socket.AsyncSocket(self, sid)
self.sockets[sid] = s
Expand Down Expand Up @@ -295,3 +300,25 @@ async def async_handler():
# connection
return False
return ret

async def _service_task(self): # pragma: no cover
"""Monitor connected clients and clean up those that time out."""
while True:
if len(self.sockets) == 0:
# nothing to do
await self.sleep(self.ping_timeout)
continue

# go through the entire client list in a ping interval cycle
sleep_interval = self.ping_timeout / len(self.sockets)

try:
# iterate over the current clients
for socket in self.sockets.copy().values():
if socket.closed:
continue
await socket.check_ping_timeout()
await self.sleep(sleep_interval)
except:
# an unexpected exception has occurred, log it and continue
self.logger.exception('service task exception')
26 changes: 21 additions & 5 deletions engineio/asyncio_socket.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,14 +50,24 @@ async def receive(self, pkt):
else:
raise exceptions.UnknownPacketError()

async def send(self, pkt):
"""Send a packet to the client."""
async def check_ping_timeout(self):
"""Make sure the client is still sending pings.
This helps detect disconnections for long-polling clients.
"""
if self.closed:
raise exceptions.SocketIsClosedError()
if time.time() - self.last_ping > self.server.ping_timeout:
self.server.logger.info('%s: Client is gone, closing socket',
self.sid)
return await self.close(wait=False, abort=True)
await self.close(wait=False, abort=True)
return False
return True

async def send(self, pkt):
"""Send a packet to the client."""
if not await self.check_ping_timeout():
return
self.server.logger.info('%s: Sending packet %s data %s',
self.sid, packet.packet_names[pkt.packet_type],
pkt.data if not isinstance(pkt.data, bytes)
Expand Down Expand Up @@ -123,7 +133,10 @@ async def _websocket_handler(self, ws):
# the socket was already connected, so this is an upgrade
await self.queue.join() # flush the queue first

pkt = await ws.wait()
try:
pkt = await ws.wait()
except IOError: # pragma: no cover
return
if pkt != packet.Packet(packet.PING,
data=six.text_type('probe')).encode(
always_bytes=False):
Expand All @@ -135,7 +148,10 @@ async def _websocket_handler(self, ws):
data=six.text_type('probe')).encode(always_bytes=False))
await self.send(packet.Packet(packet.NOOP))

pkt = await ws.wait()
try:
pkt = await ws.wait()
except IOError: # pragma: no cover
return
decoded_pkt = packet.Packet(encoded_packet=pkt)
if decoded_pkt.packet_type != packet.UPGRADE:
self.upgraded = False
Expand Down
36 changes: 35 additions & 1 deletion engineio/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,18 +64,23 @@ class Server(object):
:param async_handlers: If set to ``True``, run message event handlers in
non-blocking threads. To run handlers synchronously,
set to ``False``. The default is ``True``.
:param monitor_clients: If set to ``True``, a background task will ensure
inactive clients are closed. Set to ``False`` to
disable the monitoring task (not recommended). The
default is ``True``.
:param kwargs: Reserved for future extensions, any additional parameters
given as keyword arguments will be silently ignored.
"""
compression_methods = ['gzip', 'deflate']
event_names = ['connect', 'disconnect', 'message']
_default_monitor_clients = True

def __init__(self, async_mode=None, ping_timeout=60, ping_interval=25,
max_http_buffer_size=100000000, allow_upgrades=True,
http_compression=True, compression_threshold=1024,
cookie='io', cors_allowed_origins=None,
cors_credentials=True, logger=False, json=None,
async_handlers=True, **kwargs):
async_handlers=True, monitor_clients=None, **kwargs):
self.ping_timeout = ping_timeout
self.ping_interval = ping_interval
self.max_http_buffer_size = max_http_buffer_size
Expand All @@ -88,6 +93,8 @@ def __init__(self, async_mode=None, ping_timeout=60, ping_interval=25,
self.async_handlers = async_handlers
self.sockets = {}
self.handlers = {}
self.start_service_task = monitor_clients \
if monitor_clients is not None else self._default_monitor_clients
if json is not None:
packet.Packet.json = json
if not isinstance(logger, bool):
Expand Down Expand Up @@ -359,6 +366,11 @@ def _generate_id(self):

def _handle_connect(self, environ, start_response, transport, b64=False):
"""Handle a client connection request."""
if self.start_service_task:
# start the service task to monitor connected clients
self.start_service_task = False
self.start_background_task(self._service_task)

sid = self._generate_id()
s = socket.Socket(self, sid)
self.sockets[sid] = s
Expand Down Expand Up @@ -497,3 +509,25 @@ def _gzip(self, response):
def _deflate(self, response):
"""Apply deflate compression to a response."""
return zlib.compress(response)

def _service_task(self): # pragma: no cover
"""Monitor connected clients and clean up those that time out."""
while True:
if len(self.sockets) == 0:
# nothing to do
self.sleep(self.ping_timeout)
continue

# go through the entire client list in a ping interval cycle
sleep_interval = self.ping_timeout / len(self.sockets)

try:
# iterate over the current clients
for s in self.sockets.copy().values():
if s.closed:
continue
s.check_ping_timeout()
self.sleep(sleep_interval)
except:
# an unexpected exception has occurred, log it and continue
self.logger.exception('service task exception')
13 changes: 11 additions & 2 deletions engineio/socket.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,14 +63,23 @@ def receive(self, pkt):
else:
raise exceptions.UnknownPacketError()

def send(self, pkt):
"""Send a packet to the client."""
def check_ping_timeout(self):
"""Make sure the client is still sending pings.
This helps detect disconnections for long-polling clients.
"""
if self.closed:
raise exceptions.SocketIsClosedError()
if time.time() - self.last_ping > self.server.ping_timeout:
self.server.logger.info('%s: Client is gone, closing socket',
self.sid)
self.close(wait=False, abort=True)
return False
return True

def send(self, pkt):
"""Send a packet to the client."""
if not self.check_ping_timeout():
return
self.queue.put(pkt)
self.server.logger.info('%s: Sending packet %s data %s',
Expand Down
18 changes: 18 additions & 0 deletions tests/test_asyncio_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,9 +68,18 @@ def _get_mock_socket(self):
mock_socket.send = AsyncMock()
mock_socket.handle_get_request = AsyncMock()
mock_socket.handle_post_request = AsyncMock()
mock_socket.check_ping_timeout = AsyncMock()
mock_socket.close = AsyncMock()
return mock_socket

@classmethod
def setUpClass(cls):
asyncio_server.AsyncServer._default_monitor_clients = False

@classmethod
def tearDownClass(cls):
asyncio_server.AsyncServer._default_monitor_clients = True

def setUp(self):
logging.getLogger('engineio').setLevel(logging.NOTSET)

Expand Down Expand Up @@ -839,3 +848,12 @@ def foo_handler(arg):
ZeroDivisionError, asyncio.get_event_loop().run_until_complete,
fut)
self.assertEqual(result, ['bar'])

@mock.patch('importlib.import_module')
def test_service_task_started(self, import_module):
a = self.get_async_mock()
import_module.side_effect = [a]
s = asyncio_server.AsyncServer(monitor_clients=True)
s._service_task = AsyncMock()
_run(s.handle_request('request'))
s._service_task.mock.assert_called_once_with()
16 changes: 16 additions & 0 deletions tests/test_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,14 @@ def _get_mock_socket(self):
mock_socket.upgraded = False
return mock_socket

@classmethod
def setUpClass(cls):
server.Server._default_monitor_clients = False

@classmethod
def tearDownClass(cls):
server.Server._default_monitor_clients = True

def setUp(self):
logging.getLogger('engineio').setLevel(logging.NOTSET)

Expand Down Expand Up @@ -863,3 +871,11 @@ def test_sleep(self):
t = time.time()
s.sleep(0.1)
self.assertTrue(time.time() - t > 0.1)

def test_service_task_started(self):
s = server.Server(monitor_clients=True)
s._service_task = mock.MagicMock()
environ = {'REQUEST_METHOD': 'GET', 'QUERY_STRING': ''}
start_response = mock.MagicMock()
s.handle_request(environ, start_response)
s._service_task.assert_called_once_with()

0 comments on commit 3f583c8

Please sign in to comment.