Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add piccolo db #104

Merged
merged 8 commits into from
Mar 6, 2024
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions apphelpers/db/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
try:
from apphelpers.db.piccolo import * # noqa: F401, F403

print("apphelpers: using Piccolo")
except ImportError:
from apphelpers.db.peewee import * # noqa: F401, F403

print("apphelpers: using Peewee")
22 changes: 14 additions & 8 deletions apphelpers/db/peewee.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import datetime
import logging
import os
from contextlib import contextmanager
from enum import Enum

import pytz
Expand Down Expand Up @@ -62,6 +63,17 @@ def created():
return DateTimeTZField(default=lambda: datetime.datetime.now(pytz.utc))


@contextmanager
def dbtransaction_ctx(db):
if not db.in_transaction():
with db.connection_context():
with db.atomic():
yield
else:
with db.atomic():
yield


def dbtransaction(db):
"""
wrapper that make db transactions automic
Expand All @@ -72,14 +84,8 @@ def dbtransaction(db):
def wrapper(f):
@wraps(f)
def f_wrapped(*args, **kw):
if not db.in_transaction():
with db.connection_context():
with db.atomic():
result = f(*args, **kw)
else:
with db.atomic():
result = f(*args, **kw)
return result
with dbtransaction_ctx(db):
return f(*args, **kw)
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We won't be using it in FastAPI. This is only for hug compatibility (and apphelpers tests).


return f_wrapped

Expand Down
62 changes: 62 additions & 0 deletions apphelpers/db/piccolo.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
from contextlib import asynccontextmanager
from functools import wraps
from typing import List, Set, Type

from piccolo.engine.postgres import PostgresEngine
from piccolo.table import Table, create_db_tables_sync, drop_db_tables_sync


@asynccontextmanager
async def connection_pool_lifespan(engine: PostgresEngine, **kwargs):
print("db: starting connection pool")
await engine.start_connection_pool(**kwargs)
yield
print("db: closing connection pool")
await engine.close_connection_pool()


dbtransaction_ctx = PostgresEngine.transaction


def dbtransaction(engine: PostgresEngine, allow_nested=True):
def wrapper(f):
@wraps(f)
async def f_wrapped(*args, **kw):
async with dbtransaction_ctx(engine, allow_nested=allow_nested):
return await f(*args, **kw)

return f_wrapped

return wrapper


class BaseTable(Table):
@classmethod
def all_column_names(cls) -> Set[str]:
return {col._meta.name for col in cls._meta.columns}


def get_sub_tables(basetable: Type[Table]) -> List[Type[Table]]:
tables = []
for subtable in basetable.__subclasses__():
tables.append(subtable)
tables.extend(get_sub_tables(subtable))
return tables


def setup_db(tables: List[Type[Table]]):
create_db_tables_sync(*tables, if_not_exists=True)


def setup_db_from_basetable(basetable: Type[Table]):
tables = get_sub_tables(basetable)
setup_db(tables)


def destroy_db(tables: List[Type[Table]]):
drop_db_tables_sync(*tables)


def destroy_db_from_basetable(basetable: Type[Table]):
tables = get_sub_tables(basetable)
destroy_db(tables)
17 changes: 11 additions & 6 deletions apphelpers/rest/fastapi.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from fastapi.routing import APIRoute
from starlette.requests import Request

from apphelpers.db.peewee import dbtransaction
from apphelpers.db import dbtransaction_ctx
from apphelpers.errors.fastapi import (
HTTP401Unauthorized,
HTTP403Forbidden,
Expand Down Expand Up @@ -128,6 +128,14 @@ async def get_user_agent(request: Request):
user_agent = Depends(get_user_agent)


def dbtransaction(engine, allow_nested=True):
async def dependency():
async with dbtransaction_ctx(engine, allow_nested=allow_nested):
yield

return Depends(dependency)


class SecureRouter(APIRoute):
sessions = None

Expand Down Expand Up @@ -280,7 +288,6 @@ async def custom_route_handler(_request: Request):

class APIFactory:
def __init__(self, sessiondb_conn=None, urls_prefix="", site_identifier=None):
self.db_tr_wrapper = phony
self.access_wrapper = phony
self.multi_site_enabled = False
self.site_identifier = site_identifier
Expand All @@ -297,7 +304,7 @@ def enable_multi_site(self, site_identifier: str):
self.site_identifier = site_identifier

def setup_db_transaction(self, db):
self.db_tr_wrapper = dbtransaction(db)
self.router.dependencies.append(dbtransaction(db))

def setup_honeybadger_monitoring(self):
api_key = settings.HONEYBADGER_API_KEY
Expand Down Expand Up @@ -499,9 +506,7 @@ def build(self, method, method_args, method_kw, f):
f"[{method.__name__.upper()}] => {f.__module__}:{f.__name__}",
)
m = method(*method_args, **method_kw)
f = self.access_wrapper(
self.honeybadger_wrapper(self.db_tr_wrapper(raise_not_found_on_none(f)))
)
f = self.access_wrapper(self.honeybadger_wrapper(raise_not_found_on_none(f)))
# NOTE: ^ wrapper ordering is important. access_wrapper needs request which
# others don't. If access_wrapper comes late in the order it won't be passed
# request parameter.
Expand Down
2 changes: 1 addition & 1 deletion apphelpers/rest/hug.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from falcon import HTTPForbidden, HTTPNotFound, HTTPUnauthorized
from hug.decorators import wraps

from apphelpers.db.peewee import dbtransaction
from apphelpers.db import dbtransaction
from apphelpers.errors.hug import BaseError, InvalidSessionError
from apphelpers.loggers import api_logger
from apphelpers.rest import endpoint as ep
Expand Down
1 change: 1 addition & 0 deletions requirements_dev.txt
Original file line number Diff line number Diff line change
Expand Up @@ -29,3 +29,4 @@ types-pytz
types-redis
redis
loguru
piccolo[postgres]
70 changes: 70 additions & 0 deletions tests/test_piccolo.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
import asyncio

import settings
from piccolo import columns as col
from piccolo.engine.postgres import PostgresEngine

from apphelpers.db.piccolo import (
BaseTable,
dbtransaction_ctx,
destroy_db_from_basetable,
setup_db_from_basetable,
)

db = PostgresEngine(
config=dict(
host=settings.DB_HOST,
database=settings.DB_NAME,
user=settings.DB_USER,
password=settings.DB_PASS,
)
)


class Book(BaseTable, db=db):
name = col.Text()


async def _add_book(name):
await Book.insert(Book(name=name)).run()


async def _add_book_loser(name):
await _add_book(name)
loser # will raise # noqa: F821


def setup_module():
setup_db_from_basetable(BaseTable)


def teardown_module():
destroy_db_from_basetable(BaseTable)


async def add_with_tr():
async with dbtransaction_ctx(db):
name = "The Pillars of the Earth"
await _add_book(name)
names = [b[Book.name] for b in await Book.select().run()]
assert name in names

try:
async with dbtransaction_ctx(db):
name = "The Cathedral and the Bazaar"
await _add_book_loser(name)
except NameError:
pass

names = [b[Book.name] for b in await Book.select().run()]
assert name not in names

async with dbtransaction_ctx(db):
name = "The Ego Trick"
await _add_book(name)
names = [b[Book.name] for b in await Book.select().run()]
assert name in names


def test_add_with_tr():
asyncio.run(add_with_tr())
Loading