Skip to content

Commit

Permalink
Add JSONP support in the server (Fixes #98)
Browse files Browse the repository at this point in the history
  • Loading branch information
StoneMoe authored and miguelgrinberg committed May 26, 2019
1 parent 3de4488 commit 36a1598
Show file tree
Hide file tree
Showing 6 changed files with 184 additions and 101 deletions.
110 changes: 62 additions & 48 deletions engineio/asyncio_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -181,65 +181,77 @@ async def handle_request(self, *args, **kwargs):
environ = translate_request(*args, **kwargs)
method = environ['REQUEST_METHOD']
query = urllib.parse.parse_qs(environ.get('QUERY_STRING', ''))

sid = query['sid'][0] if 'sid' in query else None
b64 = False
jsonp = False
jsonp_index = None

if 'b64' in query:
if query['b64'][0] == "1" or query['b64'][0].lower() == "true":
b64 = True
if 'j' in query:
self.logger.warning('JSONP requests are not supported')
jsonp = True
try:
jsonp_index = int(query['j'][0])
except (ValueError, KeyError, IndexError):
# Invalid JSONP index number
pass

if jsonp and jsonp_index is None:
self.logger.warning('Invalid JSONP index number')
r = self._bad_request()
else:
sid = query['sid'][0] if 'sid' in query else None
b64 = False
if 'b64' in query:
if query['b64'][0] == "1" or query['b64'][0].lower() == "true":
b64 = True
if method == 'GET':
if sid is None:
transport = query.get('transport', ['polling'])[0]
if transport != 'polling' and transport != 'websocket':
self.logger.warning('Invalid transport %s', transport)
r = self._bad_request()
else:
r = await self._handle_connect(environ, transport,
b64)
elif method == 'GET':
if sid is None:
transport = query.get('transport', ['polling'])[0]
if transport != 'polling' and transport != 'websocket':
self.logger.warning('Invalid transport %s', transport)
r = self._bad_request()
else:
if sid not in self.sockets:
self.logger.warning('Invalid session %s', sid)
r = self._bad_request()
else:
socket = self._get_socket(sid)
try:
packets = await socket.handle_get_request(environ)
if isinstance(packets, list):
r = self._ok(packets, b64=b64)
else:
r = packets
except exceptions.EngineIOError:
if sid in self.sockets: # pragma: no cover
await self.disconnect(sid)
r = self._bad_request()
if sid in self.sockets and self.sockets[sid].closed:
del self.sockets[sid]
elif method == 'POST':
if sid is None or sid not in self.sockets:
r = await self._handle_connect(environ, transport,
b64, jsonp_index)
else:
if sid not in self.sockets:
self.logger.warning('Invalid session %s', sid)
r = self._bad_request()
else:
socket = self._get_socket(sid)
try:
await socket.handle_post_request(environ)
r = self._ok()
packets = await socket.handle_get_request(environ)
if isinstance(packets, list):
r = self._ok(packets, b64=b64,
jsonp_index=jsonp_index)
else:
r = packets
except exceptions.EngineIOError:
if sid in self.sockets: # pragma: no cover
await self.disconnect(sid)
r = self._bad_request()
except: # pragma: no cover
# for any other unexpected errors, we log the error
# and keep going
self.logger.exception('post request handler error')
r = self._ok()
elif method == 'OPTIONS':
r = self._ok()
if sid in self.sockets and self.sockets[sid].closed:
del self.sockets[sid]
elif method == 'POST':
if sid is None or sid not in self.sockets:
self.logger.warning('Invalid session %s', sid)
r = self._bad_request()
else:
self.logger.warning('Method %s not supported', method)
r = self._method_not_found()
socket = self._get_socket(sid)
try:
await socket.handle_post_request(environ)
r = self._ok(jsonp_index=jsonp_index)
except exceptions.EngineIOError:
if sid in self.sockets: # pragma: no cover
await self.disconnect(sid)
r = self._bad_request()
except: # pragma: no cover
# for any other unexpected errors, we log the error
# and keep going
self.logger.exception('post request handler error')
r = self._ok(jsonp_index=jsonp_index)
elif method == 'OPTIONS':
r = self._ok()
else:
self.logger.warning('Method %s not supported', method)
r = self._method_not_found()
if not isinstance(r, dict):
return r if r is not None else []
if self.http_compression and \
Expand Down Expand Up @@ -320,7 +332,8 @@ def create_event(self, *args, **kwargs):
"""
return asyncio.Event(*args, **kwargs)

async def _handle_connect(self, environ, transport, b64=False):
async def _handle_connect(self, environ, transport, b64=False,
jsonp_index=None):
"""Handle a client connection request."""
if self.start_service_task:
# start the service task to monitor connected clients
Expand Down Expand Up @@ -357,7 +370,8 @@ async def _handle_connect(self, environ, transport, b64=False):
if self.cookie:
headers = [('Set-Cookie', self.cookie + '=' + sid)]
try:
return self._ok(await s.poll(), headers=headers, b64=b64)
return self._ok(await s.poll(), headers=headers, b64=b64,
jsonp_index=jsonp_index)
except exceptions.QueueEmpty:
return self._bad_request()

Expand Down
15 changes: 14 additions & 1 deletion engineio/payload.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@

from . import packet

from six.moves import urllib


class Payload(object):
"""Engine.IO payload."""
Expand All @@ -10,7 +12,7 @@ def __init__(self, packets=None, encoded_payload=None):
if encoded_payload is not None:
self.decode(encoded_payload)

def encode(self, b64=False):
def encode(self, b64=False, jsonp_index=None):
"""Encode the payload for transmission."""
encoded_payload = b''
for pkt in self.packets:
Expand All @@ -29,12 +31,23 @@ def encode(self, b64=False):
else:
encoded_payload += b'\1'
encoded_payload += binary_len + b'\xff' + encoded_packet
if jsonp_index is not None:
encoded_payload = b'___eio[' + \
str(jsonp_index).encode() + \
b']("' + \
encoded_payload.replace(b'"', b'\\"') + \
b'");'
return encoded_payload

def decode(self, encoded_payload):
"""Decode a transmitted payload."""
self.packets = []
while encoded_payload:
# JSONP POST payload starts with 'd='
if encoded_payload.startswith(b'd='):
encoded_payload = urllib.parse.parse_qs(
encoded_payload)[b'd'][0]

if six.byte2int(encoded_payload[0:1]) <= 1:
packet_len = 0
i = 1
Expand Down
118 changes: 67 additions & 51 deletions engineio/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -311,66 +311,79 @@ def handle_request(self, environ, start_response):
"""
method = environ['REQUEST_METHOD']
query = urllib.parse.parse_qs(environ.get('QUERY_STRING', ''))

sid = query['sid'][0] if 'sid' in query else None
b64 = False
jsonp = False
jsonp_index = None

if 'b64' in query:
if query['b64'][0] == "1" or query['b64'][0].lower() == "true":
b64 = True
if 'j' in query:
self.logger.warning('JSONP requests are not supported')
jsonp = True
try:
jsonp_index = int(query['j'][0])
except (ValueError, KeyError, IndexError):
# Invalid JSONP index number
pass

if jsonp and jsonp_index is None:
self.logger.warning('Invalid JSONP index number')
r = self._bad_request()
else:
sid = query['sid'][0] if 'sid' in query else None
b64 = False
if 'b64' in query:
if query['b64'][0] == "1" or query['b64'][0].lower() == "true":
b64 = True
if method == 'GET':
if sid is None:
transport = query.get('transport', ['polling'])[0]
if transport != 'polling' and transport != 'websocket':
self.logger.warning('Invalid transport %s', transport)
r = self._bad_request()
else:
r = self._handle_connect(environ, start_response,
transport, b64)
elif method == 'GET':
if sid is None:
transport = query.get('transport', ['polling'])[0]
if transport != 'polling' and transport != 'websocket':
self.logger.warning('Invalid transport %s', transport)
r = self._bad_request()
else:
if sid not in self.sockets:
self.logger.warning('Invalid session %s', sid)
r = self._bad_request()
else:
socket = self._get_socket(sid)
try:
packets = socket.handle_get_request(
environ, start_response)
if isinstance(packets, list):
r = self._ok(packets, b64=b64)
else:
r = packets
except exceptions.EngineIOError:
if sid in self.sockets: # pragma: no cover
self.disconnect(sid)
r = self._bad_request()
if sid in self.sockets and self.sockets[sid].closed:
del self.sockets[sid]
elif method == 'POST':
if sid is None or sid not in self.sockets:
r = self._handle_connect(environ, start_response,
transport, b64, jsonp_index)
else:
if sid not in self.sockets:
self.logger.warning('Invalid session %s', sid)
r = self._bad_request()
else:
socket = self._get_socket(sid)
try:
socket.handle_post_request(environ)
r = self._ok()
packets = socket.handle_get_request(
environ, start_response)
if isinstance(packets, list):
r = self._ok(packets, b64=b64,
jsonp_index=jsonp_index)
else:
r = packets
except exceptions.EngineIOError:
if sid in self.sockets: # pragma: no cover
self.disconnect(sid)
r = self._bad_request()
except: # pragma: no cover
# for any other unexpected errors, we log the error
# and keep going
self.logger.exception('post request handler error')
r = self._ok()
elif method == 'OPTIONS':
r = self._ok()
if sid in self.sockets and self.sockets[sid].closed:
del self.sockets[sid]
elif method == 'POST':
if sid is None or sid not in self.sockets:
self.logger.warning('Invalid session %s', sid)
r = self._bad_request()
else:
self.logger.warning('Method %s not supported', method)
r = self._method_not_found()
socket = self._get_socket(sid)
try:
socket.handle_post_request(environ)
r = self._ok(jsonp_index=jsonp_index)
except exceptions.EngineIOError:
if sid in self.sockets: # pragma: no cover
self.disconnect(sid)
r = self._bad_request()
except: # pragma: no cover
# for any other unexpected errors, we log the error
# and keep going
self.logger.exception('post request handler error')
r = self._ok(jsonp_index=jsonp_index)
elif method == 'OPTIONS':
r = self._ok()
else:
self.logger.warning('Method %s not supported', method)
r = self._method_not_found()

if not isinstance(r, dict):
return r or []
if self.http_compression and \
Expand Down Expand Up @@ -447,7 +460,8 @@ def _generate_id(self):
"""Generate a unique session id."""
return uuid.uuid4().hex

def _handle_connect(self, environ, start_response, transport, b64=False):
def _handle_connect(self, environ, start_response, transport, b64=False,
jsonp_index=None):
"""Handle a client connection request."""
if self.start_service_task:
# start the service task to monitor connected clients
Expand Down Expand Up @@ -483,7 +497,8 @@ def _handle_connect(self, environ, start_response, transport, b64=False):
if self.cookie:
headers = [('Set-Cookie', self.cookie + '=' + sid)]
try:
return self._ok(s.poll(), headers=headers, b64=b64)
return self._ok(s.poll(), headers=headers, b64=b64,
jsonp_index=jsonp_index)
except exceptions.QueueEmpty:
return self._bad_request()

Expand Down Expand Up @@ -521,7 +536,7 @@ def _get_socket(self, sid):
raise KeyError('Session is disconnected')
return s

def _ok(self, packets=None, headers=None, b64=False):
def _ok(self, packets=None, headers=None, b64=False, jsonp_index=None):
"""Generate a successful HTTP response."""
if packets is not None:
if headers is None:
Expand All @@ -532,7 +547,8 @@ def _ok(self, packets=None, headers=None, b64=False):
headers += [('Content-Type', 'application/octet-stream')]
return {'status': '200 OK',
'headers': headers,
'response': payload.Payload(packets=packets).encode(b64)}
'response': payload.Payload(packets=packets).encode(
b64=b64, jsonp_index=jsonp_index)}
else:
return {'status': '200 OK',
'headers': [('Content-Type', 'text/plain')],
Expand Down
17 changes: 17 additions & 0 deletions tests/asyncio/test_asyncio_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -175,6 +175,23 @@ def test_jsonp_not_supported(self, import_module):
self.assertEqual(a._async['make_response'].call_args[0][0],
'400 BAD REQUEST')

@mock.patch('importlib.import_module')
def test_jsonp_index(self, import_module):
a = self.get_async_mock({'REQUEST_METHOD': 'GET',
'QUERY_STRING': 'j=233'})
import_module.side_effect = [a]
s = asyncio_server.AsyncServer()
response = _run(s.handle_request('request'))
self.assertEqual(response, 'response')
a._async['translate_request'].assert_called_once_with('request')
self.assertEqual(a._async['make_response'].call_count, 1)
self.assertEqual(a._async['make_response'].call_args[0][0], '200 OK')
print('***', a._async['make_response'].call_args[0][2])
self.assertTrue(a._async['make_response'].call_args[0][2].startswith(
b'___eio[233]("'))
self.assertTrue(a._async['make_response'].call_args[0][2].endswith(
b'");'))

@mock.patch('importlib.import_module')
def test_connect(self, import_module):
a = self.get_async_mock()
Expand Down
13 changes: 13 additions & 0 deletions tests/common/test_payload.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,19 @@ def test_decode_payload_xhr_binary(self):
p = payload.Payload(encoded_payload=b'6:b4AAEC')
self.assertEqual(p.encode(), b'\x01\x04\xff\x04\x00\x01\x02')

def test_encode_jsonp_payload(self):
pkt = packet.Packet(packet.MESSAGE, data=six.text_type('abc'))
p = payload.Payload([pkt])
self.assertEqual(p.packets, [pkt])
self.assertEqual(p.encode(jsonp_index=233),
b'___eio[233]("\x00\x04\xff4abc");')
self.assertEqual(p.encode(jsonp_index=233, b64=True),
b'___eio[233]("4:4abc");')

def test_decode_jsonp_payload(self):
p = payload.Payload(encoded_payload=b'd=4:4abc')
self.assertEqual(p.encode(), b'\x00\x04\xff4abc')

def test_decode_invalid_payload(self):
self.assertRaises(ValueError, payload.Payload,
encoded_payload=b'bad payload')
Expand Down
Loading

0 comments on commit 36a1598

Please sign in to comment.