diff --git a/app/api/src/db/sql/functions/indicators/count_public_transport_station_frequency.sql b/app/api/src/db/sql/functions/indicators/count_public_transport_services_station.sql similarity index 100% rename from app/api/src/db/sql/functions/indicators/count_public_transport_station_frequency.sql rename to app/api/src/db/sql/functions/indicators/count_public_transport_services_station.sql diff --git a/app/api/src/db/sql/init_sql.py b/app/api/src/db/sql/init_sql.py index acc21eede..096b3a6ef 100644 --- a/app/api/src/db/sql/init_sql.py +++ b/app/api/src/db/sql/init_sql.py @@ -1,3 +1,7 @@ +#! /usr/bin/env python +import argparse +import sys +import textwrap from pathlib import Path from alembic_utils.pg_function import PGFunction @@ -7,7 +11,7 @@ from src.core.config import settings from src.db.session import legacy_engine -from src.db.sql.utils import sorted_path_by_dependency +from src.db.sql.utils import report, sorted_path_by_dependency def sql_function_entities(): @@ -71,6 +75,48 @@ def upgrade_triggers(): legacy_engine.execute(text(statement.text)) +def run(args): + action = args.action + material = args.material + if action == "report": + report() + else: + globals()[f"{action}_{material}"]() + print(f"{action.title()} {material} complete!") + + +def main(): + parser = argparse.ArgumentParser( + description="Upgrade and Downgrade sql functions and triggers", + formatter_class=argparse.RawDescriptionHelpFormatter, + epilog=textwrap.dedent( + """ + example usage: + cd /app + python src/db/sql/init_sql.py upgrade -m functions + python src/db/sql/init_sql.py downgrade -m triggers + python src/db/sql/init_sql.py report + """ + ), + ) + parser.add_argument( + "action", + help="The action to do on database", + choices=["upgrade", "downgrade", "report"], + type=str, + ) + parser.add_argument( + "--material", + "-m", + required="upgrade" in sys.argv or "downgrade" in sys.argv, + help="functions or triggers", + choices=["functions", "triggers"], + type=str, + ) + parser.set_defaults(func=run) + args = parser.parse_args() + args.func(args) + + if __name__ == "__main__": - upgrade_functions() - print() + main() diff --git a/app/api/src/db/sql/utils.py b/app/api/src/db/sql/utils.py index 72282b16a..2b1f2cd26 100644 --- a/app/api/src/db/sql/utils.py +++ b/app/api/src/db/sql/utils.py @@ -1,7 +1,14 @@ +import functools import os -from collections import namedtuple +from collections import defaultdict, namedtuple from pathlib import Path +import rich +from sqlalchemy import text + +from src.core.config import settings +from src.db.session import legacy_engine + def find_unapplied_dependencies(function_content, function_list): dependencies = set() @@ -51,7 +58,61 @@ def sorted_path_by_dependency(path_list): return new_path_list -if __name__ == "__main__": - path_list = Path(str(Path().resolve()) + "/src/db/sql/functions").rglob("*.sql") - new_path_list = sorted_path_by_dependency(path_list) - print(new_path_list) +def list_functions(): + query = """ + SELECT + routine_name + FROM + information_schema.routines + WHERE + routine_type = 'FUNCTION' + AND + routine_schema = :functions_schema; + """ + query = text(query) + with legacy_engine.connect() as session: + functions = session.execute( + query, {"functions_schema": settings.POSTGRES_FUNCTIONS_SCHEMA} + ) + return [f["routine_name"] for f in functions] + + +def report(): + function_paths = Path(str(Path().resolve()) + "/src/db/sql/functions").rglob("*.sql") + triger_paths = Path(str(Path().resolve()) + "/src/db/sql/triggers").glob("*.sql") + files = list(function_paths) + list(triger_paths) + + functions = list_functions() + classified_functions = defaultdict(list) + not_in_db = [] + for fn in files: + if not fn: + continue + + file_name = fn.parts[-1] + directory_name = fn.parts[-2] + if get_name_from_path(fn) in functions: + functions.remove(get_name_from_path(fn)) + else: + not_in_db.append((directory_name, file_name)) + classified_functions[directory_name].append(file_name) + + for key in classified_functions.keys(): + rich.print(f"[bold blue]## {key}/[/bold blue]") + for fn in classified_functions[key]: + print("- ", fn) + print() + if functools: + rich.print(f"[bold red]# in db but not in files:[/bold red]") + for fn in functions: + rich.print(f"[orange1]- {fn}[/orange1]") + else: + rich.print(f"[bold green]All database function names found in files.[/bold green]") + + print() + if not_in_db: + rich.print(f"[bold red]## in files but not in db:[/bold red]") + for fn in not_in_db: + rich.print(f"- {fn[0]}/[orange1]{fn[1]}[/orange1]") + else: + rich.print(f"[bold green]All function file names found in database.[/bold green]")