Fix grouping counts #353

Merged · 7 commits · Jul 13, 2021
11 changes: 8 additions & 3 deletions db/records.py
@@ -96,17 +96,22 @@ def get_group_counts(
if field_name not in table.c:
raise GroupFieldNotFound(f"Group field {field} not found in {table}.")

group_by = _create_col_objects(table, group_by)
query = (
select(*group_by, func.count(table.c[ID]))
.group_by(*group_by)
select(table)
.limit(limit)
.offset(offset)
)
if order_by is not None:
query = apply_sort(query, order_by)
if filters is not None:
query = apply_filters(query, filters)
subquery = query.subquery()

group_by = [
subquery.columns[col] if type(col) == str else subquery.columns[col.name]
for col in group_by
]
query = select(*group_by, func.count(subquery.c[ID])).group_by(*group_by)
with engine.begin() as conn:
records = conn.execute(query).fetchall()

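The core of the fix is visible in the hunk above: instead of grouping the whole table and then limiting the number of returned groups, the query now limits, offsets, sorts, and filters the raw rows first and groups the resulting subquery. A minimal sketch of how the reworked function might read end to end, assuming the signature, default arguments, and `ID` constant that sit outside the visible hunk (the tests index the result like a dict, so the real tail of the function presumably reshapes the fetched rows):

```python
from sqlalchemy import func, select
from sqlalchemy_filters import apply_filters, apply_sort

ID = "id"  # assumption: name of the primary key column constant used in db/records.py


def get_group_counts(table, engine, group_by, limit=None, offset=None,
                     order_by=None, filters=None):
    # Restrict the raw rows first, so the counts describe the same window
    # of records the client paginated, sorted, and filtered over.
    query = select(table).limit(limit).offset(offset)
    if order_by is not None:
        query = apply_sort(query, order_by)
    if filters is not None:
        query = apply_filters(query, filters)
    subquery = query.subquery()

    # Re-point the group_by targets (column names or Column objects)
    # at the subquery's columns before grouping.
    group_by = [
        subquery.columns[col] if type(col) == str else subquery.columns[col.name]
        for col in group_by
    ]
    query = select(*group_by, func.count(subquery.c[ID])).group_by(*group_by)
    with engine.begin() as conn:
        records = conn.execute(query).fetchall()
    # The remainder of the real function sits below the visible hunk;
    # this sketch simply returns the grouped rows.
    return records
```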
94 changes: 64 additions & 30 deletions db/tests/records/test_grouping.py
@@ -3,6 +3,7 @@

import pytest
from sqlalchemy import select
from sqlalchemy_filters import apply_sort, apply_filters

from db import records
from db.records import GroupFieldNotFound, BadGroupFormat
@@ -32,35 +33,38 @@ def test_get_group_counts_mixed_str_col_field(filter_sort_table_obj):
assert ("string1", 1) in counts


def test_get_group_counts_limit_ordering(filter_sort_table_obj):
filter_sort, engine = filter_sort_table_obj
limit = 50
order_by = [{"field": "numeric", "direction": "desc", "nullslast": True}]
group_by = [filter_sort.c.numeric]
counts = records.get_group_counts(filter_sort, engine, group_by, limit=limit,
order_by=order_by)
assert len(counts) == 50
for i in range(1, 100):
if i > 50:
assert (i,) in counts
else:
assert (i,) not in counts


def test_get_group_counts_limit_offset_ordering(filter_sort_table_obj):
filter_sort, engine = filter_sort_table_obj
offset = 25
limit = 50
order_by = [{"field": "numeric", "direction": "desc", "nullslast": True}]
group_by = [filter_sort.c.numeric]
counts = records.get_group_counts(filter_sort, engine, group_by, limit=limit,
limit_offset_test_list = [
(limit, offset)
for limit in [None, 0, 25, 50, 100]
for offset in [None, 0, 25, 50, 100]
]


@pytest.mark.parametrize("limit,offset", limit_offset_test_list)
def test_get_group_counts_limit_offset_ordering(roster_table_obj, limit, offset):
roster, engine = roster_table_obj
order_by = [{"field": "Grade", "direction": "desc", "nullslast": True}]
group_by = [roster.c["Grade"]]
counts = records.get_group_counts(roster, engine, group_by, limit=limit,
offset=offset, order_by=order_by)
assert len(counts) == 50
for i in range(1, 100):
if i > 25 and i <= 75:
assert (i,) in counts
else:
assert (i,) not in counts

query = select(group_by[0])
query = apply_sort(query, order_by)
with engine.begin() as conn:
all_records = list(conn.execute(query))
if limit is None:
end = None
elif offset is None:
end = limit
else:
end = limit + offset
limit_offset_records = all_records[offset:end]
manual_count = Counter(limit_offset_records)

assert len(counts) == len(manual_count)
for value, count in manual_count.items():
assert value in counts
assert counts[value] == count


count_values_test_list = itertools.chain(*[
@@ -84,8 +88,38 @@ def test_get_group_counts_count_values(roster_table_obj, group_by):
all_records = conn.execute(select(*cols)).fetchall()
manual_count = Counter(all_records)

for key, value in counts.items():
assert manual_count[key] == value
for value, count in manual_count.items():
assert value in counts
assert counts[value] == count


filter_values_test_list = itertools.chain(*[
itertools.combinations([
{"field": "Student Name", "op": "ge", "value": "Test Name"},
{"field": "Student Email", "op": "le", "value": "Test Email"},
{"field": "Teacher Email", "op": "like", "value": "%gmail.com"},
{"field": "Subject", "op": "eq", "value": "Non-Existent Subject"},
{"field": "Grade", "op": "ne", "value": 99}
], i) for i in range(1, 3)
])


@pytest.mark.parametrize("filter_by", filter_values_test_list)
def test_get_group_counts_filter_values(roster_table_obj, filter_by):
roster, engine = roster_table_obj
group_by = ["Student Name"]
counts = records.get_group_counts(roster, engine, group_by, filters=filter_by)

cols = [roster.c[f] for f in group_by]
query = select(*cols)
query = apply_filters(query, filter_by)
with engine.begin() as conn:
all_records = conn.execute(query).fetchall()
manual_count = Counter(all_records)

for value, count in manual_count.items():
assert value in counts
assert counts[value] == count


exceptions_test_list = [
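The new tests check the grouped counts against a manual `Counter` built over the same limited, offset, sorted, or filtered rows. For orientation, a small illustration of how the two parametrization lists expand (the stand-in filter names are hypothetical; the real entries are the filter dicts shown above):

```python
import itertools

# Every pairing of the five limits with the five offsets: 25 cases.
limits = offsets = [None, 0, 25, 50, 100]
limit_offset_test_list = [(lim, off) for lim in limits for off in offsets]
assert len(limit_offset_test_list) == 25

# Each filter alone plus every unordered pair of filters: 5 + 10 = 15 cases.
filters = ["f1", "f2", "f3", "f4", "f5"]  # stand-ins for the real filter dicts
filter_values_test_list = list(itertools.chain(
    *[itertools.combinations(filters, i) for i in range(1, 3)]
))
assert len(filter_values_test_list) == 15
```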
3 changes: 2 additions & 1 deletion mathesar/pagination.py
@@ -67,7 +67,8 @@ def paginate_queryset(self, queryset, request, table_id,
filters=filters, order_by=order_by
)
# Convert the tuple keys into strings so it can be converted to JSON
group_count = {','.join(k): v for k, v in group_count.items()}
group_count = [{"values": list(cols), "count": count}
for cols, count in group_count.items()]
self.group_count = {
'group_count_by': group_count_by,
'results': group_count,
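The pagination change replaces the old comma-joined string keys with a list of value/count objects. A sketch of the resulting `group_count` payload as the visible hunk and the API tests below describe it; the counts and any additional keys outside the hunk are illustrative assumptions:

```python
# Illustrative only -- the shape asserted by the record API tests below.
group_count = {
    "group_count_by": ["Center", "Status"],
    "results": [
        {"values": ["NASA Marshall Space Flight Center", "Issued"], "count": 42},
        {"values": ["NASA Stennis Space Center", "Issued"], "count": 7},
    ],
}
```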
23 changes: 15 additions & 8 deletions mathesar/tests/views/api/test_record_api.py
@@ -116,14 +116,18 @@ def test_record_list_sort(create_table, client):


def _test_record_list_group(table, client, group_count_by, expected_groups):
order_by = [
{'field': 'Center', 'direction': 'desc'},
{'field': 'Case Number', 'direction': 'asc'},
]
json_order_by = json.dumps(order_by)
json_group_count_by = json.dumps(group_count_by)
query_str = f'group_count_by={json_group_count_by}&order_by={json_order_by}'

with patch.object(
records, "get_group_counts", side_effect=records.get_group_counts
) as mock_infer:
response = client.get(
f'/api/v0/tables/{table.id}/records/?group_count_by={json_group_count_by}'
)
response = client.get(f'/api/v0/tables/{table.id}/records/?{query_str}')
response_data = response.json()

assert response.status_code == 200
@@ -133,10 +137,13 @@ def _test_record_list_group(table, client, group_count_by, expected_groups):
assert 'group_count' in response_data
assert response_data['group_count']['group_count_by'] == group_count_by
assert 'results' in response_data['group_count']
assert 'values' in response_data['group_count']['results'][0]
assert 'count' in response_data['group_count']['results'][0]

results = response_data['group_count']['results']
returned_groups = {tuple(group['values']) for group in results}
for expected_group in expected_groups:
assert expected_group in results
assert expected_group in returned_groups

assert mock_infer.call_args is not None
assert mock_infer.call_args[0][2] == group_count_by
@@ -147,8 +154,8 @@ def test_record_list_group_single_column(create_table, client):
table = create_table(table_name)
group_count_by = ['Center']
expected_groups = [
'NASA Ames Research Center',
'NASA Kennedy Space Center'
('NASA Marshall Space Flight Center',),
('NASA Stennis Space Center',)
]
_test_record_list_group(table, client, group_count_by, expected_groups)

Expand All @@ -158,8 +165,8 @@ def test_record_list_group_multi_column(create_table, client):
table = create_table(table_name)
group_count_by = ['Center', 'Status']
expected_groups = [
'NASA Ames Research Center,Issued',
'NASA Kennedy Space Center,Issued',
('NASA Marshall Space Flight Center', 'Issued'),
('NASA Stennis Space Center', 'Issued'),
]
_test_record_list_group(table, client, group_count_by, expected_groups)

Expand Down