diff --git a/db/functions/base.py b/db/functions/base.py
index 8142e12885..cdd97f021f 100644
--- a/db/functions/base.py
+++ b/db/functions/base.py
@@ -20,6 +20,10 @@
 from db.functions.exceptions import BadDBFunctionFormat
 
 
+def sa_call_sql_function(function_name, *parameters):
+    return getattr(func, function_name)(*parameters)
+
+
 # NOTE: this class is abstract.
 class DBFunction(ABC):
     id = None
@@ -73,7 +77,7 @@ class Literal(DBFunction):
     name = 'as literal'
     hints = tuple([
         hints.parameter_count(1),
-        hints.parameter(1, hints.literal),
+        hints.parameter(0, hints.literal),
     ])
 
     @staticmethod
@@ -87,7 +91,7 @@ class ColumnName(DBFunction):
     name = 'as column name'
     hints = tuple([
         hints.parameter_count(1),
-        hints.parameter(1, hints.column),
+        hints.parameter(0, hints.column),
     ])
 
     @property
@@ -190,7 +194,8 @@ class In(DBFunction):
     hints = tuple([
         hints.returns(hints.boolean),
         hints.parameter_count(2),
-        hints.parameter(2, hints.array),
+        hints.parameter(0, hints.any),
+        hints.parameter(1, hints.array),
     ])
 
     @staticmethod
@@ -225,6 +230,36 @@ def to_sa_expression(*values):
 class StartsWith(DBFunction):
     id = 'starts_with'
     name = 'starts with'
+    hints = tuple([
+        hints.returns(hints.boolean),
+        hints.parameter_count(2),
+        hints.all_parameters(hints.string_like),
+    ])
+
+    @staticmethod
+    def to_sa_expression(string, prefix):
+        pattern = func.concat(prefix, '%')
+        return string.like(pattern)
+
+
+class Contains(DBFunction):
+    id = 'contains'
+    name = 'contains'
+    hints = tuple([
+        hints.returns(hints.boolean),
+        hints.parameter_count(2),
+        hints.all_parameters(hints.string_like),
+    ])
+
+    @staticmethod
+    def to_sa_expression(string, sub_string):
+        pattern = func.concat('%', sub_string, '%')
+        return string.like(pattern)
+
+
+class StartsWithCaseInsensitive(DBFunction):
+    id = 'starts_with_case_insensitive'
+    name = 'starts with'
     hints = tuple([
         hints.returns(hints.boolean),
         hints.parameter_count(2),
@@ -234,7 +269,24 @@ class StartsWith(DBFunction):
 
     @staticmethod
     def to_sa_expression(string, prefix):
-        return string.like(f'{prefix}%')
+        pattern = func.concat(prefix, '%')
+        return string.ilike(pattern)
+
+
+class ContainsCaseInsensitive(DBFunction):
+    id = 'contains_case_insensitive'
+    name = 'contains'
+    hints = tuple([
+        hints.returns(hints.boolean),
+        hints.parameter_count(2),
+        hints.all_parameters(hints.string_like),
+        hints.mathesar_filter,
+    ])
+
+    @staticmethod
+    def to_sa_expression(string, sub_string):
+        pattern = func.concat('%', sub_string, '%')
+        return string.ilike(pattern)
 
 
 class ToLowercase(DBFunction):
diff --git a/db/functions/known_db_functions.py b/db/functions/known_db_functions.py
index 35f0ae6da8..5c706be264 100644
--- a/db/functions/known_db_functions.py
+++ b/db/functions/known_db_functions.py
@@ -1,10 +1,7 @@
 """
-Exports the known_db_functions variable, which describes what `DBFunction`s the library is aware
-of. Note, that a `DBFunction` might be in this collection, but not be supported by a given
-database.
-
-Contains a private collection (`_db_functions_in_other_modules`) of `DBFunction` subclasses
-declared outside the base module.
+Exports the known_db_functions variable, which describes what `DBFunction` concrete subclasses the
+library is aware of. Note that a `DBFunction` might be in this collection, but not be
+supported by a given database.
 
 These variables were broken off into a discrete module to avoid circular imports.
""" @@ -13,11 +10,10 @@ import db.functions.base import db.functions.redundant +import db.types.uri from db.functions.base import DBFunction -from db.types import uri - def _get_module_members_that_satisfy(module, predicate): """ @@ -42,29 +38,21 @@ def _is_concrete_db_function_subclass(member): ) -_db_functions_in_base_module = ( - _get_module_members_that_satisfy( - db.functions.base, - _is_concrete_db_function_subclass - ) -) +_modules_to_search_in = tuple([ + db.functions.base, + db.functions.redundant, + db.types.uri +]) -_db_functions_in_redundant_module = ( +def _concat_tuples(tuples): + return sum(tuples, ()) + + +known_db_functions = _concat_tuples([ _get_module_members_that_satisfy( - db.functions.redundant, + module, _is_concrete_db_function_subclass ) -) - - -_db_functions_in_other_modules = tuple([ - uri.ExtractURIAuthority, + for module in _modules_to_search_in ]) - - -known_db_functions = ( - _db_functions_in_base_module - + _db_functions_in_redundant_module - + _db_functions_in_other_modules -) diff --git a/db/functions/operations/apply.py b/db/functions/operations/apply.py index 7e58625a2d..a0032359e1 100644 --- a/db/functions/operations/apply.py +++ b/db/functions/operations/apply.py @@ -1,5 +1,6 @@ from db.functions.base import DBFunction from db.functions.exceptions import ReferencedColumnsDontExist +from db.functions.redundant import RedundantDBFunction from db.functions.operations.deserialize import get_db_function_from_ma_function_spec @@ -32,19 +33,25 @@ def _get_columns_that_exist(relation): return set(column.name for column in columns) -def _db_function_to_sa_expression(db_function): +def _db_function_to_sa_expression(db_function_or_literal): """ - Takes a DBFunction, looks at the tree of its parameters (and the parameters of nested + Takes a DBFunction instance, looks at the tree of its parameters (and the parameters of nested DBFunctions), and turns it into an SQLAlchemy expression. Each parameter is expected to either be a DBFunction instance or a literal primitive. 
""" - if isinstance(db_function, DBFunction): + if isinstance(db_function_or_literal, RedundantDBFunction): + db_function = db_function_or_literal + unpacked_db_function = db_function.unpack() + return _db_function_to_sa_expression(unpacked_db_function) + elif isinstance(db_function_or_literal, DBFunction): + db_function = db_function_or_literal raw_parameters = db_function.parameters - parameters = [ + sa_expression_parameters = [ _db_function_to_sa_expression(raw_parameter) for raw_parameter in raw_parameters ] db_function_subclass = type(db_function) - return db_function_subclass.to_sa_expression(*parameters) + return db_function_subclass.to_sa_expression(*sa_expression_parameters) else: - return db_function + literal = db_function_or_literal + return literal diff --git a/db/functions/operations/check_support.py b/db/functions/operations/check_support.py index c337a5d69d..478af1ae72 100644 --- a/db/functions/operations/check_support.py +++ b/db/functions/operations/check_support.py @@ -32,7 +32,8 @@ def _are_db_function_dependencies_satisfied(db_function, functions_on_database): return ( no_dependencies or all( - dependency_function in functions_on_database + # dependency_function is expected to be an Enum member + dependency_function.value in functions_on_database for dependency_function in db_function.depends_on ) ) diff --git a/db/functions/redundant.py b/db/functions/redundant.py index 83243b1c5d..9da475d508 100644 --- a/db/functions/redundant.py +++ b/db/functions/redundant.py @@ -4,18 +4,33 @@ Mathesar filters not supporting composition. """ -from db.functions.base import DBFunction +from abc import abstractmethod + +from db.functions.base import DBFunction, Or, Lesser, Equal, Greater from db.functions import hints -# TODO (tech debt) define in terms of other DBFunctions -# would involve creating an alternative to to_sa_expression: something like to_db_function -# execution engine would see that to_sa_expression is not implemented, and it would look for -# to_db_function. +class RedundantDBFunction(DBFunction): + """ + A DBFunction that is meant to be unpacked into another DBFunction. A way to define a DBFunction + as a combination of DBFunctions. Its to_sa_expression method is not to used. Its concrete + implementations are expected to implement the unpack method. + """ + @staticmethod + def to_sa_expression(*_): + raise Exception("UnpackabelDBFunction.to_sa_expression should never be used.") + + @abstractmethod + def unpack(self): + """ + Should return a DBFunction instance with self.parameters forwarded to it. A way to define + a DBFunction in terms of other DBFunctions. 
+        """
+        pass
 
 
-class LesserOrEqual(DBFunction):
+class LesserOrEqual(RedundantDBFunction):
     id = 'lesser_or_equal'
     name = 'is lesser or equal to'
     hints = tuple([
@@ -25,12 +40,16 @@ class LesserOrEqual(DBFunction):
         hints.mathesar_filter,
     ])
 
-    @staticmethod
-    def to_sa_expression(value1, value2):
-        return value1 <= value2
+    def unpack(self):
+        param0 = self.parameters[0]
+        param1 = self.parameters[1]
+        return Or([
+            Lesser([param0, param1]),
+            Equal([param0, param1]),
+        ])
 
 
-class GreaterOrEqual(DBFunction):
+class GreaterOrEqual(RedundantDBFunction):
     id = 'greater_or_equal'
     name = 'is greater or equal to'
     hints = tuple([
@@ -40,6 +59,10 @@ class GreaterOrEqual(DBFunction):
         hints.mathesar_filter,
     ])
 
-    @staticmethod
-    def to_sa_expression(value1, value2):
-        return value1 >= value2
+    def unpack(self):
+        param0 = self.parameters[0]
+        param1 = self.parameters[1]
+        return Or([
+            Greater([param0, param1]),
+            Equal([param0, param1]),
+        ])
diff --git a/db/tests/conftest.py b/db/tests/conftest.py
index 0b1a0c4ebf..e7f9ea531a 100644
--- a/db/tests/conftest.py
+++ b/db/tests/conftest.py
@@ -1,10 +1,12 @@
 import os
 
 import pytest
-from sqlalchemy import MetaData, text
+from sqlalchemy import MetaData, text, Table
 
 from db import constants, types
 from db.tables.operations.split import extract_columns_from_table
+from db.engine import _add_custom_types_to_engine
+from db.types import install
 
 
 APP_SCHEMA = "test_schema"
@@ -12,6 +14,7 @@
 FILE_DIR = os.path.abspath(os.path.dirname(__file__))
 RESOURCES = os.path.join(FILE_DIR, "resources")
 ROSTER_SQL = os.path.join(RESOURCES, "roster_create.sql")
+URIS_SQL = os.path.join(RESOURCES, "uris_create.sql")
 FILTER_SORT_SQL = os.path.join(RESOURCES, "filter_sort_create.sql")
 
 
@@ -34,6 +37,17 @@ def engine_with_roster(engine_with_schema):
     return engine, schema
 
 
+@pytest.fixture
+def engine_with_uris(engine_with_schema):
+    engine, schema = engine_with_schema
+    _add_custom_types_to_engine(engine)
+    install.install_mathesar_on_database(engine)
+    with engine.begin() as conn, open(URIS_SQL) as f:
+        conn.execute(text(f"SET search_path={schema}"))
+        conn.execute(text(f.read()))
+    return engine, schema
+
+
 @pytest.fixture
 def engine_with_filter_sort(engine_with_schema):
     engine, schema = engine_with_schema
@@ -49,6 +63,11 @@ def roster_table_name():
     return "Roster"
 
 
+@pytest.fixture(scope='session')
+def uris_table_name():
+    return "uris"
+
+
 @pytest.fixture(scope='session')
 def teachers_table_name():
     return "Teachers"
@@ -86,3 +105,11 @@ def extracted_remainder_roster(engine_with_roster, roster_table_name, roster_ext
     roster_no_teachers = metadata.tables[f"{schema}.{roster_no_teachers_table_name}"]
     roster = metadata.tables[f"{schema}.{roster_table_name}"]
     return teachers, roster_no_teachers, roster, engine, schema
+
+
+@pytest.fixture
+def roster_table_obj(engine_with_roster, roster_table_name):
+    engine, schema = engine_with_roster
+    metadata = MetaData(bind=engine)
+    roster = Table(roster_table_name, metadata, schema=schema, autoload_with=engine)
+    return roster, engine
diff --git a/db/tests/functions/operations/test_filter.py b/db/tests/functions/operations/test_filter.py
index 5a65982d1d..e43770716c 100644
--- a/db/tests/functions/operations/test_filter.py
+++ b/db/tests/functions/operations/test_filter.py
@@ -4,7 +4,10 @@
 from db.utils import execute_query
 
 from db.functions.base import (
-    ColumnName, Not, Literal, Empty, Equal, Greater, And, Or
+    ColumnName, Not, Literal, Empty, Equal, Greater, And, Or, StartsWith, Contains, StartsWithCaseInsensitive, ContainsCaseInsensitive
+)
+from db.functions.redundant import (
+    GreaterOrEqual, LesserOrEqual
 )
 from db.functions.operations.apply import apply_db_function_as_filter
 
@@ -25,6 +28,10 @@ def _ilike(x, v):
     "and": lambda x: And(x),
     "or": lambda x: Or(x),
     "not": lambda x: Not(x),
+    "ge": lambda x, v: GreaterOrEqual([ColumnName([x]), Literal([v])]),
+    "le": lambda x, v: LesserOrEqual([ColumnName([x]), Literal([v])]),
+    "starts_with": lambda x, v: StartsWith([ColumnName([x]), Literal([v])]),
+    "contains": lambda x, v: Contains([ColumnName([x]), Literal([v])]),
 }
 
 
@@ -46,11 +53,18 @@ def _ilike(x, v):
     "not_any": lambda x, v: v not in x,
     "and": lambda x: all(x),
     "or": lambda x: any(x),
-    "not": lambda x: not x[0]
+    "not": lambda x: not x[0],
+    "starts_with": lambda x, v: x.startswith(v),
+    "contains": lambda x, v: v in x,
 }
 
 
+# Table can be examined in db/tests/resources/filter_sort_create.sql
 ops_test_list = [
+    # starts_with
+    ("varchar", "starts_with", "string9", 11),
+    # contains
+    ("varchar", "contains", "g9", 11),
     # is_null
     ("varchar", "is_null", None, 5),
     ("numeric", "is_null", None, 5),
@@ -69,7 +83,15 @@ def _ilike(x, v):
     # gt
     ("varchar", "gt", "string0", 100),
     ("numeric", "gt", 50, 50),
-    ("date", "gt", "2000-01-01", 99),
+    ("date", "gt", "2000-01-01 AD", 99),
+    # ge
+    ("varchar", "ge", "string1", 100),
+    ("numeric", "ge", 50, 51),
+    ("date", "ge", "2000-01-01 AD", 100),
+    # le
+    ("varchar", "le", "string2", 13),
+    ("numeric", "le", 51, 51),
+    ("date", "le", "2099-01-01 AD", 100),
 ]
 
 
@@ -98,7 +120,8 @@ def test_filter_with_db_functions(
     assert len(record_list) == res_len
     for record in record_list:
         val_func = op_to_python_func[op]
-        assert val_func(getattr(record, column_name), value)
+        actual_value = getattr(record, column_name)
+        assert val_func(actual_value, value)
 
 
 boolean_ops_test_list = [
@@ -166,3 +189,21 @@ def test_filtering_nested_boolean_ops(filter_sort_table_obj):
     for record in record_list:
         assert ((record.varchar == "string24" or record.numeric == 42)
                 and (record.varchar == "string42" or record.numeric == 24))
+
+
+@pytest.mark.parametrize("column_name,main_db_function,literal_param,expected_count", [
+    ("Student Name", StartsWithCaseInsensitive, "stephanie", 15),
+    ("Student Name", StartsWith, "stephanie", 0),
+    ("Student Name", ContainsCaseInsensitive, "JUAREZ", 5),
+    ("Student Name", Contains, "juarez", 0),
+])
+def test_case_insensitive_filtering(roster_table_obj, column_name, main_db_function, literal_param, expected_count):
+    table, engine = roster_table_obj
+    selectable = table.select()
+    db_function = main_db_function([
+        ColumnName([column_name]),
+        Literal([literal_param]),
+    ])
+    query = apply_db_function_as_filter(selectable, db_function)
+    record_list = execute_query(engine, query)
+    assert len(record_list) == expected_count
diff --git a/db/tests/records/conftest.py b/db/tests/records/conftest.py
index 371d5e153d..2cd185cc4d 100644
--- a/db/tests/records/conftest.py
+++ b/db/tests/records/conftest.py
@@ -6,14 +6,6 @@
 FILTER_SORT = "filter_sort"
 
 
-@pytest.fixture
-def roster_table_obj(engine_with_roster, roster_table_name):
-    engine, schema = engine_with_roster
-    metadata = MetaData(bind=engine)
-    roster = Table(roster_table_name, metadata, schema=schema, autoload_with=engine)
-    return roster, engine
-
-
 @pytest.fixture
 def filter_sort_table_obj(engine_with_filter_sort):
     engine, schema = engine_with_filter_sort
diff --git a/db/tests/resources/uris_create.sql b/db/tests/resources/uris_create.sql
new file mode 100644
index 0000000000..a3380bdb3c
--- /dev/null
+++ b/db/tests/resources/uris_create.sql
@@ -0,0 +1,43 @@
+CREATE TABLE "uris" (
+    id integer NOT NULL,
+    "uri" character varying(250)
+);
+
+CREATE SEQUENCE "uris_id_seq"
+    AS integer
+    START WITH 1
+    INCREMENT BY 1
+    NO MINVALUE
+    NO MAXVALUE
+    CACHE 1;
+
+ALTER SEQUENCE "uris_id_seq" OWNED BY "uris".id;
+
+ALTER TABLE ONLY "uris" ALTER COLUMN id SET DEFAULT nextval('"uris_id_seq"'::regclass);
+
+INSERT INTO "uris" VALUES
+(1, 'http://soundcloud.com/denzo-1/denzo-in-mix-0knackpunkt-nr-15-0-electro-swing'),
+(2, 'http://picasaweb.google.com/lh/photo/94RGMDCSTmCW04l6SPnteTBPFtERcSvqpRI6vP3N6YI?feat=embedwebsite'),
+(3, 'http://banedon.posterous.com/bauforstschritt-2262010'),
+(4, 'http://imgur.com/M2v2H.png'),
+(5, 'http://tweetphoto.com/31300678'),
+(6, 'http://www.youtube.com/watch?v=zXLGHyGxY2E'),
+(7, 'http://tweetphoto.com/31103212'),
+(8, 'http://soundcloud.com/dj-soro'),
+(9, 'http://i.imgur.com/H6yyu.jpg'),
+(10, 'http://www.flickr.com/photos/jocke66/4657443374/'),
+(11, 'http://tweetphoto.com/31332311'),
+(12, 'http://tweetphoto.com/31421017'),
+(13, 'http://yfrog.com/j6cimg3038gj'),
+(14, 'http://yfrog.com/msradon2p'),
+(15, 'http://soundcloud.com/hedo/hedo-der-groove-junger-knospen'),
+(16, 'http://soundcloud.com/strawberryhaze/this-is-my-house-in-summer-2010'),
+(17, 'http://tumblr.com/x4acyiuxf'),
+(18, 'ftp://foobar.com/179179'),
+(19, 'ftps://asldp.com/158915'),
+(20, 'ftp://abcdefg.com/x-y-z');
+
+SELECT pg_catalog.setval('"uris_id_seq"', 1000, true);
+
+ALTER TABLE ONLY "uris"
+    ADD CONSTRAINT "uris_pkey" PRIMARY KEY (id);
diff --git a/db/tests/types/fixtures.py b/db/tests/types/fixtures.py
index 38633997ca..cdff6a2b99 100644
--- a/db/tests/types/fixtures.py
+++ b/db/tests/types/fixtures.py
@@ -7,9 +7,11 @@
 """
 import pytest
+from sqlalchemy import MetaData, Table
 from sqlalchemy.schema import CreateSchema, DropSchema
 
 from db.engine import _add_custom_types_to_engine
 from db.types import base, install
+from db.columns.operations.alter import alter_column_type
 
 
 TEST_SCHEMA = "test_schema"
@@ -37,3 +39,24 @@ def engine_email_type(temporary_testing_schema):
     yield engine, schema
     with engine.begin() as conn:
         conn.execute(DropSchema(base.SCHEMA, cascade=True, if_exists=True))
+
+
+@pytest.fixture
+def uris_table_obj(engine_with_uris, uris_table_name):
+    engine, schema = engine_with_uris
+    metadata = MetaData(bind=engine)
+    table = Table(uris_table_name, metadata, schema=schema, autoload_with=engine)
+    # Cast "uri" column from string to URI
+    with engine.begin() as conn:
+        uri_column_name = "uri"
+        uri_type_id = "uri"
+        alter_column_type(
+            table,
+            uri_column_name,
+            engine,
+            conn,
+            uri_type_id,
+        )
+    yield table, engine
+    with engine.begin() as conn:
+        conn.execute(DropSchema(base.SCHEMA, cascade=True, if_exists=True))
diff --git a/db/tests/types/test_uri.py b/db/tests/types/test_uri.py
index e45a05e7f4..3393890f52 100644
--- a/db/tests/types/test_uri.py
+++ b/db/tests/types/test_uri.py
@@ -5,12 +5,16 @@
 from db.engine import _add_custom_types_to_engine
 from db.tests.types import fixtures
 from db.types import uri
+from db.utils import execute_query
+from db.functions.base import ColumnName, Literal
+from db.functions.operations.apply import apply_db_function_as_filter
 
 
 # We need to set these variables when the file loads, or pytest can't
 # properly detect the fixtures. Importing them directly results in a
 # flake8 unused import error, and a bunch of flake8 F811 errors.
 engine_with_types = fixtures.engine_with_types
+uris_table_obj = fixtures.uris_table_obj
 engine_email_type = fixtures.engine_email_type
 temporary_testing_schema = fixtures.temporary_testing_schema
 
@@ -206,3 +210,22 @@ def test_uri_type_domain_rejects_malformed_uris(engine_email_type, test_str):
         with engine.begin() as conn:
             conn.execute(text(f"SELECT '{test_str}'::{uri.DB_TYPE}"))
     assert type(e.orig) == CheckViolation
+
+
+@pytest.mark.parametrize("main_db_function,literal_param,expected_count", [
+    (uri.URIAuthorityContains, "soundcloud", 4),
+    (uri.URIAuthorityContains, "http", 0),
+    (uri.URISchemeEquals, "ftp", 2),
+    (uri.Contains, ".com/31421017", 1),
+])
+def test_uri_db_functions(uris_table_obj, main_db_function, literal_param, expected_count):
+    table, engine = uris_table_obj
+    selectable = table.select()
+    uris_column_name = "uri"
+    db_function = main_db_function([
+        ColumnName([uris_column_name]),
+        Literal([literal_param]),
+    ])
+    query = apply_db_function_as_filter(selectable, db_function)
+    record_list = execute_query(engine, query)
+    assert len(record_list) == expected_count
diff --git a/db/types/base.py b/db/types/base.py
index a745eb0777..21b6215c44 100644
--- a/db/types/base.py
+++ b/db/types/base.py
@@ -86,6 +86,18 @@ class MathesarCustomType(Enum):
 known_db_types = _known_vanilla_db_types + _known_custom_db_types
 
 
+# Origin: https://www.python.org/dev/peps/pep-0616/#id17
+def _remove_prefix(self: str, prefix: str, /) -> str:
+    """
+    This will remove the passed prefix, if it's there.
+    Otherwise, it will return the string unchanged.
+    """
+    if self.startswith(prefix):
+        return self[len(prefix):]
+    else:
+        return self[:]
+
+
 def get_db_type_enum_from_id(db_type_id):
     """
     Gets an instance of either the PostgresType enum or the MathesarCustomType enum corresponding
@@ -96,30 +108,33 @@ def get_db_type_enum_from_id(db_type_id):
         return PostgresType(db_type_id)
     except ValueError:
         try:
-            return MathesarCustomType(db_type_id)
+            # Sometimes MA type identifiers are qualified like so: `mathesar_types.uri`.
+            # We want to remove that prefix, when it's there, because MathesarCustomType
+            # enum stores type ids without a qualifier (e.g. `uri`).
+            possible_prefix = _ma_type_qualifier_prefix + '.'
+            preprocessed_db_type_id = _remove_prefix(db_type_id, possible_prefix)
+            return MathesarCustomType(preprocessed_db_type_id)
         except ValueError:
             return None
 
 
 def _build_db_types_hinted():
+    """
+    Builds up a map of db types to hintsets.
+    """
+    # Start out by defining some hints manually.
     db_types_hinted = {
         PostgresType.BOOLEAN: tuple([
             hints.boolean
         ]),
-        PostgresType.CHARACTER_VARYING: tuple([
-            hints.string_like
-        ]),
-        PostgresType.CHARACTER: tuple([
-            hints.string_like
-        ]),
-        PostgresType.TEXT: tuple([
-            hints.string_like
-        ]),
         MathesarCustomType.URI: tuple([
             hints.uri
         ]),
     }
 
+    # Then, start adding hints automatically.
+    # This is for many-to-many relationships, i.e. adding multiple identical hintsets to the
+    # hintsets of multiple db types.
     def _add_to_db_type_hintsets(db_types, hints):
         """
         Mutates db_types_hinted to map every hint in `hints` to every DB type in `db_types`.
@@ -131,10 +146,22 @@ def _add_to_db_type_hintsets(db_types, hints):
             else:
                 db_types_hinted[db_type] = tuple(hints)
 
+    # all types get the "any" hint
     all_db_types = known_db_types
     hints_for_all_db_types = (hints.any,)
     _add_to_db_type_hintsets(all_db_types, hints_for_all_db_types)
 
+    # string-like types get the "string_like" hint
+    string_like_db_types = (
+        PostgresType.CHARACTER_VARYING,
+        PostgresType.CHARACTER,
+        PostgresType.TEXT,
+        MathesarCustomType.URI,
+    )
+    hints_for_string_like_types = (hints.string_like,)
+    _add_to_db_type_hintsets(string_like_db_types, hints_for_string_like_types)
+
+    # numeric types get the "comparable" hint
     numeric_db_types = (
         PostgresType.BIGINT,
         PostgresType.DECIMAL,
@@ -160,8 +187,12 @@ def _add_to_db_type_hintsets(db_types, hints):
 preparer = create_engine("postgresql://").dialect.identifier_preparer
 
 
+# Should usually equal `mathesar_types`
+_ma_type_qualifier_prefix = preparer.quote_schema(SCHEMA)
+
+
 def get_qualified_name(name):
-    return ".".join([preparer.quote_schema(SCHEMA), name])
+    return ".".join([_ma_type_qualifier_prefix, name])
 
 
 def get_available_types(engine):
diff --git a/db/types/uri.py b/db/types/uri.py
index b36dc81a88..273f804659 100644
--- a/db/types/uri.py
+++ b/db/types/uri.py
@@ -1,12 +1,13 @@
 from enum import Enum
 import os
 
-from sqlalchemy import text, Text, Table, Column, String, MetaData, func
+from sqlalchemy import text, Text, Table, Column, String, MetaData
 from sqlalchemy.sql import quoted_name
 from sqlalchemy.sql.functions import GenericFunction
 from sqlalchemy.types import UserDefinedType
 
 from db.functions import hints
-from db.functions.base import DBFunction
+from db.functions.base import DBFunction, Contains, sa_call_sql_function, Equal
+from db.functions.redundant import RedundantDBFunction
 from db.types import base
 
@@ -129,7 +130,7 @@ def install_tld_lookup_table(engine):
 
 class ExtractURIAuthority(DBFunction):
     id = 'extract_uri_authority'
-    name = 'Extract URI Authority'
+    name = 'extract URI authority'
     hints = tuple([
         hints.parameter_count(1),
         hints.parameter(1, hints.uri),
@@ -138,4 +139,60 @@ class ExtractURIAuthority(DBFunction):
 
     @staticmethod
     def to_sa_expression(uri):
-        return func.getattr(URIFunction.AUTHORITY)(uri)
+        return sa_call_sql_function(URIFunction.AUTHORITY.value, uri)
+
+
+class ExtractURIScheme(DBFunction):
+    id = 'extract_uri_scheme'
+    name = 'extract URI scheme'
+    hints = tuple([
+        hints.parameter_count(1),
+        hints.parameter(1, hints.uri),
+    ])
+    depends_on = tuple([URIFunction.SCHEME])
+
+    @staticmethod
+    def to_sa_expression(uri):
+        return sa_call_sql_function(URIFunction.SCHEME.value, uri)
+
+
+class URIAuthorityContains(RedundantDBFunction):
+    id = 'uri_authority_contains'
+    name = 'URI authority contains'
+    hints = tuple([
+        hints.returns(hints.boolean),
+        hints.parameter_count(2),
+        hints.parameter(0, hints.uri),
+        hints.parameter(1, hints.string_like),
+        hints.mathesar_filter,
+    ])
+    depends_on = tuple([URIFunction.AUTHORITY])
+
+    def unpack(self):
+        param0 = self.parameters[0]
+        param1 = self.parameters[1]
+        return Contains([
+            ExtractURIAuthority([param0]),
+            param1,
+        ])
+
+
+class URISchemeEquals(RedundantDBFunction):
+    id = 'uri_scheme_equals'
+    name = 'URI scheme equals'
+    hints = tuple([
+        hints.returns(hints.boolean),
+        hints.parameter_count(2),
+        hints.parameter(0, hints.uri),
+        hints.parameter(1, hints.string_like),
+        hints.mathesar_filter,
+    ])
+    depends_on = tuple([URIFunction.SCHEME])
+
+    def unpack(self):
+        param0 = self.parameters[0]
+        param1 = self.parameters[1]
+        return Equal([
+            ExtractURIScheme([param0]),
+            param1,
+        ])
diff --git a/mathesar/database/types.py b/mathesar/database/types.py
index bf0908ef21..715b33186d 100644
--- a/mathesar/database/types.py
+++ b/mathesar/database/types.py
@@ -152,7 +152,7 @@ def get_ma_types_mapped_to_hintsets(engine):
         # TODO
         ma_type = get_ma_type_enum_from_id(ma_type_description['identifier'])
         associated_db_type_descriptions = ma_type_description['db_types']
-        associated_db_types = (
+        associated_db_types = tuple(
             # TODO why is db_type_descriptions a list that seems to always have one element?
             get_db_type_enum_from_id(db_type_descriptions[0]["sa_type_name"])
             for _db_type_id, db_type_descriptions in associated_db_type_descriptions.items()
diff --git a/mathesar/models.py b/mathesar/models.py
index e691a7e0ed..da90900945 100644
--- a/mathesar/models.py
+++ b/mathesar/models.py
@@ -181,6 +181,10 @@ def save(self, *args, **kwargs):
         self.validate_unique()
         super().save(*args, **kwargs)
 
+    @cached_property
+    def _sa_engine(self):
+        return self.schema._sa_engine
+
     @cached_property
     def _sa_table(self):
         try:
diff --git a/mathesar/tests/api/conftest.py b/mathesar/tests/api/conftest.py
index 659a70ac11..af61a73524 100644
--- a/mathesar/tests/api/conftest.py
+++ b/mathesar/tests/api/conftest.py
@@ -1,19 +1,14 @@
 from django.core.files import File
 import pytest
 from rest_framework.test import APIClient
-from sqlalchemy import Column, MetaData, text, Integer
-from sqlalchemy import Table as SATable
+from sqlalchemy import text
 
-from db.types import base, install
-from db.tables.operations.select import get_oid_from_table
 from mathesar.database.base import create_mathesar_engine
 from mathesar.imports.csv import create_table_from_csv
-from mathesar.models import Table, DataFile
+from mathesar.models import DataFile
 
 
 TEST_SCHEMA = 'import_csv_schema'
-PATENT_SCHEMA = 'Patents'
-NASA_TABLE = 'NASA Schema List'
 
 
 @pytest.fixture
@@ -43,35 +38,6 @@ def _create_table(table_name, schema='Data Types'):
     return _create_table
 
 
-@pytest.fixture
-def patent_schema(test_db_model, create_schema):
-    engine = create_mathesar_engine(test_db_model.name)
-    install.install_mathesar_on_database(engine)
-    with engine.begin() as conn:
-        conn.execute(text(f'DROP SCHEMA IF EXISTS "{PATENT_SCHEMA}" CASCADE;'))
-    yield create_schema(PATENT_SCHEMA)
-    with engine.begin() as conn:
-        conn.execute(text(f'DROP SCHEMA {base.SCHEMA} CASCADE;'))
-
-
-@pytest.fixture
-def empty_nasa_table(patent_schema):
-    engine = create_mathesar_engine(patent_schema.database.name)
-    db_table = SATable(
-        NASA_TABLE, MetaData(bind=engine),
-        Column('id', Integer, primary_key=True),
-        schema=patent_schema.name,
-    )
-    db_table.create()
-    db_table_oid = get_oid_from_table(db_table.name, db_table.schema, engine)
-    table = Table.current_objects.create(oid=db_table_oid, schema=patent_schema)
-
-    yield table
-
-    table.delete_sa_table()
-    table.delete()
-
-
 @pytest.fixture
 def table_for_reflection(test_db_name):
     engine = create_mathesar_engine(test_db_name)
diff --git a/mathesar/tests/conftest.py b/mathesar/tests/conftest.py
index 52c923befd..7c590b013d 100644
--- a/mathesar/tests/conftest.py
+++ b/mathesar/tests/conftest.py
@@ -1,16 +1,25 @@
-from django.core.files import File
 """
 This inherits the fixtures in the root conftest.py
 """
 
 import pytest
-from sqlalchemy import text
+from django.core.files import File
+
+from sqlalchemy import Column, MetaData, text, Integer
+from sqlalchemy import Table as SATable
 
+from db.types import base, install
 from db.schemas.operations.create import create_schema as create_sa_schema
 from db.schemas.utils import get_schema_oid_from_name, get_schema_name_from_oid
+from db.tables.operations.select import get_oid_from_table
+
+from mathesar.models import Schema, Table, Database, DataFile
+from mathesar.database.base import create_mathesar_engine
 from mathesar.imports.csv import create_table_from_csv
-from mathesar.models import Database
-from mathesar.models import Schema, DataFile
+
+
+PATENT_SCHEMA = 'Patents'
+NASA_TABLE = 'NASA Schema List'
 
 
 @pytest.fixture(scope="session", autouse=True)
@@ -121,6 +130,35 @@ def _create_schema(schema_name):
         conn.execute(text(f'DROP SCHEMA IF EXISTS "{schema}" CASCADE;'))
 
 
+@pytest.fixture
+def patent_schema(test_db_model, create_schema):
+    engine = create_mathesar_engine(test_db_model.name)
+    install.install_mathesar_on_database(engine)
+    with engine.begin() as conn:
+        conn.execute(text(f'DROP SCHEMA IF EXISTS "{PATENT_SCHEMA}" CASCADE;'))
+    yield create_schema(PATENT_SCHEMA)
+    with engine.begin() as conn:
+        conn.execute(text(f'DROP SCHEMA {base.SCHEMA} CASCADE;'))
+
+
+@pytest.fixture
+def empty_nasa_table(patent_schema):
+    engine = create_mathesar_engine(patent_schema.database.name)
+    db_table = SATable(
+        NASA_TABLE, MetaData(bind=engine),
+        Column('id', Integer, primary_key=True),
+        schema=patent_schema.name,
+    )
+    db_table.create()
+    db_table_oid = get_oid_from_table(db_table.name, db_table.schema, engine)
+    table = Table.current_objects.create(oid=db_table_oid, schema=patent_schema)
+
+    yield table
+
+    table.delete_sa_table()
+    table.delete()
+
+
 @pytest.fixture
 def create_table(csv_filename, create_schema):
     with open(csv_filename, 'rb') as csv_file:
diff --git a/mathesar/tests/filters/test_filters.py b/mathesar/tests/filters/test_filters.py
index 482aff306c..1ee404bf4e 100644
--- a/mathesar/tests/filters/test_filters.py
+++ b/mathesar/tests/filters/test_filters.py
@@ -1,13 +1,21 @@
 from mathesar.filters.base import get_available_filters
 
 
-def test_available_filters_structure(test_db_model):
-    engine = test_db_model._sa_engine
+def test_available_filters_structure(empty_nasa_table):
+    engine = empty_nasa_table._sa_engine
     available_filters = get_available_filters(engine)
     assert len(available_filters) > 0
     available_filter_ids = tuple(filter['id'] for filter in available_filters)
     some_filters_that_we_expect_to_be_there = [
-        'greater', 'lesser', 'empty', 'equal', 'greater_or_equal', 'starts_with',
+        'greater',
+        'lesser',
+        'empty',
+        'equal',
+        'greater_or_equal',
+        'contains_case_insensitive',
+        'starts_with_case_insensitive',
+        'uri_authority_contains',
+        'uri_scheme_equals',
     ]
     expected_filters_are_available = set.issubset(
         set(some_filters_that_we_expect_to_be_there),