From adcbb86e6de343e9f0630f6a7e5d190b81bed775 Mon Sep 17 00:00:00 2001 From: Brent Moran Date: Tue, 12 Apr 2022 16:53:58 +0800 Subject: [PATCH 1/8] first pass custom grouping function --- db/records/operations/group.py | 61 ++++++++++++++++++++++++++++++++++ 1 file changed, 61 insertions(+) diff --git a/db/records/operations/group.py b/db/records/operations/group.py index f5f78d83b6..63ac4f87bb 100644 --- a/db/records/operations/group.py +++ b/db/records/operations/group.py @@ -206,6 +206,67 @@ def _get_pretty_bound_expr(id_offset): ) +def _get_custom_endpoints_range_group_select(table, columns, bound_tuples_list): + column_names = [col.name for col in columns] + GROUP_OBJ = 'group_obj' + RANGE_ID = 'range_id' + GEQ_BOUND = 'geq_bound' + LT_BOUND = 'lt_bound' + + def _get_inner_json_object(bound_tuple): + key_value_tuples = ( + (literal(str(col)), literal(val)) + for col, val in zip(column_names, bound_tuple) + ) + key_value_list = [ + part for tup in key_value_tuples for part in tup + ] + return func.json_build_object(*key_value_list) + + ranges = [ + ( + and_( + func.ROW(*columns) >= func.ROW(*bound_tuples_list[i]), + func.ROW(*columns) < func.ROW(*bound_tuples_list[i + 1]) + ), + func.json_build_object( + literal(RANGE_ID), + i + 1, + literal(GEQ_BOUND), + _get_inner_json_object(*bound_tuples_list[i]), + literal(LT_BOUND), + _get_inner_json_object(*bound_tuples_list[i + 1]), + ) + ) + for i in range(len(bound_tuples_list)) + ] + ranges_cte = select( + *columns, + case(*ranges).label(GROUP_OBJ) + ).cte() + + ranges_aggregation_cols = [ + col for col in ranges_cte.columns if col.name in column_names + ] + window_def = GroupingWindowDefinition( + order_by=ranges_aggregation_cols, + partition_by=ranges_cte.columns[GROUP_OBJ][RANGE_ID] + ) + group_id_expr = window_def.partition_by + geq_expr = ranges_cte.columns[GROUP_OBJ][GEQ_BOUND] + lt_expr = ranges_cte.columns[GROUP_OBJ][LT_BOUND] + return select( + *[col for col in ranges_cte.columns if col.name in table.columns], + _get_group_metadata_definition( + window_def, + ranges_aggregation_cols, + group_id_expr, + geq_expr=geq_expr, + lt_expr=lt_expr, + ) + ) + + def _get_percentile_range_group_select(table, columns, num_groups): column_names = [col.name for col in columns] # cume_dist is a PostgreSQL function that calculates the cumulative From 99ca9e24e672e685d1651c8dc1c098d7636cc4c0 Mon Sep 17 00:00:00 2001 From: Brent Moran Date: Thu, 14 Apr 2022 16:59:41 +0800 Subject: [PATCH 2/8] fix bugs in custom range select generating function --- db/records/operations/group.py | 51 +++++++++++++++++++--------------- 1 file changed, 29 insertions(+), 22 deletions(-) diff --git a/db/records/operations/group.py b/db/records/operations/group.py index 63ac4f87bb..83a1c58b9f 100644 --- a/db/records/operations/group.py +++ b/db/records/operations/group.py @@ -1,7 +1,8 @@ from enum import Enum import json import logging -from sqlalchemy import select, func, and_, case, literal +from sqlalchemy import select, func, and_, cast, case, literal +from sqlalchemy.dialects.postgresql import JSON from db.records import exceptions as records_exceptions from db.records.operations import calculation @@ -223,26 +224,32 @@ def _get_inner_json_object(bound_tuple): ] return func.json_build_object(*key_value_list) - ranges = [ - ( - and_( - func.ROW(*columns) >= func.ROW(*bound_tuples_list[i]), - func.ROW(*columns) < func.ROW(*bound_tuples_list[i + 1]) - ), - func.json_build_object( - literal(RANGE_ID), - i + 1, - literal(GEQ_BOUND), - _get_inner_json_object(*bound_tuples_list[i]), - literal(LT_BOUND), - _get_inner_json_object(*bound_tuples_list[i + 1]), + def _build_range_cases(result_expr): + return [ + ( + and_( + func.ROW(*columns) >= func.ROW(*bound_tuples_list[i]), + func.ROW(*columns) < func.ROW(*bound_tuples_list[i + 1]) + ), + result_expr(i) ) - ) - for i in range(len(bound_tuples_list)) - ] + for i in range(len(bound_tuples_list) - 1) + ] ranges_cte = select( *columns, - case(*ranges).label(GROUP_OBJ) + case(*_build_range_cases(lambda x: x + 1), else_=None).label(RANGE_ID), + case( + *_build_range_cases( + lambda x: _get_inner_json_object(bound_tuples_list[x]) + ), + else_=None + ).label(GEQ_BOUND), + case( + *_build_range_cases( + lambda x: _get_inner_json_object(bound_tuples_list[x + 1]) + ), + else_=None + ).label(LT_BOUND), ).cte() ranges_aggregation_cols = [ @@ -250,11 +257,11 @@ def _get_inner_json_object(bound_tuple): ] window_def = GroupingWindowDefinition( order_by=ranges_aggregation_cols, - partition_by=ranges_cte.columns[GROUP_OBJ][RANGE_ID] + partition_by=ranges_cte.columns[RANGE_ID] ) group_id_expr = window_def.partition_by - geq_expr = ranges_cte.columns[GROUP_OBJ][GEQ_BOUND] - lt_expr = ranges_cte.columns[GROUP_OBJ][LT_BOUND] + geq_expr = ranges_cte.columns[GEQ_BOUND] + lt_expr = ranges_cte.columns[LT_BOUND] return select( *[col for col in ranges_cte.columns if col.name in table.columns], _get_group_metadata_definition( @@ -264,7 +271,7 @@ def _get_inner_json_object(bound_tuple): geq_expr=geq_expr, lt_expr=lt_expr, ) - ) + ).where(ranges_cte.columns[RANGE_ID] != None) def _get_percentile_range_group_select(table, columns, num_groups): From 0571bdc8e8d05691e6c976655bebdd07f3cb5c59 Mon Sep 17 00:00:00 2001 From: Brent Moran Date: Thu, 14 Apr 2022 22:20:04 +0800 Subject: [PATCH 3/8] wire endpoints function up to API --- db/records/operations/group.py | 34 +++++++++++++++++++++++++--------- 1 file changed, 25 insertions(+), 9 deletions(-) diff --git a/db/records/operations/group.py b/db/records/operations/group.py index 83a1c58b9f..7a583e865d 100644 --- a/db/records/operations/group.py +++ b/db/records/operations/group.py @@ -1,8 +1,7 @@ from enum import Enum import json import logging -from sqlalchemy import select, func, and_, cast, case, literal -from sqlalchemy.dialects.postgresql import JSON +from sqlalchemy import select, func, and_, case, literal from db.records import exceptions as records_exceptions from db.records.operations import calculation @@ -15,6 +14,7 @@ class GroupMode(Enum): DISTINCT = 'distinct' + ENDPOINTS = 'endpoints' MAGNITUDE = 'magnitude' PERCENTILE = 'percentile' @@ -32,11 +32,16 @@ class GroupMetadataField(Enum): class GroupBy: def __init__( - self, columns, mode=GroupMode.DISTINCT.value, num_groups=None + self, + columns, + mode=GroupMode.DISTINCT.value, + num_groups=None, + bound_tuples=None, ): self._columns = tuple(columns) if type(columns) != str else tuple([columns]) self._mode = mode self._num_groups = num_groups + self._bound_tuples = bound_tuples self._ranged = bool(mode != GroupMode.DISTINCT.value) @property @@ -51,6 +56,10 @@ def mode(self): def num_groups(self): return self._num_groups + @property + def bound_tuples(self): + return self._bound_tuples + @property def ranged(self): return self._ranged @@ -62,16 +71,20 @@ def validate(self): f'mode "{self.mode}" is invalid. valid modes are: ' + ', '.join([f"'{gm}'" for gm in group_modes]) ) - if ( + elif ( self.mode == GroupMode.PERCENTILE.value and not type(self.num_groups) == int ): raise records_exceptions.BadGroupFormat( - 'percentile mode requires integer num_groups' + f'{GroupMode.PERCENTILE.value} mode requires integer num_groups' ) - if self.mode == GroupMode.MAGNITUDE.value and not len(self.columns) == 1: + elif self.mode == GroupMode.MAGNITUDE.value and not len(self.columns) == 1: raise records_exceptions.BadGroupFormat( - 'magnitude mode only works on single columns' + f'{GroupMode.MAGNITUDE.value} mode only works on single columns' + ) + elif self.mode == GroupMode.ENDPOINTS.value and self.bound_tuples is None: + raise records_exceptions.BadGroupFormat( + f'{GroupMode.ENDPOINTS.value} mode requires bound_tuples' ) for col in self.columns: @@ -124,6 +137,10 @@ def get_group_augmented_records_query(table, group_by): query = _get_percentile_range_group_select( table, grouping_columns, group_by.num_groups ) + elif group_by.mode == GroupMode.ENDPOINTS.value: + query = _get_custom_endpoints_range_group_select( + table, grouping_columns, group_by.bound_tuples + ) elif group_by.mode == GroupMode.MAGNITUDE.value: query = _get_tens_powers_range_group_select(table, grouping_columns) elif group_by.mode == GroupMode.DISTINCT.value: @@ -209,7 +226,6 @@ def _get_pretty_bound_expr(id_offset): def _get_custom_endpoints_range_group_select(table, columns, bound_tuples_list): column_names = [col.name for col in columns] - GROUP_OBJ = 'group_obj' RANGE_ID = 'range_id' GEQ_BOUND = 'geq_bound' LT_BOUND = 'lt_bound' @@ -271,7 +287,7 @@ def _build_range_cases(result_expr): geq_expr=geq_expr, lt_expr=lt_expr, ) - ).where(ranges_cte.columns[RANGE_ID] != None) + ).where(ranges_cte.columns[RANGE_ID] != None) # noqa def _get_percentile_range_group_select(table, columns, num_groups): From eca7752ad3faf52801d8ad8af6a837997b05dfd3 Mon Sep 17 00:00:00 2001 From: Brent Moran Date: Thu, 14 Apr 2022 23:59:59 +0800 Subject: [PATCH 4/8] add tests, including one before solution --- db/tests/records/operations/test_group.py | 58 ++++++++++++++++++++++- 1 file changed, 57 insertions(+), 1 deletion(-) diff --git a/db/tests/records/operations/test_group.py b/db/tests/records/operations/test_group.py index 63b53a12a5..a76dcffdcf 100644 --- a/db/tests/records/operations/test_group.py +++ b/db/tests/records/operations/test_group.py @@ -111,6 +111,15 @@ def test_GB_validate_passes_valid_kwargs_mag(): gb.validate() +def test_GB_validate_passes_valid_kwargs_endpoints(): + gb = group.GroupBy( + columns=['col1'], + mode=group.GroupMode.ENDPOINTS.value, + bound_tuples=[('a', 5), ('b', 0)], + ) + gb.validate() + + def test_GB_validate_fails_invalid_mode(): gb = group.GroupBy( columns=['col1', 'col2'], @@ -140,6 +149,15 @@ def test_GB_validate_fails_invalid_columns_len(): gb.validate() +def test_GB_validate_fails_missing_bound_tuples(): + gb = group.GroupBy( + columns=['col1', 'col2'], + mode=group.GroupMode.ENDPOINTS.value, + ) + with pytest.raises(records_exceptions.BadGroupFormat): + gb.validate() + + def test_GB_get_valid_group_by_columns_str_cols(roster_table_obj): roster, _ = roster_table_obj column_names = ['Student Number', 'Student Email'] @@ -184,6 +202,7 @@ def _group_id(row): basic_group_modes = [ group.GroupMode.DISTINCT.value, group.GroupMode.PERCENTILE.value, + group.GroupMode.ENDPOINTS.value, ] @@ -191,7 +210,14 @@ def _group_id(row): def test_get_group_augmented_records_query_metadata_fields(roster_table_obj, group_mode): roster, engine = roster_table_obj group_by = group.GroupBy( - ['Student Number', 'Student Name'], mode=group_mode, num_groups=12 + ['Student Number', 'Student Name'], + mode=group_mode, + num_groups=12, + bound_tuples=[ + ('00000000-0000-0000-0000-000000000000', 'Alice'), + ('77777777-7777-7777-7777-777777777777', 'Margot'), + ('ffffffff-ffff-ffff-ffff-ffffffffffff', 'Zachary'), + ] ) augmented_query = group.get_group_augmented_records_query(roster, group_by) with engine.begin() as conn: @@ -259,6 +285,16 @@ def test_smoke_get_group_augmented_records_query_magnitude(magnitude_table_obj): num_groups=1500, ), 1500 + ), + ( + group.GroupBy( + ['Subject', 'Grade'], + mode=group.GroupMode.ENDPOINTS.value, + bound_tuples=[ + ('a', 50), ('f', 75), ('k', 25), ('p', 90), ('r', 100) + ] + ), + 4 ) ] @@ -357,6 +393,26 @@ def test_magnitude_group_select_inside_bounds(magnitude_table_obj, col_name): ) +invalid_endpoints_setups = [ + (['Grade'], [(0,), (2,), (1,)]) +] + + +@pytest.mark.parametrize('columns,bound_tuples', invalid_endpoints_setups) +def test_invalid_bound_tuples_lists(roster_table_obj, columns, bound_tuples): + roster, engine = roster_table_obj + input_cols = columns + group_by = group.GroupBy( + columns=input_cols, + mode=group.GroupMode.ENDPOINTS.value, + bound_tuples=bound_tuples + ) + sel = group.get_group_augmented_records_query(roster, group_by) + with pytest.raises(records_exceptions.BadGroupFormat): + with engine.begin() as conn: + conn.execute(sel).fetchall() + + def test_get_distinct_group_select_correct_first_last_row_match(roster_distinct_setup): res = roster_distinct_setup for row in res: From e468b41a9df551774a2eb7176868942f43d3c47b Mon Sep 17 00:00:00 2001 From: Brent Moran Date: Tue, 19 Apr 2022 17:36:58 +0800 Subject: [PATCH 5/8] Move GroupBy validation to __init__ --- db/records/operations/group.py | 9 +++- db/tests/records/operations/test_group.py | 60 +++++++---------------- 2 files changed, 25 insertions(+), 44 deletions(-) diff --git a/db/records/operations/group.py b/db/records/operations/group.py index 7a583e865d..853abb8b4b 100644 --- a/db/records/operations/group.py +++ b/db/records/operations/group.py @@ -14,8 +14,9 @@ class GroupMode(Enum): DISTINCT = 'distinct' - ENDPOINTS = 'endpoints' + ENDPOINTS = 'endpoints' # intended for internal use at the moment MAGNITUDE = 'magnitude' + COUNT_BY = 'count_by' PERCENTILE = 'percentile' @@ -37,12 +38,17 @@ def __init__( mode=GroupMode.DISTINCT.value, num_groups=None, bound_tuples=None, + global_min=None, + global_max=None, ): self._columns = tuple(columns) if type(columns) != str else tuple([columns]) self._mode = mode self._num_groups = num_groups self._bound_tuples = bound_tuples + self._global_min = global_min + self._global_max = global_max self._ranged = bool(mode != GroupMode.DISTINCT.value) + self.validate() @property def columns(self): @@ -94,7 +100,6 @@ def validate(self): ) def get_validated_group_by_columns(self, table): - self.validate() for col in self.columns: col_name = col if isinstance(col, str) else col.name if col_name not in table.columns: diff --git a/db/tests/records/operations/test_group.py b/db/tests/records/operations/test_group.py index a76dcffdcf..e9f1bceaa9 100644 --- a/db/tests/records/operations/test_group.py +++ b/db/tests/records/operations/test_group.py @@ -121,41 +121,37 @@ def test_GB_validate_passes_valid_kwargs_endpoints(): def test_GB_validate_fails_invalid_mode(): - gb = group.GroupBy( - columns=['col1', 'col2'], - mode='potato', - num_groups=1234, - ) with pytest.raises(records_exceptions.InvalidGroupType): - gb.validate() + group.GroupBy( + columns=['col1', 'col2'], + mode='potato', + num_groups=1234, + ) def test_GB_validate_fails_invalid_num_group(): - gb = group.GroupBy( - columns=['col1', 'col2'], - mode=group.GroupMode.PERCENTILE.value, - num_groups=None, - ) with pytest.raises(records_exceptions.BadGroupFormat): - gb.validate() + group.GroupBy( + columns=['col1', 'col2'], + mode=group.GroupMode.PERCENTILE.value, + num_groups=None, + ) def test_GB_validate_fails_invalid_columns_len(): - gb = group.GroupBy( - columns=['col1', 'col2'], - mode=group.GroupMode.MAGNITUDE.value, - ) with pytest.raises(records_exceptions.BadGroupFormat): - gb.validate() + group.GroupBy( + columns=['col1', 'col2'], + mode=group.GroupMode.MAGNITUDE.value, + ) def test_GB_validate_fails_missing_bound_tuples(): - gb = group.GroupBy( - columns=['col1', 'col2'], - mode=group.GroupMode.ENDPOINTS.value, - ) with pytest.raises(records_exceptions.BadGroupFormat): - gb.validate() + group.GroupBy( + columns=['col1', 'col2'], + mode=group.GroupMode.ENDPOINTS.value, + ) def test_GB_get_valid_group_by_columns_str_cols(roster_table_obj): @@ -393,26 +389,6 @@ def test_magnitude_group_select_inside_bounds(magnitude_table_obj, col_name): ) -invalid_endpoints_setups = [ - (['Grade'], [(0,), (2,), (1,)]) -] - - -@pytest.mark.parametrize('columns,bound_tuples', invalid_endpoints_setups) -def test_invalid_bound_tuples_lists(roster_table_obj, columns, bound_tuples): - roster, engine = roster_table_obj - input_cols = columns - group_by = group.GroupBy( - columns=input_cols, - mode=group.GroupMode.ENDPOINTS.value, - bound_tuples=bound_tuples - ) - sel = group.get_group_augmented_records_query(roster, group_by) - with pytest.raises(records_exceptions.BadGroupFormat): - with engine.begin() as conn: - conn.execute(sel).fetchall() - - def test_get_distinct_group_select_correct_first_last_row_match(roster_distinct_setup): res = roster_distinct_setup for row in res: From 14a3516e1d0e90aa4c36aca8c4018c1f4b501752 Mon Sep 17 00:00:00 2001 From: Brent Moran Date: Thu, 21 Apr 2022 21:35:12 +0800 Subject: [PATCH 6/8] add count_by mode to GroupBy --- db/records/operations/group.py | 27 ++++++++++++++++++++++++--- 1 file changed, 24 insertions(+), 3 deletions(-) diff --git a/db/records/operations/group.py b/db/records/operations/group.py index 853abb8b4b..fd455ffcc3 100644 --- a/db/records/operations/group.py +++ b/db/records/operations/group.py @@ -38,6 +38,7 @@ def __init__( mode=GroupMode.DISTINCT.value, num_groups=None, bound_tuples=None, + count_by=None, global_min=None, global_max=None, ): @@ -45,6 +46,7 @@ def __init__( self._mode = mode self._num_groups = num_groups self._bound_tuples = bound_tuples + self._count_by = count_by self._global_min = global_min self._global_max = global_max self._ranged = bool(mode != GroupMode.DISTINCT.value) @@ -64,12 +66,21 @@ def num_groups(self): @property def bound_tuples(self): - return self._bound_tuples + if self._bound_tuples is not None: + return self._bound_tuples + elif self._mode == GroupMode.COUNT_BY.value: + return [bt for bt in self._bound_tuple_generator()] @property def ranged(self): return self._ranged + def _bound_tuple_generator(self): + val = self._global_min + while val <= self._global_max: + yield (val,) + val += self._count_by + def validate(self): group_modes = {group_mode.value for group_mode in GroupMode} if self.mode not in group_modes: @@ -88,10 +99,17 @@ def validate(self): raise records_exceptions.BadGroupFormat( f'{GroupMode.MAGNITUDE.value} mode only works on single columns' ) - elif self.mode == GroupMode.ENDPOINTS.value and self.bound_tuples is None: + elif self.mode == GroupMode.ENDPOINTS.value and not self.bound_tuples: raise records_exceptions.BadGroupFormat( f'{GroupMode.ENDPOINTS.value} mode requires bound_tuples' ) + elif ( + self.mode == GroupMode.COUNT_BY.value + and (self._global_min is None or self._global_max is None) + ): + raise records_exceptions.BadGroupFormat( + f'{GroupMode.COUNT_BY.value} mode requires global_min and global_max' + ) for col in self.columns: if type(col) != str: @@ -142,7 +160,10 @@ def get_group_augmented_records_query(table, group_by): query = _get_percentile_range_group_select( table, grouping_columns, group_by.num_groups ) - elif group_by.mode == GroupMode.ENDPOINTS.value: + elif ( + group_by.mode == GroupMode.ENDPOINTS.value + or group_by.mode == GroupMode.COUNT_BY.value + ) : query = _get_custom_endpoints_range_group_select( table, grouping_columns, group_by.bound_tuples ) From 442bcc7214d9e95800db7b983d623001579cbd1d Mon Sep 17 00:00:00 2001 From: Brent Moran Date: Mon, 25 Apr 2022 22:53:50 +0800 Subject: [PATCH 7/8] test count_by grouping mode, fix bugs found --- db/records/operations/group.py | 13 ++- db/tests/records/operations/test_group.py | 102 +++++++++++++++++++++- mathesar/tests/api/test_record_api.py | 1 + 3 files changed, 111 insertions(+), 5 deletions(-) diff --git a/db/records/operations/group.py b/db/records/operations/group.py index fd455ffcc3..0d2a6a1cce 100644 --- a/db/records/operations/group.py +++ b/db/records/operations/group.py @@ -105,10 +105,17 @@ def validate(self): ) elif ( self.mode == GroupMode.COUNT_BY.value - and (self._global_min is None or self._global_max is None) + and ( + self._count_by is None + or not len(self.columns) == 1 + or self._global_min is None + or self._global_max is None + ) ): raise records_exceptions.BadGroupFormat( - f'{GroupMode.COUNT_BY.value} mode requires global_min and global_max' + f'{GroupMode.COUNT_BY.value} mode requires' + ' count_by, global_min, and global_max.' + ' further, it works only for single columns.' ) for col in self.columns: @@ -163,7 +170,7 @@ def get_group_augmented_records_query(table, group_by): elif ( group_by.mode == GroupMode.ENDPOINTS.value or group_by.mode == GroupMode.COUNT_BY.value - ) : + ): query = _get_custom_endpoints_range_group_select( table, grouping_columns, group_by.bound_tuples ) diff --git a/db/tests/records/operations/test_group.py b/db/tests/records/operations/test_group.py index e9f1bceaa9..7a43e62a30 100644 --- a/db/tests/records/operations/test_group.py +++ b/db/tests/records/operations/test_group.py @@ -120,6 +120,17 @@ def test_GB_validate_passes_valid_kwargs_endpoints(): gb.validate() +def test_GB_validate_passes_valid_kwargs_count_by(): + gb = group.GroupBy( + columns=['col1'], + mode=group.GroupMode.COUNT_BY.value, + count_by=3, + global_min=234.5, + global_max=987.6 + ) + gb.validate() + + def test_GB_validate_fails_invalid_mode(): with pytest.raises(records_exceptions.InvalidGroupType): group.GroupBy( @@ -154,6 +165,50 @@ def test_GB_validate_fails_missing_bound_tuples(): ) +def test_GB_validate_fails_missing_count_by(): + with pytest.raises(records_exceptions.BadGroupFormat): + group.GroupBy( + columns=['col1'], + mode=group.GroupMode.COUNT_BY.value, + count_by=None, + global_min=234.5, + global_max=987.6 + ) + + +def test_GB_validate_fails_missing_global_min(): + with pytest.raises(records_exceptions.BadGroupFormat): + group.GroupBy( + columns=['col1'], + mode=group.GroupMode.COUNT_BY.value, + count_by=3, + global_min=None, + global_max=987.6 + ) + + +def test_GB_validate_fails_missing_global_max(): + with pytest.raises(records_exceptions.BadGroupFormat): + group.GroupBy( + columns=['col1'], + mode=group.GroupMode.COUNT_BY.value, + count_by=3, + global_min=234.5, + global_max=None + ) + + +def test_GB_validate_fails_multiple_cols_with_count_by(): + with pytest.raises(records_exceptions.BadGroupFormat): + group.GroupBy( + columns=['col1', 'col2'], + mode=group.GroupMode.COUNT_BY.value, + count_by=3, + global_min=234.5, + global_max=987.6 + ) + + def test_GB_get_valid_group_by_columns_str_cols(roster_table_obj): roster, _ = roster_table_obj column_names = ['Student Number', 'Student Email'] @@ -227,9 +282,22 @@ def test_get_group_augmented_records_query_metadata_fields(roster_table_obj, gro ) -def test_smoke_get_group_augmented_records_query_magnitude(magnitude_table_obj): +single_col_number_modes = [ + group.GroupMode.MAGNITUDE.value, + group.GroupMode.COUNT_BY.value, +] + + +@pytest.mark.parametrize('mode', single_col_number_modes) +def test_smoke_get_group_augmented_records_query_magnitude(magnitude_table_obj, mode): magnitude, engine = magnitude_table_obj - group_by = group.GroupBy(['big_num'], mode=group.GroupMode.MAGNITUDE.value) + group_by = group.GroupBy( + ['big_num'], + mode=mode, + count_by=50, + global_min=0, + global_max=1000 + ) augmented_query = group.get_group_augmented_records_query(magnitude, group_by) with engine.begin() as conn: res = conn.execute(augmented_query).fetchall() @@ -332,6 +400,36 @@ def test_group_select_correct_num_group_id_magnitude( assert max([_group_id(row) for row in res]) == num +count_by_count_by = [0.000005, 0.00001, 7, 80.5, 750, 25, 100] +count_by_global_min = [0, 0, 0, -100, -4500, -100, 0] +count_by_global_max = [0.0003, 0.001, 250, 600, 5500, 100, 2000] +count_by_max_group_id = [59, 99, 29, 8, 13, 8, 20] + + +@pytest.mark.parametrize( + 'col_name,count_by,global_min,global_max,num', zip( + magnitude_columns, count_by_count_by, count_by_global_min, + count_by_global_max, count_by_max_group_id + ) +) +def test_group_select_correct_num_group_id_count_by( + magnitude_table_obj, col_name, count_by, global_min, global_max, num +): + magnitude, engine = magnitude_table_obj + group_by = group.GroupBy( + [col_name], + mode=group.GroupMode.COUNT_BY.value, + count_by=count_by, + global_min=global_min, + global_max=global_max, + ) + augmented_query = group.get_group_augmented_records_query(magnitude, group_by) + with engine.begin() as conn: + res = conn.execute(augmented_query).fetchall() + + assert max([_group_id(row) for row in res]) == num + + @pytest.mark.parametrize('col_name', magnitude_columns) def test_magnitude_group_select_bounds_chain(magnitude_table_obj, col_name): magnitude, engine = magnitude_table_obj diff --git a/mathesar/tests/api/test_record_api.py b/mathesar/tests/api/test_record_api.py index 623e634882..0ce9fd1a47 100644 --- a/mathesar/tests/api/test_record_api.py +++ b/mathesar/tests/api/test_record_api.py @@ -505,6 +505,7 @@ def test_record_list_groups( json_grouping = json.dumps(ids_converted_group_by) limit = 100 query_str = f'grouping={json_grouping}&order_by={json_order_by}&limit={limit}' + print(query_str) response = client.get(f'/api/db/v0/tables/{table.id}/records/?{query_str}') response_data = response.json() From a25d676187beb92e140358eaff527c150e4b35d2 Mon Sep 17 00:00:00 2001 From: Brent Moran Date: Tue, 26 Apr 2022 17:18:29 +0800 Subject: [PATCH 8/8] remove debugging print statement --- mathesar/tests/api/test_record_api.py | 1 - 1 file changed, 1 deletion(-) diff --git a/mathesar/tests/api/test_record_api.py b/mathesar/tests/api/test_record_api.py index 0ce9fd1a47..623e634882 100644 --- a/mathesar/tests/api/test_record_api.py +++ b/mathesar/tests/api/test_record_api.py @@ -505,7 +505,6 @@ def test_record_list_groups( json_grouping = json.dumps(ids_converted_group_by) limit = 100 query_str = f'grouping={json_grouping}&order_by={json_order_by}&limit={limit}' - print(query_str) response = client.get(f'/api/db/v0/tables/{table.id}/records/?{query_str}') response_data = response.json()