Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

count_by and endpoints grouping modes #1312

Merged
merged 13 commits into from
May 4, 2022
129 changes: 123 additions & 6 deletions db/records/operations/group.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,9 @@

class GroupMode(Enum):
DISTINCT = 'distinct'
ENDPOINTS = 'endpoints' # intended for internal use at the moment
MAGNITUDE = 'magnitude'
COUNT_BY = 'count_by'
PERCENTILE = 'percentile'


Expand All @@ -31,12 +33,24 @@ class GroupMetadataField(Enum):

class GroupBy:
def __init__(
self, columns, mode=GroupMode.DISTINCT.value, num_groups=None
self,
columns,
mode=GroupMode.DISTINCT.value,
num_groups=None,
bound_tuples=None,
count_by=None,
global_min=None,
global_max=None,
):
self._columns = tuple(columns) if type(columns) != str else tuple([columns])
self._mode = mode
self._num_groups = num_groups
self._bound_tuples = bound_tuples
self._count_by = count_by
self._global_min = global_min
self._global_max = global_max
self._ranged = bool(mode != GroupMode.DISTINCT.value)
self.validate()

@property
def columns(self):
Expand All @@ -50,27 +64,58 @@ def mode(self):
def num_groups(self):
return self._num_groups

@property
def bound_tuples(self):
if self._bound_tuples is not None:
return self._bound_tuples
elif self._mode == GroupMode.COUNT_BY.value:
return [bt for bt in self._bound_tuple_generator()]

@property
def ranged(self):
return self._ranged

def _bound_tuple_generator(self):
val = self._global_min
while val <= self._global_max:
yield (val,)
val += self._count_by

def validate(self):
group_modes = {group_mode.value for group_mode in GroupMode}
if self.mode not in group_modes:
raise records_exceptions.InvalidGroupType(
f'mode "{self.mode}" is invalid. valid modes are: '
+ ', '.join([f"'{gm}'" for gm in group_modes])
)
if (
elif (
self.mode == GroupMode.PERCENTILE.value
and not type(self.num_groups) == int
):
raise records_exceptions.BadGroupFormat(
'percentile mode requires integer num_groups'
f'{GroupMode.PERCENTILE.value} mode requires integer num_groups'
)
elif self.mode == GroupMode.MAGNITUDE.value and not len(self.columns) == 1:
raise records_exceptions.BadGroupFormat(
f'{GroupMode.MAGNITUDE.value} mode only works on single columns'
)
if self.mode == GroupMode.MAGNITUDE.value and not len(self.columns) == 1:
elif self.mode == GroupMode.ENDPOINTS.value and not self.bound_tuples:
raise records_exceptions.BadGroupFormat(
'magnitude mode only works on single columns'
f'{GroupMode.ENDPOINTS.value} mode requires bound_tuples'
)
elif (
self.mode == GroupMode.COUNT_BY.value
and (
self._count_by is None
or not len(self.columns) == 1
or self._global_min is None
or self._global_max is None
)
):
raise records_exceptions.BadGroupFormat(
f'{GroupMode.COUNT_BY.value} mode requires'
' count_by, global_min, and global_max.'
' further, it works only for single columns.'
)

for col in self.columns:
Expand All @@ -80,7 +125,6 @@ def validate(self):
)

def get_validated_group_by_columns(self, table):
self.validate()
for col in self.columns:
col_name = col if isinstance(col, str) else col.name
if col_name not in table.columns:
Expand Down Expand Up @@ -123,6 +167,13 @@ def get_group_augmented_records_query(table, group_by):
query = _get_percentile_range_group_select(
table, grouping_columns, group_by.num_groups
)
elif (
group_by.mode == GroupMode.ENDPOINTS.value
or group_by.mode == GroupMode.COUNT_BY.value
):
query = _get_custom_endpoints_range_group_select(
table, grouping_columns, group_by.bound_tuples
)
elif group_by.mode == GroupMode.MAGNITUDE.value:
query = _get_tens_powers_range_group_select(table, grouping_columns)
elif group_by.mode == GroupMode.DISTINCT.value:
Expand Down Expand Up @@ -206,6 +257,72 @@ def _get_pretty_bound_expr(id_offset):
)


def _get_custom_endpoints_range_group_select(table, columns, bound_tuples_list):
column_names = [col.name for col in columns]
RANGE_ID = 'range_id'
GEQ_BOUND = 'geq_bound'
LT_BOUND = 'lt_bound'

def _get_inner_json_object(bound_tuple):
key_value_tuples = (
(literal(str(col)), literal(val))
for col, val in zip(column_names, bound_tuple)
)
key_value_list = [
part for tup in key_value_tuples for part in tup
]
return func.json_build_object(*key_value_list)

def _build_range_cases(result_expr):
return [
(
and_(
func.ROW(*columns) >= func.ROW(*bound_tuples_list[i]),
func.ROW(*columns) < func.ROW(*bound_tuples_list[i + 1])
),
result_expr(i)
)
for i in range(len(bound_tuples_list) - 1)
]
ranges_cte = select(
*columns,
case(*_build_range_cases(lambda x: x + 1), else_=None).label(RANGE_ID),
case(
*_build_range_cases(
lambda x: _get_inner_json_object(bound_tuples_list[x])
),
else_=None
).label(GEQ_BOUND),
case(
*_build_range_cases(
lambda x: _get_inner_json_object(bound_tuples_list[x + 1])
),
else_=None
).label(LT_BOUND),
).cte()

ranges_aggregation_cols = [
col for col in ranges_cte.columns if col.name in column_names
]
window_def = GroupingWindowDefinition(
order_by=ranges_aggregation_cols,
partition_by=ranges_cte.columns[RANGE_ID]
)
group_id_expr = window_def.partition_by
geq_expr = ranges_cte.columns[GEQ_BOUND]
lt_expr = ranges_cte.columns[LT_BOUND]
return select(
*[col for col in ranges_cte.columns if col.name in table.columns],
_get_group_metadata_definition(
window_def,
ranges_aggregation_cols,
group_id_expr,
geq_expr=geq_expr,
lt_expr=lt_expr,
)
).where(ranges_cte.columns[RANGE_ID] != None) # noqa


def _get_percentile_range_group_select(table, columns, num_groups):
column_names = [col.name for col in columns]
# cume_dist is a PostgreSQL function that calculates the cumulative
Expand Down
Loading