diff --git a/apphelpers/db/piccolo.py b/apphelpers/db/piccolo.py index d88014f..ad343ad 100644 --- a/apphelpers/db/piccolo.py +++ b/apphelpers/db/piccolo.py @@ -15,7 +15,10 @@ async def connection_pool_lifespan(engine: PostgresEngine, **kwargs): await engine.close_connection_pool() -dbtransaction_ctx = PostgresEngine.transaction +@asynccontextmanager +async def dbtransaction_ctx(engine: PostgresEngine, allow_nested=True): + async with engine.transaction(allow_nested=allow_nested): + yield def dbtransaction(engine: PostgresEngine, allow_nested=True): diff --git a/apphelpers/rest/fastapi.py b/apphelpers/rest/fastapi.py index 27bc52f..4fb7fb4 100644 --- a/apphelpers/rest/fastapi.py +++ b/apphelpers/rest/fastapi.py @@ -128,6 +128,14 @@ async def get_user_agent(request: Request): user_agent = Depends(get_user_agent) +def dbtransaction(engine, allow_nested=True): + async def dependency(): + async with dbtransaction_ctx(engine, allow_nested=allow_nested): + yield + + return Depends(dependency) + + class SecureRouter(APIRoute): sessions = None @@ -295,8 +303,8 @@ def enable_multi_site(self, site_identifier: str): self.multi_site_enabled = True self.site_identifier = site_identifier - def setup_db_transaction(self, db=None): - self.router.dependencies.append(Depends(dbtransaction_ctx(db))) + def setup_db_transaction(self, db): + self.router.dependencies.append(dbtransaction(db)) def setup_honeybadger_monitoring(self): api_key = settings.HONEYBADGER_API_KEY