diff --git a/db/records.py b/db/records.py index 301a07e4b6..dfeb7db0b6 100644 --- a/db/records.py +++ b/db/records.py @@ -4,7 +4,6 @@ from sqlalchemy_filters import apply_filters, apply_sort from sqlalchemy_filters.exceptions import FieldNotFound -from db.constants import ID logger = logging.getLogger(__name__) @@ -32,13 +31,27 @@ def _create_col_objects(table, column_list): ] +def _get_query(table, limit, offset, order_by, filters): + query = select(table).limit(limit).offset(offset) + if order_by is not None: + query = apply_sort(query, order_by) + if filters is not None: + query = apply_filters(query, filters) + return query + + +def _execute_query(query, engine): + with engine.begin() as conn: + records = conn.execute(query).fetchall() + return records + + def get_record(table, engine, id_value): primary_key_column = _get_primary_key_column(table) query = select(table).where(primary_key_column == id_value) - with engine.begin() as conn: - result = conn.execute(query).fetchall() - assert len(result) <= 1 - return result[0] if result else None + result = _execute_query(query, engine) + assert len(result) <= 1 + return result[0] if result else None def get_records( @@ -59,13 +72,8 @@ def get_records( field, in addition to an 'value' field if appropriate. See: https://github.com/centerofci/sqlalchemy-filters#filters-format """ - query = select(table).limit(limit).offset(offset) - if order_by is not None: - query = apply_sort(query, order_by) - if filters is not None: - query = apply_filters(query, filters) - with engine.begin() as conn: - return conn.execute(query).fetchall() + query = _get_query(table, limit, offset, order_by, filters) + return _execute_query(query, engine) def get_group_counts( @@ -96,24 +104,17 @@ def get_group_counts( if field_name not in table.c: raise GroupFieldNotFound(f"Group field {field} not found in {table}.") - query = ( - select(table) - .limit(limit) - .offset(offset) - ) - if order_by is not None: - query = apply_sort(query, order_by) - if filters is not None: - query = apply_filters(query, filters) - subquery = query.subquery() + # Get the list of groups that we should count. + # We're considering limit and offset here so that we only count relevant groups + relevant_groups_query = _get_query(table, limit, offset, order_by, filters) + subquery = relevant_groups_query.subquery() - group_by = [ + columns = [ subquery.columns[col] if type(col) == str else subquery.columns[col.name] for col in group_by ] - query = select(*group_by, func.count(subquery.c[ID])).group_by(*group_by) - with engine.begin() as conn: - records = conn.execute(query).fetchall() + count_query = select(*columns, func.count(columns[0])).group_by(*columns) + records = _execute_query(count_query, engine) # Last field is the count, preceding fields are the group by fields counts = { @@ -155,9 +156,8 @@ def get_distinct_tuple_values( .limit(limit) .offset(offset) ) - with engine.begin() as conn: - res = conn.execute(query).fetchall() - return [tuple(zip(column_objects, row)) for row in res] + result = _execute_query(query, engine) + return [tuple(zip(column_objects, row)) for row in result] def distinct_tuples_to_filter(distinct_tuples):