diff --git a/db/records/operations/group.py b/db/records/operations/group.py index f5f78d83b6..0d2a6a1cce 100644 --- a/db/records/operations/group.py +++ b/db/records/operations/group.py @@ -14,7 +14,9 @@ class GroupMode(Enum): DISTINCT = 'distinct' + ENDPOINTS = 'endpoints' # intended for internal use at the moment MAGNITUDE = 'magnitude' + COUNT_BY = 'count_by' PERCENTILE = 'percentile' @@ -31,12 +33,24 @@ class GroupMetadataField(Enum): class GroupBy: def __init__( - self, columns, mode=GroupMode.DISTINCT.value, num_groups=None + self, + columns, + mode=GroupMode.DISTINCT.value, + num_groups=None, + bound_tuples=None, + count_by=None, + global_min=None, + global_max=None, ): self._columns = tuple(columns) if type(columns) != str else tuple([columns]) self._mode = mode self._num_groups = num_groups + self._bound_tuples = bound_tuples + self._count_by = count_by + self._global_min = global_min + self._global_max = global_max self._ranged = bool(mode != GroupMode.DISTINCT.value) + self.validate() @property def columns(self): @@ -50,10 +64,23 @@ def mode(self): def num_groups(self): return self._num_groups + @property + def bound_tuples(self): + if self._bound_tuples is not None: + return self._bound_tuples + elif self._mode == GroupMode.COUNT_BY.value: + return [bt for bt in self._bound_tuple_generator()] + @property def ranged(self): return self._ranged + def _bound_tuple_generator(self): + val = self._global_min + while val <= self._global_max: + yield (val,) + val += self._count_by + def validate(self): group_modes = {group_mode.value for group_mode in GroupMode} if self.mode not in group_modes: @@ -61,16 +88,34 @@ def validate(self): f'mode "{self.mode}" is invalid. valid modes are: ' + ', '.join([f"'{gm}'" for gm in group_modes]) ) - if ( + elif ( self.mode == GroupMode.PERCENTILE.value and not type(self.num_groups) == int ): raise records_exceptions.BadGroupFormat( - 'percentile mode requires integer num_groups' + f'{GroupMode.PERCENTILE.value} mode requires integer num_groups' + ) + elif self.mode == GroupMode.MAGNITUDE.value and not len(self.columns) == 1: + raise records_exceptions.BadGroupFormat( + f'{GroupMode.MAGNITUDE.value} mode only works on single columns' ) - if self.mode == GroupMode.MAGNITUDE.value and not len(self.columns) == 1: + elif self.mode == GroupMode.ENDPOINTS.value and not self.bound_tuples: raise records_exceptions.BadGroupFormat( - 'magnitude mode only works on single columns' + f'{GroupMode.ENDPOINTS.value} mode requires bound_tuples' + ) + elif ( + self.mode == GroupMode.COUNT_BY.value + and ( + self._count_by is None + or not len(self.columns) == 1 + or self._global_min is None + or self._global_max is None + ) + ): + raise records_exceptions.BadGroupFormat( + f'{GroupMode.COUNT_BY.value} mode requires' + ' count_by, global_min, and global_max.' + ' further, it works only for single columns.' ) for col in self.columns: @@ -80,7 +125,6 @@ def validate(self): ) def get_validated_group_by_columns(self, table): - self.validate() for col in self.columns: col_name = col if isinstance(col, str) else col.name if col_name not in table.columns: @@ -123,6 +167,13 @@ def get_group_augmented_records_query(table, group_by): query = _get_percentile_range_group_select( table, grouping_columns, group_by.num_groups ) + elif ( + group_by.mode == GroupMode.ENDPOINTS.value + or group_by.mode == GroupMode.COUNT_BY.value + ): + query = _get_custom_endpoints_range_group_select( + table, grouping_columns, group_by.bound_tuples + ) elif group_by.mode == GroupMode.MAGNITUDE.value: query = _get_tens_powers_range_group_select(table, grouping_columns) elif group_by.mode == GroupMode.DISTINCT.value: @@ -206,6 +257,72 @@ def _get_pretty_bound_expr(id_offset): ) +def _get_custom_endpoints_range_group_select(table, columns, bound_tuples_list): + column_names = [col.name for col in columns] + RANGE_ID = 'range_id' + GEQ_BOUND = 'geq_bound' + LT_BOUND = 'lt_bound' + + def _get_inner_json_object(bound_tuple): + key_value_tuples = ( + (literal(str(col)), literal(val)) + for col, val in zip(column_names, bound_tuple) + ) + key_value_list = [ + part for tup in key_value_tuples for part in tup + ] + return func.json_build_object(*key_value_list) + + def _build_range_cases(result_expr): + return [ + ( + and_( + func.ROW(*columns) >= func.ROW(*bound_tuples_list[i]), + func.ROW(*columns) < func.ROW(*bound_tuples_list[i + 1]) + ), + result_expr(i) + ) + for i in range(len(bound_tuples_list) - 1) + ] + ranges_cte = select( + *columns, + case(*_build_range_cases(lambda x: x + 1), else_=None).label(RANGE_ID), + case( + *_build_range_cases( + lambda x: _get_inner_json_object(bound_tuples_list[x]) + ), + else_=None + ).label(GEQ_BOUND), + case( + *_build_range_cases( + lambda x: _get_inner_json_object(bound_tuples_list[x + 1]) + ), + else_=None + ).label(LT_BOUND), + ).cte() + + ranges_aggregation_cols = [ + col for col in ranges_cte.columns if col.name in column_names + ] + window_def = GroupingWindowDefinition( + order_by=ranges_aggregation_cols, + partition_by=ranges_cte.columns[RANGE_ID] + ) + group_id_expr = window_def.partition_by + geq_expr = ranges_cte.columns[GEQ_BOUND] + lt_expr = ranges_cte.columns[LT_BOUND] + return select( + *[col for col in ranges_cte.columns if col.name in table.columns], + _get_group_metadata_definition( + window_def, + ranges_aggregation_cols, + group_id_expr, + geq_expr=geq_expr, + lt_expr=lt_expr, + ) + ).where(ranges_cte.columns[RANGE_ID] != None) # noqa + + def _get_percentile_range_group_select(table, columns, num_groups): column_names = [col.name for col in columns] # cume_dist is a PostgreSQL function that calculates the cumulative diff --git a/db/tests/records/operations/test_group.py b/db/tests/records/operations/test_group.py index 63b53a12a5..7a43e62a30 100644 --- a/db/tests/records/operations/test_group.py +++ b/db/tests/records/operations/test_group.py @@ -111,33 +111,102 @@ def test_GB_validate_passes_valid_kwargs_mag(): gb.validate() -def test_GB_validate_fails_invalid_mode(): +def test_GB_validate_passes_valid_kwargs_endpoints(): gb = group.GroupBy( - columns=['col1', 'col2'], - mode='potato', - num_groups=1234, + columns=['col1'], + mode=group.GroupMode.ENDPOINTS.value, + bound_tuples=[('a', 5), ('b', 0)], ) - with pytest.raises(records_exceptions.InvalidGroupType): - gb.validate() + gb.validate() -def test_GB_validate_fails_invalid_num_group(): +def test_GB_validate_passes_valid_kwargs_count_by(): gb = group.GroupBy( - columns=['col1', 'col2'], - mode=group.GroupMode.PERCENTILE.value, - num_groups=None, + columns=['col1'], + mode=group.GroupMode.COUNT_BY.value, + count_by=3, + global_min=234.5, + global_max=987.6 ) + gb.validate() + + +def test_GB_validate_fails_invalid_mode(): + with pytest.raises(records_exceptions.InvalidGroupType): + group.GroupBy( + columns=['col1', 'col2'], + mode='potato', + num_groups=1234, + ) + + +def test_GB_validate_fails_invalid_num_group(): with pytest.raises(records_exceptions.BadGroupFormat): - gb.validate() + group.GroupBy( + columns=['col1', 'col2'], + mode=group.GroupMode.PERCENTILE.value, + num_groups=None, + ) def test_GB_validate_fails_invalid_columns_len(): - gb = group.GroupBy( - columns=['col1', 'col2'], - mode=group.GroupMode.MAGNITUDE.value, - ) with pytest.raises(records_exceptions.BadGroupFormat): - gb.validate() + group.GroupBy( + columns=['col1', 'col2'], + mode=group.GroupMode.MAGNITUDE.value, + ) + + +def test_GB_validate_fails_missing_bound_tuples(): + with pytest.raises(records_exceptions.BadGroupFormat): + group.GroupBy( + columns=['col1', 'col2'], + mode=group.GroupMode.ENDPOINTS.value, + ) + + +def test_GB_validate_fails_missing_count_by(): + with pytest.raises(records_exceptions.BadGroupFormat): + group.GroupBy( + columns=['col1'], + mode=group.GroupMode.COUNT_BY.value, + count_by=None, + global_min=234.5, + global_max=987.6 + ) + + +def test_GB_validate_fails_missing_global_min(): + with pytest.raises(records_exceptions.BadGroupFormat): + group.GroupBy( + columns=['col1'], + mode=group.GroupMode.COUNT_BY.value, + count_by=3, + global_min=None, + global_max=987.6 + ) + + +def test_GB_validate_fails_missing_global_max(): + with pytest.raises(records_exceptions.BadGroupFormat): + group.GroupBy( + columns=['col1'], + mode=group.GroupMode.COUNT_BY.value, + count_by=3, + global_min=234.5, + global_max=None + ) + + +def test_GB_validate_fails_multiple_cols_with_count_by(): + with pytest.raises(records_exceptions.BadGroupFormat): + group.GroupBy( + columns=['col1', 'col2'], + mode=group.GroupMode.COUNT_BY.value, + count_by=3, + global_min=234.5, + global_max=987.6 + ) def test_GB_get_valid_group_by_columns_str_cols(roster_table_obj): @@ -184,6 +253,7 @@ def _group_id(row): basic_group_modes = [ group.GroupMode.DISTINCT.value, group.GroupMode.PERCENTILE.value, + group.GroupMode.ENDPOINTS.value, ] @@ -191,7 +261,14 @@ def _group_id(row): def test_get_group_augmented_records_query_metadata_fields(roster_table_obj, group_mode): roster, engine = roster_table_obj group_by = group.GroupBy( - ['Student Number', 'Student Name'], mode=group_mode, num_groups=12 + ['Student Number', 'Student Name'], + mode=group_mode, + num_groups=12, + bound_tuples=[ + ('00000000-0000-0000-0000-000000000000', 'Alice'), + ('77777777-7777-7777-7777-777777777777', 'Margot'), + ('ffffffff-ffff-ffff-ffff-ffffffffffff', 'Zachary'), + ] ) augmented_query = group.get_group_augmented_records_query(roster, group_by) with engine.begin() as conn: @@ -205,9 +282,22 @@ def test_get_group_augmented_records_query_metadata_fields(roster_table_obj, gro ) -def test_smoke_get_group_augmented_records_query_magnitude(magnitude_table_obj): +single_col_number_modes = [ + group.GroupMode.MAGNITUDE.value, + group.GroupMode.COUNT_BY.value, +] + + +@pytest.mark.parametrize('mode', single_col_number_modes) +def test_smoke_get_group_augmented_records_query_magnitude(magnitude_table_obj, mode): magnitude, engine = magnitude_table_obj - group_by = group.GroupBy(['big_num'], mode=group.GroupMode.MAGNITUDE.value) + group_by = group.GroupBy( + ['big_num'], + mode=mode, + count_by=50, + global_min=0, + global_max=1000 + ) augmented_query = group.get_group_augmented_records_query(magnitude, group_by) with engine.begin() as conn: res = conn.execute(augmented_query).fetchall() @@ -259,6 +349,16 @@ def test_smoke_get_group_augmented_records_query_magnitude(magnitude_table_obj): num_groups=1500, ), 1500 + ), + ( + group.GroupBy( + ['Subject', 'Grade'], + mode=group.GroupMode.ENDPOINTS.value, + bound_tuples=[ + ('a', 50), ('f', 75), ('k', 25), ('p', 90), ('r', 100) + ] + ), + 4 ) ] @@ -300,6 +400,36 @@ def test_group_select_correct_num_group_id_magnitude( assert max([_group_id(row) for row in res]) == num +count_by_count_by = [0.000005, 0.00001, 7, 80.5, 750, 25, 100] +count_by_global_min = [0, 0, 0, -100, -4500, -100, 0] +count_by_global_max = [0.0003, 0.001, 250, 600, 5500, 100, 2000] +count_by_max_group_id = [59, 99, 29, 8, 13, 8, 20] + + +@pytest.mark.parametrize( + 'col_name,count_by,global_min,global_max,num', zip( + magnitude_columns, count_by_count_by, count_by_global_min, + count_by_global_max, count_by_max_group_id + ) +) +def test_group_select_correct_num_group_id_count_by( + magnitude_table_obj, col_name, count_by, global_min, global_max, num +): + magnitude, engine = magnitude_table_obj + group_by = group.GroupBy( + [col_name], + mode=group.GroupMode.COUNT_BY.value, + count_by=count_by, + global_min=global_min, + global_max=global_max, + ) + augmented_query = group.get_group_augmented_records_query(magnitude, group_by) + with engine.begin() as conn: + res = conn.execute(augmented_query).fetchall() + + assert max([_group_id(row) for row in res]) == num + + @pytest.mark.parametrize('col_name', magnitude_columns) def test_magnitude_group_select_bounds_chain(magnitude_table_obj, col_name): magnitude, engine = magnitude_table_obj