Skip to content

Commit

Permalink
Accept an initialized requests or aiohttp session object
Browse files Browse the repository at this point in the history
  • Loading branch information
miguelgrinberg committed Sep 1, 2020
1 parent 72ddc6a commit f371ad1
Show file tree
Hide file tree
Showing 4 changed files with 287 additions and 37 deletions.
8 changes: 6 additions & 2 deletions engineio/asyncio_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,10 @@ class AsyncClient(client.Client):
versions.
:param request_timeout: A timeout in seconds for requests. The default is
5 seconds.
:param http_session: an initialized ``aiohttp.ClientSession`` object to be
used when sending requests to the server. Use it if
you need to add special client options such as proxy
servers, SSL certificates, etc.
:param ssl_verify: ``True`` to verify SSL certificates, or ``False`` to
skip SSL certificate verification, allowing
connections to servers with self signed certificates.
Expand All @@ -57,7 +61,7 @@ class AsyncClient(client.Client):
def is_asyncio_based(self):
return True

async def connect(self, url, headers={}, transports=None,
async def connect(self, url, headers=None, transports=None,
engineio_path='engine.io'):
"""Connect to an Engine.IO server.
Expand Down Expand Up @@ -99,7 +103,7 @@ async def connect(self, url, headers={}, transports=None,
self.transports = transports or valid_transports
self.queue = self.create_queue()
return await getattr(self, '_connect_' + self.transports[0])(
url, headers, engineio_path)
url, headers or {}, engineio_path)

async def wait(self):
"""Wait until the connection with the server ends.
Expand Down
62 changes: 56 additions & 6 deletions engineio/client.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from base64 import b64encode
import logging
try:
import queue
Expand Down Expand Up @@ -63,6 +64,10 @@ class Client(object):
versions.
:param request_timeout: A timeout in seconds for requests. The default is
5 seconds.
:param http_session: an initialized ``requests.Session`` object to be used
when sending requests to the server. Use it if you
need to add special client options such as proxy
servers, SSL certificates, etc.
:param ssl_verify: ``True`` to verify SSL certificates, or ``False`` to
skip SSL certificate verification, allowing
connections to servers with self signed certificates.
Expand All @@ -74,6 +79,7 @@ def __init__(self,
logger=False,
json=None,
request_timeout=5,
http_session=None,
ssl_verify=True):
global original_signal_handler
if original_signal_handler is None and \
Expand All @@ -89,7 +95,7 @@ def __init__(self,
self.ping_interval = None
self.ping_timeout = None
self.pong_received = True
self.http = None
self.http = http_session
self.ws = None
self.read_loop_task = None
self.write_loop_task = None
Expand Down Expand Up @@ -151,7 +157,7 @@ def set_handler(handler):
return set_handler
set_handler(handler)

def connect(self, url, headers={}, transports=None,
def connect(self, url, headers=None, transports=None,
engineio_path='engine.io'):
"""Connect to an Engine.IO server.
Expand Down Expand Up @@ -185,7 +191,7 @@ def connect(self, url, headers={}, transports=None,
self.transports = transports or valid_transports
self.queue = self.create_queue()
return getattr(self, '_connect_' + self.transports[0])(
url, headers, engineio_path)
url, headers or {}, engineio_path)

def wait(self):
"""Wait until the connection with the server ends.
Expand Down Expand Up @@ -353,10 +359,12 @@ def _connect_websocket(self, url, headers, engineio_path):
self.logger.info(
'Attempting WebSocket connection to ' + websocket_url)

# get the cookies from the long-polling connection so that they can
# also be sent the the WebSocket route
# get cookies and other settings from the long-polling connection
# so that they are preserved when connecting to the WebSocket route
cookies = None
extra_options = {}
if self.http:
# cookies
cookies = '; '.join(["{}={}".format(cookie.name, cookie.value)
for cookie in self.http.cookies])
for header, value in headers.items():
Expand All @@ -367,7 +375,49 @@ def _connect_websocket(self, url, headers, engineio_path):
del headers[header]
break

extra_options = {}
# auth
if 'Authorization' not in headers and self.http.auth is not None:
if not isinstance(self.http.auth, tuple): # pragma: no cover
raise ValueError('Only basic authentication is supported')
basic_auth = '{}:{}'.format(
self.http.auth[0], self.http.auth[1]).encode('utf-8')
basic_auth = b64encode(basic_auth).decode('utf-8')
headers['Authorization'] = 'Basic ' + basic_auth

# cert
# this can be given as ('certfile', 'keyfile') or just 'certfile'
if isinstance(self.http.cert, tuple):
extra_options['sslopt'] = {
'certfile': self.http.cert[0],
'keyfile': self.http.cert[1]}
elif self.http.cert:
extra_options['sslopt'] = {'certfile': self.http.cert}

# proxies
if self.http.proxies:
proxy_url = None
if websocket_url.startswith('ws://'):
proxy_url = self.http.proxies.get(
'ws', self.http.proxies.get('http'))
else: # wss://
proxy_url = self.http.proxies.get(
'wss', self.http.proxies.get('https'))
if proxy_url:
parsed_url = urllib.parse.urlparse(
proxy_url if '://' in proxy_url
else 'scheme://' + proxy_url)
print(parsed_url)
extra_options['http_proxy_host'] = parsed_url.hostname
extra_options['http_proxy_port'] = parsed_url.port
extra_options['http_proxy_auth'] = (
(parsed_url.username, parsed_url.password)
if parsed_url.username or parsed_url.password
else None)

# verify
if not self.http.verify:
self.ssl_verify = False

if not self.ssl_verify:
extra_options['sslopt'] = {"cert_reqs": ssl.CERT_NONE}
try:
Expand Down
11 changes: 0 additions & 11 deletions tests/asyncio/test_asyncio_client.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
import asyncio
import ssl
import sys
import time
import unittest

import six
Expand Down Expand Up @@ -347,7 +346,6 @@ def test_polling_connection_successful(self):
on_connect = AsyncMock()
c.on('connect', on_connect)
_run(c.connect('http://foo'))
time.sleep(0.1)

c._ping_loop.mock.assert_called_once_with()
c._read_loop_polling.mock.assert_called_once_with()
Expand Down Expand Up @@ -391,7 +389,6 @@ def test_polling_https_noverify_connection_successful(self):
on_connect = AsyncMock()
c.on('connect', on_connect)
_run(c.connect('https://foo'))
time.sleep(0.1)

c._ping_loop.mock.assert_called_once_with()
c._read_loop_polling.mock.assert_called_once_with()
Expand Down Expand Up @@ -437,7 +434,6 @@ def test_polling_connection_with_more_packets(self):
on_connect = AsyncMock()
c.on('connect', on_connect)
_run(c.connect('http://foo'))
time.sleep(0.1)
assert c._receive_packet.mock.call_count == 1
assert (
c._receive_packet.mock.call_args_list[0][0][0].packet_type
Expand Down Expand Up @@ -509,7 +505,6 @@ def test_polling_connection_not_upgraded(self):
on_connect = mock.MagicMock()
c.on('connect', on_connect)
_run(c.connect('http://foo'))
time.sleep(0.1)

c._connect_websocket.mock.assert_called_once_with(
'http://foo', {}, 'engine.io'
Expand Down Expand Up @@ -590,7 +585,6 @@ def test_websocket_connection_successful(self, _time):
on_connect = mock.MagicMock()
c.on('connect', on_connect)
_run(c.connect('ws://foo', transports=['websocket']))
time.sleep(0.1)

c._ping_loop.mock.assert_called_once_with()
c._read_loop_polling.mock.assert_not_called()
Expand Down Expand Up @@ -633,7 +627,6 @@ def test_websocket_https_noverify_connection_successful(self, _time):
on_connect = mock.MagicMock()
c.on('connect', on_connect)
_run(c.connect('wss://foo', transports=['websocket']))
time.sleep(0.1)

c._ping_loop.mock.assert_called_once_with()
c._read_loop_polling.mock.assert_not_called()
Expand Down Expand Up @@ -681,7 +674,6 @@ def test_websocket_connection_with_cookies(self, _time):
on_connect = mock.MagicMock()
c.on('connect', on_connect)
_run(c.connect('ws://foo', transports=['websocket']))
time.sleep(0.1)
c.http.ws_connect.mock.assert_called_once_with(
'ws://foo/engine.io/?transport=websocket&EIO=3&t=123.456',
headers={},
Expand Down Expand Up @@ -717,7 +709,6 @@ def test_websocket_connection_with_cookie_header(self, _time):
transports=['websocket'],
)
)
time.sleep(0.1)
c.http.ws_connect.mock.assert_called_once_with(
'ws://foo/engine.io/?transport=websocket&EIO=3&t=123.456',
headers={},
Expand Down Expand Up @@ -760,7 +751,6 @@ def test_websocket_connection_with_cookies_and_headers(self, _time):
transports=['websocket'],
)
)
time.sleep(0.1)
c.http.ws_connect.mock.assert_called_once_with(
'ws://foo/engine.io/?transport=websocket&EIO=3&t=123.456',
headers={'Foo': 'Bar'},
Expand Down Expand Up @@ -823,7 +813,6 @@ def test_websocket_upgrade_successful(self):
on_connect = mock.MagicMock()
c.on('connect', on_connect)
assert _run(c.connect('ws://foo', transports=['websocket']))
time.sleep(0.1)

c._ping_loop.mock.assert_called_once_with()
c._read_loop_polling.mock.assert_not_called()
Expand Down
Loading

0 comments on commit f371ad1

Please sign in to comment.