Skip to content

Commit

Permalink
Use a scoped session to ensure either commit or rollback is called
Browse files Browse the repository at this point in the history
  • Loading branch information
polyrabbit committed Mar 1, 2024
1 parent b5cca5a commit 2b1eb43
Show file tree
Hide file tree
Showing 8 changed files with 67 additions and 47 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/probe.yml
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ on:
types: [ probe-hn-sites ]

schedule:
- cron: "*/10 * * * *"
- cron: "*/30 * * * *"

# Allow one concurrent deployment
concurrency:
Expand Down
20 changes: 16 additions & 4 deletions db/engine.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,31 @@
import datetime
import logging
import time

from contextlib import contextmanager
from sqlalchemy import create_engine, TIMESTAMP, event, Engine
from sqlalchemy.orm import DeclarativeBase, mapped_column, Session

import config

logger = logging.getLogger(__name__)
engine = create_engine(config.DATABASE_URL, echo=config.DATABASE_ECHO_SQL) # lazy connection

# TODO: should have a scope
session = Session(engine)


@contextmanager
def session_scope(defer_commit: bool = False) -> Session:
"""Provide a transactional scope around a series of operations."""
try:
yield session
if not defer_commit:
session.commit()
except:
session.rollback()
raise
# finally:
# session.close()


class Base(DeclarativeBase):
access = mapped_column(TIMESTAMP, default=datetime.datetime.utcnow)

Expand All @@ -25,6 +37,6 @@ def before_cursor_execute(conn, cursor, statement, parameters, context, executem

@event.listens_for(Engine, "after_cursor_execute")
def after_cursor_execute(conn, cursor, statement, parameters, context, executemany):
cost = (time.time() - conn.info["query_start_time"].pop(-1))*1000
cost = (time.time() - conn.info["query_start_time"].pop(-1)) * 1000
if cost >= config.SLOW_SQL_MS:
logger.warning(f'Slow sql {statement}, cost(ms): {cost:.2f}')
11 changes: 6 additions & 5 deletions db/image.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

import config
from db import Summary
from db.engine import session
from db.engine import session_scope

logger = logging.getLogger(__name__)

Expand All @@ -26,10 +26,11 @@ def expire():
stmt = select(values).join(Summary, Summary.image_name == values.c.name,
isouter=True # Add this to implement left outer join
).where(Summary.image_name.is_(None))
for image_name in session.execute(stmt):
logger.debug(f'removing {image_name[0]}')
os.remove(os.path.join(config.image_dir, image_name[0]))
removed += 1
with session_scope() as session:
for image_name in session.execute(stmt):
logger.debug(f'removing {image_name[0]}')
os.remove(os.path.join(config.image_dir, image_name[0]))
removed += 1
cost = (time.time() - start) * 1000
logger.info(f'removed {removed} feature images, cost(ms): {cost:.2f}')

Expand Down
36 changes: 19 additions & 17 deletions db/summary.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from sqlalchemy.orm import mapped_column

import config
from db.engine import Base, session
from db.engine import Base, session_scope

logger = logging.getLogger(__name__)
CONTENT_TTL = 1 * 24 * 60 * 60
Expand Down Expand Up @@ -70,7 +70,8 @@ def get_summary_model(self):
def get(url) -> Summary:
if config.disable_summary_cache:
return Summary(url)
summary = session.get(Summary, url) # Try to leverage the identity map cache
with session_scope() as session:
summary = session.get(Summary, url) # Try to leverage the identity map cache
return summary or Summary(url)


Expand All @@ -83,32 +84,33 @@ def put(db_summary: Summary) -> Summary:
db_summary.image_name = db_summary.image_name[:Summary.image_name.type.length]
if db_summary.image_json:
db_summary.image_json = db_summary.image_json[:Summary.image_json.type.length]
db_summary = session.merge(db_summary)
session.commit()
with session_scope() as session:
db_summary = session.merge(db_summary)
return db_summary


def filter_url(url_list: list[str]) -> set[str]:
# use `all()` to populate the Identity Map so that following `get` can read from cache
summaries = session.scalars(select(Summary).where(Summary.url.in_(url_list))).all()
assert len(session.identity_map) == len(summaries)
with session_scope() as session:
summaries = session.scalars(select(Summary).where(Summary.url.in_(url_list))).all()
assert len(session.identity_map) == len(summaries)
return set(s.url for s in summaries)


def expire():
start = time.time()
stmt = delete(Summary).where(
Summary.access < datetime.utcnow() - timedelta(seconds=config.summary_ttl))
result = session.execute(stmt)
deleted = result.rowcount
logger.info(f'evicted {result.rowcount} summary items')
with session_scope() as session:
result = session.execute(stmt)
deleted = result.rowcount
logger.info(f'evicted {result.rowcount} summary items')

stmt = delete(Summary).where(
Summary.access < datetime.utcnow() - timedelta(seconds=CONTENT_TTL),
Summary.model.not_in((Model.OPENAI.value, Model.TRANSFORMER.value, Model.LLAMA.value)))
result = session.execute(stmt)
cost = (time.time() - start) * 1000
logger.info(f'evicted {result.rowcount} full content items, cost(ms): {cost:.2f}')

stmt = delete(Summary).where(
Summary.access < datetime.utcnow() - timedelta(seconds=CONTENT_TTL),
Summary.model.not_in((Model.OPENAI.value, Model.TRANSFORMER.value, Model.LLAMA.value)))
result = session.execute(stmt)
cost = (time.time() - start) * 1000
logger.info(f'evicted {result.rowcount} full content items, cost(ms): {cost:.2f}')

session.commit()
return deleted + result.rowcount
20 changes: 10 additions & 10 deletions db/translation.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from sqlalchemy.orm import mapped_column

import config
from db.engine import Base, session
from db.engine import Base, session_scope

logger = logging.getLogger(__name__)

Expand All @@ -28,11 +28,11 @@ def get(text, to_lang):
return text # shortcut
text = text[:Translation.source.type.length]
stmt = select(Translation).where(Translation.source == text, Translation.language == to_lang)
trans = session.scalars(stmt).first()
if trans:
trans.access = datetime.utcnow()
# session.commit() # Ok to batch it
return trans.target
with session_scope(defer_commit=True) as session: # Ok to batch it
trans = session.scalars(stmt).first()
if trans:
trans.access = datetime.utcnow()
return trans.target
return text


Expand All @@ -42,16 +42,16 @@ def add(source, target, lang):
source = source[:Translation.source.type.length]
target = target[:Translation.source.type.length]
trans = Translation(source=source, target=target, language=lang, access=datetime.utcnow())
session.merge(trans) # source is primary key
session.commit()
with session_scope() as session:
session.merge(trans) # source is primary key


def expire():
start = time.time()
stmt = delete(Translation).where(
Translation.access < datetime.utcnow() - timedelta(seconds=config.summary_ttl))
result = session.execute(stmt)
session.commit()
with session_scope() as session:
result = session.execute(stmt)
cost = (time.time() - start) * 1000
logger.info(f'evicted {result.rowcount} translation items, cost(ms): {cost:.2f}')
return result.rowcount
5 changes: 4 additions & 1 deletion page_content_extractor/http.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,10 @@ class CustomHTTPAdapter(HTTPAdapter):

def __init__(self, *args, **kwargs):
if "max_retries" not in kwargs:
kwargs['max_retries'] = 3
# Just fail fast, otherwise the total timeout will be 30s * max_retries
# bad case is https://struct.ai/blog/introducing-the-struct-chat-platform,
# which blocks all image requests, so the whole update-round times out
kwargs['max_retries'] = 1
# Remove until switching to Python 3.12,
# https://stackoverflow.com/questions/71603314/ssl-error-unsafe-legacy-renegotiation-disabled
# https://github.com/urllib3/urllib3/issues/2653
Expand Down
12 changes: 6 additions & 6 deletions test/test_news_summary.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

import config
import db
from db.engine import session
from db.engine import session_scope
from db.summary import Model
from hacker_news.llm.coze import summarize_by_coze
from hacker_news.news import News
Expand Down Expand Up @@ -94,8 +94,8 @@ def test_fallback_when_new_fetch_failed(self):
self.assertEqual('wonderful summary', summary)
self.assertEqual(news.cache.get_summary_model(), summarized_by)
finally:
session.delete(news.cache)
session.commit()
with session_scope() as session:
session.delete(news.cache)

@mock.patch.object(News, 'parser')
def test_all_from_cache(self, mock_news_parser):
Expand All @@ -116,6 +116,6 @@ def test_all_from_cache(self, mock_news_parser):
self.assertEqual(db_summary, cached)
self.assertFalse(mock_news_parser.called)
finally:
session.delete(news.cache)
pathlib.Path(os.path.join(config.image_dir, db_summary.image_name)).unlink()
session.commit()
with session_scope() as session:
session.delete(news.cache)
pathlib.Path(os.path.join(config.image_dir, db_summary.image_name)).unlink()
8 changes: 5 additions & 3 deletions test/test_summary_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import config
import db.summary
from db import translation, Translation, summary
from db.engine import session
from db.engine import session_scope


class TranslationCacheTestCase(unittest.TestCase):
Expand All @@ -14,7 +14,8 @@ def test_translation_cache(self):
self.assertEqual('hello', translation.get('hello', 'en'))
deleted = translation.expire()
self.assertEqual(0, deleted)
trans = session.get(Translation, 'hello')
with session_scope() as session:
trans = session.get(Translation, 'hello')
trans.access = datetime.utcnow() - timedelta(seconds=translation.config.summary_ttl + 1)
deleted = translation.expire()
self.assertEqual(1, deleted)
Expand All @@ -24,7 +25,8 @@ def test_summary_cache(self):
self.assertEqual('world', summary.get('hello').summary)
deleted = summary.expire()
self.assertEqual(0, deleted)
summ = session.get(db.Summary, 'hello')
with session_scope() as session:
summ = session.get(db.Summary, 'hello')
summ.access = datetime.utcnow() - timedelta(seconds=summary.CONTENT_TTL + 1) # not expired
deleted = summary.expire()
self.assertEqual(0, deleted)
Expand Down

0 comments on commit 2b1eb43

Please sign in to comment.