Skip to content

Commit

Permalink
Merge pull request #353 from centerofci/fix_grouping_counts
Browse files Browse the repository at this point in the history
Fix grouping counts
  • Loading branch information
kgodey committed Jul 13, 2021
2 parents 0d7bf2e + 3f7b0b4 commit a6267e0
Show file tree
Hide file tree
Showing 4 changed files with 89 additions and 42 deletions.
11 changes: 8 additions & 3 deletions db/records.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,17 +96,22 @@ def get_group_counts(
if field_name not in table.c:
raise GroupFieldNotFound(f"Group field {field} not found in {table}.")

group_by = _create_col_objects(table, group_by)
query = (
select(*group_by, func.count(table.c[ID]))
.group_by(*group_by)
select(table)
.limit(limit)
.offset(offset)
)
if order_by is not None:
query = apply_sort(query, order_by)
if filters is not None:
query = apply_filters(query, filters)
subquery = query.subquery()

group_by = [
subquery.columns[col] if type(col) == str else subquery.columns[col.name]
for col in group_by
]
query = select(*group_by, func.count(subquery.c[ID])).group_by(*group_by)
with engine.begin() as conn:
records = conn.execute(query).fetchall()

Expand Down
94 changes: 64 additions & 30 deletions db/tests/records/test_grouping.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

import pytest
from sqlalchemy import select
from sqlalchemy_filters import apply_sort, apply_filters

from db import records
from db.records import GroupFieldNotFound, BadGroupFormat
Expand Down Expand Up @@ -32,35 +33,38 @@ def test_get_group_counts_mixed_str_col_field(filter_sort_table_obj):
assert ("string1", 1) in counts


def test_get_group_counts_limit_ordering(filter_sort_table_obj):
    # Verify that the LIMIT is applied to the descending-sorted rows, so the
    # grouped counts contain only the largest values of the "numeric" column.
    filter_sort, engine = filter_sort_table_obj
    limit = 50
    order_by = [{"field": "numeric", "direction": "desc", "nullslast": True}]
    group_by = [filter_sort.c.numeric]
    counts = records.get_group_counts(filter_sort, engine, group_by, limit=limit,
                                      order_by=order_by)
    assert len(counts) == 50
    # With descending order and limit=50, values strictly above 50 should be
    # present and the lower half absent — assumes the fixture's "numeric"
    # column holds the values 1..100 (TODO confirm against the fixture).
    for i in range(1, 100):
        if i > 50:
            assert (i,) in counts
        else:
            assert (i,) not in counts


def test_get_group_counts_limit_offset_ordering(filter_sort_table_obj):
filter_sort, engine = filter_sort_table_obj
offset = 25
limit = 50
order_by = [{"field": "numeric", "direction": "desc", "nullslast": True}]
group_by = [filter_sort.c.numeric]
counts = records.get_group_counts(filter_sort, engine, group_by, limit=limit,
# Every (limit, offset) pairing to exercise, covering "absent" (None),
# zero, and values below/at/above the fixture's row count.
limit_offset_test_list = list(
    itertools.product([None, 0, 25, 50, 100], [None, 0, 25, 50, 100])
)


@pytest.mark.parametrize("limit,offset", limit_offset_test_list)
def test_get_group_counts_limit_offset_ordering(roster_table_obj, limit, offset):
    """Grouped counts over a limited/offset window must match a manual tally.

    Fix: the block contained stale leftover lines from the pre-change version
    of this test (``assert len(counts) == 50`` and a hard-coded range loop),
    which contradict the Counter-based verification below and would fail for
    most parametrized (limit, offset) pairs. They are removed here; the
    manual count is the single source of truth.
    """
    roster, engine = roster_table_obj
    order_by = [{"field": "Grade", "direction": "desc", "nullslast": True}]
    group_by = [roster.c["Grade"]]
    counts = records.get_group_counts(roster, engine, group_by, limit=limit,
                                      offset=offset, order_by=order_by)

    # Recompute the expected result by hand: sort the grouped column, slice
    # out the same [offset, offset + limit) window, then tally.
    query = select(group_by[0])
    query = apply_sort(query, order_by)
    with engine.begin() as conn:
        all_records = list(conn.execute(query))
    if limit is None:
        end = None  # no limit: take everything after the offset
    elif offset is None:
        end = limit  # no offset: window starts at the beginning
    else:
        end = limit + offset
    limit_offset_records = all_records[offset:end]
    manual_count = Counter(limit_offset_records)

    assert len(counts) == len(manual_count)
    for value, count in manual_count.items():
        assert value in counts
        assert counts[value] == count


count_values_test_list = itertools.chain(*[
Expand All @@ -84,8 +88,38 @@ def test_get_group_counts_count_values(roster_table_obj, group_by):
all_records = conn.execute(select(*cols)).fetchall()
manual_count = Counter(all_records)

for key, value in counts.items():
assert manual_count[key] == value
for value, count in manual_count.items():
assert value in counts
assert counts[value] == count


# Candidate filter clauses; the parametrized list contains every single
# clause and every unordered pair of clauses (combinations of size 1 and 2).
_filter_options = [
    {"field": "Student Name", "op": "ge", "value": "Test Name"},
    {"field": "Student Email", "op": "le", "value": "Test Email"},
    {"field": "Teacher Email", "op": "like", "value": "%gmail.com"},
    {"field": "Subject", "op": "eq", "value": "Non-Existent Subject"},
    {"field": "Grade", "op": "ne", "value": 99}
]
filter_values_test_list = itertools.chain.from_iterable(
    itertools.combinations(_filter_options, size) for size in range(1, 3)
)


@pytest.mark.parametrize("filter_by", filter_values_test_list)
def test_get_group_counts_filter_values(roster_table_obj, filter_by):
    """Grouped counts under a filter must agree with a manual tally of the
    filtered rows."""
    roster, engine = roster_table_obj
    group_by = ["Student Name"]
    counts = records.get_group_counts(roster, engine, group_by, filters=filter_by)

    # Tally the same filtered column directly for comparison.
    columns = [roster.c[name] for name in group_by]
    filtered_query = apply_filters(select(*columns), filter_by)
    with engine.begin() as conn:
        expected = Counter(conn.execute(filtered_query).fetchall())

    for group, expected_count in expected.items():
        assert group in counts
        assert counts[group] == expected_count


exceptions_test_list = [
Expand Down
3 changes: 2 additions & 1 deletion mathesar/pagination.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,8 @@ def paginate_queryset(self, queryset, request, table_id,
filters=filters, order_by=order_by
)
# Convert the tuple keys into strings so it can be converted to JSON
group_count = {','.join(k): v for k, v in group_count.items()}
group_count = [{"values": list(cols), "count": count}
for cols, count in group_count.items()]
self.group_count = {
'group_count_by': group_count_by,
'results': group_count,
Expand Down
23 changes: 15 additions & 8 deletions mathesar/tests/views/api/test_record_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,14 +116,18 @@ def test_record_list_sort(create_table, client):


def _test_record_list_group(table, client, group_count_by, expected_groups):
order_by = [
{'field': 'Center', 'direction': 'desc'},
{'field': 'Case Number', 'direction': 'asc'},
]
json_order_by = json.dumps(order_by)
json_group_count_by = json.dumps(group_count_by)
query_str = f'group_count_by={json_group_count_by}&order_by={json_order_by}'

with patch.object(
records, "get_group_counts", side_effect=records.get_group_counts
) as mock_infer:
response = client.get(
f'/api/v0/tables/{table.id}/records/?group_count_by={json_group_count_by}'
)
response = client.get(f'/api/v0/tables/{table.id}/records/?{query_str}')
response_data = response.json()

assert response.status_code == 200
Expand All @@ -133,10 +137,13 @@ def _test_record_list_group(table, client, group_count_by, expected_groups):
assert 'group_count' in response_data
assert response_data['group_count']['group_count_by'] == group_count_by
assert 'results' in response_data['group_count']
assert 'values' in response_data['group_count']['results'][0]
assert 'count' in response_data['group_count']['results'][0]

results = response_data['group_count']['results']
returned_groups = {tuple(group['values']) for group in results}
for expected_group in expected_groups:
assert expected_group in results
assert expected_group in returned_groups

assert mock_infer.call_args is not None
assert mock_infer.call_args[0][2] == group_count_by
Expand All @@ -147,8 +154,8 @@ def test_record_list_group_single_column(create_table, client):
table = create_table(table_name)
group_count_by = ['Center']
expected_groups = [
'NASA Ames Research Center',
'NASA Kennedy Space Center'
('NASA Marshall Space Flight Center',),
('NASA Stennis Space Center',)
]
_test_record_list_group(table, client, group_count_by, expected_groups)

Expand All @@ -158,8 +165,8 @@ def test_record_list_group_multi_column(create_table, client):
table = create_table(table_name)
group_count_by = ['Center', 'Status']
expected_groups = [
'NASA Ames Research Center,Issued',
'NASA Kennedy Space Center,Issued',
('NASA Marshall Space Flight Center', 'Issued'),
('NASA Stennis Space Center', 'Issued'),
]
_test_record_list_group(table, client, group_count_by, expected_groups)

Expand Down

0 comments on commit a6267e0

Please sign in to comment.