From d9bda0bb096a24714b8c9a4e1cd25adf939e896b Mon Sep 17 00:00:00 2001 From: Arijit Basu Date: Wed, 14 Feb 2024 19:27:54 +0530 Subject: [PATCH] Add piccolo db Also use FastAPI dependency for async db transaction. --- apphelpers/db/__init__.py | 8 +++++ apphelpers/db/peewee.py | 22 +++++++----- apphelpers/db/piccolo.py | 55 ++++++++++++++++++++++++++++++ apphelpers/rest/fastapi.py | 9 ++--- apphelpers/rest/hug.py | 2 +- requirements_dev.txt | 1 + tests/test_piccolo.py | 70 ++++++++++++++++++++++++++++++++++++++ 7 files changed, 152 insertions(+), 15 deletions(-) create mode 100644 apphelpers/db/piccolo.py create mode 100644 tests/test_piccolo.py diff --git a/apphelpers/db/__init__.py b/apphelpers/db/__init__.py index e69de29..0589f79 100644 --- a/apphelpers/db/__init__.py +++ b/apphelpers/db/__init__.py @@ -0,0 +1,8 @@ +try: + from apphelpers.db.piccolo import * # noqa: F401, F403 + + print("apphelpers: using Piccolo") +except ImportError: + from apphelpers.db.peewee import * # noqa: F401, F403 + + print("apphelpers: using Peewee") diff --git a/apphelpers/db/peewee.py b/apphelpers/db/peewee.py index b96c35f..bd5e581 100644 --- a/apphelpers/db/peewee.py +++ b/apphelpers/db/peewee.py @@ -1,6 +1,7 @@ import datetime import logging import os +from contextlib import contextmanager from enum import Enum import pytz @@ -62,6 +63,17 @@ def created(): return DateTimeTZField(default=lambda: datetime.datetime.now(pytz.utc)) +@contextmanager +def dbtransaction_ctx(db): + if not db.in_transaction(): + with db.connection_context(): + with db.atomic(): + yield + else: + with db.atomic(): + yield + + def dbtransaction(db): """ wrapper that make db transactions automic @@ -72,14 +84,8 @@ def dbtransaction(db): def wrapper(f): @wraps(f) def f_wrapped(*args, **kw): - if not db.in_transaction(): - with db.connection_context(): - with db.atomic(): - result = f(*args, **kw) - else: - with db.atomic(): - result = f(*args, **kw) - return result + with dbtransaction_ctx(db): + return f(*args, **kw) return f_wrapped diff --git a/apphelpers/db/piccolo.py b/apphelpers/db/piccolo.py new file mode 100644 index 0000000..3156e94 --- /dev/null +++ b/apphelpers/db/piccolo.py @@ -0,0 +1,55 @@ +from contextlib import asynccontextmanager +from typing import List, Optional, Set, Type, cast + +from piccolo.engine import engine_finder +from piccolo.engine.postgres import PostgresEngine +from piccolo.table import Table, create_db_tables_sync, drop_db_tables_sync + + +@asynccontextmanager +async def connection_pool_lifespan(engine: Optional[PostgresEngine] = None, **kwargs): + if engine := engine or cast(PostgresEngine, engine_finder()): + print("db: starting connection pool") + await engine.start_connection_pool(**kwargs) + yield + print("db: closing connection pool") + await engine.close_connection_pool() + + +@asynccontextmanager +async def dbtransaction_ctx(engine: Optional[PostgresEngine] = None, allow_nested=True): + if engine := engine or cast(PostgresEngine, engine_finder()): + async with engine.transaction(allow_nested=allow_nested): + yield + + +class BaseTable(Table): + @classmethod + def all_column_names(cls) -> Set[str]: + return {col._meta.name for col in cls._meta.columns} + + +def get_sub_tables(basetable: Type[Table]) -> List[type[Table]]: + tables = [] + for subtable in basetable.__subclasses__(): + tables.append(subtable) + tables.extend(get_sub_tables(subtable)) + return tables + + +def setup_db(tables: list[Type[Table]]): + create_db_tables_sync(*tables, if_not_exists=True) + + +def setup_db_from_basetable(basetable: Type[Table]): + tables = get_sub_tables(basetable) + setup_db(tables) + + +def destroy_db(tables: List[Type[Table]]): + drop_db_tables_sync(*tables) + + +def destroy_db_from_basetable(basetable: Type[Table]): + tables = get_sub_tables(basetable) + destroy_db(tables) diff --git a/apphelpers/rest/fastapi.py b/apphelpers/rest/fastapi.py index 8d08f05..3ac4aeb 100644 --- a/apphelpers/rest/fastapi.py +++ b/apphelpers/rest/fastapi.py @@ -6,7 +6,7 @@ from fastapi.routing import APIRoute from starlette.requests import Request -from apphelpers.db.peewee import dbtransaction +from apphelpers.db import dbtransaction_ctx from apphelpers.errors.fastapi import ( HTTP401Unauthorized, HTTP403Forbidden, @@ -280,7 +280,6 @@ async def custom_route_handler(_request: Request): class APIFactory: def __init__(self, sessiondb_conn=None, urls_prefix="", site_identifier=None): - self.db_tr_wrapper = phony self.access_wrapper = phony self.multi_site_enabled = False self.site_identifier = site_identifier @@ -297,7 +296,7 @@ def enable_multi_site(self, site_identifier: str): self.site_identifier = site_identifier def setup_db_transaction(self, db): - self.db_tr_wrapper = dbtransaction(db) + self.router.dependencies.append(Depends(dbtransaction_ctx(db))) def setup_honeybadger_monitoring(self): api_key = settings.HONEYBADGER_API_KEY @@ -499,9 +498,7 @@ def build(self, method, method_args, method_kw, f): f"[{method.__name__.upper()}] => {f.__module__}:{f.__name__}", ) m = method(*method_args, **method_kw) - f = self.access_wrapper( - self.honeybadger_wrapper(self.db_tr_wrapper(raise_not_found_on_none(f))) - ) + f = self.access_wrapper(self.honeybadger_wrapper(raise_not_found_on_none(f))) # NOTE: ^ wrapper ordering is important. access_wrapper needs request which # others don't. If access_wrapper comes late in the order it won't be passed # request parameter. diff --git a/apphelpers/rest/hug.py b/apphelpers/rest/hug.py index 81a98ef..6385801 100644 --- a/apphelpers/rest/hug.py +++ b/apphelpers/rest/hug.py @@ -7,7 +7,7 @@ from falcon import HTTPForbidden, HTTPNotFound, HTTPUnauthorized from hug.decorators import wraps -from apphelpers.db.peewee import dbtransaction +from apphelpers.db import dbtransaction from apphelpers.errors.hug import BaseError, InvalidSessionError from apphelpers.loggers import api_logger from apphelpers.rest import endpoint as ep diff --git a/requirements_dev.txt b/requirements_dev.txt index 851014c..6817075 100644 --- a/requirements_dev.txt +++ b/requirements_dev.txt @@ -29,3 +29,4 @@ types-pytz types-redis redis loguru +piccolo[postgres] diff --git a/tests/test_piccolo.py b/tests/test_piccolo.py new file mode 100644 index 0000000..296b688 --- /dev/null +++ b/tests/test_piccolo.py @@ -0,0 +1,70 @@ +import asyncio + +import settings +from piccolo import columns as col +from piccolo.engine.postgres import PostgresEngine + +from apphelpers.db.piccolo import ( + BaseTable, + dbtransaction_ctx, + destroy_db_from_basetable, + setup_db_from_basetable, +) + +db = PostgresEngine( + config=dict( + host=settings.DB_HOST, + database=settings.DB_NAME, + user=settings.DB_USER, + password=settings.DB_PASS, + ) +) + + +class Book(BaseTable, db=db): + name = col.Text() + + +async def _add_book(name): + await Book.insert(Book(name=name)).run() + + +async def _add_book_loser(name): + await _add_book(name) + loser # will raise # noqa: F821 + + +def setup_module(): + setup_db_from_basetable(BaseTable) + + +def teardown_module(): + destroy_db_from_basetable(BaseTable) + + +async def add_with_tr(): + async with dbtransaction_ctx(db): + name = "The Pillars of the Earth" + await _add_book(name) + names = [b[Book.name] for b in await Book.select().run()] + assert name in names + + try: + async with dbtransaction_ctx(db): + name = "The Cathedral and the Bazaar" + await _add_book_loser(name) + except NameError: + pass + + names = [b[Book.name] for b in await Book.select().run()] + assert name not in names + + async with dbtransaction_ctx(db): + name = "The Ego Trick" + await _add_book(name) + names = [b[Book.name] for b in await Book.select().run()] + assert name in names + + +def test_add_with_tr(): + asyncio.run(add_with_tr())