Skip to content

Commit

Permalink
Added site_ctx for FastAPI
Browse files Browse the repository at this point in the history
  • Loading branch information
gauravr committed Dec 22, 2023
1 parent a57d210 commit 21c8b4a
Show file tree
Hide file tree
Showing 12 changed files with 119 additions and 15 deletions.
7 changes: 7 additions & 0 deletions CHANGELOG.rst
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,13 @@
History
=======

0.92.0 (2023-12-23)
-------------------
* site_ctx implementation for FastAPI
* user_agent directive for FastAPI & Hug
* ignore_site_ctx implementation for FastAPI & Hug
* count_matched_keys implementation for ReadOnlyCachedModel

0.91.0 (2023-12-22)
-------------------
* Breaking: moved apphelpers.sessions.whoami to apphelpers.rest.{hug/fastapi}.whoami
Expand Down
2 changes: 1 addition & 1 deletion apphelpers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,4 +4,4 @@

__author__ = """Scroll Tech"""
__email__ = "[email protected]"
__version__ = "0.91.0"
__version__ = "0.92.0"
1 change: 1 addition & 0 deletions apphelpers/rest/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ class User:
email: Optional[str] = None
mobile: Optional[str] = None
site_groups: Dict[int, int] = field(default_factory=dict)
site_ctx: Optional[int] = None

def to_dict(self):
return asdict(self)
Expand Down
5 changes: 5 additions & 0 deletions apphelpers/rest/endpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,3 +41,8 @@ def decorator(func):
def not_found_on_none(func):
func.not_found_on_none = True
return func


def ignore_site_ctx(func):
func.ignore_site_ctx = True
return func
50 changes: 45 additions & 5 deletions apphelpers/rest/fastapi.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,10 @@ async def get_raw_body(request: Request):
return request.body()


async def get_user_agent(request: Request):
return request.headers.get("USER-AGENT", "")


user = Depends(get_current_user)
user_id = Depends(get_current_user_id)
user_name = Depends(get_current_user_name)
Expand All @@ -110,6 +114,7 @@ async def get_raw_body(request: Request):
domain = Depends(get_current_domain)
raw_body = Depends(get_raw_body)
json_body = Depends(get_json_body)
user_agent = Depends(get_user_agent)


class SecureRouter(APIRoute):
Expand All @@ -123,13 +128,14 @@ def get_route_handler(self):
original_route_handler = super().get_route_handler()

async def custom_route_handler(_request: Request):
uid, groups, name, email, mobile, site_groups = (
uid, groups, name, email, mobile, site_groups, site_ctx = (
None,
[],
"",
None,
None,
{},
None,
)

token = _request.headers.get("Authorization")
Expand All @@ -143,15 +149,17 @@ async def custom_route_handler(_request: Request):
"email",
"mobile",
"site_groups",
"site_ctx",
],
)
uid, name, groups, email, mobile, site_groups = (
uid, name, groups, email, mobile, site_groups, site_ctx = (
session["uid"],
session["name"],
session["groups"],
session["email"],
session["mobile"],
session["site_groups"],
session["site_ctx"],
)

_request.state.user = User(
Expand All @@ -162,6 +170,7 @@ async def custom_route_handler(_request: Request):
email=email,
mobile=mobile,
site_groups=site_groups,
site_ctx=site_ctx,
)

return await original_route_handler(_request)
Expand Down Expand Up @@ -194,22 +203,39 @@ def get_route_handler(self):
original_route_handler = super().get_route_handler()

async def custom_route_handler(_request: Request):
uid, groups, name, email, mobile, site_groups = None, [], "", None, None, {}
uid, groups, name, email, mobile, site_groups, site_ctx = (
None,
[],
"",
None,
None,
{},
None,
)

token = _request.headers.get("Authorization")
if token:
try:
session = self.sessions.get( # type: ignore
token,
["uid", "name", "groups", "email", "mobile", "site_groups"],
[
"uid",
"name",
"groups",
"email",
"mobile",
"site_groups",
"site_ctx",
],
)
uid, name, groups, email, mobile, site_groups = (
uid, name, groups, email, mobile, site_groups, site_ctx = (
session["uid"],
session["name"],
session["groups"],
session["email"],
session["mobile"],
session["site_groups"],
session["site_ctx"],
)
except InvalidSessionError:
pass
Expand All @@ -222,6 +248,7 @@ async def custom_route_handler(_request: Request):
email=email,
mobile=mobile,
site_groups=site_groups,
site_ctx=site_ctx,
)
return await original_route_handler(_request)

Expand Down Expand Up @@ -371,11 +398,24 @@ def multisite_access_wrapper(f):
@wraps(f)
async def wrapper(_request, *args, **kw):
user: User = _request.state.user
site_id = (
int(kw[self.site_identifier])
if self.site_identifier in kw
else None
)

# this is authentication part
if not user.id:
raise HTTP401Unauthorized("Invalid or expired session")

# bound site authorization
if (
user.site_ctx
and site_id != user.site_ctx
and getattr(f, "ignore_site_ctx", False) is False
):
raise HTTP401Unauthorized("Invalid or expired session")

# this is authorization part
groups = set(user.groups)
if self.site_identifier in kw:
Expand Down
11 changes: 10 additions & 1 deletion apphelpers/rest/hug.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,6 +123,11 @@ def user_mobile(default=None, request=None, **kwargs):
return request.context["user"].mobile


@hug.directive()
def user_agent(default=None, request=None, **kwargs):
return request.headers.get("USER-AGENT", "")


@dataclass
class User:
sid: str = None
Expand Down Expand Up @@ -376,7 +381,11 @@ def wrapper(request, *args, **kw):
raise HTTPUnauthorized("Invalid or expired session")

# bound site authorization
if user.site_ctx and site_id != user.site_ctx:
if (
user.site_ctx
and site_id != user.site_ctx
and getattr(f, "ignore_site_ctx", False) is False
):
raise HTTPUnauthorized("Invalid or expired session")

# this is authorization part
Expand Down
2 changes: 1 addition & 1 deletion apphelpers/sessions.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ def create(
if extras:
session_dict.update(extras)
session = {k: pickle.dumps(v) for k, v in session_dict.items()}
self.rconn.hmset(key, session)
self.rconn.hset(key, mapping=session)

if uid:
rev_key = rev_lookup_key(uid, site_ctx)
Expand Down
5 changes: 5 additions & 0 deletions apphelpers/utilities/caching.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,11 @@ def get_count(cls, **data: Any) -> int:
count: Optional[Any] = cls.connection.get(key)
return int(count) if count else 0

@classmethod
def count_matched_keys(cls, **data) -> int:
keys = cls._get_matched_keys(data)
return len(keys)


class ReadWriteCachedModel(ReadOnlyCachedModel):
"""
Expand Down
22 changes: 17 additions & 5 deletions fastapi_tests/app/endpoints.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,5 @@
from apphelpers.rest import endpoint as ep
from apphelpers.rest.fastapi import (
json_body,
user,
user_id,
)
from apphelpers.rest.fastapi import json_body, user, user_id, user_agent


def echo(word, user=user):
Expand Down Expand Up @@ -68,6 +64,17 @@ async def echo_site_groups_async(site_id: int, user=user):
return user.site_groups[site_id]


@ep.login_required
async def echo_user_agent_async(user_agent=user_agent):
return user_agent


@ep.login_required
@ep.ignore_site_ctx
async def echo_user_agent_without_site_ctx_async(user_agent=user_agent):
return user_agent


def setup_routes(factory):
factory.get("/echo/{word}")(echo)
factory.get("/echo-async/{word}")(echo_async)
Expand All @@ -86,3 +93,8 @@ def setup_routes(factory):

factory.get("/sites/{site_id}/echo-groups")(echo_site_groups)
factory.get("/sites/{site_id}/echo-groups-async")(echo_site_groups_async)

factory.get("/echo-user-agent-async")(echo_user_agent_async)
factory.get("/echo-user-agent-without-site-ctx-async")(
echo_user_agent_without_site_ctx_async
)
25 changes: 25 additions & 0 deletions fastapi_tests/test_rest.py
Original file line number Diff line number Diff line change
Expand Up @@ -245,3 +245,28 @@ def test_not_found_on_none_async():

url = base_url + "snakes-async"
assert requests.get(url).status_code == 404


def test_user_agent_async_and_site_ctx():
url = base_url + "echo-user-agent-async"

headers = {"Authorization": sessionsdb.create(uid=1214)}
response = requests.get(url, headers=headers)
assert response.status_code == 200
assert "python-requests" in response.text

headers = {"Authorization": sessionsdb.create(uid=1215, site_ctx=4011)}
response = requests.get(url, headers=headers)
assert response.status_code == 401

url = base_url + "echo-user-agent-without-site-ctx-async"

headers = {"Authorization": sessionsdb.create(uid=1214)}
response = requests.get(url, headers=headers)
assert response.status_code == 200
assert "python-requests" in response.text

headers = {"Authorization": sessionsdb.create(uid=1215, site_ctx=4011)}
response = requests.get(url, headers=headers)
assert response.status_code == 200
assert "python-requests" in response.text
2 changes: 1 addition & 1 deletion setup.cfg
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
[bumpversion]
current_version = 0.91.0
current_version = 0.92.0
commit = True
tag = True

Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,6 @@
test_suite="tests",
tests_require=test_requirements,
url="https://github.com/scrolltech/apphelpers",
version="0.91.0",
version="0.92.0",
zip_safe=False,
)

0 comments on commit 21c8b4a

Please sign in to comment.