Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Record grouping #315

Merged
merged 21 commits into from
Jul 7, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
21 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
83 changes: 72 additions & 11 deletions db/records.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,37 @@
import logging
from sqlalchemy import delete, select, Column
from sqlalchemy import delete, select, Column, func
from sqlalchemy.inspection import inspect
from sqlalchemy_filters import apply_filters, apply_sort
from sqlalchemy_filters.exceptions import FieldNotFound

from db.constants import ID

logger = logging.getLogger(__name__)


# Grouping exceptions follow the sqlalchemy_filters exception patterns
class BadGroupFormat(Exception):
    """Raised when a group-by specification is malformed.

    ``get_group_counts`` raises this when the group spec is not a list or
    tuple, or when one of its fields is neither a string nor a ``Column``.
    """


class GroupFieldNotFound(FieldNotFound):
    """Raised when a group-by field does not exist on the target table."""


def _get_primary_key_column(table):
    """Return the single primary key column of ``table``.

    Args:
        table: SQLAlchemy table object

    Raises:
        AssertionError: if the table does not have exactly one primary key
            column; composite primary keys are not supported.
    """
    primary_key_list = list(inspect(table).primary_key)
    # We do not support getting by composite primary keys.  Raise explicitly
    # instead of using ``assert`` so the check is not stripped by ``python -O``.
    if len(primary_key_list) != 1:
        raise AssertionError(
            f"Expected exactly one primary key column on {table}, "
            f"found {len(primary_key_list)}."
        )
    return primary_key_list[0]


def _create_col_objects(table, column_list):
return [
table.columns[col] if type(col) == str else col
for col in column_list
]


def get_record(table, engine, id_value):
primary_key_column = _get_primary_key_column(table)
query = select(table).where(primary_key_column == id_value)
Expand All @@ -23,7 +42,7 @@ def get_record(table, engine, id_value):


def get_records(
table, engine, limit=None, offset=None, order_by=[], filters=[]
table, engine, limit=None, offset=None, order_by=[], filters=[],
):
"""
Returns records from a table.
Expand All @@ -40,18 +59,63 @@ def get_records(
field, in addition to an 'value' field if appropriate.
See: https://github.com/centerofci/sqlalchemy-filters#filters-format
"""
query = select(table)
if order_by:
query = select(table).limit(limit).offset(offset)
if order_by is not None:
query = apply_sort(query, order_by)
if filters is not None:
query = apply_filters(query, filters)
with engine.begin() as conn:
return conn.execute(query).fetchall()


def get_group_counts(
    table, engine, group_by, limit=None, offset=None, order_by=[], filters=[],
):
    """
    Returns counts by specified groupings

    Args:
        table:    SQLAlchemy table object
        engine:   SQLAlchemy engine object
        group_by: list or tuple of column names or column objects to group by
        limit:    int, gives number of rows to return
        offset:   int, gives number of rows to skip
        order_by: list of dictionaries, where each dictionary has a 'field' and
                  'direction' field.
                  See: https://github.com/centerofci/sqlalchemy-filters#sort-format
        filters:  list of dictionaries, where each dictionary has a 'field' and 'op'
                  field, in addition to an 'value' field if appropriate.
                  See: https://github.com/centerofci/sqlalchemy-filters#filters-format

    Returns:
        dict mapping a tuple of group-by values (one entry per group_by
        field, in order) to the count for that group.

    Raises:
        BadGroupFormat: if group_by is not a list or tuple, or one of its
            fields is neither a string nor a Column.
        GroupFieldNotFound: if a group_by field is not a column of table.
    """
    # Validate the whole spec up front so a bad field never reaches the query.
    if not isinstance(group_by, (tuple, list)):
        raise BadGroupFormat(f"Group spec {group_by} must be list or tuple.")
    for field in group_by:
        if not isinstance(field, (str, Column)):
            raise BadGroupFormat(f"Group field {field} must be a string or Column.")
        field_name = field if isinstance(field, str) else field.name
        if field_name not in table.c:
            raise GroupFieldNotFound(f"Group field {field} not found in {table}.")

    group_by = _create_col_objects(table, group_by)
    query = (
        select(*group_by, func.count(table.c[ID]))
        .group_by(*group_by)
        .limit(limit)
        .offset(offset)
    )
    if order_by is not None:
        query = apply_sort(query, order_by)
    if filters is not None:
        query = apply_filters(query, filters)
    with engine.begin() as conn:
        records = conn.execute(query).fetchall()

    # Last field is the count, preceding fields are the group by fields
    counts = {
        tuple(record[:-1]): record[-1]
        for record in records
    }
    return counts


def get_distinct_tuple_values(
Expand All @@ -71,10 +135,7 @@ def get_distinct_tuple_values(
SQLAlchemy column objects associated with a table.
"""
if table is not None:
column_objects = [
table.columns[col] if type(col) == str else col
for col in column_list
]
column_objects = _create_col_objects(table, column_list)
else:
column_objects = column_list
try:
Expand Down
23 changes: 23 additions & 0 deletions db/tests/records/conftest.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
import pytest

from sqlalchemy import MetaData, Table


ROSTER = "Roster"
FILTER_SORT = "filter_sort"


@pytest.fixture
def roster_table_obj(engine_with_roster):
    """Provide the reflected roster table together with its engine."""
    engine, schema = engine_with_roster
    roster_table = Table(
        ROSTER,
        MetaData(bind=engine),
        schema=schema,
        autoload_with=engine,
    )
    return roster_table, engine


@pytest.fixture
def filter_sort_table_obj(engine_with_filter_sort):
    """Provide the reflected filter_sort table together with its engine."""
    engine, schema = engine_with_filter_sort
    filter_sort_table = Table(
        FILTER_SORT,
        MetaData(bind=engine),
        schema=schema,
        autoload_with=engine,
    )
    return filter_sort_table, engine
21 changes: 0 additions & 21 deletions db/tests/records/test_filtering.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,32 +2,11 @@
import pytest
from datetime import datetime

from sqlalchemy import MetaData, Table
from sqlalchemy_filters.exceptions import BadFilterFormat, FilterFieldNotFound

from db import records


ROSTER = "Roster"
FILTER_SORT = "filter_sort"


@pytest.fixture
def roster_table_obj(engine_with_roster):
engine, schema = engine_with_roster
metadata = MetaData(bind=engine)
roster = Table(ROSTER, metadata, schema=schema, autoload_with=engine)
return roster, engine


@pytest.fixture
def filter_sort_table_obj(engine_with_filter_sort):
engine, schema = engine_with_filter_sort
metadata = MetaData(bind=engine)
roster = Table(FILTER_SORT, metadata, schema=schema, autoload_with=engine)
return roster, engine


def test_get_records_filters_using_col_str_names(roster_table_obj):
roster, engine = roster_table_obj
filter_list = [
Expand Down
103 changes: 103 additions & 0 deletions db/tests/records/test_grouping.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,103 @@
import itertools
from collections import Counter

import pytest
from sqlalchemy import select

from db import records
from db.records import GroupFieldNotFound, BadGroupFormat


def test_get_group_counts_str_field(filter_sort_table_obj):
    """Group by a column given by its string name."""
    table, engine = filter_sort_table_obj
    counts = records.get_group_counts(table, engine, ["varchar"])
    assert len(counts) == 101
    assert ("string1",) in counts


def test_get_group_counts_col_field(filter_sort_table_obj):
    """Group by a column given as a column object."""
    table, engine = filter_sort_table_obj
    counts = records.get_group_counts(table, engine, [table.c.varchar])
    assert len(counts) == 101
    assert ("string1",) in counts


def test_get_group_counts_mixed_str_col_field(filter_sort_table_obj):
    """Group by a spec that mixes a string name and a column object."""
    table, engine = filter_sort_table_obj
    mixed_spec = ["varchar", table.c.numeric]
    counts = records.get_group_counts(table, engine, mixed_spec)
    assert len(counts) == 101
    assert ("string1", 1) in counts


def test_get_group_counts_limit_ordering(filter_sort_table_obj):
    """With a descending sort, a limit of 50 keeps only groups above 50."""
    table, engine = filter_sort_table_obj
    sort_spec = [{"field": "numeric", "direction": "desc", "nullslast": True}]
    counts = records.get_group_counts(
        table, engine, [table.c.numeric], limit=50, order_by=sort_spec,
    )
    assert len(counts) == 50
    for i in range(1, 100):
        # Membership must match exactly whether the value survives the limit.
        assert ((i,) in counts) == (i > 50)


def test_get_group_counts_limit_offset_ordering(filter_sort_table_obj):
    """With a descending sort, offset 25 and limit 50 keep groups 26..75."""
    table, engine = filter_sort_table_obj
    sort_spec = [{"field": "numeric", "direction": "desc", "nullslast": True}]
    counts = records.get_group_counts(
        table,
        engine,
        [table.c.numeric],
        limit=50,
        offset=25,
        order_by=sort_spec,
    )
    assert len(counts) == 50
    for i in range(1, 100):
        # Membership must match exactly whether the value is in the window.
        assert ((i,) in counts) == (25 < i <= 75)


# Every combination of one to four of the listed roster columns.  Built as a
# list (not a bare itertools.chain iterator) so it can be iterated more than
# once without silently coming up empty.
count_values_test_list = list(itertools.chain.from_iterable(
    itertools.combinations([
        "Student Name",
        "Student Email",
        "Teacher Email",
        "Subject",
        "Grade"
    ], i) for i in range(1, 5)
))


@pytest.mark.parametrize("group_by", count_values_test_list)
def test_get_group_counts_count_values(roster_table_obj, group_by):
    """Each reported group count matches a count computed in Python."""
    roster, engine = roster_table_obj
    counts = records.get_group_counts(roster, engine, group_by)

    # Fetch the raw grouped columns and tally them independently.
    grouped_cols = [roster.c[name] for name in group_by]
    with engine.begin() as conn:
        rows = conn.execute(select(*grouped_cols)).fetchall()
    expected = Counter(rows)

    for group_key, count in counts.items():
        assert expected[group_key] == count


# Invalid group_by specs paired with the exception that
# records.get_group_counts is expected to raise for each of them.
exceptions_test_list = [
    # A bare string is not a list or tuple.
    ("string", BadGroupFormat),
    # Neither is a dictionary.
    ({"dictionary": ""}, BadGroupFormat),
    # A list is fine, but its fields must be strings or Columns, not dicts.
    ([{"field": "varchar"}], BadGroupFormat),
    # Well-formed spec naming a column the table does not have.
    (["non_existent_field"], GroupFieldNotFound),
]


@pytest.mark.parametrize("group_by,exception", exceptions_test_list)
def test_get_group_counts_exceptions(filter_sort_table_obj, group_by, exception):
    """An invalid group spec raises the exception paired with it."""
    table, engine = filter_sort_table_obj
    with pytest.raises(exception):
        records.get_group_counts(table, engine, group_by)
20 changes: 0 additions & 20 deletions db/tests/records/test_records.py
Original file line number Diff line number Diff line change
@@ -1,26 +1,6 @@
import pytest
from sqlalchemy import MetaData, Table
from db import records

ROSTER = "Roster"
FILTER_SORT = "filter_sort"


@pytest.fixture
def roster_table_obj(engine_with_roster):
engine, schema = engine_with_roster
metadata = MetaData(bind=engine)
roster = Table(ROSTER, metadata, schema=schema, autoload_with=engine)
return roster, engine


@pytest.fixture
def filter_sort_table_obj(engine_with_filter_sort):
engine, schema = engine_with_filter_sort
metadata = MetaData(bind=engine)
roster = Table(FILTER_SORT, metadata, schema=schema, autoload_with=engine)
return roster, engine


def test_get_records_gets_all_records(roster_table_obj):
roster, engine = roster_table_obj
Expand Down
21 changes: 0 additions & 21 deletions db/tests/records/test_sorting.py
Original file line number Diff line number Diff line change
@@ -1,31 +1,10 @@
import pytest

from sqlalchemy import MetaData, Table
from sqlalchemy_filters.exceptions import BadSortFormat, SortFieldNotFound

from db import records


ROSTER = "Roster"
FILTER_SORT = "filter_sort"


@pytest.fixture
def roster_table_obj(engine_with_roster):
engine, schema = engine_with_roster
metadata = MetaData(bind=engine)
roster = Table(ROSTER, metadata, schema=schema, autoload_with=engine)
return roster, engine


@pytest.fixture
def filter_sort_table_obj(engine_with_filter_sort):
engine, schema = engine_with_filter_sort
metadata = MetaData(bind=engine)
roster = Table(FILTER_SORT, metadata, schema=schema, autoload_with=engine)
return roster, engine


def test_get_records_gets_ordered_records_str_col_name(roster_table_obj):
roster, engine = roster_table_obj
order_list = [{"field": "Teacher", "direction": "asc"}]
Expand Down
1 change: 1 addition & 0 deletions mathesar/forms/forms.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,3 +34,4 @@ class UploadFileForm(forms.Form):
class RecordListFilterForm(forms.Form):
    # Form with three JSON fields; none is required and each is declared
    # with an empty list as its empty value.
    # NOTE(review): presumably these carry the filter, sort and group-count
    # specs for a record listing — confirm against the view using this form.
    filters = forms.JSONField(required=False, empty_value=[])
    order_by = forms.JSONField(required=False, empty_value=[])
    group_count_by = forms.JSONField(required=False, empty_value=[])
7 changes: 7 additions & 0 deletions mathesar/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,13 @@ def get_records(self, limit=None, offset=None, filters=[], order_by=[]):
return records.get_records(self._sa_table, self.schema._sa_engine, limit,
offset, filters=filters, order_by=order_by)

def get_group_counts(
    self, group_by, limit=None, offset=None, filters=[], order_by=[]
):
    """Return group counts for this table via records.get_group_counts."""
    sa_table = self._sa_table
    sa_engine = self.schema._sa_engine
    return records.get_group_counts(
        sa_table,
        sa_engine,
        group_by,
        limit,
        offset,
        filters=filters,
        order_by=order_by,
    )

def create_record_or_records(self, record_data):
    """Delegate to records.create_record_or_records for this table.

    Passes this object's SQLAlchemy table, its schema's engine and
    ``record_data`` through unchanged and returns whatever the helper returns.
    """
    return records.create_record_or_records(self._sa_table, self.schema._sa_engine, record_data)

Expand Down
Loading