Skip to content

Commit

Permalink
shutdown() method for the Engine.IO server
Browse files Browse the repository at this point in the history
  • Loading branch information
miguelgrinberg committed Aug 20, 2023
1 parent 35cc5ec commit 87f6003
Show file tree
Hide file tree
Showing 4 changed files with 67 additions and 8 deletions.
30 changes: 26 additions & 4 deletions src/engineio/asyncio_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -315,6 +315,18 @@ async def handle_request(self, *args, **kwargs):
break
return await self._make_response(r, environ)

async def shutdown(self):
"""Stop Socket.IO background tasks.
This method stops background activity initiated by the Socket.IO
server. It must be called before shutting down the web server.
"""
self.logger.info('Socket.IO is shutting down')
if self.service_task_event: # pragma: no cover
self.service_task_event.set()
await self.service_task_handle
self.service_task_handle = None

def start_background_task(self, target, *args, **kwargs):
"""Start a background task using the appropriate async model.
Expand Down Expand Up @@ -392,7 +404,8 @@ async def _handle_connect(self, environ, transport, jsonp_index=None):
if self.start_service_task:
# start the service task to monitor connected clients
self.start_service_task = False
self.start_background_task(self._service_task)
self.service_task_handle = self.start_background_task(
self._service_task)

sid = self.generate_id()
s = asyncio_socket.AsyncSocket(self, sid)
Expand Down Expand Up @@ -480,10 +493,15 @@ async def async_handler():

async def _service_task(self): # pragma: no cover
"""Monitor connected clients and clean up those that time out."""
while True:
self.service_task_event = self.create_event()
while not self.service_task_event.is_set():
if len(self.sockets) == 0:
# nothing to do
await self.sleep(self.ping_timeout)
try:
await asyncio.wait_for(self.service_task_event.wait(),
timeout=self.ping_timeout)
except asyncio.TimeoutError:
break
continue

# go through the entire client list in a ping interval cycle
Expand All @@ -494,7 +512,11 @@ async def _service_task(self): # pragma: no cover
for socket in self.sockets.copy().values():
if not socket.closing and not socket.closed:
await socket.check_ping_timeout()
await self.sleep(sleep_interval)
try:
await asyncio.wait_for(self.service_task_event.wait(),
timeout=sleep_interval)
except asyncio.TimeoutError:
raise KeyboardInterrupt()
except (
SystemExit,
KeyboardInterrupt,
Expand Down
26 changes: 22 additions & 4 deletions src/engineio/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,8 @@ def __init__(self, async_mode=None, ping_interval=25, ping_timeout=20,
self.log_message_keys = set()
self.start_service_task = monitor_clients \
if monitor_clients is not None else self._default_monitor_clients
self.service_task_handle = None
self.service_task_event = None
if json is not None:
packet.Packet.json = json
if not isinstance(logger, bool):
Expand Down Expand Up @@ -460,6 +462,18 @@ def handle_request(self, environ, start_response):
start_response(r['status'], r['headers'] + cors_headers)
return [r['response']]

def shutdown(self):
"""Stop Socket.IO background tasks.
This method stops background activity initiated by the Socket.IO
server. It must be called before shutting down the web server.
"""
self.logger.info('Socket.IO is shutting down')
if self.service_task_event: # pragma: no cover
self.service_task_event.set()
self.service_task_handle.join()
self.service_task_handle = None

def start_background_task(self, target, *args, **kwargs):
"""Start a background task using the appropriate async model.
Expand Down Expand Up @@ -543,7 +557,8 @@ def _handle_connect(self, environ, start_response, transport,
if self.start_service_task:
# start the service task to monitor connected clients
self.start_service_task = False
self.start_background_task(self._service_task)
self.service_task_handle = self.start_background_task(
self._service_task)

sid = self.generate_id()
s = socket.Socket(self, sid)
Expand Down Expand Up @@ -747,10 +762,12 @@ def _log_error_once(self, message, message_key):

def _service_task(self): # pragma: no cover
"""Monitor connected clients and clean up those that time out."""
while True:
self.service_task_event = self.create_event()
while not self.service_task_event.is_set():
if len(self.sockets) == 0:
# nothing to do
self.sleep(self.ping_timeout)
if self.service_task_event.wait(timeout=self.ping_timeout):
break
continue

# go through the entire client list in a ping interval cycle
Expand All @@ -761,7 +778,8 @@ def _service_task(self): # pragma: no cover
for s in self.sockets.copy().values():
if not s.closing and not s.closed:
s.check_ping_timeout()
self.sleep(sleep_interval)
if self.service_task_event.wait(timeout=sleep_interval):
raise KeyboardInterrupt()
except (SystemExit, KeyboardInterrupt):
self.logger.info('service task canceled')
break
Expand Down
10 changes: 10 additions & 0 deletions tests/asyncio/test_asyncio_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -1144,6 +1144,16 @@ def test_service_task_started(self, import_module):
_run(s.handle_request('request'))
s._service_task.mock.assert_called_once_with()

@mock.patch('importlib.import_module')
def test_shutdown(self, import_module):
a = self.get_async_mock()
import_module.side_effect = [a]
s = asyncio_server.AsyncServer(monitor_clients=True)
_run(s.handle_request('request'))
assert s.service_task_handle is not None
_run(s.shutdown())
assert s.service_task_handle is None

@mock.patch('importlib.import_module')
def test_transports_disallowed(self, import_module):
a = self.get_async_mock(
Expand Down
9 changes: 9 additions & 0 deletions tests/common/test_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -1170,6 +1170,15 @@ def test_service_task_started(self):
s.handle_request(environ, start_response)
s._service_task.assert_called_once_with()

def test_shutdown(self):
s = server.Server(monitor_clients=True)
environ = {'REQUEST_METHOD': 'GET', 'QUERY_STRING': 'EIO=4'}
start_response = mock.MagicMock()
s.handle_request(environ, start_response)
assert s.service_task_handle is not None
s.shutdown()
assert s.service_task_handle is None

def test_transports_invalid(self):
with pytest.raises(ValueError):
server.Server(transports='invalid')
Expand Down

0 comments on commit 87f6003

Please sign in to comment.