Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add piccolo db #104

Merged
merged 8 commits into from
Mar 6, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions apphelpers/db/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
try:
from apphelpers.db.piccolo import * # noqa: F401, F403

print("apphelpers: using Piccolo")
except ImportError:
from apphelpers.db.peewee import * # noqa: F401, F403

print("apphelpers: using Peewee")
22 changes: 14 additions & 8 deletions apphelpers/db/peewee.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import datetime
import logging
import os
from contextlib import contextmanager
from enum import Enum

import pytz
Expand Down Expand Up @@ -62,6 +63,17 @@ def created():
return DateTimeTZField(default=lambda: datetime.datetime.now(pytz.utc))


@contextmanager
def dbtransaction_ctx(db):
if not db.in_transaction():
with db.connection_context():
with db.atomic():
yield
else:
with db.atomic():
yield


def dbtransaction(db):
"""
wrapper that make db transactions automic
Expand All @@ -72,14 +84,8 @@ def dbtransaction(db):
def wrapper(f):
@wraps(f)
def f_wrapped(*args, **kw):
if not db.in_transaction():
with db.connection_context():
with db.atomic():
result = f(*args, **kw)
else:
with db.atomic():
result = f(*args, **kw)
return result
with dbtransaction_ctx(db):
return f(*args, **kw)
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We won't be using it in FastAPI. This is only for hug compatibility (and apphelpers tests).


return f_wrapped

Expand Down
62 changes: 62 additions & 0 deletions apphelpers/db/piccolo.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
from contextlib import asynccontextmanager
from functools import wraps
from typing import List, Set, Type

from piccolo.engine.postgres import PostgresEngine
from piccolo.table import Table, create_db_tables_sync, drop_db_tables_sync


@asynccontextmanager
async def connection_pool_lifespan(engine: PostgresEngine, **kwargs):
print("db: starting connection pool")
await engine.start_connection_pool(**kwargs)
yield
print("db: closing connection pool")
await engine.close_connection_pool()


dbtransaction_ctx = PostgresEngine.transaction


def dbtransaction(engine: PostgresEngine, allow_nested=True):
def wrapper(f):
@wraps(f)
async def f_wrapped(*args, **kw):
async with dbtransaction_ctx(engine, allow_nested=allow_nested):
return await f(*args, **kw)

return f_wrapped

return wrapper


class BaseTable(Table):
@classmethod
def all_column_names(cls) -> Set[str]:
return {col._meta.name for col in cls._meta.columns}


def get_sub_tables(basetable: Type[Table]) -> List[Type[Table]]:
tables = []
for subtable in basetable.__subclasses__():
tables.append(subtable)
tables.extend(get_sub_tables(subtable))
return tables


def setup_db(tables: List[Type[Table]]):
create_db_tables_sync(*tables, if_not_exists=True)


def setup_db_from_basetable(basetable: Type[Table]):
tables = get_sub_tables(basetable)
setup_db(tables)


def destroy_db(tables: List[Type[Table]]):
drop_db_tables_sync(*tables)


def destroy_db_from_basetable(basetable: Type[Table]):
tables = get_sub_tables(basetable)
destroy_db(tables)
17 changes: 11 additions & 6 deletions apphelpers/rest/fastapi.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from fastapi.routing import APIRoute
from starlette.requests import Request

from apphelpers.db.peewee import dbtransaction
from apphelpers.db import dbtransaction_ctx
from apphelpers.errors.fastapi import (
HTTP401Unauthorized,
HTTP403Forbidden,
Expand Down Expand Up @@ -128,6 +128,14 @@ async def get_user_agent(request: Request):
user_agent = Depends(get_user_agent)


def dbtransaction(engine, allow_nested=True):
async def dependency():
async with dbtransaction_ctx(engine, allow_nested=allow_nested):
yield

return Depends(dependency)


class SecureRouter(APIRoute):
sessions = None

Expand Down Expand Up @@ -280,7 +288,6 @@ async def custom_route_handler(_request: Request):

class APIFactory:
def __init__(self, sessiondb_conn=None, urls_prefix="", site_identifier=None):
self.db_tr_wrapper = phony
self.access_wrapper = phony
self.multi_site_enabled = False
self.site_identifier = site_identifier
Expand All @@ -297,7 +304,7 @@ def enable_multi_site(self, site_identifier: str):
self.site_identifier = site_identifier

def setup_db_transaction(self, db):
self.db_tr_wrapper = dbtransaction(db)
self.router.dependencies.append(dbtransaction(db))

def setup_honeybadger_monitoring(self):
api_key = settings.HONEYBADGER_API_KEY
Expand Down Expand Up @@ -499,9 +506,7 @@ def build(self, method, method_args, method_kw, f):
f"[{method.__name__.upper()}] => {f.__module__}:{f.__name__}",
)
m = method(*method_args, **method_kw)
f = self.access_wrapper(
self.honeybadger_wrapper(self.db_tr_wrapper(raise_not_found_on_none(f)))
)
f = self.access_wrapper(self.honeybadger_wrapper(raise_not_found_on_none(f)))
# NOTE: ^ wrapper ordering is important. access_wrapper needs request which
# others don't. If access_wrapper comes late in the order it won't be passed
# request parameter.
Expand Down
2 changes: 1 addition & 1 deletion apphelpers/rest/hug.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from falcon import HTTPForbidden, HTTPNotFound, HTTPUnauthorized
from hug.decorators import wraps

from apphelpers.db.peewee import dbtransaction
from apphelpers.db import dbtransaction
from apphelpers.errors.hug import BaseError, InvalidSessionError
from apphelpers.loggers import api_logger
from apphelpers.rest import endpoint as ep
Expand Down
15 changes: 15 additions & 0 deletions fastapi_tests/app/endpoints.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

from apphelpers.rest import endpoint as ep
from apphelpers.rest.fastapi import json_body, user, user_agent, user_id
from fastapi_tests.app.models import Book


def echo(word, user=user):
Expand Down Expand Up @@ -93,6 +94,18 @@ async def get_fields(fields: set = Query(..., default_factory=set)):
return {k: v for k, v in data.items() if k in fields}


async def add_books(succeed: bool):
await Book.insert(Book(name="The Pillars of the Earth")).run()
await Book.insert(Book(name="The Cathedral and the Bazaar")).run()
if not succeed:
raise ValueError("Failure")
await Book.insert(Book(name="The Ego Trick")).run()


async def count_books():
return await Book.count()


def setup_routes(factory):
factory.get("/echo/{word}")(echo)
factory.get("/echo-async/{word}")(echo_async)
Expand All @@ -117,3 +130,5 @@ def setup_routes(factory):
echo_user_agent_without_site_ctx_async
)
factory.get("/fields")(get_fields)
factory.get("/count-books")(count_books)
factory.post("/add-books")(add_books)
18 changes: 18 additions & 0 deletions fastapi_tests/app/models.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
import settings
from piccolo import columns as col
from piccolo.engine.postgres import PostgresEngine

from apphelpers.db.piccolo import BaseTable

db = PostgresEngine(
config=dict(
host=settings.DB_HOST,
database=settings.DB_NAME,
user=settings.DB_USER,
password=settings.DB_PASS,
)
)


class Book(BaseTable, db=db):
name = col.Text()
2 changes: 2 additions & 0 deletions fastapi_tests/service.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@

from apphelpers.rest.fastapi import APIFactory
from fastapi_tests.app.endpoints import setup_routes
from fastapi_tests.app.models import db


def make_app():
Expand All @@ -20,6 +21,7 @@ def make_app():
)

api_factory = APIFactory(sessiondb_conn=sessiondb_conn, site_identifier="site_id")
api_factory.setup_db_transaction(db)
setup_routes(api_factory)

app = fastapi.FastAPI()
Expand Down
28 changes: 28 additions & 0 deletions fastapi_tests/test_rest.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@
from converge import settings

import apphelpers.sessions as sessionslib
from apphelpers.db.piccolo import setup_db_from_basetable, destroy_db_from_basetable
from fastapi_tests.app.models import BaseTable

base_url = "http://127.0.0.1:5000/"
echo_url = base_url + "echo"
Expand All @@ -24,6 +26,13 @@

def setup_module():
sessionsdb.destroy_all()
destroy_db_from_basetable(BaseTable)
setup_db_from_basetable(BaseTable)


def teardown_module():
sessionsdb.destroy_all()
destroy_db_from_basetable(BaseTable)


def test_get():
Expand Down Expand Up @@ -284,3 +293,22 @@ def test_user_agent_async_and_site_ctx():
response = requests.get(url, headers=headers)
assert response.status_code == 200
assert "python-requests" in response.text


def test_piccolo():
url = base_url + "count-books"
assert requests.get(url).json() == 0

url = base_url + "add-books"
data = {"succeed": True}
assert requests.post(url, params=data).status_code == 200

url = base_url + "count-books"
assert requests.get(url).json() == 3

url = base_url + "add-books"
data = {"succeed": False}
assert requests.post(url, params=data).status_code == 500

url = base_url + "count-books"
assert requests.get(url).json() == 3
1 change: 1 addition & 0 deletions requirements_dev.txt
Original file line number Diff line number Diff line change
Expand Up @@ -29,3 +29,4 @@ types-pytz
types-redis
redis
loguru
piccolo[postgres]
70 changes: 70 additions & 0 deletions tests/test_piccolo.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
import asyncio

import settings
from piccolo import columns as col
from piccolo.engine.postgres import PostgresEngine

from apphelpers.db.piccolo import (
BaseTable,
dbtransaction_ctx,
destroy_db_from_basetable,
setup_db_from_basetable,
)

db = PostgresEngine(
config=dict(
host=settings.DB_HOST,
database=settings.DB_NAME,
user=settings.DB_USER,
password=settings.DB_PASS,
)
)


class Book(BaseTable, db=db):
name = col.Text()


async def _add_book(name):
await Book.insert(Book(name=name)).run()


async def _add_book_loser(name):
await _add_book(name)
loser # will raise # noqa: F821


def setup_module():
setup_db_from_basetable(BaseTable)


def teardown_module():
destroy_db_from_basetable(BaseTable)


async def add_with_tr():
async with dbtransaction_ctx(db):
name = "The Pillars of the Earth"
await _add_book(name)
names = [b[Book.name] for b in await Book.select().run()]
assert name in names

try:
async with dbtransaction_ctx(db):
name = "The Cathedral and the Bazaar"
await _add_book_loser(name)
except NameError:
pass

names = [b[Book.name] for b in await Book.select().run()]
assert name not in names

async with dbtransaction_ctx(db):
name = "The Ego Trick"
await _add_book(name)
names = [b[Book.name] for b in await Book.select().run()]
assert name in names


def test_add_with_tr():
asyncio.run(add_with_tr())
Loading