Skip to content

Commit

Permalink
Allow configuring underlying websocket connection with custom options
Browse files Browse the repository at this point in the history
(Fixes #293)
  • Loading branch information
bruce-y authored and miguelgrinberg committed Nov 2, 2022
1 parent 4a8a9a6 commit 45e97b8
Show file tree
Hide file tree
Showing 4 changed files with 79 additions and 16 deletions.
31 changes: 19 additions & 12 deletions src/engineio/asyncio_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,11 @@ class AsyncClient(client.Client):
leave interrupt handling to the calling application.
Interrupt handling can only be enabled when the
client instance is created in the main thread.
:param websocket_extra_options: Dictionary containing additional keyword
arguments passed to
``aiohttp.ws_connect()``.
"""

def is_asyncio_based(self):
return True

Expand Down Expand Up @@ -297,19 +301,22 @@ async def _connect_websocket(self, url, headers, engineio_path):
break
self.http.cookie_jar.update_cookies(cookies)

extra_options = {'timeout': self.request_timeout}
if not self.ssl_verify:
ssl_context = ssl.create_default_context()
ssl_context.check_hostname = False
ssl_context.verify_mode = ssl.CERT_NONE
extra_options['ssl'] = ssl_context

# combine internally generated options with the ones supplied by the
# caller. The caller's options take precedence.
headers.update(self.websocket_extra_options.pop('headers', {}))
extra_options['headers'] = headers
extra_options.update(self.websocket_extra_options)

try:
if not self.ssl_verify:
ssl_context = ssl.create_default_context()
ssl_context.check_hostname = False
ssl_context.verify_mode = ssl.CERT_NONE
ws = await self.http.ws_connect(
websocket_url + self._get_url_timestamp(),
headers=headers, ssl=ssl_context,
timeout=self.request_timeout)
else:
ws = await self.http.ws_connect(
websocket_url + self._get_url_timestamp(),
headers=headers, timeout=self.request_timeout)
ws = await self.http.ws_connect(
websocket_url + self._get_url_timestamp(), **extra_options)
except (aiohttp.client_exceptions.WSServerHandshakeError,
aiohttp.client_exceptions.ServerConnectionError,
aiohttp.client_exceptions.ClientConnectionError):
Expand Down
20 changes: 16 additions & 4 deletions src/engineio/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,11 +71,15 @@ class Client(object):
leave interrupt handling to the calling application.
Interrupt handling can only be enabled when the
client instance is created in the main thread.
:param websocket_extra_options: Dictionary containing additional keyword
arguments passed to
``websocket.create_connection()``.
"""
event_names = ['connect', 'disconnect', 'message']

def __init__(self, logger=False, json=None, request_timeout=5,
http_session=None, ssl_verify=True, handle_sigint=True):
http_session=None, ssl_verify=True, handle_sigint=True,
websocket_extra_options=None):
global original_signal_handler
if handle_sigint and original_signal_handler is None and \
threading.current_thread() == threading.main_thread():
Expand All @@ -97,6 +101,7 @@ def __init__(self, logger=False, json=None, request_timeout=5,
self.queue = None
self.state = 'disconnected'
self.ssl_verify = ssl_verify
self.websocket_extra_options = websocket_extra_options or {}

if json is not None:
packet.Packet.json = json
Expand Down Expand Up @@ -414,11 +419,18 @@ def _connect_websocket(self, url, headers, engineio_path):

if not self.ssl_verify:
extra_options['sslopt'] = {"cert_reqs": ssl.CERT_NONE}

# combine internally generated options with the ones supplied by the
# caller. The caller's options take precedence.
headers.update(self.websocket_extra_options.pop('header', {}))
extra_options['header'] = headers
extra_options['cookie'] = cookies
extra_options['enable_multithread'] = True
extra_options['timeout'] = self.request_timeout
extra_options.update(self.websocket_extra_options)
try:
ws = websocket.create_connection(
websocket_url + self._get_url_timestamp(), header=headers,
cookie=cookies, enable_multithread=True,
timeout=self.request_timeout, **extra_options)
websocket_url + self._get_url_timestamp(), **extra_options)
except (ConnectionError, IOError, websocket.WebSocketException):
if upgrade:
self.logger.warning(
Expand Down
24 changes: 24 additions & 0 deletions tests/asyncio/test_asyncio_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -542,6 +542,30 @@ def test_websocket_connection_failed(self, _time):
timeout=5
)

@mock.patch('engineio.client.time.time', return_value=123.456)
def test_websocket_connection_extra(self, _time):
c = asyncio_client.AsyncClient(websocket_extra_options={
'headers': {'Baz': 'Qux'},
'timeout': 10
})
c.http = mock.MagicMock(closed=False)
c.http.ws_connect = AsyncMock(
side_effect=[aiohttp.client_exceptions.ServerConnectionError()]
)
with pytest.raises(exceptions.ConnectionError):
_run(
c.connect(
'http://foo',
transports=['websocket'],
headers={'Foo': 'Bar'},
)
)
c.http.ws_connect.mock.assert_called_once_with(
'ws://foo/engine.io/?transport=websocket&EIO=4&t=123.456',
headers={'Foo': 'Bar', 'Baz': 'Qux'},
timeout=10,
)

@mock.patch('engineio.client.time.time', return_value=123.456)
def test_websocket_upgrade_failed(self, _time):
c = asyncio_client.AsyncClient()
Expand Down
20 changes: 20 additions & 0 deletions tests/common/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -561,6 +561,26 @@ def test_websocket_connection_failed(self, create_connection, _time):
timeout=5
)

@mock.patch('engineio.client.time.time', return_value=123.456)
@mock.patch(
'engineio.client.websocket.create_connection',
side_effect=[ConnectionError],
)
def test_websocket_connection_extra(self, create_connection, _time):
c = client.Client(websocket_extra_options={'header': {'Baz': 'Qux'},
'timeout': 10})
with pytest.raises(exceptions.ConnectionError):
c.connect(
'http://foo', transports=['websocket'], headers={'Foo': 'Bar'}
)
create_connection.assert_called_once_with(
'ws://foo/engine.io/?transport=websocket&EIO=4&t=123.456',
header={'Foo': 'Bar', 'Baz': 'Qux'},
cookie=None,
enable_multithread=True,
timeout=10
)

@mock.patch('engineio.client.time.time', return_value=123.456)
@mock.patch(
'engineio.client.websocket.create_connection',
Expand Down

0 comments on commit 45e97b8

Please sign in to comment.