Skip to content

Commit

Permalink
Improved handling of rejected connections
Browse files Browse the repository at this point in the history
  • Loading branch information
miguelgrinberg committed May 22, 2020
1 parent 5b79e28 commit 0e0b26f
Show file tree
Hide file tree
Showing 9 changed files with 75 additions and 15 deletions.
8 changes: 8 additions & 0 deletions docs/server.rst
Original file line number Diff line number Diff line change
Expand Up @@ -181,6 +181,14 @@ The ``sid`` argument passed into all the event handlers is a connection
identifier for the client. All the events from a client will use the same
``sid`` value.

The ``connect`` handler is the place where the server can perform
authentication. The value returned by this handler is used to determine if the
connection is accepted or rejected. When the handler does not return any value
(which is the same as returning ``None``) or when it returns ``True`` the
connection is accepted. If the handler returns ``False`` or any JSON
compatible data type (string, integer, list or dictionary) the connection is
rejected. A rejected connection triggers a response with a 401 status code.

The ``data`` argument passed to the ``'message'`` event handler contains
application-specific data provided by the client with the event.

Expand Down
3 changes: 2 additions & 1 deletion engineio/asyncio_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -183,9 +183,10 @@ async def _connect_polling(self, url, headers, engineio_path):
raise exceptions.ConnectionError(
'Connection refused by the server')
if r.status < 200 or r.status >= 300:
self._reset()
raise exceptions.ConnectionError(
'Unexpected status code {} in server response'.format(
r.status))
r.status), await r.json())
try:
p = payload.Payload(encoded_payload=await r.read())
except ValueError:
Expand Down
4 changes: 2 additions & 2 deletions engineio/asyncio_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -384,10 +384,10 @@ async def _handle_connect(self, environ, transport, b64=False,

ret = await self._trigger_event('connect', sid, environ,
run_async=False)
if ret is False:
if ret is not None and ret is not True:
del self.sockets[sid]
self.logger.warning('Application rejected connection')
return self._unauthorized()
return self._unauthorized(ret or None)

if transport == 'websocket':
ret = await s.handle_get_request(environ)
Expand Down
3 changes: 2 additions & 1 deletion engineio/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -295,9 +295,10 @@ def _connect_polling(self, url, headers, engineio_path):
raise exceptions.ConnectionError(
'Connection refused by the server')
if r.status_code < 200 or r.status_code >= 300:
self._reset()
raise exceptions.ConnectionError(
'Unexpected status code {} in server response'.format(
r.status_code))
r.status_code), r.json())
try:
p = payload.Payload(encoded_payload=r.content)
except ValueError:
Expand Down
13 changes: 8 additions & 5 deletions engineio/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -506,10 +506,10 @@ def _handle_connect(self, environ, start_response, transport, b64=False,
s.send(pkt)

ret = self._trigger_event('connect', sid, environ, run_async=False)
if ret is False:
if ret is not None and ret is not True:
del self.sockets[sid]
self.logger.warning('Application rejected connection')
return self._unauthorized()
return self._unauthorized(ret or None)

if transport == 'websocket':
ret = s.handle_get_request(environ, start_response)
Expand Down Expand Up @@ -592,11 +592,14 @@ def _method_not_found(self):
'headers': [('Content-Type', 'text/plain')],
'response': b'Method Not Found'}

def _unauthorized(self):
def _unauthorized(self, message=None):
"""Generate a unauthorized HTTP error response."""
if message is None:
message = 'Unauthorized'
message = packet.Packet.json.dumps(message)
return {'status': '401 UNAUTHORIZED',
'headers': [('Content-Type', 'text/plain')],
'response': b'Unauthorized'}
'headers': [('Content-Type', 'application/json')],
'response': message.encode('utf-8')}

def _cors_allowed_origins(self, environ):
default_origins = []
Expand Down
11 changes: 9 additions & 2 deletions tests/asyncio/test_asyncio_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -272,8 +272,15 @@ def test_polling_connection_404(self):
c = asyncio_client.AsyncClient()
c._send_request = AsyncMock()
c._send_request.mock.return_value.status = 404
self.assertRaises(
exceptions.ConnectionError, _run, c.connect('http://foo'))
c._send_request.mock.return_value.json = AsyncMock(
return_value={'foo': 'bar'})
try:
_run(c.connect('http://foo'))
except exceptions.ConnectionError as exc:
self.assertEqual(len(exc.args), 2)
self.assertEqual(exc.args[0],
'Unexpected status code 404 in server response')
self.assertEqual(exc.args[1], {'foo': 'bar'})

def test_polling_connection_invalid_packet(self):
c = asyncio_client.AsyncClient()
Expand Down
20 changes: 20 additions & 0 deletions tests/asyncio/test_asyncio_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -599,6 +599,26 @@ def mock_connect(sid, environ):
self.assertEqual(len(s.sockets), 0)
self.assertEqual(a._async['make_response'].call_args[0][0],
'401 UNAUTHORIZED')
self.assertEqual(a._async['make_response'].call_args[0][2],
b'"Unauthorized"')

@mock.patch('importlib.import_module')
def test_connect_event_rejects_with_message(self, import_module):
a = self.get_async_mock()
import_module.side_effect = [a]
s = asyncio_server.AsyncServer()
s._generate_id = mock.MagicMock(return_value='123')

def mock_connect(sid, environ):
return {'not': 'allowed'}

s.on('connect')(mock_connect)
_run(s.handle_request('request'))
self.assertEqual(len(s.sockets), 0)
self.assertEqual(a._async['make_response'].call_args[0][0],
'401 UNAUTHORIZED')
self.assertEqual(a._async['make_response'].call_args[0][2],
b'{"not": "allowed"}')

@mock.patch('importlib.import_module')
def test_method_not_found(self, import_module):
Expand Down
11 changes: 9 additions & 2 deletions tests/common/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -294,8 +294,15 @@ def test_polling_connection_failed(self, _send_request, _time):
@mock.patch('engineio.client.Client._send_request')
def test_polling_connection_404(self, _send_request):
_send_request.return_value.status_code = 404
c = client.Client()
self.assertRaises(exceptions.ConnectionError, c.connect, 'http://foo')
_send_request.return_value.json.return_value = {'foo': 'bar'}
c = client.Client()
try:
c.connect('http://foo')
except exceptions.ConnectionError as exc:
self.assertEqual(len(exc.args), 2)
self.assertEqual(exc.args[0],
'Unexpected status code 404 in server response')
self.assertEqual(exc.args[1], {'foo': 'bar'})

@mock.patch('engineio.client.Client._send_request')
def test_polling_connection_invalid_packet(self, _send_request):
Expand Down
17 changes: 15 additions & 2 deletions tests/common/test_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -724,7 +724,7 @@ def test_connect_cors_disabled_no_origin(self):
def test_connect_event(self):
s = server.Server()
s._generate_id = mock.MagicMock(return_value='123')
mock_event = mock.MagicMock()
mock_event = mock.MagicMock(return_value=None)
s.on('connect')(mock_event)
environ = {'REQUEST_METHOD': 'GET', 'QUERY_STRING': ''}
start_response = mock.MagicMock()
Expand All @@ -739,9 +739,22 @@ def test_connect_event_rejects(self):
s.on('connect')(mock_event)
environ = {'REQUEST_METHOD': 'GET', 'QUERY_STRING': ''}
start_response = mock.MagicMock()
s.handle_request(environ, start_response)
ret = s.handle_request(environ, start_response)
self.assertEqual(len(s.sockets), 0)
self.assertEqual(start_response.call_args[0][0], '401 UNAUTHORIZED')
self.assertEqual(ret, [b'"Unauthorized"'])

def test_connect_event_rejects_with_message(self):
s = server.Server()
s._generate_id = mock.MagicMock(return_value='123')
mock_event = mock.MagicMock(return_value='not allowed')
s.on('connect')(mock_event)
environ = {'REQUEST_METHOD': 'GET', 'QUERY_STRING': ''}
start_response = mock.MagicMock()
ret = s.handle_request(environ, start_response)
self.assertEqual(len(s.sockets), 0)
self.assertEqual(start_response.call_args[0][0], '401 UNAUTHORIZED')
self.assertEqual(ret, [b'"not allowed"'])

def test_method_not_found(self):
s = server.Server()
Expand Down

0 comments on commit 0e0b26f

Please sign in to comment.