Skip to content

Commit

Permalink
Accept direct websocket connections
Browse files Browse the repository at this point in the history
  • Loading branch information
miguelgrinberg committed Sep 6, 2015
1 parent fbc018f commit 448acfb
Show file tree
Hide file tree
Showing 5 changed files with 100 additions and 34 deletions.
Empty file modified .travis.yml
100755 → 100644
Empty file.
25 changes: 19 additions & 6 deletions engineio/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -206,7 +206,13 @@ def handle_request(self, environ, start_response):
b64 = True
if method == 'GET':
if sid is None:
r = self._handle_connect(environ, b64)
transport = query.get('transport', ['polling'])[0]
if transport != 'polling' and transport != 'websocket':
self.logger.warning('Invalid transport %s', transport)
r = self._bad_request()
else:
r = self._handle_connect(environ, start_response,
transport, b64)
else:
if sid not in self.sockets:
self.logger.warning('Invalid session %s', sid)
Expand Down Expand Up @@ -252,25 +258,32 @@ def _generate_id(self):
"""Generate a unique session id."""
return uuid.uuid4().hex

def _handle_connect(self, environ, b64=False):
def _handle_connect(self, environ, start_response, transport, b64=False):
"""Handle a client connection request."""
sid = self._generate_id()
s = socket.Socket(self, sid)
self.sockets[sid] = s

pkt = packet.Packet(
packet.OPEN, {'sid': sid,
'upgrades': self._upgrades(sid),
'pingTimeout': int(self.ping_timeout * 1000),
'pingInterval': int(self.ping_interval * 1000)})
s.send(pkt)

if self._trigger_event('connect', sid, environ) is False:
self.logger.warning('Application rejected connection')
del self.sockets[sid]
return self._unauthorized()
headers = None
if self.cookie:
headers = [('Set-Cookie', self.cookie + '=' + sid)]
return self._ok(s.poll(), headers=headers, b64=b64)

if transport == 'websocket':
s.handle_get_request(environ, start_response)
return self._ok()
else:
headers = None
if self.cookie:
headers = [('Set-Cookie', self.cookie + '=' + sid)]
return self._ok(s.poll(), headers=headers, b64=b64)

def _upgrades(self, sid):
"""Return the list of possible upgrades for a client connection."""
Expand Down
46 changes: 27 additions & 19 deletions engineio/socket.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ def __init__(self, server, sid):
self.queue = getattr(self.server.async['queue'],
self.server.async['queue_class'])()
self.last_ping = time.time()
self.connected = False
self.upgraded = False
self.closed = False

Expand Down Expand Up @@ -72,6 +73,7 @@ def handle_get_request(self, environ, start_response):
self.sid, transport)
return getattr(self, '_upgrade_' + transport)(environ,
start_response)
self.connected = True
try:
packets = self.poll()
except IOError as e:
Expand Down Expand Up @@ -110,25 +112,31 @@ def _upgrade_websocket(self, environ, start_response):

def _websocket_handler(self, ws):
"""Engine.IO handler for websocket transport."""
pkt = ws.wait()
if pkt != packet.Packet(packet.PING,
data=six.text_type('probe')).encode(
always_bytes=False):
self.server.logger.info(
'%s: Failed websocket upgrade, no PING packet', self.sid)
return
ws.send(packet.Packet(packet.PONG, data=six.text_type('probe')).encode(
always_bytes=False))
self.send(packet.Packet(packet.NOOP))
self.upgraded = True
self.queue.join()

pkt = ws.wait()
if pkt != packet.Packet(packet.UPGRADE).encode(always_bytes=False):
self.upgraded = False
self.server.logger.info(
'%s: Failed websocket upgrade, no UPGRADE packet', self.sid)
return
if self.connected:
# the socket was already connected, so this is an upgrade
pkt = ws.wait()
if pkt != packet.Packet(packet.PING,
data=six.text_type('probe')).encode(
always_bytes=False):
self.server.logger.info(
'%s: Failed websocket upgrade, no PING packet', self.sid)
return
ws.send(packet.Packet(
packet.PONG,
data=six.text_type('probe')).encode(always_bytes=False))
self.send(packet.Packet(packet.NOOP))
self.upgraded = True
self.queue.join()

pkt = ws.wait()
if pkt != packet.Packet(packet.UPGRADE).encode(always_bytes=False):
self.upgraded = False
self.server.logger.info(
'%s: Failed websocket upgrade, no UPGRADE packet',
self.sid)
return
else:
self.connected = True

def writer():
while True:
Expand Down
20 changes: 20 additions & 0 deletions tests/test_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -313,6 +313,26 @@ def test_connect_custom_ping_times(self):
self.assertEqual(packets[0].data['pingTimeout'], 123000)
self.assertEqual(packets[0].data['pingInterval'], 456000)

@mock.patch('engineio.socket.Socket',
return_value=mock.MagicMock(connected=False, closed=False))
def test_connect_transport_websocket(self, Socket):
s = server.Server()
s._generate_id = mock.MagicMock(return_value='123')
environ = {'REQUEST_METHOD': 'GET',
'QUERY_STRING': 'transport=websocket'}
start_response = mock.MagicMock()
s.handle_request(environ, start_response)
self.assertEqual(s.sockets['123'].send.call_args[0][0].packet_type,
packet.OPEN)

def test_connect_transport_invalid(self):
s = server.Server()
environ = {'REQUEST_METHOD': 'GET', 'QUERY_STRING': 'transport=foo'}
start_response = mock.MagicMock()
s.handle_request(environ, start_response)
self.assertEqual(start_response.call_args[0][0],
'400 BAD REQUEST')

def test_connect_cors_headers(self):
s = server.Server()
environ = {'REQUEST_METHOD': 'GET', 'QUERY_STRING': ''}
Expand Down
43 changes: 34 additions & 9 deletions tests/test_socket.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,6 +142,7 @@ def test_upgrade(self):
mock_server.async['websocket'].WebSocket.configure_mock(
return_value=mock_ws)
s = socket.Socket(mock_server, 'sid')
s.connected = True
environ = "foo"
start_response = "bar"
s._upgrade_websocket(environ, start_response)
Expand All @@ -153,6 +154,7 @@ def test_upgrade_twice(self):
mock_server = self._get_mock_server()
mock_server.async['websocket'] = mock.MagicMock()
s = socket.Socket(mock_server, 'sid')
s.connected = True
s.upgraded = True
environ = "foo"
start_response = "bar"
Expand All @@ -162,6 +164,7 @@ def test_upgrade_twice(self):
def test_upgrade_packet(self):
mock_server = self._get_mock_server()
s = socket.Socket(mock_server, 'sid')
s.connected = True
s.receive(packet.Packet(packet.UPGRADE))
r = s.poll()
self.assertEqual(len(r), 1)
Expand All @@ -170,6 +173,7 @@ def test_upgrade_packet(self):
def test_upgrade_no_probe(self):
mock_server = self._get_mock_server()
s = socket.Socket(mock_server, 'sid')
s.connected = True
ws = mock.MagicMock()
ws.wait.return_value = packet.Packet(packet.NOOP).encode(
always_bytes=False)
Expand All @@ -179,6 +183,7 @@ def test_upgrade_no_probe(self):
def test_upgrade_no_upgrade_packet(self):
mock_server = self._get_mock_server()
s = socket.Socket(mock_server, 'sid')
s.connected = True
s.queue.join = mock.MagicMock(return_value=None)
ws = mock.MagicMock()
probe = six.text_type('probe')
Expand All @@ -195,6 +200,31 @@ def test_upgrade_no_upgrade_packet(self):
def test_websocket_read_write(self):
mock_server = self._get_mock_server()
s = socket.Socket(mock_server, 'sid')
s.connected = False
s.queue.join = mock.MagicMock(return_value=None)
foo = six.text_type('foo')
bar = six.text_type('bar')
s.poll = mock.MagicMock(side_effect=[
[packet.Packet(packet.MESSAGE, data=bar)], IOError])
ws = mock.MagicMock()
ws.wait.side_effect = [
packet.Packet(packet.MESSAGE, data=foo).encode(
always_bytes=False),
None]
s._websocket_handler(ws)
time.sleep(0)
self.assertTrue(s.connected)
self.assertFalse(s.upgraded)
self.assertEqual(mock_server._trigger_event.call_count, 2)
mock_server._trigger_event.assert_has_calls([
mock.call('message', 'sid', 'foo'),
mock.call('disconnect', 'sid')])
ws.send.assert_called_with('4bar')

def test_websocket_upgrade_read_write(self):
mock_server = self._get_mock_server()
s = socket.Socket(mock_server, 'sid')
s.connected = True
s.queue.join = mock.MagicMock(return_value=None)
foo = six.text_type('foo')
bar = six.text_type('bar')
Expand All @@ -221,17 +251,15 @@ def test_websocket_read_write(self):
def test_websocket_read_write_fail(self):
mock_server = self._get_mock_server()
s = socket.Socket(mock_server, 'sid')
s.connected = False
s.queue.join = mock.MagicMock(return_value=None)
foo = six.text_type('foo')
bar = six.text_type('bar')
probe = six.text_type('probe')
s.poll = mock.MagicMock(side_effect=[
[packet.Packet(packet.MESSAGE, data=bar)],
[packet.Packet(packet.MESSAGE, data=bar)], IOError])
ws = mock.MagicMock()
ws.wait.side_effect = [
packet.Packet(packet.PING, data=probe).encode(
always_bytes=False),
packet.Packet(packet.UPGRADE).encode(always_bytes=False),
packet.Packet(packet.MESSAGE, data=foo).encode(
always_bytes=False),
RuntimeError]
Expand All @@ -243,24 +271,21 @@ def test_websocket_read_write_fail(self):
def test_websocket_ignore_invalid_packet(self):
mock_server = self._get_mock_server()
s = socket.Socket(mock_server, 'sid')
s.connected = False
s.queue.join = mock.MagicMock(return_value=None)
foo = six.text_type('foo')
bar = six.text_type('bar')
probe = six.text_type('probe')
s.poll = mock.MagicMock(side_effect=[
[packet.Packet(packet.MESSAGE, data=bar)], IOError])
ws = mock.MagicMock()
ws.wait.side_effect = [
packet.Packet(packet.PING, data=probe).encode(
always_bytes=False),
packet.Packet(packet.UPGRADE).encode(always_bytes=False),
packet.Packet(packet.OPEN).encode(always_bytes=False),
packet.Packet(packet.MESSAGE, data=foo).encode(
always_bytes=False),
None]
s._websocket_handler(ws)
time.sleep(0)
self.assertTrue(s.upgraded)
self.assertTrue(s.connected)
self.assertEqual(mock_server._trigger_event.call_count, 2)
mock_server._trigger_event.assert_has_calls([
mock.call('message', 'sid', foo),
Expand Down

0 comments on commit 448acfb

Please sign in to comment.