From f624e1681c92ebdf313943bfb00a592469b38556 Mon Sep 17 00:00:00 2001 From: Kriti Godey Date: Thu, 22 Jul 2021 15:57:51 -0400 Subject: [PATCH 1/4] Consolidate repeated code into a function. --- db/records.py | 25 +++++++++++-------------- 1 file changed, 11 insertions(+), 14 deletions(-) diff --git a/db/records.py b/db/records.py index 301a07e4b6..9d779fb92f 100644 --- a/db/records.py +++ b/db/records.py @@ -32,6 +32,15 @@ def _create_col_objects(table, column_list): ] +def _get_query(table, limit, offset, order_by, filters): + query = select(table).limit(limit).offset(offset) + if order_by is not None: + query = apply_sort(query, order_by) + if filters is not None: + query = apply_filters(query, filters) + return query + + def get_record(table, engine, id_value): primary_key_column = _get_primary_key_column(table) query = select(table).where(primary_key_column == id_value) @@ -59,11 +68,7 @@ def get_records( field, in addition to an 'value' field if appropriate. See: https://github.com/centerofci/sqlalchemy-filters#filters-format """ - query = select(table).limit(limit).offset(offset) - if order_by is not None: - query = apply_sort(query, order_by) - if filters is not None: - query = apply_filters(query, filters) + query = _get_query(table, limit, offset, order_by, filters) with engine.begin() as conn: return conn.execute(query).fetchall() @@ -96,15 +101,7 @@ def get_group_counts( if field_name not in table.c: raise GroupFieldNotFound(f"Group field {field} not found in {table}.") - query = ( - select(table) - .limit(limit) - .offset(offset) - ) - if order_by is not None: - query = apply_sort(query, order_by) - if filters is not None: - query = apply_filters(query, filters) + query = _get_query(table, limit, offset, order_by, filters) subquery = query.subquery() group_by = [ From 5df91b727876c88a8f3efcfcc0de0af8c0ac8301 Mon Sep 17 00:00:00 2001 From: Kriti Godey Date: Thu, 22 Jul 2021 16:44:10 -0400 Subject: [PATCH 2/4] Don't reuse the same variable names for different things. --- db/records.py | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/db/records.py b/db/records.py index 9d779fb92f..ff62c3f8a6 100644 --- a/db/records.py +++ b/db/records.py @@ -101,16 +101,18 @@ def get_group_counts( if field_name not in table.c: raise GroupFieldNotFound(f"Group field {field} not found in {table}.") - query = _get_query(table, limit, offset, order_by, filters) - subquery = query.subquery() + # Get the list of groups that we should count. + # We're considering limit and offset here so that we only count relevant groups + relevant_groups_query = _get_query(table, limit, offset, order_by, filters) + subquery = relevant_groups_query.subquery() - group_by = [ + columns = [ subquery.columns[col] if type(col) == str else subquery.columns[col.name] for col in group_by ] - query = select(*group_by, func.count(subquery.c[ID])).group_by(*group_by) + count_query = select(*columns, func.count(subquery.c[ID])).group_by(*columns) with engine.begin() as conn: - records = conn.execute(query).fetchall() + records = conn.execute(count_query).fetchall() # Last field is the count, preceding fields are the group by fields counts = { From a37db866f7e5e64dd9d2663ae1b4d117f18d4943 Mon Sep 17 00:00:00 2001 From: Kriti Godey Date: Thu, 22 Jul 2021 16:56:32 -0400 Subject: [PATCH 3/4] Don't assume that there's going to be an ID column to count. --- db/records.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/db/records.py b/db/records.py index ff62c3f8a6..bdafae3bbd 100644 --- a/db/records.py +++ b/db/records.py @@ -4,7 +4,6 @@ from sqlalchemy_filters import apply_filters, apply_sort from sqlalchemy_filters.exceptions import FieldNotFound -from db.constants import ID logger = logging.getLogger(__name__) @@ -110,7 +109,7 @@ def get_group_counts( subquery.columns[col] if type(col) == str else subquery.columns[col.name] for col in group_by ] - count_query = select(*columns, func.count(subquery.c[ID])).group_by(*columns) + count_query = select(*columns, func.count(columns[0])).group_by(*columns) with engine.begin() as conn: records = conn.execute(count_query).fetchall() From 2e57f4653ff114ed010d21e831b0fded6712a4c0 Mon Sep 17 00:00:00 2001 From: Kriti Godey Date: Thu, 22 Jul 2021 17:55:55 -0400 Subject: [PATCH 4/4] Move query execution into its own function. --- db/records.py | 24 +++++++++++++----------- 1 file changed, 13 insertions(+), 11 deletions(-) diff --git a/db/records.py b/db/records.py index bdafae3bbd..dfeb7db0b6 100644 --- a/db/records.py +++ b/db/records.py @@ -40,13 +40,18 @@ def _get_query(table, limit, offset, order_by, filters): return query +def _execute_query(query, engine): + with engine.begin() as conn: + records = conn.execute(query).fetchall() + return records + + def get_record(table, engine, id_value): primary_key_column = _get_primary_key_column(table) query = select(table).where(primary_key_column == id_value) - with engine.begin() as conn: - result = conn.execute(query).fetchall() - assert len(result) <= 1 - return result[0] if result else None + result = _execute_query(query, engine) + assert len(result) <= 1 + return result[0] if result else None def get_records( @@ -68,8 +73,7 @@ def get_records( See: https://github.com/centerofci/sqlalchemy-filters#filters-format """ query = _get_query(table, limit, offset, order_by, filters) - with engine.begin() as conn: - return conn.execute(query).fetchall() + return _execute_query(query, engine) def get_group_counts( @@ -110,8 +114,7 @@ def get_group_counts( for col in group_by ] count_query = select(*columns, func.count(columns[0])).group_by(*columns) - with engine.begin() as conn: - records = conn.execute(count_query).fetchall() + records = _execute_query(count_query, engine) # Last field is the count, preceding fields are the group by fields counts = { @@ -153,9 +156,8 @@ def get_distinct_tuple_values( .limit(limit) .offset(offset) ) - with engine.begin() as conn: - res = conn.execute(query).fetchall() - return [tuple(zip(column_objects, row)) for row in res] + result = _execute_query(query, engine) + return [tuple(zip(column_objects, row)) for row in result] def distinct_tuples_to_filter(distinct_tuples):