diff --git a/db/records.py b/db/records.py index 1c6f02004c..c815e3b209 100644 --- a/db/records.py +++ b/db/records.py @@ -96,10 +96,8 @@ def get_group_counts( if field_name not in table.c: raise GroupFieldNotFound(f"Group field {field} not found in {table}.") - group_by = _create_col_objects(table, group_by) query = ( - select(*group_by, func.count(table.c[ID])) - .group_by(*group_by) + select(table) .limit(limit) .offset(offset) ) @@ -107,6 +105,13 @@ def get_group_counts( query = apply_sort(query, order_by) if filters is not None: query = apply_filters(query, filters) + subquery = query.subquery() + + group_by = [ + subquery.columns[col] if isinstance(col, str) else subquery.columns[col.name] + for col in group_by + ] + query = select(*group_by, func.count(subquery.c[ID])).group_by(*group_by) with engine.begin() as conn: records = conn.execute(query).fetchall() diff --git a/db/tests/records/test_grouping.py b/db/tests/records/test_grouping.py index 9134ee3ee2..ebf65abccb 100644 --- a/db/tests/records/test_grouping.py +++ b/db/tests/records/test_grouping.py @@ -3,6 +3,7 @@ import pytest from sqlalchemy import select +from sqlalchemy_filters import apply_sort, apply_filters from db import records from db.records import GroupFieldNotFound, BadGroupFormat @@ -32,35 +33,38 @@ def test_get_group_counts_mixed_str_col_field(filter_sort_table_obj): assert ("string1", 1) in counts -def test_get_group_counts_limit_ordering(filter_sort_table_obj): - filter_sort, engine = filter_sort_table_obj - limit = 50 - order_by = [{"field": "numeric", "direction": "desc", "nullslast": True}] - group_by = [filter_sort.c.numeric] - counts = records.get_group_counts(filter_sort, engine, group_by, limit=limit, - order_by=order_by) - assert len(counts) == 50 - for i in range(1, 100): - if i > 50: - assert (i,) in counts - else: - assert (i,) not in counts - - -def test_get_group_counts_limit_offset_ordering(filter_sort_table_obj): - filter_sort, engine = filter_sort_table_obj - offset = 25 - 
limit = 50 - order_by = [{"field": "numeric", "direction": "desc", "nullslast": True}] - group_by = [filter_sort.c.numeric] - counts = records.get_group_counts(filter_sort, engine, group_by, limit=limit, +limit_offset_test_list = [ + (limit, offset) + for limit in [None, 0, 25, 50, 100] + for offset in [None, 0, 25, 50, 100] +] + + +@pytest.mark.parametrize("limit,offset", limit_offset_test_list) +def test_get_group_counts_limit_offset_ordering(roster_table_obj, limit, offset): + roster, engine = roster_table_obj + order_by = [{"field": "Grade", "direction": "desc", "nullslast": True}] + group_by = [roster.c["Grade"]] + counts = records.get_group_counts(roster, engine, group_by, limit=limit, offset=offset, order_by=order_by) - assert len(counts) == 50 - for i in range(1, 100): - if i > 25 and i <= 75: - assert (i,) in counts - else: - assert (i,) not in counts + + query = select(group_by[0]) + query = apply_sort(query, order_by) + with engine.begin() as conn: + all_records = list(conn.execute(query)) + if limit is None: + end = None + elif offset is None: + end = limit + else: + end = limit + offset + limit_offset_records = all_records[offset:end] + manual_count = Counter(limit_offset_records) + + assert len(counts) == len(manual_count) + for value, count in manual_count.items(): + assert value in counts + assert counts[value] == count count_values_test_list = itertools.chain(*[ @@ -84,8 +88,38 @@ def test_get_group_counts_count_values(roster_table_obj, group_by): all_records = conn.execute(select(*cols)).fetchall() manual_count = Counter(all_records) - for key, value in counts.items(): - assert manual_count[key] == value + for value, count in manual_count.items(): + assert value in counts + assert counts[value] == count + + +filter_values_test_list = itertools.chain(*[ + itertools.combinations([ + {"field": "Student Name", "op": "ge", "value": "Test Name"}, + {"field": "Student Email", "op": "le", "value": "Test Email"}, + {"field": "Teacher Email", "op": "like", 
"value": "%gmail.com"}, + {"field": "Subject", "op": "eq", "value": "Non-Existent Subject"}, + {"field": "Grade", "op": "ne", "value": 99} + ], i) for i in range(1, 3) +]) + + +@pytest.mark.parametrize("filter_by", filter_values_test_list) +def test_get_group_counts_filter_values(roster_table_obj, filter_by): + roster, engine = roster_table_obj + group_by = ["Student Name"] + counts = records.get_group_counts(roster, engine, group_by, filters=filter_by) + + cols = [roster.c[f] for f in group_by] + query = select(*cols) + query = apply_filters(query, filter_by) + with engine.begin() as conn: + all_records = conn.execute(query).fetchall() + manual_count = Counter(all_records) + + for value, count in manual_count.items(): + assert value in counts + assert counts[value] == count exceptions_test_list = [ diff --git a/mathesar/pagination.py b/mathesar/pagination.py index d972e2d4f5..abbef4b507 100644 --- a/mathesar/pagination.py +++ b/mathesar/pagination.py @@ -67,7 +67,8 @@ def paginate_queryset(self, queryset, request, table_id, filters=filters, order_by=order_by ) # Convert the tuple keys into strings so it can be converted to JSON - group_count = {','.join(k): v for k, v in group_count.items()} + group_count = [{"values": list(cols), "count": count} + for cols, count in group_count.items()] self.group_count = { 'group_count_by': group_count_by, 'results': group_count, diff --git a/mathesar/tests/views/api/test_record_api.py b/mathesar/tests/views/api/test_record_api.py index 7446ea1814..0a09c9f3a8 100644 --- a/mathesar/tests/views/api/test_record_api.py +++ b/mathesar/tests/views/api/test_record_api.py @@ -116,14 +116,18 @@ def test_record_list_sort(create_table, client): def _test_record_list_group(table, client, group_count_by, expected_groups): + order_by = [ + {'field': 'Center', 'direction': 'desc'}, + {'field': 'Case Number', 'direction': 'asc'}, + ] + json_order_by = json.dumps(order_by) json_group_count_by = json.dumps(group_count_by) + query_str = 
f'group_count_by={json_group_count_by}&order_by={json_order_by}' with patch.object( records, "get_group_counts", side_effect=records.get_group_counts ) as mock_infer: - response = client.get( - f'/api/v0/tables/{table.id}/records/?group_count_by={json_group_count_by}' - ) + response = client.get(f'/api/v0/tables/{table.id}/records/?{query_str}') response_data = response.json() assert response.status_code == 200 @@ -133,10 +137,13 @@ def _test_record_list_group(table, client, group_count_by, expected_groups): assert 'group_count' in response_data assert response_data['group_count']['group_count_by'] == group_count_by assert 'results' in response_data['group_count'] + assert 'values' in response_data['group_count']['results'][0] + assert 'count' in response_data['group_count']['results'][0] results = response_data['group_count']['results'] + returned_groups = {tuple(group['values']) for group in results} for expected_group in expected_groups: - assert expected_group in results + assert expected_group in returned_groups assert mock_infer.call_args is not None assert mock_infer.call_args[0][2] == group_count_by @@ -147,8 +154,8 @@ def test_record_list_group_single_column(create_table, client): table = create_table(table_name) group_count_by = ['Center'] expected_groups = [ - 'NASA Ames Research Center', - 'NASA Kennedy Space Center' + ('NASA Marshall Space Flight Center',), + ('NASA Stennis Space Center',) ] _test_record_list_group(table, client, group_count_by, expected_groups) @@ -158,8 +165,8 @@ def test_record_list_group_multi_column(create_table, client): table = create_table(table_name) group_count_by = ['Center', 'Status'] expected_groups = [ - 'NASA Ames Research Center,Issued', - 'NASA Kennedy Space Center,Issued', + ('NASA Marshall Space Flight Center', 'Issued'), + ('NASA Stennis Space Center', 'Issued'), ] _test_record_list_group(table, client, group_count_by, expected_groups)