Skip to content

Commit

Permalink
Apply timeouts to all HTTP requests (Fixes #127)
Browse files Browse the repository at this point in the history
  • Loading branch information
miguelgrinberg committed Aug 4, 2019
1 parent 8acfb81 commit 6666d6a
Show file tree
Hide file tree
Showing 4 changed files with 78 additions and 34 deletions.
25 changes: 17 additions & 8 deletions engineio/asyncio_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,8 @@ class AsyncClient(client.Client):
packets. Custom json modules must have ``dumps`` and ``loads``
functions that are compatible with the standard library
versions.
:param request_timeout: A timeout in seconds for requests. The default is
5 seconds.
"""
def is_asyncio_based(self):
return True
Expand Down Expand Up @@ -172,7 +174,8 @@ async def _connect_polling(self, url, headers, engineio_path):
self.base_url = self._get_engineio_url(url, engineio_path, 'polling')
self.logger.info('Attempting polling connection to ' + self.base_url)
r = await self._send_request(
'GET', self.base_url + self._get_url_timestamp(), headers=headers)
'GET', self.base_url + self._get_url_timestamp(), headers=headers,
timeout=self.request_timeout)
if r is None:
self._reset()
raise exceptions.ConnectionError(
Expand Down Expand Up @@ -348,14 +351,18 @@ async def _send_packet(self, pkt):
pkt.data if not isinstance(pkt.data, bytes) else '<binary>')

async def _send_request(
self, method, url, headers=None, body=None): # pragma: no cover
self, method, url, headers=None, body=None,
timeout=None): # pragma: no cover
if self.http is None or self.http.closed:
self.http = aiohttp.ClientSession()
method = getattr(self.http, method.lower())
http_method = getattr(self.http, method.lower())
try:
return await method(url, headers=headers, data=body)
except aiohttp.ClientError:
return
return await http_method(
url, headers=headers, data=body,
timeout=aiohttp.ClientTimeout(total=timeout))
except (aiohttp.ClientError, asyncio.TimeoutError) as exc:
self.logger.info('HTTP %s request to %s failed with error %s.',
method, url, exc)

async def _trigger_event(self, event, *args, **kwargs):
"""Invoke an event handler."""
Expand Down Expand Up @@ -424,7 +431,8 @@ async def _read_loop_polling(self):
self.logger.info(
'Sending polling GET request to ' + self.base_url)
r = await self._send_request(
'GET', self.base_url + self._get_url_timestamp())
'GET', self.base_url + self._get_url_timestamp(),
timeout=max(self.ping_interval, self.ping_timeout) + 5)
if r is None:
self.logger.warning(
'Connection refused by the server, aborting')
Expand Down Expand Up @@ -530,7 +538,8 @@ async def _write_loop(self):
p = payload.Payload(packets=packets)
r = await self._send_request(
'POST', self.base_url, body=p.encode(),
headers={'Content-Type': 'application/octet-stream'})
headers={'Content-Type': 'application/octet-stream'},
timeout=self.request_timeout)
for pkt in packets:
self.queue.task_done()
if r is None:
Expand Down
26 changes: 18 additions & 8 deletions engineio/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,10 +57,12 @@ class Client(object):
packets. Custom json modules must have ``dumps`` and ``loads``
functions that are compatible with the standard library
versions.
:param request_timeout: A timeout in seconds for requests. The default is
5 seconds.
"""
event_names = ['connect', 'disconnect', 'message']

def __init__(self, logger=False, json=None):
def __init__(self, logger=False, json=None, request_timeout=5):
self.handlers = {}
self.base_url = None
self.transports = None
Expand Down Expand Up @@ -93,6 +95,8 @@ def __init__(self, logger=False, json=None):
self.logger.setLevel(logging.ERROR)
self.logger.addHandler(logging.StreamHandler())

self.request_timeout = request_timeout

def is_asyncio_based(self):
return False

Expand Down Expand Up @@ -264,7 +268,8 @@ def _connect_polling(self, url, headers, engineio_path):
self.base_url = self._get_engineio_url(url, engineio_path, 'polling')
self.logger.info('Attempting polling connection to ' + self.base_url)
r = self._send_request(
'GET', self.base_url + self._get_url_timestamp(), headers=headers)
'GET', self.base_url + self._get_url_timestamp(), headers=headers,
timeout=self.request_timeout)
if r is None:
self._reset()
raise exceptions.ConnectionError(
Expand Down Expand Up @@ -437,13 +442,16 @@ def _send_packet(self, pkt):
pkt.data if not isinstance(pkt.data, bytes) else '<binary>')

def _send_request(
self, method, url, headers=None, body=None): # pragma: no cover
self, method, url, headers=None, body=None,
timeout=None): # pragma: no cover
if self.http is None:
self.http = requests.Session()
try:
return self.http.request(method, url, headers=headers, data=body)
except requests.exceptions.RequestException:
pass
return self.http.request(method, url, headers=headers, data=body,
timeout=timeout)
except requests.exceptions.RequestException as exc:
self.logger.info('HTTP %s request to %s failed with error %s.',
method, url, exc)

def _trigger_event(self, event, *args, **kwargs):
"""Invoke an event handler."""
Expand Down Expand Up @@ -507,7 +515,8 @@ def _read_loop_polling(self):
self.logger.info(
'Sending polling GET request to ' + self.base_url)
r = self._send_request(
'GET', self.base_url + self._get_url_timestamp())
'GET', self.base_url + self._get_url_timestamp(),
timeout=max(self.ping_interval, self.ping_timeout) + 5)
if r is None:
self.logger.warning(
'Connection refused by the server, aborting')
Expand Down Expand Up @@ -612,7 +621,8 @@ def _write_loop(self):
p = payload.Payload(packets=packets)
r = self._send_request(
'POST', self.base_url, body=p.encode(),
headers={'Content-Type': 'application/octet-stream'})
headers={'Content-Type': 'application/octet-stream'},
timeout=self.request_timeout)
for pkt in packets:
self.queue.task_done()
if r is None:
Expand Down
26 changes: 17 additions & 9 deletions tests/asyncio/test_asyncio_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -270,7 +270,7 @@ def test_polling_connection_failed(self, _time):
'http://foo', headers={'Foo': 'Bar'}))
c._send_request.mock.assert_called_once_with(
'GET', 'http://foo/engine.io/?transport=polling&EIO=3&t=123.456',
headers={'Foo': 'Bar'})
headers={'Foo': 'Bar'}, timeout=5)

def test_polling_connection_404(self):
c = asyncio_client.AsyncClient()
Expand Down Expand Up @@ -824,6 +824,8 @@ def test_read_loop_polling_disconnected(self):
@mock.patch('engineio.client.time.time', return_value=123.456)
def test_read_loop_polling_no_response(self, _time):
c = asyncio_client.AsyncClient()
c.ping_interval = 25
c.ping_timeout = 5
c.state = 'connected'
c.base_url = 'http://foo'
c.queue = mock.MagicMock()
Expand All @@ -836,13 +838,15 @@ def test_read_loop_polling_no_response(self, _time):
self.assertEqual(c.state, 'disconnected')
c.queue.put.mock.assert_called_once_with(None)
c._send_request.mock.assert_called_once_with(
'GET', 'http://foo&t=123.456')
'GET', 'http://foo&t=123.456', timeout=30)
c._trigger_event.mock.assert_called_once_with('disconnect',
run_async=False)

@mock.patch('engineio.client.time.time', return_value=123.456)
def test_read_loop_polling_bad_status(self, _time):
c = asyncio_client.AsyncClient()
c.ping_interval = 25
c.ping_timeout = 5
c.state = 'connected'
c.base_url = 'http://foo'
c.queue = mock.MagicMock()
Expand All @@ -855,11 +859,13 @@ def test_read_loop_polling_bad_status(self, _time):
self.assertEqual(c.state, 'disconnected')
c.queue.put.mock.assert_called_once_with(None)
c._send_request.mock.assert_called_once_with(
'GET', 'http://foo&t=123.456')
'GET', 'http://foo&t=123.456', timeout=30)

@mock.patch('engineio.client.time.time', return_value=123.456)
def test_read_loop_polling_bad_packet(self, _time):
c = asyncio_client.AsyncClient()
c.ping_interval = 25
c.ping_timeout = 60
c.state = 'connected'
c.base_url = 'http://foo'
c.queue = mock.MagicMock()
Expand All @@ -874,10 +880,12 @@ def test_read_loop_polling_bad_packet(self, _time):
self.assertEqual(c.state, 'disconnected')
c.queue.put.mock.assert_called_once_with(None)
c._send_request.mock.assert_called_once_with(
'GET', 'http://foo&t=123.456')
'GET', 'http://foo&t=123.456', timeout=65)

def test_read_loop_polling(self):
c = asyncio_client.AsyncClient()
c.ping_interval = 25
c.ping_timeout = 5
c.state = 'connected'
c.base_url = 'http://foo'
c.queue = mock.MagicMock()
Expand Down Expand Up @@ -1008,7 +1016,7 @@ def test_write_loop_polling_one_packet(self):
packets=[packet.Packet(packet.MESSAGE, {'foo': 'bar'})])
c._send_request.mock.assert_called_once_with(
'POST', 'http://foo', body=p.encode(),
headers={'Content-Type': 'application/octet-stream'})
headers={'Content-Type': 'application/octet-stream'}, timeout=5)

def test_write_loop_polling_three_packets(self):
c = asyncio_client.AsyncClient()
Expand Down Expand Up @@ -1039,7 +1047,7 @@ def test_write_loop_polling_three_packets(self):
])
c._send_request.mock.assert_called_once_with(
'POST', 'http://foo', body=p.encode(),
headers={'Content-Type': 'application/octet-stream'})
headers={'Content-Type': 'application/octet-stream'}, timeout=5)

def test_write_loop_polling_two_packets_done(self):
c = asyncio_client.AsyncClient()
Expand Down Expand Up @@ -1068,7 +1076,7 @@ def test_write_loop_polling_two_packets_done(self):
])
c._send_request.mock.assert_called_once_with(
'POST', 'http://foo', body=p.encode(),
headers={'Content-Type': 'application/octet-stream'})
headers={'Content-Type': 'application/octet-stream'}, timeout=5)
self.assertEqual(c.state, 'connected')

def test_write_loop_polling_bad_connection(self):
Expand All @@ -1093,7 +1101,7 @@ def test_write_loop_polling_bad_connection(self):
packets=[packet.Packet(packet.MESSAGE, {'foo': 'bar'})])
c._send_request.mock.assert_called_once_with(
'POST', 'http://foo', body=p.encode(),
headers={'Content-Type': 'application/octet-stream'})
headers={'Content-Type': 'application/octet-stream'}, timeout=5)
self.assertEqual(c.state, 'connected')

def test_write_loop_polling_bad_status(self):
Expand All @@ -1119,7 +1127,7 @@ def test_write_loop_polling_bad_status(self):
packets=[packet.Packet(packet.MESSAGE, {'foo': 'bar'})])
c._send_request.mock.assert_called_once_with(
'POST', 'http://foo', body=p.encode(),
headers={'Content-Type': 'application/octet-stream'})
headers={'Content-Type': 'application/octet-stream'}, timeout=5)
self.assertEqual(c.state, 'disconnected')

def test_write_loop_websocket_one_packet(self):
Expand Down
35 changes: 26 additions & 9 deletions tests/common/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,12 @@ def test_logger(self):
c = client.Client(logger=my_logger)
self.assertEqual(c.logger, my_logger)

def test_custon_timeout(self):
c = client.Client()
self.assertEqual(c.request_timeout, 5)
c = client.Client(request_timeout=27)
self.assertEqual(c.request_timeout, 27)

def test_on_event(self):
c = client.Client()

Expand Down Expand Up @@ -282,7 +288,7 @@ def test_polling_connection_failed(self, _send_request, _time):
headers={'Foo': 'Bar'})
_send_request.assert_called_once_with(
'GET', 'http://foo/engine.io/?transport=polling&EIO=3&t=123.456',
headers={'Foo': 'Bar'})
headers={'Foo': 'Bar'}, timeout=5)

@mock.patch('engineio.client.Client._send_request')
def test_polling_connection_404(self, _send_request):
Expand Down Expand Up @@ -773,6 +779,8 @@ def test_read_loop_polling_disconnected(self):
@mock.patch('engineio.client.time.time', return_value=123.456)
def test_read_loop_polling_no_response(self, _time):
c = client.Client()
c.ping_interval = 25
c.ping_timeout = 5
c.state = 'connected'
c.base_url = 'http://foo'
c.queue = mock.MagicMock()
Expand All @@ -785,13 +793,16 @@ def test_read_loop_polling_no_response(self, _time):
c.queue.put.assert_called_once_with(None)
c.write_loop_task.join.assert_called_once_with()
c.ping_loop_task.join.assert_called_once_with()
c._send_request.assert_called_once_with('GET', 'http://foo&t=123.456')
c._send_request.assert_called_once_with('GET', 'http://foo&t=123.456',
timeout=30)
c._trigger_event.assert_called_once_with('disconnect',
run_async=False)

@mock.patch('engineio.client.time.time', return_value=123.456)
def test_read_loop_polling_bad_status(self, _time):
c = client.Client()
c.ping_interval = 25
c.ping_timeout = 5
c.state = 'connected'
c.base_url = 'http://foo'
c.queue = mock.MagicMock()
Expand All @@ -804,11 +815,14 @@ def test_read_loop_polling_bad_status(self, _time):
c.queue.put.assert_called_once_with(None)
c.write_loop_task.join.assert_called_once_with()
c.ping_loop_task.join.assert_called_once_with()
c._send_request.assert_called_once_with('GET', 'http://foo&t=123.456')
c._send_request.assert_called_once_with('GET', 'http://foo&t=123.456',
timeout=30)

@mock.patch('engineio.client.time.time', return_value=123.456)
def test_read_loop_polling_bad_packet(self, _time):
c = client.Client()
c.ping_interval = 25
c.ping_timeout = 60
c.state = 'connected'
c.base_url = 'http://foo'
c.queue = mock.MagicMock()
Expand All @@ -822,10 +836,13 @@ def test_read_loop_polling_bad_packet(self, _time):
c.queue.put.assert_called_once_with(None)
c.write_loop_task.join.assert_called_once_with()
c.ping_loop_task.join.assert_called_once_with()
c._send_request.assert_called_once_with('GET', 'http://foo&t=123.456')
c._send_request.assert_called_once_with('GET', 'http://foo&t=123.456',
timeout=65)

def test_read_loop_polling(self):
c = client.Client()
c.ping_interval = 25
c.ping_timeout = 5
c.state = 'connected'
c.base_url = 'http://foo'
c.queue = mock.MagicMock()
Expand Down Expand Up @@ -954,7 +971,7 @@ def test_write_loop_polling_one_packet(self):
packets=[packet.Packet(packet.MESSAGE, {'foo': 'bar'})])
c._send_request.assert_called_once_with(
'POST', 'http://foo', body=p.encode(),
headers={'Content-Type': 'application/octet-stream'})
headers={'Content-Type': 'application/octet-stream'}, timeout=5)

def test_write_loop_polling_three_packets(self):
c = client.Client()
Expand Down Expand Up @@ -983,7 +1000,7 @@ def test_write_loop_polling_three_packets(self):
])
c._send_request.assert_called_once_with(
'POST', 'http://foo', body=p.encode(),
headers={'Content-Type': 'application/octet-stream'})
headers={'Content-Type': 'application/octet-stream'}, timeout=5)

def test_write_loop_polling_two_packets_done(self):
c = client.Client()
Expand All @@ -1010,7 +1027,7 @@ def test_write_loop_polling_two_packets_done(self):
])
c._send_request.assert_called_once_with(
'POST', 'http://foo', body=p.encode(),
headers={'Content-Type': 'application/octet-stream'})
headers={'Content-Type': 'application/octet-stream'}, timeout=5)
self.assertEqual(c.state, 'connected')

def test_write_loop_polling_bad_connection(self):
Expand All @@ -1034,7 +1051,7 @@ def test_write_loop_polling_bad_connection(self):
packets=[packet.Packet(packet.MESSAGE, {'foo': 'bar'})])
c._send_request.assert_called_once_with(
'POST', 'http://foo', body=p.encode(),
headers={'Content-Type': 'application/octet-stream'})
headers={'Content-Type': 'application/octet-stream'}, timeout=5)
self.assertEqual(c.state, 'connected')

def test_write_loop_polling_bad_status(self):
Expand All @@ -1058,7 +1075,7 @@ def test_write_loop_polling_bad_status(self):
packets=[packet.Packet(packet.MESSAGE, {'foo': 'bar'})])
c._send_request.assert_called_once_with(
'POST', 'http://foo', body=p.encode(),
headers={'Content-Type': 'application/octet-stream'})
headers={'Content-Type': 'application/octet-stream'}, timeout=5)
self.assertEqual(c.state, 'disconnected')

def test_write_loop_websocket_one_packet(self):
Expand Down

0 comments on commit 6666d6a

Please sign in to comment.