Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix grouping #477

Merged
merged 5 commits into from
Jul 23, 2021
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
58 changes: 29 additions & 29 deletions db/records.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
from sqlalchemy_filters import apply_filters, apply_sort
from sqlalchemy_filters.exceptions import FieldNotFound

from db.constants import ID

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -32,13 +31,27 @@ def _create_col_objects(table, column_list):
]


def _get_query(table, limit, offset, order_by, filters):
query = select(table).limit(limit).offset(offset)
if order_by is not None:
query = apply_sort(query, order_by)
if filters is not None:
query = apply_filters(query, filters)
return query


def _execute_query(query, engine):
with engine.begin() as conn:
records = conn.execute(query).fetchall()
return records


def get_record(table, engine, id_value):
primary_key_column = _get_primary_key_column(table)
query = select(table).where(primary_key_column == id_value)
with engine.begin() as conn:
result = conn.execute(query).fetchall()
assert len(result) <= 1
return result[0] if result else None
result = _execute_query(query, engine)
assert len(result) <= 1
return result[0] if result else None


def get_records(
Expand All @@ -59,13 +72,8 @@ def get_records(
field, in addition to an 'value' field if appropriate.
See: https://github.com/centerofci/sqlalchemy-filters#filters-format
"""
query = select(table).limit(limit).offset(offset)
if order_by is not None:
query = apply_sort(query, order_by)
if filters is not None:
query = apply_filters(query, filters)
with engine.begin() as conn:
return conn.execute(query).fetchall()
query = _get_query(table, limit, offset, order_by, filters)
return _execute_query(query, engine)


def get_group_counts(
Expand Down Expand Up @@ -96,24 +104,17 @@ def get_group_counts(
if field_name not in table.c:
raise GroupFieldNotFound(f"Group field {field} not found in {table}.")

query = (
select(table)
.limit(limit)
.offset(offset)
)
if order_by is not None:
query = apply_sort(query, order_by)
if filters is not None:
query = apply_filters(query, filters)
subquery = query.subquery()
# Get the list of groups that we should count.
# We're considering limit and offset here so that we only count relevant groups
relevant_groups_query = _get_query(table, limit, offset, order_by, filters)
subquery = relevant_groups_query.subquery()

group_by = [
columns = [
subquery.columns[col] if type(col) == str else subquery.columns[col.name]
for col in group_by
]
query = select(*group_by, func.count(subquery.c[ID])).group_by(*group_by)
with engine.begin() as conn:
records = conn.execute(query).fetchall()
count_query = select(*columns, func.count(columns[0])).group_by(*columns)
records = _execute_query(count_query, engine)

# Last field is the count, preceding fields are the group by fields
counts = {
Expand Down Expand Up @@ -155,9 +156,8 @@ def get_distinct_tuple_values(
.limit(limit)
.offset(offset)
)
with engine.begin() as conn:
res = conn.execute(query).fetchall()
return [tuple(zip(column_objects, row)) for row in res]
result = _execute_query(query, engine)
return [tuple(zip(column_objects, row)) for row in result]


def distinct_tuples_to_filter(distinct_tuples):
Expand Down