Skip to content

Commit

Permalink
tolerate payloads in UPGRADE packet (fixes #7)
Browse files Browse the repository at this point in the history
  • Loading branch information
miguelgrinberg committed Dec 2, 2015
1 parent b46aced commit 0193b5c
Show file tree
Hide file tree
Showing 3 changed files with 23 additions and 5 deletions.
2 changes: 1 addition & 1 deletion engineio/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -280,7 +280,7 @@ def start_background_task(self, target, *args, **kwargs):
self.async['thread_class'])(target=target, args=args,
kwargs=kwargs)
th.start()
return th
return th # pragma: no cover

def _generate_id(self):
"""Generate a unique session id."""
Expand Down
8 changes: 5 additions & 3 deletions engineio/socket.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,11 +128,13 @@ def _websocket_handler(self, ws):
self.send(packet.Packet(packet.NOOP))

pkt = ws.wait()
if pkt != packet.Packet(packet.UPGRADE).encode(always_bytes=False):
decoded_pkt = packet.Packet(encoded_packet=pkt)
if decoded_pkt.packet_type != packet.UPGRADE:
self.upgraded = False
self.server.logger.info(
'%s: Failed websocket upgrade, no UPGRADE packet',
self.sid)
('%s: Failed websocket upgrade, expected UPGRADE packet, '
'received %s instead.'),
self.sid, pkt)
return
self.upgraded = True
else:
Expand Down
18 changes: 17 additions & 1 deletion tests/test_socket.py
Original file line number Diff line number Diff line change
Expand Up @@ -175,7 +175,7 @@ def test_upgrade_packet(self):
s.receive(packet.Packet(packet.UPGRADE))
r = s.poll()
self.assertEqual(len(r), 1)
self.assertTrue(r[0].encode(), b'6')
self.assertEqual(r[0].encode(), packet.Packet(packet.NOOP).encode())

def test_upgrade_no_probe(self):
mock_server = self._get_mock_server()
Expand Down Expand Up @@ -255,6 +255,22 @@ def test_websocket_upgrade_read_write(self):
mock.call('disconnect', 'sid')])
ws.send.assert_called_with('4bar')

def test_websocket_upgrade_with_payload(self):
mock_server = self._get_mock_server()
s = socket.Socket(mock_server, 'sid')
s.connected = True
s.queue.join = mock.MagicMock(return_value=None)
probe = six.text_type('probe')
ws = mock.MagicMock()
ws.wait.side_effect = [
packet.Packet(packet.PING, data=probe).encode(
always_bytes=False),
packet.Packet(packet.UPGRADE, data=b'2').encode(
always_bytes=False)]
s._websocket_handler(ws)
time.sleep(0)
self.assertTrue(s.upgraded)

def test_websocket_read_write_fail(self):
mock_server = self._get_mock_server()
s = socket.Socket(mock_server, 'sid')
Expand Down

0 comments on commit 0193b5c

Please sign in to comment.