Skip to content

Commit

Permalink
Merge pull request #1079 from centerofci/add-uri-filters
Browse files Browse the repository at this point in the history
Add URI and text filters
  • Loading branch information
mathemancer committed Feb 28, 2022
2 parents 02e22cb + cb9a118 commit 67c6439
Show file tree
Hide file tree
Showing 18 changed files with 448 additions and 124 deletions.
60 changes: 56 additions & 4 deletions db/functions/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,10 @@
from db.functions.exceptions import BadDBFunctionFormat


def sa_call_sql_function(function_name, *parameters):
return getattr(func, function_name)(*parameters)


# NOTE: this class is abstract.
class DBFunction(ABC):
id = None
Expand Down Expand Up @@ -73,7 +77,7 @@ class Literal(DBFunction):
name = 'as literal'
hints = tuple([
hints.parameter_count(1),
hints.parameter(1, hints.literal),
hints.parameter(0, hints.literal),
])

@staticmethod
Expand All @@ -87,7 +91,7 @@ class ColumnName(DBFunction):
name = 'as column name'
hints = tuple([
hints.parameter_count(1),
hints.parameter(1, hints.column),
hints.parameter(0, hints.column),
])

@property
Expand Down Expand Up @@ -190,7 +194,8 @@ class In(DBFunction):
hints = tuple([
hints.returns(hints.boolean),
hints.parameter_count(2),
hints.parameter(2, hints.array),
hints.parameter(0, hints.any),
hints.parameter(1, hints.array),
])

@staticmethod
Expand Down Expand Up @@ -225,6 +230,36 @@ def to_sa_expression(*values):
class StartsWith(DBFunction):
id = 'starts_with'
name = 'starts with'
hints = tuple([
hints.returns(hints.boolean),
hints.parameter_count(2),
hints.all_parameters(hints.string_like),
])

@staticmethod
def to_sa_expression(string, prefix):
pattern = func.concat(prefix, '%')
return string.like(pattern)


class Contains(DBFunction):
id = 'contains'
name = 'contains'
hints = tuple([
hints.returns(hints.boolean),
hints.parameter_count(2),
hints.all_parameters(hints.string_like),
])

@staticmethod
def to_sa_expression(string, sub_string):
pattern = func.concat('%', sub_string, '%')
return string.like(pattern)


class StartsWithCaseInsensitive(DBFunction):
id = 'starts_with_case_insensitive'
name = 'starts with'
hints = tuple([
hints.returns(hints.boolean),
hints.parameter_count(2),
Expand All @@ -234,7 +269,24 @@ class StartsWith(DBFunction):

@staticmethod
def to_sa_expression(string, prefix):
return string.like(f'{prefix}%')
pattern = func.concat(prefix, '%')
return string.ilike(pattern)


class ContainsCaseInsensitive(DBFunction):
id = 'contains_case_insensitive'
name = 'contains'
hints = tuple([
hints.returns(hints.boolean),
hints.parameter_count(2),
hints.all_parameters(hints.string_like),
hints.mathesar_filter,
])

@staticmethod
def to_sa_expression(string, sub_string):
pattern = func.concat('%', sub_string, '%')
return string.ilike(pattern)


class ToLowercase(DBFunction):
Expand Down
44 changes: 16 additions & 28 deletions db/functions/known_db_functions.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,7 @@
"""
Exports the known_db_functions variable, which describes what `DBFunction`s the library is aware
of. Note, that a `DBFunction` might be in this collection, but not be supported by a given
database.
Contains a private collection (`_db_functions_in_other_modules`) of `DBFunction` subclasses
declared outside the base module.
Exports the known_db_functions variable, which describes what `DBFunction` concrete subclasses the
library is aware of. Note, that a `DBFunction` might be in this collection, but not be
supported by a given database.
These variables were broken off into a discrete module to avoid circular imports.
"""
Expand All @@ -13,11 +10,10 @@

import db.functions.base
import db.functions.redundant
import db.types.uri

from db.functions.base import DBFunction

from db.types import uri


def _get_module_members_that_satisfy(module, predicate):
"""
Expand All @@ -42,29 +38,21 @@ def _is_concrete_db_function_subclass(member):
)


_db_functions_in_base_module = (
_get_module_members_that_satisfy(
db.functions.base,
_is_concrete_db_function_subclass
)
)
_modules_to_search_in = tuple([
db.functions.base,
db.functions.redundant,
db.types.uri
])


_db_functions_in_redundant_module = (
def _concat_tuples(tuples):
return sum(tuples, ())


known_db_functions = _concat_tuples([
_get_module_members_that_satisfy(
db.functions.redundant,
module,
_is_concrete_db_function_subclass
)
)


_db_functions_in_other_modules = tuple([
uri.ExtractURIAuthority,
for module in _modules_to_search_in
])


known_db_functions = (
_db_functions_in_base_module
+ _db_functions_in_redundant_module
+ _db_functions_in_other_modules
)
19 changes: 13 additions & 6 deletions db/functions/operations/apply.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from db.functions.base import DBFunction
from db.functions.exceptions import ReferencedColumnsDontExist
from db.functions.redundant import RedundantDBFunction
from db.functions.operations.deserialize import get_db_function_from_ma_function_spec


Expand Down Expand Up @@ -32,19 +33,25 @@ def _get_columns_that_exist(relation):
return set(column.name for column in columns)


def _db_function_to_sa_expression(db_function):
def _db_function_to_sa_expression(db_function_or_literal):
"""
Takes a DBFunction, looks at the tree of its parameters (and the parameters of nested
Takes a DBFunction instance, looks at the tree of its parameters (and the parameters of nested
DBFunctions), and turns it into an SQLAlchemy expression. Each parameter is expected to either
be a DBFunction instance or a literal primitive.
"""
if isinstance(db_function, DBFunction):
if isinstance(db_function_or_literal, RedundantDBFunction):
db_function = db_function_or_literal
unpacked_db_function = db_function.unpack()
return _db_function_to_sa_expression(unpacked_db_function)
elif isinstance(db_function_or_literal, DBFunction):
db_function = db_function_or_literal
raw_parameters = db_function.parameters
parameters = [
sa_expression_parameters = [
_db_function_to_sa_expression(raw_parameter)
for raw_parameter in raw_parameters
]
db_function_subclass = type(db_function)
return db_function_subclass.to_sa_expression(*parameters)
return db_function_subclass.to_sa_expression(*sa_expression_parameters)
else:
return db_function
literal = db_function_or_literal
return literal
3 changes: 2 additions & 1 deletion db/functions/operations/check_support.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,8 @@ def _are_db_function_dependencies_satisfied(db_function, functions_on_database):
return (
no_dependencies
or all(
dependency_function in functions_on_database
# dependency_function is expected to be an Enum member
dependency_function.value in functions_on_database
for dependency_function in db_function.depends_on
)
)
49 changes: 36 additions & 13 deletions db/functions/redundant.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,18 +4,33 @@
Mathesar filters not supporting composition.
"""

from db.functions.base import DBFunction
from abc import abstractmethod

from db.functions.base import DBFunction, Or, Lesser, Equal, Greater

from db.functions import hints


# TODO (tech debt) define in terms of other DBFunctions
# would involve creating an alternative to to_sa_expression: something like to_db_function
# execution engine would see that to_sa_expression is not implemented, and it would look for
# to_db_function.
class RedundantDBFunction(DBFunction):
"""
A DBFunction that is meant to be unpacked into another DBFunction. A way to define a DBFunction
as a combination of DBFunctions. Its to_sa_expression method is not to used. Its concrete
implementations are expected to implement the unpack method.
"""
@staticmethod
def to_sa_expression(*_):
raise Exception("UnpackabelDBFunction.to_sa_expression should never be used.")

@abstractmethod
def unpack(self):
"""
Should return a DBFunction instance with self.parameters forwarded to it. A way to define
a DBFunction in terms of other DBFunctions.
"""
pass


class LesserOrEqual(DBFunction):
class LesserOrEqual(RedundantDBFunction):
id = 'lesser_or_equal'
name = 'is lesser or equal to'
hints = tuple([
Expand All @@ -25,12 +40,16 @@ class LesserOrEqual(DBFunction):
hints.mathesar_filter,
])

@staticmethod
def to_sa_expression(value1, value2):
return value1 <= value2
def unpack(self):
param0 = self.parameters[0]
param1 = self.parameters[1]
return Or([
Lesser([param0, param1]),
Equal([param0, param1]),
])


class GreaterOrEqual(DBFunction):
class GreaterOrEqual(RedundantDBFunction):
id = 'greater_or_equal'
name = 'is greater or equal to'
hints = tuple([
Expand All @@ -40,6 +59,10 @@ class GreaterOrEqual(DBFunction):
hints.mathesar_filter,
])

@staticmethod
def to_sa_expression(value1, value2):
return value1 >= value2
def unpack(self):
param0 = self.parameters[0]
param1 = self.parameters[1]
return Or([
Greater([param0, param1]),
Equal([param0, param1]),
])
29 changes: 28 additions & 1 deletion db/tests/conftest.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,20 @@
import os

import pytest
from sqlalchemy import MetaData, text
from sqlalchemy import MetaData, text, Table

from db import constants, types
from db.tables.operations.split import extract_columns_from_table
from db.engine import _add_custom_types_to_engine
from db.types import install


APP_SCHEMA = "test_schema"

FILE_DIR = os.path.abspath(os.path.dirname(__file__))
RESOURCES = os.path.join(FILE_DIR, "resources")
ROSTER_SQL = os.path.join(RESOURCES, "roster_create.sql")
URIS_SQL = os.path.join(RESOURCES, "uris_create.sql")
FILTER_SORT_SQL = os.path.join(RESOURCES, "filter_sort_create.sql")


Expand All @@ -34,6 +37,17 @@ def engine_with_roster(engine_with_schema):
return engine, schema


@pytest.fixture
def engine_with_uris(engine_with_schema):
engine, schema = engine_with_schema
_add_custom_types_to_engine(engine)
install.install_mathesar_on_database(engine)
with engine.begin() as conn, open(URIS_SQL) as f:
conn.execute(text(f"SET search_path={schema}"))
conn.execute(text(f.read()))
return engine, schema


@pytest.fixture
def engine_with_filter_sort(engine_with_schema):
engine, schema = engine_with_schema
Expand All @@ -49,6 +63,11 @@ def roster_table_name():
return "Roster"


@pytest.fixture(scope='session')
def uris_table_name():
return "uris"


@pytest.fixture(scope='session')
def teachers_table_name():
return "Teachers"
Expand Down Expand Up @@ -86,3 +105,11 @@ def extracted_remainder_roster(engine_with_roster, roster_table_name, roster_ext
roster_no_teachers = metadata.tables[f"{schema}.{roster_no_teachers_table_name}"]
roster = metadata.tables[f"{schema}.{roster_table_name}"]
return teachers, roster_no_teachers, roster, engine, schema


@pytest.fixture
def roster_table_obj(engine_with_roster, roster_table_name):
engine, schema = engine_with_roster
metadata = MetaData(bind=engine)
roster = Table(roster_table_name, metadata, schema=schema, autoload_with=engine)
return roster, engine
Loading

0 comments on commit 67c6439

Please sign in to comment.