Skip to content

Commit

Permalink
Add piccolo db
Browse files Browse the repository at this point in the history
Also use FastAPI dependency for async db transaction.
  • Loading branch information
sayanarijit committed Feb 14, 2024
1 parent 9907adb commit d9bda0b
Show file tree
Hide file tree
Showing 7 changed files with 152 additions and 15 deletions.
8 changes: 8 additions & 0 deletions apphelpers/db/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
try:
from apphelpers.db.piccolo import * # noqa: F401, F403

print("apphelpers: using Piccolo")
except ImportError:
from apphelpers.db.peewee import * # noqa: F401, F403

print("apphelpers: using Peewee")
22 changes: 14 additions & 8 deletions apphelpers/db/peewee.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import datetime
import logging
import os
from contextlib import contextmanager
from enum import Enum

import pytz
Expand Down Expand Up @@ -62,6 +63,17 @@ def created():
return DateTimeTZField(default=lambda: datetime.datetime.now(pytz.utc))


@contextmanager
def dbtransaction_ctx(db):
if not db.in_transaction():
with db.connection_context():
with db.atomic():
yield
else:
with db.atomic():
yield


def dbtransaction(db):
"""
wrapper that make db transactions automic
Expand All @@ -72,14 +84,8 @@ def dbtransaction(db):
def wrapper(f):
@wraps(f)
def f_wrapped(*args, **kw):
if not db.in_transaction():
with db.connection_context():
with db.atomic():
result = f(*args, **kw)
else:
with db.atomic():
result = f(*args, **kw)
return result
with dbtransaction_ctx(db):
return f(*args, **kw)

return f_wrapped

Expand Down
55 changes: 55 additions & 0 deletions apphelpers/db/piccolo.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
from contextlib import asynccontextmanager
from typing import List, Optional, Set, Type, cast

from piccolo.engine import engine_finder
from piccolo.engine.postgres import PostgresEngine
from piccolo.table import Table, create_db_tables_sync, drop_db_tables_sync


@asynccontextmanager
async def connection_pool_lifespan(engine: Optional[PostgresEngine] = None, **kwargs):
if engine := engine or cast(PostgresEngine, engine_finder()):
print("db: starting connection pool")
await engine.start_connection_pool(**kwargs)
yield
print("db: closing connection pool")
await engine.close_connection_pool()


@asynccontextmanager
async def dbtransaction_ctx(engine: Optional[PostgresEngine] = None, allow_nested=True):
if engine := engine or cast(PostgresEngine, engine_finder()):
async with engine.transaction(allow_nested=allow_nested):
yield


class BaseTable(Table):
@classmethod
def all_column_names(cls) -> Set[str]:
return {col._meta.name for col in cls._meta.columns}


def get_sub_tables(basetable: Type[Table]) -> List[type[Table]]:
tables = []
for subtable in basetable.__subclasses__():
tables.append(subtable)
tables.extend(get_sub_tables(subtable))
return tables


def setup_db(tables: list[Type[Table]]):
create_db_tables_sync(*tables, if_not_exists=True)


def setup_db_from_basetable(basetable: Type[Table]):
tables = get_sub_tables(basetable)
setup_db(tables)


def destroy_db(tables: List[Type[Table]]):
drop_db_tables_sync(*tables)


def destroy_db_from_basetable(basetable: Type[Table]):
tables = get_sub_tables(basetable)
destroy_db(tables)
9 changes: 3 additions & 6 deletions apphelpers/rest/fastapi.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from fastapi.routing import APIRoute
from starlette.requests import Request

from apphelpers.db.peewee import dbtransaction
from apphelpers.db import dbtransaction_ctx
from apphelpers.errors.fastapi import (
HTTP401Unauthorized,
HTTP403Forbidden,
Expand Down Expand Up @@ -280,7 +280,6 @@ async def custom_route_handler(_request: Request):

class APIFactory:
def __init__(self, sessiondb_conn=None, urls_prefix="", site_identifier=None):
self.db_tr_wrapper = phony
self.access_wrapper = phony
self.multi_site_enabled = False
self.site_identifier = site_identifier
Expand All @@ -297,7 +296,7 @@ def enable_multi_site(self, site_identifier: str):
self.site_identifier = site_identifier

def setup_db_transaction(self, db):
self.db_tr_wrapper = dbtransaction(db)
self.router.dependencies.append(Depends(dbtransaction_ctx(db)))

def setup_honeybadger_monitoring(self):
api_key = settings.HONEYBADGER_API_KEY
Expand Down Expand Up @@ -499,9 +498,7 @@ def build(self, method, method_args, method_kw, f):
f"[{method.__name__.upper()}] => {f.__module__}:{f.__name__}",
)
m = method(*method_args, **method_kw)
f = self.access_wrapper(
self.honeybadger_wrapper(self.db_tr_wrapper(raise_not_found_on_none(f)))
)
f = self.access_wrapper(self.honeybadger_wrapper(raise_not_found_on_none(f)))
# NOTE: ^ wrapper ordering is important. access_wrapper needs request which
# others don't. If access_wrapper comes late in the order it won't be passed
# request parameter.
Expand Down
2 changes: 1 addition & 1 deletion apphelpers/rest/hug.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from falcon import HTTPForbidden, HTTPNotFound, HTTPUnauthorized
from hug.decorators import wraps

from apphelpers.db.peewee import dbtransaction
from apphelpers.db import dbtransaction
from apphelpers.errors.hug import BaseError, InvalidSessionError
from apphelpers.loggers import api_logger
from apphelpers.rest import endpoint as ep
Expand Down
1 change: 1 addition & 0 deletions requirements_dev.txt
Original file line number Diff line number Diff line change
Expand Up @@ -29,3 +29,4 @@ types-pytz
types-redis
redis
loguru
piccolo[postgres]
70 changes: 70 additions & 0 deletions tests/test_piccolo.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
import asyncio

import settings
from piccolo import columns as col
from piccolo.engine.postgres import PostgresEngine

from apphelpers.db.piccolo import (
BaseTable,
dbtransaction_ctx,
destroy_db_from_basetable,
setup_db_from_basetable,
)

db = PostgresEngine(
config=dict(
host=settings.DB_HOST,
database=settings.DB_NAME,
user=settings.DB_USER,
password=settings.DB_PASS,
)
)


class Book(BaseTable, db=db):
name = col.Text()


async def _add_book(name):
await Book.insert(Book(name=name)).run()


async def _add_book_loser(name):
await _add_book(name)
loser # will raise # noqa: F821


def setup_module():
setup_db_from_basetable(BaseTable)


def teardown_module():
destroy_db_from_basetable(BaseTable)


async def add_with_tr():
async with dbtransaction_ctx(db):
name = "The Pillars of the Earth"
await _add_book(name)
names = [b[Book.name] for b in await Book.select().run()]
assert name in names

try:
async with dbtransaction_ctx(db):
name = "The Cathedral and the Bazaar"
await _add_book_loser(name)
except NameError:
pass

names = [b[Book.name] for b in await Book.select().run()]
assert name not in names

async with dbtransaction_ctx(db):
name = "The Ego Trick"
await _add_book(name)
names = [b[Book.name] for b in await Book.select().run()]
assert name in names


def test_add_with_tr():
asyncio.run(add_with_tr())

0 comments on commit d9bda0b

Please sign in to comment.