Skip to content

Commit

Permalink
proper handling of closed sockets
Browse files Browse the repository at this point in the history
  • Loading branch information
miguelgrinberg committed Mar 15, 2017
1 parent 52b37e5 commit 2144535
Show file tree
Hide file tree
Showing 6 changed files with 57 additions and 6 deletions.
6 changes: 5 additions & 1 deletion engineio/asyncio_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -220,7 +220,11 @@ async def _handle_connect(self, environ, transport, b64=False):
return self._unauthorized()

if transport == 'websocket':
return await s.handle_get_request(environ)
ret = await s.handle_get_request(environ)
if s.closed:
# websocket connection ended, so we are done
del self.sockets[sid]
return ret
else:
s.connected = True
headers = None
Expand Down
5 changes: 3 additions & 2 deletions engineio/asyncio_socket.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,7 +160,7 @@ async def writer():
await ws.send(pkt.encode(always_bytes=False))
except:
break
asyncio.ensure_future(writer())
writer_task = asyncio.ensure_future(writer())

self.server.logger.info(
'%s: Upgrade to websocket successful', self.sid)
Expand All @@ -182,5 +182,6 @@ async def writer():
except ValueError:
pass

await self.close(wait=True, abort=True)
await self.queue.put(None) # unlock the writer task so it can exit
await asyncio.wait_for(writer_task, timeout=None)
await self.close(wait=True, abort=True)
6 changes: 5 additions & 1 deletion engineio/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -352,7 +352,11 @@ def _handle_connect(self, environ, start_response, transport, b64=False):
return self._unauthorized()

if transport == 'websocket':
return s.handle_get_request(environ, start_response)
ret = s.handle_get_request(environ, start_response)
if s.closed:
# websocket connection ended, so we are done
del self.sockets[sid]
return ret
else:
s.connected = True
headers = None
Expand Down
5 changes: 3 additions & 2 deletions engineio/socket.py
Original file line number Diff line number Diff line change
Expand Up @@ -175,7 +175,7 @@ def writer():
ws.send(pkt.encode(always_bytes=False))
except:
break
self.server.start_background_task(writer)
writer_task = self.server.start_background_task(writer)

self.server.logger.info(
'%s: Upgrade to websocket successful', self.sid)
Expand All @@ -197,7 +197,8 @@ def writer():
except ValueError:
pass

self.close(wait=True, abort=True)
self.queue.put(None) # unlock the writer task so that it can exit
writer_task.join()
self.close(wait=True, abort=True)

return []
23 changes: 23 additions & 0 deletions tests/test_asyncio_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -278,11 +278,34 @@ def test_connect_transport_websocket(self, import_module, AsyncSocket):
AsyncSocket.return_value = self._get_mock_socket()
s = asyncio_server.AsyncServer()
s._generate_id = mock.MagicMock(return_value='123')
# force socket to stay open, so that we can check it later
AsyncSocket().closed = False
_run(s.handle_request('request'))
self.assertEqual(
s.sockets['123'].send.mock.call_args[0][0].packet_type,
packet.OPEN)

@mock.patch('engineio.asyncio_socket.AsyncSocket')
@mock.patch('importlib.import_module')
def test_connect_transport_websocket_closed(self, import_module,
AsyncSocket):
a = self.get_async_mock({'REQUEST_METHOD': 'GET',
'QUERY_STRING': 'transport=websocket'})
import_module.side_effect = [a]
AsyncSocket.return_value = self._get_mock_socket()
s = asyncio_server.AsyncServer()
s._generate_id = mock.MagicMock(return_value='123')

# this mock handler just closes the socket, as it would happen on a
# real websocket exchange
@asyncio.coroutine
def mock_handle(environ):
s.sockets['123'].closed = True

AsyncSocket().handle_get_request = mock_handle
_run(s.handle_request('request'))
self.assertNotIn('123', s.sockets) # socket should close on its own

@mock.patch('importlib.import_module')
def test_connect_transport_invalid(self, import_module):
a = self.get_async_mock({'REQUEST_METHOD': 'GET',
Expand Down
18 changes: 18 additions & 0 deletions tests/test_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -432,10 +432,28 @@ def test_connect_transport_websocket(self, Socket):
environ = {'REQUEST_METHOD': 'GET',
'QUERY_STRING': 'transport=websocket'}
start_response = mock.MagicMock()
# force socket to stay open, so that we can check it later
Socket().closed = False
s.handle_request(environ, start_response)
self.assertEqual(s.sockets['123'].send.call_args[0][0].packet_type,
packet.OPEN)

@mock.patch('engineio.socket.Socket',
return_value=mock.MagicMock(connected=False, closed=False))
def test_connect_transport_websocket_closed(self, Socket):
s = server.Server()
s._generate_id = mock.MagicMock(return_value='123')
environ = {'REQUEST_METHOD': 'GET',
'QUERY_STRING': 'transport=websocket'}
start_response = mock.MagicMock()

def mock_handle(environ, start_response):
s.sockets['123'].closed = True

Socket().handle_get_request = mock_handle
s.handle_request(environ, start_response)
self.assertNotIn('123', s.sockets)

def test_connect_transport_invalid(self):
s = server.Server()
environ = {'REQUEST_METHOD': 'GET', 'QUERY_STRING': 'transport=foo'}
Expand Down

0 comments on commit 2144535

Please sign in to comment.