Skip to content

Commit

Permalink
user sessions
Browse files Browse the repository at this point in the history
  • Loading branch information
miguelgrinberg committed Dec 30, 2018
1 parent e6614a9 commit 561efad
Show file tree
Hide file tree
Showing 5 changed files with 141 additions and 0 deletions.
58 changes: 58 additions & 0 deletions engineio/asyncio_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,64 @@ async def send(self, sid, data, binary=None):
await socket.send(packet.Packet(packet.MESSAGE, data=data,
binary=binary))

async def get_session(self, sid):
"""Return the user session for a client.
:param sid: The session id of the client.
The return value is a dictionary. Modifications made to this
dictionary are not guaranteed to be preserved. If you want to modify
the user session, use the ``session`` context manager instead.
"""
socket = self._get_socket(sid)
return socket.session

async def save_session(self, sid, session):
"""Store the user session for a client.
:param sid: The session id of the client.
:param session: The session dictionary.
"""
socket = self._get_socket(sid)
socket.session = session

def session(self, sid):
"""Return the user session for a client with context manager syntax.
:param sid: The session id of the client.
This is a context manager that returns the user session dictionary for
the client. Any changes that are made to this dictionary inside the
context manager block are saved back to the session. Example usage::
@eio.on('connect')
def on_connect(sid, environ):
username = authenticate_user(environ)
if not username:
return False
with eio.session(sid) as session:
session['username'] = username
@eio.on('message')
def on_message(sid, msg):
async with eio.session(sid) as session:
print('received message from ', session['username'])
"""
class _session_context_manager(object):
def __init__(self, server, sid):
self.server = server
self.sid = sid
self.session = None

async def __aenter__(self):
self.session = await self.server.get_session(sid)
return self.session

async def __aexit__(self, *args):
await self.server.save_session(sid, self.session)

return _session_context_manager(self, sid)

async def disconnect(self, sid=None):
"""Disconnect a client.
Expand Down
59 changes: 59 additions & 0 deletions engineio/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -207,6 +207,65 @@ def send(self, sid, data, binary=None):
return
socket.send(packet.Packet(packet.MESSAGE, data=data, binary=binary))

def get_session(self, sid):
"""Return the user session for a client.
:param sid: The session id of the client.
The return value is a dictionary. Modifications made to this
dictionary are not guaranteed to be preserved unless
``save_session()`` is called, or when the ``session`` context manager
is used.
"""
socket = self._get_socket(sid)
return socket.session

def save_session(self, sid, session):
"""Store the user session for a client.
:param sid: The session id of the client.
:param session: The session dictionary.
"""
socket = self._get_socket(sid)
socket.session = session

def session(self, sid):
"""Return the user session for a client with context manager syntax.
:param sid: The session id of the client.
This is a context manager that returns the user session dictionary for
the client. Any changes that are made to this dictionary inside the
context manager block are saved back to the session. Example usage::
@eio.on('connect')
def on_connect(sid, environ):
username = authenticate_user(environ)
if not username:
return False
with eio.session(sid) as session:
session['username'] = username
@eio.on('message')
def on_message(sid, msg):
with eio.session(sid) as session:
print('received message from ', session['username'])
"""
class _session_context_manager(object):
def __init__(self, server, sid):
self.server = server
self.sid = sid
self.session = None

def __enter__(self):
self.session = self.server.get_session(sid)
return self.session

def __exit__(self, *args):
self.server.save_session(sid, self.session)

return _session_context_manager(self, sid)

def disconnect(self, sid=None):
"""Disconnect a client.
Expand Down
1 change: 1 addition & 0 deletions engineio/socket.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ def __init__(self, server, sid):
self.upgraded = False
self.closing = False
self.closed = False
self.session = {}

def create_queue(self):
return getattr(self.server._async['queue'],
Expand Down
13 changes: 13 additions & 0 deletions tests/test_asyncio_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,7 @@ def _get_mock_socket(self):
mock_socket.handle_post_request = AsyncMock()
mock_socket.check_ping_timeout = AsyncMock()
mock_socket.close = AsyncMock()
mock_socket.session = {}
return mock_socket

@classmethod
Expand Down Expand Up @@ -138,6 +139,18 @@ def test_attach(self, import_module):
s.attach('app', engineio_path='jkl/')
a._async['create_route'].assert_called_with('app', s, '/jkl/')

def test_session(self):
s = asyncio_server.AsyncServer()
s.sockets['foo'] = mock_socket = self._get_mock_socket()

async def _func():
async with s.session('foo') as session:
await s.sleep(0)
session['username'] = 'bar'
self.assertEqual(await s.get_session('foo'), {'username': 'bar'})

_run(_func())

def test_disconnect(self):
s = asyncio_server.AsyncServer()
s.sockets['foo'] = mock_socket = self._get_mock_socket()
Expand Down
10 changes: 10 additions & 0 deletions tests/test_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ def _get_mock_socket(self):
mock_socket.closed = False
mock_socket.closing = False
mock_socket.upgraded = False
mock_socket.session = {}
return mock_socket

@classmethod
Expand Down Expand Up @@ -299,6 +300,15 @@ def bar(sid, data):
r = s._trigger_event('message', 3, 4, run_async=False)
self.assertEqual(r, None)

def test_session(self):
s = server.Server()
mock_socket = self._get_mock_socket()
s.sockets['foo'] = mock_socket
with s.session('foo') as session:
self.assertEqual(session, {})
session['username'] = 'bar'
self.assertEqual(s.get_session('foo'), {'username': 'bar'})

def test_close_one_socket(self):
s = server.Server()
mock_socket = self._get_mock_socket()
Expand Down

0 comments on commit 561efad

Please sign in to comment.