Skip to content

Commit

Permalink
Support async_handlers option for the asyncio server
Browse files Browse the repository at this point in the history
  • Loading branch information
miguelgrinberg committed Jun 27, 2017
1 parent 22ee3dc commit 6609416
Show file tree
Hide file tree
Showing 8 changed files with 129 additions and 44 deletions.
53 changes: 34 additions & 19 deletions engineio/asyncio_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,9 @@ class AsyncServer(server.Server):
packets. Custom json modules must have ``dumps`` and ``loads``
functions that are compatible with the standard library
versions.
:param async_handlers: If set to ``True``, run message event handlers in
non-blocking threads. To run handlers synchronously,
set to ``False``. The default is ``True``.
:param kwargs: Reserved for future extensions, any additional parameters
given as keyword arguments will be silently ignored.
"""
Expand Down Expand Up @@ -223,7 +226,8 @@ async def _handle_connect(self, environ, transport, b64=False):
'pingInterval': int(self.ping_interval * 1000)})
await s.send(pkt)

ret = await self._trigger_event('connect', sid, environ)
ret = await self._trigger_event('connect', sid, environ,
run_async=False)
if ret is False:
del self.sockets[sid]
self.logger.warning('Application rejected connection')
Expand All @@ -244,26 +248,37 @@ async def _handle_connect(self, environ, transport, b64=False):

async def _trigger_event(self, event, *args, **kwargs):
"""Invoke an event handler."""
run_async = kwargs.pop('run_async', False)
ret = None
if event in self.handlers:
if asyncio.iscoroutinefunction(self.handlers[event]) is True:
try:
ret = await self.handlers[event](*args)
except asyncio.CancelledError: # pragma: no cover
pass
except:
self.logger.exception(event + ' async handler error')
if event == 'connect':
# if connect handler raised error we reject the
# connection
return False
if run_async:
return self.start_background_task(self.handlers[event],
*args)
else:
try:
ret = await self.handlers[event](*args)
except asyncio.CancelledError: # pragma: no cover
pass
except:
self.logger.exception(event + ' async handler error')
if event == 'connect':
# if connect handler raised error we reject the
# connection
return False
else:
try:
return self.handlers[event](*args)
except:
self.logger.exception(event + ' handler error')
if event == 'connect':
# if connect handler raised error we reject the
# connection
return False
if run_async:
async def async_handler():
return self.handlers[event](*args)

return self.start_background_task(async_handler)
else:
try:
ret = self.handlers[event](*args)
except:
self.logger.exception(event + ' handler error')
if event == 'connect':
# if connect handler raised error we reject the
# connection
return False
return ret
4 changes: 3 additions & 1 deletion engineio/asyncio_socket.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,9 @@ async def receive(self, pkt):
self.last_ping = time.time()
await self.send(packet.Packet(packet.PONG, pkt.data))
elif pkt.packet_type == packet.MESSAGE:
await self.server._trigger_event('message', self.sid, pkt.data)
await self.server._trigger_event(
'message', self.sid, pkt.data,
run_async=self.server.async_handlers)
elif pkt.packet_type == packet.UPGRADE:
await self.send(packet.Packet(packet.NOOP))
elif pkt.packet_type == packet.CLOSE:
Expand Down
6 changes: 3 additions & 3 deletions engineio/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -354,7 +354,7 @@ def _handle_connect(self, environ, start_response, transport, b64=False):
'pingInterval': int(self.ping_interval * 1000)})
s.send(pkt)

ret = self._trigger_event('connect', sid, environ, async=False)
ret = self._trigger_event('connect', sid, environ, run_async=False)
if ret is False:
del self.sockets[sid]
self.logger.warning('Application rejected connection')
Expand Down Expand Up @@ -383,9 +383,9 @@ def _upgrades(self, sid, transport):

def _trigger_event(self, event, *args, **kwargs):
"""Invoke an event handler."""
async = kwargs.pop('async', False)
run_async = kwargs.pop('run_async', False)
if event in self.handlers:
if async:
if run_async:
return self.start_background_task(self.handlers[event], *args)
else:
try:
Expand Down
4 changes: 2 additions & 2 deletions engineio/socket.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ def receive(self, pkt):
self.send(packet.Packet(packet.PONG, pkt.data))
elif pkt.packet_type == packet.MESSAGE:
self.server._trigger_event('message', self.sid, pkt.data,
async=self.server.async_handlers)
run_async=self.server.async_handlers)
elif pkt.packet_type == packet.UPGRADE:
self.send(packet.Packet(packet.NOOP))
elif pkt.packet_type == packet.CLOSE:
Expand Down Expand Up @@ -109,7 +109,7 @@ def close(self, wait=True, abort=False):
"""Close the socket connection."""
if not self.closed and not self.closing:
self.closing = True
self.server._trigger_event('disconnect', self.sid, async=False)
self.server._trigger_event('disconnect', self.sid, run_async=False)
if not abort:
self.send(packet.Packet(packet.CLOSE))
self.closed = True
Expand Down
58 changes: 58 additions & 0 deletions tests/test_asyncio_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -768,3 +768,61 @@ def foo_handler(arg):
s.on('message', handler=foo_handler)
self.assertFalse(_run(s._trigger_event('connect', '123')))
self.assertIsNone(_run(s._trigger_event('message', 'bar')))

def test_trigger_event_function_async(self):
result = []

def foo_handler(arg):
result.append('ok')
result.append(arg)

s = asyncio_server.AsyncServer()
s.on('message', handler=foo_handler)
fut = _run(s._trigger_event('message', 'bar', run_async=True))
asyncio.get_event_loop().run_until_complete(fut)
self.assertEqual(result, ['ok', 'bar'])

def test_trigger_event_coroutine_async(self):
result = []

@asyncio.coroutine
def foo_handler(arg):
result.append('ok')
result.append(arg)

s = asyncio_server.AsyncServer()
s.on('message', handler=foo_handler)
fut = _run(s._trigger_event('message', 'bar', run_async=True))
asyncio.get_event_loop().run_until_complete(fut)
self.assertEqual(result, ['ok', 'bar'])

def test_trigger_event_function_async_error(self):
result = []

def foo_handler(arg):
result.append(arg)
return 1 / 0

s = asyncio_server.AsyncServer()
s.on('message', handler=foo_handler)
fut = _run(s._trigger_event('message', 'bar', run_async=True))
self.assertRaises(
ZeroDivisionError, asyncio.get_event_loop().run_until_complete,
fut)
self.assertEqual(result, ['bar'])

def test_trigger_event_coroutine_async_error(self):
result = []

@asyncio.coroutine
def foo_handler(arg):
result.append(arg)
return 1 / 0

s = asyncio_server.AsyncServer()
s.on('message', handler=foo_handler)
fut = _run(s._trigger_event('message', 'bar', run_async=True))
self.assertRaises(
ZeroDivisionError, asyncio.get_event_loop().run_until_complete,
fut)
self.assertEqual(result, ['bar'])
20 changes: 14 additions & 6 deletions tests/test_asyncio_socket.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ def _get_mock_server(self):
mock_server = mock.Mock()
mock_server.ping_timeout = 0.2
mock_server.ping_interval = 0.2
mock_server.async_handlers = True
mock_server.async_handlers = False
mock_server._async = {'asyncio': True,
'create_route': mock.MagicMock(),
'translate_request': mock.MagicMock(),
Expand Down Expand Up @@ -102,12 +102,20 @@ def test_ping_pong(self):
self.assertEqual(len(r), 1)
self.assertTrue(r[0].encode(), b'3abc')

def test_message_handler(self):
def test_message_sync_handler(self):
mock_server = self._get_mock_server()
s = asyncio_socket.AsyncSocket(mock_server, 'sid')
_run(s.receive(packet.Packet(packet.MESSAGE, data='foo')))
mock_server._trigger_event.mock.assert_called_once_with(
'message', 'sid', 'foo')
'message', 'sid', 'foo', run_async=False)

def test_message_async_handler(self):
mock_server = self._get_mock_server()
s = asyncio_socket.AsyncSocket(mock_server, 'sid')
mock_server.async_handlers = True
_run(s.receive(packet.Packet(packet.MESSAGE, data='foo')))
mock_server._trigger_event.mock.assert_called_once_with(
'message', 'sid', 'foo', run_async=True)

def test_invalid_packet(self):
mock_server = self._get_mock_server()
Expand Down Expand Up @@ -281,7 +289,7 @@ def test_websocket_read_write(self):
self.assertTrue(s.upgraded)
self.assertEqual(mock_server._trigger_event.mock.call_count, 2)
mock_server._trigger_event.mock.assert_has_calls([
mock.call('message', 'sid', 'foo'),
mock.call('message', 'sid', 'foo', run_async=False),
mock.call('disconnect', 'sid')])
ws.send.mock.assert_called_with('4bar')

Expand Down Expand Up @@ -309,7 +317,7 @@ def test_websocket_upgrade_read_write(self):
self.assertTrue(s.upgraded)
self.assertEqual(mock_server._trigger_event.mock.call_count, 2)
mock_server._trigger_event.mock.assert_has_calls([
mock.call('message', 'sid', 'foo'),
mock.call('message', 'sid', 'foo', run_async=False),
mock.call('disconnect', 'sid')])
ws.send.mock.assert_called_with('4bar')

Expand Down Expand Up @@ -372,7 +380,7 @@ def test_websocket_ignore_invalid_packet(self):
self.assertTrue(s.connected)
self.assertEqual(mock_server._trigger_event.mock.call_count, 2)
mock_server._trigger_event.mock.assert_has_calls([
mock.call('message', 'sid', foo),
mock.call('message', 'sid', foo, run_async=False),
mock.call('disconnect', 'sid')])
ws.send.mock.assert_called_with('4bar')

Expand Down
8 changes: 4 additions & 4 deletions tests/test_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -267,9 +267,9 @@ def bar(sid, data):
f['bar'] = sid + data
return 'bar'

r = s._trigger_event('connect', 1, 2, async=False)
r = s._trigger_event('connect', 1, 2, run_async=False)
self.assertEqual(r, 3)
r = s._trigger_event('message', 3, 4, async=True)
r = s._trigger_event('message', 3, 4, run_async=True)
r.join()
self.assertEqual(f['bar'], 7)
r = s._trigger_event('message', 5, 6)
Expand All @@ -286,9 +286,9 @@ def foo(sid, environ):
def bar(sid, data):
return 1 / 0

r = s._trigger_event('connect', 1, 2, async=False)
r = s._trigger_event('connect', 1, 2, run_async=False)
self.assertEqual(r, False)
r = s._trigger_event('message', 3, 4, async=False)
r = s._trigger_event('message', 3, 4, run_async=False)
self.assertEqual(r, None)

def test_close_one_socket(self):
Expand Down
20 changes: 11 additions & 9 deletions tests/test_socket.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,15 +86,17 @@ def test_message_async_handler(self):
s = socket.Socket(mock_server, 'sid')
s.receive(packet.Packet(packet.MESSAGE, data='foo'))
mock_server._trigger_event.assert_called_once_with('message', 'sid',
'foo', async=True)
'foo',
run_async=True)

def test_message_sync_handler(self):
mock_server = self._get_mock_server()
mock_server.async_handlers = False
s = socket.Socket(mock_server, 'sid')
s.receive(packet.Packet(packet.MESSAGE, data='foo'))
mock_server._trigger_event.assert_called_once_with('message', 'sid',
'foo', async=False)
'foo',
run_async=False)

def test_invalid_packet(self):
mock_server = self._get_mock_server()
Expand Down Expand Up @@ -275,8 +277,8 @@ def test_websocket_read_write(self):
self.assertTrue(s.upgraded)
self.assertEqual(mock_server._trigger_event.call_count, 2)
mock_server._trigger_event.assert_has_calls([
mock.call('message', 'sid', 'foo', async=True),
mock.call('disconnect', 'sid', async=False)])
mock.call('message', 'sid', 'foo', run_async=True),
mock.call('disconnect', 'sid', run_async=False)])
ws.send.assert_called_with('4bar')

def test_websocket_upgrade_read_write(self):
Expand All @@ -302,8 +304,8 @@ def test_websocket_upgrade_read_write(self):
self.assertTrue(s.upgraded)
self.assertEqual(mock_server._trigger_event.call_count, 2)
mock_server._trigger_event.assert_has_calls([
mock.call('message', 'sid', 'foo', async=True),
mock.call('disconnect', 'sid', async=False)])
mock.call('message', 'sid', 'foo', run_async=True),
mock.call('disconnect', 'sid', run_async=False)])
ws.send.assert_called_with('4bar')

def test_websocket_upgrade_with_payload(self):
Expand Down Expand Up @@ -362,8 +364,8 @@ def test_websocket_ignore_invalid_packet(self):
self.assertTrue(s.connected)
self.assertEqual(mock_server._trigger_event.call_count, 2)
mock_server._trigger_event.assert_has_calls([
mock.call('message', 'sid', foo, async=True),
mock.call('disconnect', 'sid', async=False)])
mock.call('message', 'sid', foo, run_async=True),
mock.call('disconnect', 'sid', run_async=False)])
ws.send.assert_called_with('4bar')

def test_send_after_close(self):
Expand All @@ -379,7 +381,7 @@ def test_close_after_close(self):
self.assertTrue(s.closed)
self.assertEqual(mock_server._trigger_event.call_count, 1)
mock_server._trigger_event.assert_called_once_with('disconnect', 'sid',
async=False)
run_async=False)
s.close()
self.assertEqual(mock_server._trigger_event.call_count, 1)

Expand Down

0 comments on commit 6609416

Please sign in to comment.