Skip to content

Commit

Permalink
Support async make_response function in ASGI driver (Fixes #145)
Browse files Browse the repository at this point in the history
  • Loading branch information
miguelgrinberg committed Nov 3, 2019
1 parent fc984aa commit 3f6391c
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 3 deletions.
8 changes: 6 additions & 2 deletions engineio/asyncio_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -200,8 +200,12 @@ async def handle_request(self, *args, **kwargs):
self.logger.info(origin + ' is not an accepted origin.')
r = self._bad_request()
make_response = self._async['make_response']
response = make_response(r['status'], r['headers'],
r['response'], environ)
if asyncio.iscoroutinefunction(make_response):
response = await make_response(
r['status'], r['headers'], r['response'], environ)
else:
response = make_response(r['status'], r['headers'],
r['response'], environ)
return response

method = environ['REQUEST_METHOD']
Expand Down
17 changes: 16 additions & 1 deletion tests/asyncio/test_asyncio_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -186,7 +186,6 @@ def test_jsonp_index(self, import_module):
a._async['translate_request'].assert_called_once_with('request')
self.assertEqual(a._async['make_response'].call_count, 1)
self.assertEqual(a._async['make_response'].call_args[0][0], '200 OK')
print('***', a._async['make_response'].call_args[0][2])
self.assertTrue(a._async['make_response'].call_args[0][2].startswith(
b'___eio[233]("'))
self.assertTrue(a._async['make_response'].call_args[0][2].endswith(
Expand Down Expand Up @@ -428,6 +427,22 @@ def test_connect_cors_not_allowed_origin(self, import_module):
self.assertNotIn(('Access-Control-Allow-Origin', 'c'), headers)
self.assertNotIn(('Access-Control-Allow-Origin', '*'), headers)

@mock.patch('importlib.import_module')
def test_connect_cors_not_allowed_origin_async_response(self,
import_module):
a = self.get_async_mock({'REQUEST_METHOD': 'GET', 'QUERY_STRING': '',
'HTTP_ORIGIN': 'c'})
a._async['make_response'] = AsyncMock(
return_value=a._async['make_response'].return_value)
import_module.side_effect = [a]
s = asyncio_server.AsyncServer(cors_allowed_origins=['a', 'b'])
_run(s.handle_request('request'))
self.assertEqual(a._async['make_response'].mock.call_args[0][0],
'400 BAD REQUEST')
headers = a._async['make_response'].mock.call_args[0][1]
self.assertNotIn(('Access-Control-Allow-Origin', 'c'), headers)
self.assertNotIn(('Access-Control-Allow-Origin', '*'), headers)

@mock.patch('importlib.import_module')
def test_connect_cors_all_origins(self, import_module):
a = self.get_async_mock({'REQUEST_METHOD': 'GET', 'QUERY_STRING': '',
Expand Down

0 comments on commit 3f6391c

Please sign in to comment.