Skip to content

Commit

Permalink
Enable or disable specific transports (Fixes #259)
Browse files Browse the repository at this point in the history
  • Loading branch information
Neverous authored and miguelgrinberg committed Oct 23, 2021
1 parent bcd1a42 commit 8a0e4c3
Show file tree
Hide file tree
Showing 4 changed files with 67 additions and 9 deletions.
17 changes: 13 additions & 4 deletions src/engineio/asyncio_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,9 @@ class AsyncServer(server.Server):
:param async_handlers: If set to ``True``, run message event handlers in
non-blocking threads. To run handlers synchronously,
set to ``False``. The default is ``True``.
:param transports: The list of allowed transports. Valid transports
are ``'polling'`` and ``'websocket'``. Defaults to
``['polling', 'websocket']``.
:param kwargs: Reserved for future extensions, any additional parameters
given as keyword arguments will be silently ignored.
"""
Expand Down Expand Up @@ -213,6 +216,13 @@ async def handle_request(self, *args, **kwargs):
jsonp = False
jsonp_index = None

# make sure the client uses an allowed transport
transport = query.get('transport', ['polling'])[0]
if transport not in self.transports:
self._log_error_once('Invalid transport', 'bad-transport')
return await self._make_response(
self._bad_request('Invalid transport'), environ)

# make sure the client speaks a compatible Engine.IO version
sid = query['sid'][0] if 'sid' in query else None
if sid is None and query.get('EIO') != ['4']:
Expand All @@ -239,7 +249,6 @@ async def handle_request(self, *args, **kwargs):
r = self._bad_request('Invalid JSONP index number')
elif method == 'GET':
if sid is None:
transport = query.get('transport', ['polling'])[0]
# transport must be one of 'polling' or 'websocket'.
# if 'websocket', the HTTP_UPGRADE header must match.
upgrade_header = environ.get('HTTP_UPGRADE').lower() \
Expand All @@ -249,9 +258,9 @@ async def handle_request(self, *args, **kwargs):
r = await self._handle_connect(environ, transport,
jsonp_index)
else:
self._log_error_once('Invalid transport ' + transport,
'bad-transport')
r = self._bad_request('Invalid transport ' + transport)
self._log_error_once('Invalid websocket upgrade',
'bad-upgrade')
r = self._bad_request('Invalid websocket upgrade')
else:
if sid not in self.sockets:
self._log_error_once('Invalid session ' + sid, 'bad-sid')
Expand Down
30 changes: 25 additions & 5 deletions src/engineio/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,11 +78,15 @@ class Server(object):
inactive clients are closed. Set to ``False`` to
disable the monitoring task (not recommended). The
default is ``True``.
:param transports: The list of allowed transports. Valid transports
are ``'polling'`` and ``'websocket'``. Defaults to
``['polling', 'websocket']``.
:param kwargs: Reserved for future extensions, any additional parameters
given as keyword arguments will be silently ignored.
"""
compression_methods = ['gzip', 'deflate']
event_names = ['connect', 'disconnect', 'message']
valid_transports = ['polling', 'websocket']
_default_monitor_clients = True
sequence_number = 0

Expand All @@ -91,7 +95,8 @@ def __init__(self, async_mode=None, ping_interval=25, ping_timeout=20,
http_compression=True, compression_threshold=1024,
cookie=None, cors_allowed_origins=None,
cors_credentials=True, logger=False, json=None,
async_handlers=True, monitor_clients=None, **kwargs):
async_handlers=True, monitor_clients=None, transports=None,
**kwargs):
self.ping_timeout = ping_timeout
if isinstance(ping_interval, tuple):
self.ping_interval = ping_interval[0]
Expand Down Expand Up @@ -152,6 +157,14 @@ def __init__(self, async_mode=None, ping_interval=25, ping_timeout=20,
self._async['asyncio']: # pragma: no cover
raise ValueError('The selected async_mode requires asyncio and '
'must use the AsyncServer class')
if transports is not None:
if isinstance(transports, str):
transports = [transports]
transports = [transport for transport in transports
if transport in self.valid_transports]
if not transports:
raise ValueError('No valid transports provided')
self.transports = transports or self.valid_transports
self.logger.info('Server initialized for %s.', self.async_mode)

def is_asyncio_based(self):
Expand Down Expand Up @@ -342,6 +355,14 @@ def handle_request(self, environ, start_response):
jsonp = False
jsonp_index = None

# make sure the client uses an allowed transport
transport = query.get('transport', ['polling'])[0]
if transport not in self.transports:
self._log_error_once('Invalid transport', 'bad-transport')
r = self._bad_request('Invalid transport')
start_response(r['status'], r['headers'])
return [r['response']]

# make sure the client speaks a compatible Engine.IO version
sid = query['sid'][0] if 'sid' in query else None
if sid is None and query.get('EIO') != ['4']:
Expand All @@ -368,7 +389,6 @@ def handle_request(self, environ, start_response):
r = self._bad_request('Invalid JSONP index number')
elif method == 'GET':
if sid is None:
transport = query.get('transport', ['polling'])[0]
# transport must be one of 'polling' or 'websocket'.
# if 'websocket', the HTTP_UPGRADE header must match.
upgrade_header = environ.get('HTTP_UPGRADE').lower() \
Expand All @@ -378,9 +398,9 @@ def handle_request(self, environ, start_response):
r = self._handle_connect(environ, start_response,
transport, jsonp_index)
else:
self._log_error_once('Invalid transport ' + transport,
'bad-transport')
r = self._bad_request('Invalid transport')
self._log_error_once('Invalid websocket upgrade',
'bad-upgrade')
r = self._bad_request('Invalid websocket upgrade')
else:
if sid not in self.sockets:
self._log_error_once('Invalid session ' + sid, 'bad-sid')
Expand Down
13 changes: 13 additions & 0 deletions tests/asyncio/test_asyncio_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -1121,3 +1121,16 @@ def test_service_task_started(self, import_module):
s._service_task = AsyncMock()
_run(s.handle_request('request'))
s._service_task.mock.assert_called_once_with()

@mock.patch('importlib.import_module')
def test_transports_disallowed(self, import_module):
a = self.get_async_mock(
{'REQUEST_METHOD': 'GET', 'QUERY_STRING': 'transport=polling'}
)
import_module.side_effect = [a]
s = asyncio_server.AsyncServer(transports='websocket')
response = _run(s.handle_request('request'))
assert response == 'response'
a._async['translate_request'].assert_called_once_with('request')
assert a._async['make_response'].call_count == 1
assert a._async['make_response'].call_args[0][0] == '400 BAD REQUEST'
16 changes: 16 additions & 0 deletions tests/common/test_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -1149,3 +1149,19 @@ def test_service_task_started(self):
start_response = mock.MagicMock()
s.handle_request(environ, start_response)
s._service_task.assert_called_once_with()

def test_transports_invalid(self):
with pytest.raises(ValueError):
server.Server(transports='invalid')
with pytest.raises(ValueError):
server.Server(transports=['invalid', 'foo'])

def test_transports_disallowed(self):
s = server.Server(transports='websocket')
environ = {
'REQUEST_METHOD': 'GET',
'QUERY_STRING': 'transport=polling',
}
start_response = mock.MagicMock()
s.handle_request(environ, start_response)
assert start_response.call_args[0][0] == '400 BAD REQUEST'

0 comments on commit 8a0e4c3

Please sign in to comment.