Skip to content

Commit

Permalink
Merge pull request #1022 from centerofci/function-api-denuked
Browse files Browse the repository at this point in the history
Introduce function API, without affecting existing APIs
  • Loading branch information
dmos62 committed Feb 7, 2022
2 parents 6d708e8 + 70ebae2 commit 5a6db9f
Show file tree
Hide file tree
Showing 20 changed files with 728 additions and 1 deletion.
Empty file added db/functions/__init__.py
Empty file.
228 changes: 228 additions & 0 deletions db/functions/base.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,228 @@
"""
This namespace defines the DBFunction abstract class and its subclasses. These subclasses
represent functions that have identifiers, display names and hints, and their instances
hold parameters. Each DBFunction subclass defines how its instance can be converted into an
SQLAlchemy expression.
Hints hold information about what kind of input the function might expect and what output
can be expected from it. This is used to provide interface information without constraining its
user.
These classes might be used, for example, to define a filter for an SQL query, or to
access hints on what composition of functions and parameters should be valid.
"""

from abc import ABC, abstractmethod

from sqlalchemy import column, not_, and_, or_, func, literal

from db.functions import hints


class DBFunction(ABC):
    """
    Abstract base of every database function.

    Subclasses set the class attributes `id` and `name` (both mandatory), and may set `hints`
    and `depends_on`. Instances only hold the `parameters` they were built with; conversion to
    SQLAlchemy is done by the subclass's static `to_sa_expression`.
    """
    id = None
    name = None
    hints = None

    # SQL functions this DBFunction relies on, if any. The names are meant to be compared
    # with the SQL functions defined on a database to decide whether that database supports
    # this DBFunction. Must be None or a tuple of SQL function name strings.
    depends_on = None

    def __init__(self, parameters):
        """
        Store `parameters` after validating the subclass's class-level declarations.

        Raises ValueError when `id` or `name` is missing, or when `depends_on` is neither
        None nor a tuple.
        """
        if self.id is None:
            raise ValueError('DBFunction subclasses must define an ID.')
        if self.name is None:
            raise ValueError('DBFunction subclasses must define a name.')
        depends_on_is_valid = self.depends_on is None or isinstance(self.depends_on, tuple)
        if not depends_on_is_valid:
            raise ValueError(
                "DBFunction subclasses' depends_on attribute must either be None or a tuple of SQL function names."
            )
        self.parameters = parameters

    @property
    def referenced_columns(self):
        """
        Set of columns referenced anywhere in this function's parameter tree.

        Handy for verifying that every referenced column is present in the queried relation.
        """
        found = set()
        for parameter in self.parameters:
            # ColumnReference is itself a DBFunction, so it has to be tested first.
            if isinstance(parameter, ColumnReference):
                found.add(parameter.column)
            elif isinstance(parameter, DBFunction):
                found.update(parameter.referenced_columns)
        return found

    @staticmethod
    @abstractmethod
    def to_sa_expression():
        """Build the SQLAlchemy expression; every subclass must override this."""
        return None


class Literal(DBFunction):
    """Wraps a single primitive value so it can take part in an SQLAlchemy expression."""
    id = 'literal'
    name = 'Literal'
    hints = (
        hints.parameter_count(1),
        hints.parameter(1, hints.literal),
    )

    @staticmethod
    def to_sa_expression(primitive):
        """Return `primitive` passed through SQLAlchemy's `literal`."""
        return literal(primitive)


class ColumnReference(DBFunction):
    """Refers to a column of the queried relation; its single parameter is the column."""
    id = 'column_reference'
    name = 'Column Reference'
    hints = (
        hints.parameter_count(1),
        hints.parameter(1, hints.column),
    )

    @property
    def column(self):
        """The referenced column, i.e. the first (and only expected) parameter."""
        first_parameter = self.parameters[0]
        return first_parameter

    @staticmethod
    def to_sa_expression(column_name):
        """Return an SQLAlchemy `column` clause for `column_name`."""
        return column(column_name)


class List(DBFunction):
    """Groups any number of parameters into a plain Python list. Declares no hints."""
    id = 'list'
    name = 'List'

    @staticmethod
    def to_sa_expression(*items):
        """Return the given items, in order, as a new list."""
        return [*items]


class Empty(DBFunction):
    """Boolean test of whether a value is null."""
    id = 'empty'
    name = 'Empty'
    hints = (
        hints.returns(hints.boolean),
        hints.parameter_count(1),
    )

    @staticmethod
    def to_sa_expression(value):
        """Return `value.is_(None)`, the SQLAlchemy spelling of an `IS NULL` comparison."""
        return value.is_(None)


class Not(DBFunction):
    """Boolean negation of a single value."""
    id = 'not'
    name = 'Not'
    hints = (
        hints.returns(hints.boolean),
        hints.parameter_count(1),
    )

    @staticmethod
    def to_sa_expression(value):
        """Return `value` negated with SQLAlchemy's `not_`."""
        return not_(value)


class Equal(DBFunction):
    """Boolean equality comparison of two values."""
    id = 'equal'
    name = 'Equal'
    hints = (
        hints.returns(hints.boolean),
        hints.parameter_count(2),
    )

    @staticmethod
    def to_sa_expression(value1, value2):
        """Return the result of `value1 == value2`."""
        return value1 == value2


class Greater(DBFunction):
    """Boolean "greater than" comparison of two comparable values."""
    id = 'greater'
    name = 'Greater'
    hints = (
        hints.returns(hints.boolean),
        hints.parameter_count(2),
        hints.all_parameters(hints.comparable),
    )

    @staticmethod
    def to_sa_expression(value1, value2):
        """Return the result of `value1 > value2`."""
        return value1 > value2


class Lesser(DBFunction):
    """Boolean "less than" comparison of two comparable values."""
    id = 'lesser'
    name = 'Lesser'
    hints = (
        hints.returns(hints.boolean),
        hints.parameter_count(2),
        hints.all_parameters(hints.comparable),
    )

    @staticmethod
    def to_sa_expression(value1, value2):
        """Return the result of `value1 < value2`."""
        return value1 < value2


class In(DBFunction):
    """Boolean membership test: is the first value among the second (array) parameter?"""
    id = 'in'
    name = 'In'
    hints = (
        hints.returns(hints.boolean),
        hints.parameter_count(2),
        hints.parameter(2, hints.array),
    )

    @staticmethod
    def to_sa_expression(value1, value2):
        """Return `value1.in_(value2)`."""
        return value1.in_(value2)


class And(DBFunction):
    """Boolean conjunction over any number of values."""
    id = 'and'
    name = 'And'
    hints = (
        hints.returns(hints.boolean),
    )

    @staticmethod
    def to_sa_expression(*values):
        """Return all `values` joined with SQLAlchemy's `and_`."""
        return and_(*values)


class Or(DBFunction):
    """Boolean disjunction over any number of values."""
    id = 'or'
    name = 'Or'
    hints = (
        hints.returns(hints.boolean),
    )

    @staticmethod
    def to_sa_expression(*values):
        """Return all `values` joined with SQLAlchemy's `or_`."""
        return or_(*values)


class StartsWith(DBFunction):
    """Boolean test of whether a string-like value begins with a given prefix."""
    id = 'starts_with'
    name = 'Starts With'
    hints = tuple([
        hints.returns(hints.boolean),
        hints.parameter_count(2),
        hints.all_parameters(hints.string_like),
    ])

    @staticmethod
    def to_sa_expression(string, prefix):
        """
        Return a `LIKE` comparison of `string` against `prefix` followed by the `%` wildcard.

        `string` must be an SQLAlchemy expression; `prefix` may be either an SQLAlchemy
        expression or a plain Python string.
        """
        # By the time this runs, nested DBFunctions have already been converted, so `prefix`
        # is often an SQLAlchemy expression (e.g. the result of Literal or ColumnReference).
        # Interpolating it into a Python string would embed the expression's repr in the
        # pattern, so the wildcard is appended on the SQL side instead.
        pattern = func.concat(prefix, '%')
        return string.like(pattern)


class ToLowercase(DBFunction):
    """Lowercases a single string-like value."""
    id = 'to_lowercase'
    name = 'To Lowercase'
    hints = (
        hints.parameter_count(1),
        hints.all_parameters(hints.string_like),
    )

    @staticmethod
    def to_sa_expression(string):
        """Return a call of the SQL `lower` function on `string`."""
        return func.lower(string)
10 changes: 10 additions & 0 deletions db/functions/exceptions.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
class BadDBFunctionFormat(Exception):
    """Root of the exception hierarchy defined in this module."""


class UnknownDBFunctionId(BadDBFunctionFormat):
    """
    Signals a DBFunction id that is not recognized.

    NOTE(review): the raise sites are outside this module -- confirm the exact condition.
    """


class ReferencedColumnsDontExist(BadDBFunctionFormat):
    """Signals that a DBFunction references columns missing from the relation it targets."""
42 changes: 42 additions & 0 deletions db/functions/hints.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
from frozendict import frozendict


def _make_hint(id, **rest):
    """Build an immutable hint: a frozendict holding `id` first, then any extra fields."""
    fields = {"id": id}
    fields.update(rest)
    return frozendict(fields)


def parameter_count(count):
    """Hint carrying a parameter count in its `count` field."""
    hint = _make_hint("parameter_count", count=count)
    return hint


def parameter(index, *hints):
    """
    Hint attaching the given `hints` (stored as a tuple) to the parameter at `index`.

    Callers in db/functions/base.py pass 1 for a function's first parameter.
    """
    hint = _make_hint("parameter", index=index, hints=hints)
    return hint


def all_parameters(*hints):
    """Hint attaching the given `hints` (stored as a tuple) to every parameter."""
    hint = _make_hint("all_parameters", hints=hints)
    return hint


def returns(*hints):
    """Hint attaching the given `hints` (stored as a tuple) to the return value."""
    hint = _make_hint("returns", hints=hints)
    return hint


# Leaf hints: field-less markers meant to be nested inside `parameter`, `all_parameters`
# or `returns` hints.

# Passed to `returns(...)` by the boolean-valued functions in db/functions/base.py.
boolean = _make_hint("boolean")


# Applied to all parameters of the Greater and Lesser functions.
comparable = _make_hint("comparable")


# Applied to the parameter of ColumnReference.
column = _make_hint("column")


# Applied to the second parameter of In.
array = _make_hint("array")


# Applied to all parameters of StartsWith and ToLowercase.
string_like = _make_hint("string_like")


# NOTE(review): presumably for URI-related functions (e.g. in db.types.uri) -- confirm.
uri = _make_hint("uri")


# Applied to the parameter of Literal.
literal = _make_hint("literal")
57 changes: 57 additions & 0 deletions db/functions/known_db_functions.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
"""
Exports the known_db_functions variable, which describes what `DBFunction`s the library is aware
of. Note that a `DBFunction` might be in this collection, but not be supported by a given
database.
Contains a private collection (`_db_functions_in_other_modules`) of `DBFunction` subclasses
declared outside the base module.
These variables were broken off into a discrete module to avoid circular imports.
"""

import inspect

import db.functions.base

from db.functions.base import DBFunction

from db.types import uri


def _get_module_members_that_satisfy(module, predicate):
"""
Looks at the members of the provided module and filters them using the provided predicate.
In this context, it (together with the appropriate predicate) is used to automatically collect
all DBFunction subclasses found as top-level members of a module.
"""
all_members_in_defining_module = inspect.getmembers(module)
return tuple(
member
for _, member in all_members_in_defining_module
if predicate(member)
)


def _is_concrete_db_function_subclass(member):
    """
    Tell whether `member` is an instantiable DBFunction subclass.

    True only for classes that derive from DBFunction, are not DBFunction itself, and have
    no unimplemented abstract methods.
    """
    return (
        inspect.isclass(member)
        # Identity is the intended comparison for class objects.
        and member is not DBFunction
        and issubclass(member, DBFunction)
        # An intermediate subclass that leaves `to_sa_expression` abstract cannot be
        # instantiated, so it must not be advertised as a known function.
        and not inspect.isabstract(member)
    )


# DBFunction subclasses discovered automatically among the members of db.functions.base.
_db_functions_in_base_module = _get_module_members_that_satisfy(
    db.functions.base,
    _is_concrete_db_function_subclass,
)


# DBFunction subclasses declared outside the base module; these are listed by hand.
_db_functions_in_other_modules = (
    uri.ExtractURIAuthority,
)


# Every DBFunction the library is aware of: base-module ones first, then the rest.
known_db_functions = _db_functions_in_base_module + _db_functions_in_other_modules
50 changes: 50 additions & 0 deletions db/functions/operations/apply.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
from db.functions.base import DBFunction
from db.functions.exceptions import ReferencedColumnsDontExist
from db.functions.operations.deserialize import get_db_function_from_ma_function_spec


def apply_ma_function_spec_as_filter(relation, ma_function_spec):
    """Deserialize `ma_function_spec` into a DBFunction and apply it to `relation` as a filter."""
    return apply_db_function_as_filter(
        relation,
        get_db_function_from_ma_function_spec(ma_function_spec),
    )


def apply_db_function_as_filter(relation, db_function):
    """
    Return `relation` filtered by the SQLAlchemy expression built from `db_function`.

    Raises ReferencedColumnsDontExist if `db_function` references a column the relation
    doesn't select.
    """
    _assert_that_all_referenced_columns_exist(relation, db_function)
    condition = _db_function_to_sa_expression(db_function)
    return relation.filter(condition)


def _assert_that_all_referenced_columns_exist(relation, db_function):
    """Raise ReferencedColumnsDontExist if `db_function` references columns absent from `relation`."""
    available = _get_columns_that_exist(relation)
    requested = db_function.referenced_columns
    missing = set.difference(requested, available)
    if missing:
        raise ReferencedColumnsDontExist(
            "These referenced columns don't exist on the relevant relation: "
            + f"{missing}"
        )


def _get_columns_that_exist(relation):
columns = relation.selected_columns
return set(column.name for column in columns)


def _db_function_to_sa_expression(db_function):
    """
    Recursively convert a DBFunction tree into an SQLAlchemy expression.

    Every parameter (at any nesting depth) is expected to be either a DBFunction instance or
    a literal primitive. Anything that is not a DBFunction is returned untouched.
    """
    if not isinstance(db_function, DBFunction):
        return db_function
    # Convert the children first; the subclass then combines the converted values.
    converted_parameters = [
        _db_function_to_sa_expression(child)
        for child in db_function.parameters
    ]
    return type(db_function).to_sa_expression(*converted_parameters)
Loading

0 comments on commit 5a6db9f

Please sign in to comment.