diff --git a/CHANGELOG.md b/CHANGELOG.md index 1a666d78f..bebe92e2c 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,36 @@ # Changelog +## [11.2.0](https://github.com/uktrade/platform-tools/compare/11.1.0...11.2.0) (2024-11-04) + + +### Features + +* DBTP-1071 Generate terraform config for environment pipeline ([#611](https://github.com/uktrade/platform-tools/issues/611)) ([237fb35](https://github.com/uktrade/platform-tools/commit/237fb35fe06df7fd13e93419d282dc067187d952)) + + +### Documentation + +* Call out the offline command fix in the changelog ([#613](https://github.com/uktrade/platform-tools/issues/613)) ([e2a6396](https://github.com/uktrade/platform-tools/commit/e2a63961260d3b60a1ae9aa99a1bd06927e98ae9)) + +## [11.1.0](https://github.com/uktrade/platform-tools/compare/11.0.1...11.1.0) (2024-10-30) + + +### Features + +* DBTP-1159 Add validation for duplicate entries in platform-config.yml ([#604](https://github.com/uktrade/platform-tools/issues/604)) ([d00e143](https://github.com/uktrade/platform-tools/commit/d00e143ecaa9e86645563d996ed79779cae52597)) +* DBTP-1215 Improve error message when AWS profile not set ([#607](https://github.com/uktrade/platform-tools/issues/607)) ([beb0e7f](https://github.com/uktrade/platform-tools/commit/beb0e7f12013f035a1ffe2796a22b2a1bc70ed5f)) +* Delete data dump from S3 after data load has been successful ([#600](https://github.com/uktrade/platform-tools/issues/600)) ([410cd56](https://github.com/uktrade/platform-tools/commit/410cd5673eccce5855d03b4f0cbb4d6c1377085a)) + + +### Bug Fixes + +* Fix issue with offline command resulting in 'CreateRule operation: Priority '100' is currently in use' error + + +### Documentation + +* Add a note about regression/integration testing to the README.md ([#612](https://github.com/uktrade/platform-tools/issues/612)) ([d219356](https://github.com/uktrade/platform-tools/commit/d219356e41efb3b6eab3950a921aaf6e5b3b7d9c)) + ## 
[11.0.1](https://github.com/uktrade/platform-tools/compare/11.0.0...11.0.1) (2024-10-22) diff --git a/README.md b/README.md index 3ab852cb7..c49950929 100644 --- a/README.md +++ b/README.md @@ -65,6 +65,10 @@ Run `pip install ` and confirm the installation has worked by running `pla > [!IMPORTANT] > When testing is complete, do not forget to revert the `dbt-platform-helper` installation back to what it was; e.g. `pip install dbt-platform-helper==0.1.39`. +#### End to end testing + +Because this codebase is only fully exercised in conjunction with several others, we have [platform-end-to-end-tests](https://github.com/uktrade/platform-end-to-end-tests), which orchestrates the testing of them working together. + ### Publishing Publishing to PyPI happens automatically when a GitHub Release is published. To publish the Python package `dbt-platform-helper` manually, you will need an API token. diff --git a/buildspec.database-copy.yml b/buildspec.database-copy.yml new file mode 100644 index 000000000..bafdbd9a0 --- /dev/null +++ b/buildspec.database-copy.yml @@ -0,0 +1,24 @@ +version: 0.2 + +phases: + pre_build: + commands: + - echo "3.9" > .python-version + - cd images/tools/database-copy/ + - echo Login to Amazon ECR + - aws ecr-public get-login-password --region us-east-1 | docker login --username AWS --password-stdin public.ecr.aws/uktrade + - SHORT_HASH=$(git rev-parse --short HEAD) + + build: + commands: + - echo "Build database-copy ($SHORT_HASH) started on $(date)" + - docker build --tag public.ecr.aws/uktrade/database-copy:tag-latest . 
+ - docker tag public.ecr.aws/uktrade/database-copy:tag-latest public.ecr.aws/uktrade/database-copy:$SHORT_HASH + - echo "Build database-copy ($SHORT_HASH) completed on $(date)" + + post_build: + commands: + - echo "Push database-copy ($SHORT_HASH) started on $(date)" + - docker push public.ecr.aws/uktrade/database-copy:tag-latest + - docker push public.ecr.aws/uktrade/database-copy:$SHORT_HASH + - echo "Push database-copy ($SHORT_HASH) completed on $(date)" diff --git a/dbt_platform_helper/COMMANDS.md b/dbt_platform_helper/COMMANDS.md index cd6a0b944..34cb5ebb9 100644 --- a/dbt_platform_helper/COMMANDS.md +++ b/dbt_platform_helper/COMMANDS.md @@ -42,6 +42,7 @@ - [platform-helper database](#platform-helper-database) - [platform-helper database dump](#platform-helper-database-dump) - [platform-helper database load](#platform-helper-database-load) +- [platform-helper database copy](#platform-helper-database-copy) - [platform-helper version](#platform-helper-version) - [platform-helper version get-platform-helper-for-project](#platform-helper-version-get-platform-helper-for-project) @@ -792,14 +793,36 @@ platform-helper pipeline generate Given a platform-config.yml file, generate environment and service deployment pipelines. + This command does the following in relation to the environment pipelines: + - Reads contents of `platform-config.yml/environment-pipelines` configuration. + The `terraform/environment-pipelines//main.tf` file is generated using this configuration. + The `main.tf` file is then used to generate Terraform for creating an environment pipeline resource. 
+ + This command does the following in relation to the codebase pipelines: + - Generates the copilot pipeline manifest.yml for copilot/pipelines/ + + (Deprecated) This command does the following for non terraform projects (legacy AWS Copilot): + - Generates the copilot manifest.yml for copilot/environments/ + ## Usage ``` -platform-helper pipeline generate +platform-helper pipeline generate [--terraform-platform-modules-version ] + [--deploy-branch ] ``` ## Options +- `--terraform-platform-modules-version ` + - Override the default version of terraform-platform-modules with a specific version or branch. +Precedence of version used is version supplied via CLI, then the version found in +platform-config.yml/default_versions/terraform-platform-modules. +In absence of these inputs, defaults to version '5'. +- `--deploy-branch ` + - Specify the branch of -deploy used to configure the source stage in the environment-pipeline resource. +This is generated from the terraform/environments-pipeline//main.tf file. +(Default -deploy branch is specified in +-deploy/platform-config.yml/environment_pipelines//branch). - `--help ` _Defaults to False._ - Show this message and exit. @@ -962,10 +985,12 @@ platform-helper notify add-comment [↩ Parent](#platform-helper) + Commands to copy data between databases. + ## Usage ``` -platform-helper database (dump|load) +platform-helper database (dump|load|copy) ``` ## Options @@ -975,6 +1000,7 @@ platform-helper database (dump|load) ## Commands +- [`copy` ↪](#platform-helper-database-copy) - [`dump` ↪](#platform-helper-database-dump) - [`load` ↪](#platform-helper-database-load) @@ -987,23 +1013,20 @@ platform-helper database (dump|load) ## Usage ``` -platform-helper database dump --account-id --app - --env --database - --vpc-name +platform-helper database dump --from --database + [--app ] [--from-vpc ] ``` ## Options -- `--account-id ` - - `--app ` - -- `--env ` - + - The application name. 
Required unless you are running the command from your deploy repo +- `--from ` + - The environment you are dumping data from - `--database ` - -- `--vpc-name ` - + - The name of the database you are dumping data from +- `--from-vpc ` + - The vpc the specified environment is running in. Required unless you are running the command from your deploy repo - `--help ` _Defaults to False._ - Show this message and exit. @@ -1016,22 +1039,62 @@ platform-helper database dump --account-id --app ## Usage ``` -platform-helper database load --account-id --app - --env --database - --vpc-name +platform-helper database load --to --database + [--app ] [--to-vpc ] + [--auto-approve] ``` ## Options -- `--account-id ` - - `--app ` + - The application name. Required unless you are running the command from your deploy repo +- `--to ` + - The environment you are loading data into +- `--database ` + - The name of the database you are loading data into +- `--to-vpc ` + - The vpc the specified environment is running in. Required unless you are running the command from your deploy repo +- `--auto-approve ` _Defaults to False._ -- `--env ` +- `--help ` _Defaults to False._ + - Show this message and exit. +# platform-helper database copy + +[↩ Parent](#platform-helper-database) + + Copy a database between environments. + +## Usage + +``` +platform-helper database copy --from --to --database + --svc [--app ] [--from-vpc ] + [--to-vpc ] [--template (default|migration|dmas-migration)] + [--auto-approve] [--no-maintenance-page] +``` + +## Options + +- `--app ` + - The application name. Required unless you are running the command from your deploy repo +- `--from ` + - The environment you are copying data from +- `--to ` + - The environment you are copying data into - `--database ` + - The name of the database you are copying +- `--from-vpc ` + - The vpc the environment you are copying from is running in. 
Required unless you are running the command from your deploy repo +- `--to-vpc ` + - The vpc the environment you are copying into is running in. Required unless you are running the command from your deploy repo +- `--auto-approve ` _Defaults to False._ -- `--vpc-name ` +- `--svc ` _Defaults to ['web']._ + +- `--template ` _Defaults to default._ + - The maintenance page you wish to put up. +- `--no-maintenance-page ` _Defaults to False._ - `--help ` _Defaults to False._ - Show this message and exit. diff --git a/dbt_platform_helper/commands/database.py b/dbt_platform_helper/commands/database.py index 0c1f9a881..c8d4e729e 100644 --- a/dbt_platform_helper/commands/database.py +++ b/dbt_platform_helper/commands/database.py @@ -1,33 +1,112 @@ import click -from dbt_platform_helper.commands.database_helpers import DatabaseCopy +from dbt_platform_helper.commands.environment import AVAILABLE_TEMPLATES +from dbt_platform_helper.domain.database_copy import DatabaseCopy from dbt_platform_helper.utils.click import ClickDocOptGroup @click.group(chain=True, cls=ClickDocOptGroup) def database(): - pass + """Commands to copy data between databases.""" @database.command(name="dump") -@click.option("--account-id", type=str, required=True) -@click.option("--app", type=str, required=True) -@click.option("--env", type=str, required=True) -@click.option("--database", type=str, required=True) -@click.option("--vpc-name", type=str, required=True) -def dump(account_id, app, env, database, vpc_name): +@click.option( + "--app", + type=str, + help="The application name. Required unless you are running the command from your deploy repo", +) +@click.option( + "--from", + "from_env", + type=str, + required=True, + help="The environment you are dumping data from", +) +@click.option( + "--database", type=str, required=True, help="The name of the database you are dumping data from" +) +@click.option( + "--from-vpc", + type=str, + help="The vpc the specified environment is running in. 
Required unless you are running the command from your deploy repo", +) +def dump(app, from_env, database, from_vpc): """Dump a database into an S3 bucket.""" - data_copy = DatabaseCopy(account_id, app, env, database, vpc_name) - data_copy.dump() + data_copy = DatabaseCopy(app, database) + data_copy.dump(from_env, from_vpc) @database.command(name="load") -@click.option("--account-id", type=str, required=True) -@click.option("--app", type=str, required=True) -@click.option("--env", type=str, required=True) -@click.option("--database", type=str, required=True) -@click.option("--vpc-name", type=str, required=True) -def load(account_id, app, env, database, vpc_name): +@click.option( + "--app", + type=str, + help="The application name. Required unless you are running the command from your deploy repo", +) +@click.option( + "--to", "to_env", type=str, required=True, help="The environment you are loading data into" +) +@click.option( + "--database", type=str, required=True, help="The name of the database you are loading data into" +) +@click.option( + "--to-vpc", + type=str, + help="The vpc the specified environment is running in. Required unless you are running the command from your deploy repo", +) +@click.option("--auto-approve/--no-auto-approve", default=False) +def load(app, to_env, database, to_vpc, auto_approve): """Load a database from an S3 bucket.""" - data_copy = DatabaseCopy(account_id, app, env, database, vpc_name) - data_copy.load() + data_copy = DatabaseCopy(app, database, auto_approve) + data_copy.load(to_env, to_vpc) + + +@database.command(name="copy") +@click.option( + "--app", + type=str, + help="The application name. 
Required unless you are running the command from your deploy repo", +) +@click.option( + "--from", "from_env", type=str, required=True, help="The environment you are copying data from" +) +@click.option( + "--to", "to_env", type=str, required=True, help="The environment you are copying data into" +) +@click.option( + "--database", type=str, required=True, help="The name of the database you are copying" +) +@click.option( + "--from-vpc", + type=str, + help="The vpc the environment you are copying from is running in. Required unless you are running the command from your deploy repo", +) +@click.option( + "--to-vpc", + type=str, + help="The vpc the environment you are copying into is running in. Required unless you are running the command from your deploy repo", +) +@click.option("--auto-approve/--no-auto-approve", default=False) +@click.option("--svc", type=str, required=True, multiple=True, default=["web"]) +@click.option( + "--template", + type=click.Choice(AVAILABLE_TEMPLATES), + default="default", + help="The maintenance page you wish to put up.", +) +@click.option("--no-maintenance-page", flag_value=True) +def copy( + app, + from_env, + to_env, + database, + from_vpc, + to_vpc, + auto_approve, + svc, + template, + no_maintenance_page, +): + """Copy a database between environments.""" + data_copy = DatabaseCopy(app, database, auto_approve) + data_copy.copy(from_env, to_env, from_vpc, to_vpc, svc, template, no_maintenance_page) diff --git a/dbt_platform_helper/commands/database_helpers.py b/dbt_platform_helper/commands/database_helpers.py deleted file mode 100644 index f3c5a8ccd..000000000 --- a/dbt_platform_helper/commands/database_helpers.py +++ /dev/null @@ -1,145 +0,0 @@ -import boto3 -import click - -from dbt_platform_helper.utils.aws import Vpc -from dbt_platform_helper.utils.aws import get_aws_session_or_abort -from dbt_platform_helper.utils.aws import get_connection_string -from dbt_platform_helper.utils.aws import get_vpc_info_by_name - - -def 
run_database_copy_task( - session: boto3.session.Session, - account_id: str, - app: str, - env: str, - database: str, - vpc_config: Vpc, - is_dump: bool, - db_connection_string: str, -): - client = session.client("ecs") - action = "dump" if is_dump else "load" - response = client.run_task( - taskDefinition=f"arn:aws:ecs:eu-west-2:{account_id}:task-definition/{app}-{env}-{database}-{action}", - cluster=f"{app}-{env}", - capacityProviderStrategy=[ - {"capacityProvider": "FARGATE", "weight": 1, "base": 0}, - ], - networkConfiguration={ - "awsvpcConfiguration": { - "subnets": vpc_config.subnets, - "securityGroups": vpc_config.security_groups, - "assignPublicIp": "DISABLED", - } - }, - overrides={ - "containerOverrides": [ - { - "name": f"{app}-{env}-{database}-{action}", - "environment": [ - {"name": "DATA_COPY_OPERATION", "value": action.upper()}, - {"name": "DB_CONNECTION_STRING", "value": db_connection_string}, - ], - } - ] - }, - ) - - return response.get("tasks", [{}])[0].get("taskArn") - - -class DatabaseCopy: - def __init__( - self, - account_id, - app, - env, - database, - vpc_name, - get_session_fn=get_aws_session_or_abort, - run_database_copy_fn=run_database_copy_task, - vpc_config_fn=get_vpc_info_by_name, - db_connection_string_fn=get_connection_string, - input_fn=click.prompt, - echo_fn=click.secho, - ): - self.account_id = account_id - self.app = app - self.env = env - self.database = database - self.vpc_name = vpc_name - self.get_session_fn = get_session_fn - self.run_database_copy_fn = run_database_copy_fn - self.vpc_config_fn = vpc_config_fn - self.db_connection_string_fn = db_connection_string_fn - self.input_fn = input_fn - self.echo_fn = echo_fn - - def _execute_operation(self, is_dump): - session = self.get_session_fn() - vpc_config = self.vpc_config_fn(session, self.app, self.env, self.vpc_name) - database_identifier = f"{self.app}-{self.env}-{self.database}" - db_connection_string = self.db_connection_string_fn( - session, self.app, self.env, 
database_identifier - ) - task_arn = self.run_database_copy_fn( - session, - self.account_id, - self.app, - self.env, - self.database, - vpc_config, - is_dump, - db_connection_string, - ) - - self.echo_fn( - f"Task {task_arn} started. Waiting for it to complete (this may take some time)...", - fg="green", - ) - self.tail_logs(is_dump) - self.wait_for_task_to_stop(task_arn) - - def dump(self): - self._execute_operation(True) - - def load(self): - if self.is_confirmed_ready_to_load(): - self._execute_operation(False) - - def is_confirmed_ready_to_load(self): - user_input = self.input_fn( - f"Are all tasks using {self.database} in the {self.env} environment stopped? (y/n)" - ) - return user_input.lower().strip() in ["y", "yes"] - - def tail_logs(self, is_dump: bool): - action = "dump" if is_dump else "load" - log_group_name = f"/ecs/{self.app}-{self.env}-{self.database}-{action}" - log_group_arn = f"arn:aws:logs:eu-west-2:{self.account_id}:log-group:{log_group_name}" - self.echo_fn(f"Tailing logs for {log_group_name}", fg="yellow") - session = self.get_session_fn() - response = session.client("logs").start_live_tail(logGroupIdentifiers=[log_group_arn]) - - stopped = False - for data in response["responseStream"]: - if stopped: - break - results = data.get("sessionUpdate", {}).get("sessionResults", []) - for result in results: - message = result.get("message") - - if message: - if message.startswith("Stopping data "): - stopped = True - self.echo_fn(message) - - def wait_for_task_to_stop(self, task_arn): - self.echo_fn("Waiting for task to complete", fg="yellow") - client = self.get_session_fn().client("ecs") - waiter = client.get_waiter("tasks_stopped") - waiter.wait( - cluster=f"{self.app}-{self.env}", - tasks=[task_arn], - WaiterConfig={"Delay": 6, "MaxAttempts": 300}, - ) diff --git a/dbt_platform_helper/commands/environment.py b/dbt_platform_helper/commands/environment.py index c25f79880..824dfce40 100644 --- a/dbt_platform_helper/commands/environment.py +++ 
b/dbt_platform_helper/commands/environment.py @@ -1,19 +1,11 @@ -import random -import re -import string -from pathlib import Path -from typing import List -from typing import Union - import boto3 import click from schema import SchemaError from dbt_platform_helper.constants import DEFAULT_TERRAFORM_PLATFORM_MODULES_VERSION from dbt_platform_helper.constants import PLATFORM_CONFIG_FILE -from dbt_platform_helper.utils.application import Environment -from dbt_platform_helper.utils.application import Service -from dbt_platform_helper.utils.application import load_application +from dbt_platform_helper.domain.maintenance_page import MaintenancePageProvider +from dbt_platform_helper.providers.load_balancers import find_https_listener from dbt_platform_helper.utils.aws import get_aws_session_or_abort from dbt_platform_helper.utils.click import ClickDocOptGroup from dbt_platform_helper.utils.files import apply_environment_defaults @@ -47,72 +39,8 @@ def environment(): @click.option("--vpc", type=str) def offline(app, env, svc, template, vpc): """Take load-balanced web services offline with a maintenance page.""" - application = get_application(app) - application_environment = get_app_environment(app, env) - - if "*" in svc: - services = [ - s for s in application.services.values() if s.kind == "Load Balanced Web Service" - ] - else: - all_services = [get_app_service(app, s) for s in list(svc)] - services = [s for s in all_services if s.kind == "Load Balanced Web Service"] - - if not services: - click.secho(f"No services deployed yet to {app} environment {env}", fg="red") - raise click.Abort - - try: - https_listener = find_https_listener(application_environment.session, app, env) - current_maintenance_page = get_maintenance_page( - application_environment.session, https_listener - ) - remove_current_maintenance_page = False - if current_maintenance_page: - remove_current_maintenance_page = click.confirm( - f"There is currently a '{current_maintenance_page}' maintenance 
page for the {env} " - f"environment in {app}.\nWould you like to replace it with a '{template}' " - f"maintenance page?" - ) - if not remove_current_maintenance_page: - raise click.Abort - - if remove_current_maintenance_page or click.confirm( - f"You are about to enable the '{template}' maintenance page for the {env} " - f"environment in {app}.\nWould you like to continue?" - ): - if current_maintenance_page and remove_current_maintenance_page: - remove_maintenance_page(application_environment.session, https_listener) - - allowed_ips = get_env_ips(vpc, application_environment) - - add_maintenance_page( - application_environment.session, - https_listener, - app, - env, - services, - allowed_ips, - template, - ) - click.secho( - f"Maintenance page '{template}' added for environment {env} in application {app}", - fg="green", - ) - else: - raise click.Abort - - except LoadBalancerNotFoundError: - click.secho( - f"No load balancer found for environment {env} in the application {app}.", fg="red" - ) - raise click.Abort - - except ListenerNotFoundError: - click.secho( - f"No HTTPS listener found for environment {env} in the application {app}.", fg="red" - ) - raise click.Abort + maintenance_page = MaintenancePageProvider() + maintenance_page.activate(app, env, svc, template, vpc) @environment.command() @@ -120,85 +48,8 @@ def offline(app, env, svc, template, vpc): @click.option("--env", type=str, required=True) def online(app, env): """Remove a maintenance page from an environment.""" - application_environment = get_app_environment(app, env) - - try: - https_listener = find_https_listener(application_environment.session, app, env) - current_maintenance_page = get_maintenance_page( - application_environment.session, https_listener - ) - if not current_maintenance_page: - click.secho("There is no current maintenance page to remove", fg="red") - raise click.Abort - - if not click.confirm( - f"There is currently a '{current_maintenance_page}' maintenance page, " - f"would 
you like to remove it?" - ): - raise click.Abort - - remove_maintenance_page(application_environment.session, https_listener) - click.secho( - f"Maintenance page removed from environment {env} in application {app}", fg="green" - ) - - except LoadBalancerNotFoundError: - click.secho( - f"No load balancer found for environment {env} in the application {app}.", fg="red" - ) - raise click.Abort - - except ListenerNotFoundError: - click.secho( - f"No HTTPS listener found for environment {env} in the application {app}.", fg="red" - ) - raise click.Abort - - -def get_application(app_name: str): - return load_application(app_name) - - -def get_app_environment(app_name: str, env_name: str) -> Environment: - application = get_application(app_name) - application_environment = application.environments.get(env_name) - - if not application_environment: - click.secho( - f"The environment {env_name} was not found in the application {app_name}. " - f"It either does not exist, or has not been deployed.", - fg="red", - ) - raise click.Abort - - return application_environment - - -def get_app_service(app_name: str, svc_name: str) -> Service: - application = get_application(app_name) - application_service = application.services.get(svc_name) - - if not application_service: - click.secho( - f"The service {svc_name} was not found in the application {app_name}. 
" - f"It either does not exist, or has not been deployed.", - fg="red", - ) - raise click.Abort - - return application_service - - -def get_listener_rule_by_tag(elbv2_client, listener_arn, tag_key, tag_value): - response = elbv2_client.describe_rules(ListenerArn=listener_arn) - for rule in response["Rules"]: - rule_arn = rule["RuleArn"] - - tags_response = elbv2_client.describe_tags(ResourceArns=[rule_arn]) - for tag_description in tags_response["TagDescriptions"]: - for tag in tag_description["Tags"]: - if tag["Key"] == tag_key and tag["Value"] == tag_value: - return rule + maintenance_page = MaintenancePageProvider() + maintenance_page.deactivate(app, env) def get_vpc_id(session, env_name, vpc_name=None): @@ -250,20 +101,6 @@ def get_cert_arn(session, application, env_name): return arn -def get_env_ips(vpc: str, application_environment: Environment) -> List[str]: - account_name = f"{application_environment.session.profile_name}-vpc" - vpc_name = vpc if vpc else account_name - ssm_client = application_environment.session.client("ssm") - - try: - param_value = ssm_client.get_parameter(Name=f"/{vpc_name}/EGRESS_IPS")["Parameter"]["Value"] - except ssm_client.exceptions.ParameterNotFound: - click.secho(f"No parameter found with name: /{vpc_name}/EGRESS_IPS") - raise click.Abort - - return [ip.strip() for ip in param_value.split(",")] - - @environment.command() @click.option("--vpc-name", hidden=True) @click.option("--name", "-n", required=True) @@ -362,44 +199,6 @@ def _determine_terraform_platform_modules_version(env_conf, cli_terraform_platfo return [version for version in version_preference_order if version][0] -def find_load_balancer(session: boto3.Session, app: str, env: str) -> str: - lb_client = session.client("elbv2") - - describe_response = lb_client.describe_load_balancers() - load_balancers = [lb["LoadBalancerArn"] for lb in describe_response["LoadBalancers"]] - - load_balancers = lb_client.describe_tags(ResourceArns=load_balancers)["TagDescriptions"] - - 
load_balancer_arn = None - for lb in load_balancers: - tags = {t["Key"]: t["Value"] for t in lb["Tags"]} - if tags.get("copilot-application") == app and tags.get("copilot-environment") == env: - load_balancer_arn = lb["ResourceArn"] - - if not load_balancer_arn: - raise LoadBalancerNotFoundError() - - return load_balancer_arn - - -def find_https_listener(session: boto3.Session, app: str, env: str) -> str: - load_balancer_arn = find_load_balancer(session, app, env) - lb_client = session.client("elbv2") - listeners = lb_client.describe_listeners(LoadBalancerArn=load_balancer_arn)["Listeners"] - - listener_arn = None - - try: - listener_arn = next(l["ListenerArn"] for l in listeners if l["Protocol"] == "HTTPS") - except StopIteration: - pass - - if not listener_arn: - raise ListenerNotFoundError() - - return listener_arn - - def find_https_certificate(session: boto3.Session, app: str, env: str) -> str: listener_arn = find_https_listener(session, app, env) cert_client = session.client("elbv2") @@ -407,8 +206,6 @@ def find_https_certificate(session: boto3.Session, app: str, env: str) -> str: "Certificates" ] - certificate_arn = None - try: certificate_arn = next(c["CertificateArn"] for c in certificates if c["IsDefault"]) except StopIteration: @@ -417,314 +214,5 @@ def find_https_certificate(session: boto3.Session, app: str, env: str) -> str: return certificate_arn -def find_target_group(app: str, env: str, svc: str, session: boto3.Session) -> str: - rg_tagging_client = session.client("resourcegroupstaggingapi") - response = rg_tagging_client.get_resources( - TagFilters=[ - { - "Key": "copilot-application", - "Values": [ - app, - ], - "Key": "copilot-environment", - "Values": [ - env, - ], - "Key": "copilot-service", - "Values": [ - svc, - ], - }, - ], - ResourceTypeFilters=[ - "elasticloadbalancing:targetgroup", - ], - ) - for resource in response["ResourceTagMappingList"]: - tags = {tag["Key"]: tag["Value"] for tag in resource["Tags"]} - - if ( - "copilot-service" in 
tags - and tags["copilot-service"] == svc - and "copilot-environment" in tags - and tags["copilot-environment"] == env - and "copilot-application" in tags - and tags["copilot-application"] == app - ): - return resource["ResourceARN"] - - click.secho( - f"No target group found for application: {app}, environment: {env}, service: {svc}", - fg="red", - ) - - return None - - -def get_maintenance_page(session: boto3.Session, listener_arn: str) -> Union[str, None]: - lb_client = session.client("elbv2") - - rules = lb_client.describe_rules(ListenerArn=listener_arn)["Rules"] - tag_descriptions = get_rules_tag_descriptions(rules, lb_client) - - maintenance_page_type = None - for description in tag_descriptions: - tags = {t["Key"]: t["Value"] for t in description["Tags"]} - if tags.get("name") == "MaintenancePage": - maintenance_page_type = tags.get("type") - - return maintenance_page_type - - -def delete_listener_rule(tag_descriptions: list, tag_name: str, lb_client: boto3.client): - current_rule_arn = None - - for description in tag_descriptions: - tags = {t["Key"]: t["Value"] for t in description["Tags"]} - if tags.get("name") == tag_name: - current_rule_arn = description["ResourceArn"] - if current_rule_arn: - lb_client.delete_rule(RuleArn=current_rule_arn) - - return current_rule_arn - - -def remove_maintenance_page(session: boto3.Session, listener_arn: str): - lb_client = session.client("elbv2") - - rules = lb_client.describe_rules(ListenerArn=listener_arn)["Rules"] - tag_descriptions = get_rules_tag_descriptions(rules, lb_client) - tag_descriptions = lb_client.describe_tags(ResourceArns=[r["RuleArn"] for r in rules])[ - "TagDescriptions" - ] - - for name in ["MaintenancePage", "AllowedIps", "BypassIpFilter", "AllowedSourceIps"]: - deleted = delete_listener_rule(tag_descriptions, name, lb_client) - - if name == "MaintenancePage" and not deleted: - raise ListenerRuleNotFoundError() - - -def get_rules_tag_descriptions(rules: list, lb_client): - tag_descriptions = [] - 
chunk_size = 20 - - for i in range(0, len(rules), chunk_size): - chunk = rules[i : i + chunk_size] - resource_arns = [r["RuleArn"] for r in chunk] - response = lb_client.describe_tags(ResourceArns=resource_arns) - tag_descriptions.extend(response["TagDescriptions"]) - - return tag_descriptions - - -def get_host_conditions(lb_client: boto3.client, listener_arn: str, target_group_arn: str): - rules = lb_client.describe_rules(ListenerArn=listener_arn)["Rules"] - - # Get current set of forwarding conditions for the target group - for rule in rules: - for action in rule["Actions"]: - if action["Type"] == "forward" and action["TargetGroupArn"] == target_group_arn: - conditions = rule["Conditions"] - - # filter to host-header conditions - conditions = [ - {i: condition[i] for i in condition if i != "Values"} - for condition in conditions - if condition["Field"] == "host-header" - ] - - # remove internal hosts - conditions[0]["HostHeaderConfig"]["Values"] = [ - v for v in conditions[0]["HostHeaderConfig"]["Values"] - ] - - return conditions - - -def create_header_rule( - lb_client: boto3.client, - listener_arn: str, - target_group_arn: str, - header_name: str, - values: list, - rule_name: str, - priority: int, -): - conditions = get_host_conditions(lb_client, listener_arn, target_group_arn) - - # add new condition to existing conditions - combined_conditions = [ - { - "Field": "http-header", - "HttpHeaderConfig": {"HttpHeaderName": header_name, "Values": values}, - } - ] + conditions - - lb_client.create_rule( - ListenerArn=listener_arn, - Priority=priority, - Conditions=combined_conditions, - Actions=[{"Type": "forward", "TargetGroupArn": target_group_arn}], - Tags=[ - {"Key": "name", "Value": rule_name}, - ], - ) - - click.secho( - f"Creating listener rule {rule_name} for HTTPS Listener with arn {listener_arn}.\n\nIf request header {header_name} contains one of the values {values}, the request will be forwarded to target group with arn {target_group_arn}.", - fg="green", 
- ) - - -def create_source_ip_rule( - lb_client: boto3.client, - listener_arn: str, - target_group_arn: str, - values: list, - rule_name: str, - priority: int, -): - conditions = get_host_conditions(lb_client, listener_arn, target_group_arn) - - # add new condition to existing conditions - combined_conditions = [ - { - "Field": "source-ip", - "SourceIpConfig": {"Values": [value + "/32" for value in values]}, - } - ] + conditions - - lb_client.create_rule( - ListenerArn=listener_arn, - Priority=priority, - Conditions=combined_conditions, - Actions=[{"Type": "forward", "TargetGroupArn": target_group_arn}], - Tags=[ - {"Key": "name", "Value": rule_name}, - ], - ) - - click.secho( - f"Creating listener rule {rule_name} for HTTPS Listener with arn {listener_arn}.\n\nIf request source ip matches one of the values {values}, the request will be forwarded to target group with arn {target_group_arn}.", - fg="green", - ) - - -def add_maintenance_page( - session: boto3.Session, - listener_arn: str, - app: str, - env: str, - services: List[Service], - allowed_ips: tuple, - template: str = "default", -): - lb_client = session.client("elbv2") - maintenance_page_content = get_maintenance_page_template(template) - bypass_value = "".join(random.choices(string.ascii_lowercase + string.digits, k=12)) - - service_number = 1 - - for svc in services: - target_group_arn = find_target_group(app, env, svc.name, session) - - # not all of an application's services are guaranteed to have been deployed to an environment - if not target_group_arn: - continue - - allowed_ips = list(allowed_ips) - max_allowed_ips = 100 - for ip_index, ip in enumerate(allowed_ips): - forwarded_rule_priority = (service_number * max_allowed_ips) + ip_index - create_header_rule( - lb_client, - listener_arn, - target_group_arn, - "X-Forwarded-For", - [ip], - "AllowedIps", - forwarded_rule_priority, - ) - create_source_ip_rule( - lb_client, - listener_arn, - target_group_arn, - [ip], - "AllowedSourceIps", - 
forwarded_rule_priority + 1, - ) - - bypass_rule_priority = service_number - create_header_rule( - lb_client, - listener_arn, - target_group_arn, - "Bypass-Key", - [bypass_value], - "BypassIpFilter", - bypass_rule_priority, - ) - - service_number += 1 - - click.secho( - f"\nUse a browser plugin to add `Bypass-Key` header with value {bypass_value} to your requests. For more detail, visit https://platform.readme.trade.gov.uk/activities/holding-and-maintenance-pages/", - fg="green", - ) - - fixed_rule_priority = (service_number + 5) * max_allowed_ips - lb_client.create_rule( - ListenerArn=listener_arn, - Priority=fixed_rule_priority, # big number because we create multiple higher priority "AllowedIps" rules for each allowed ip for each service above. - Conditions=[ - { - "Field": "path-pattern", - "PathPatternConfig": {"Values": ["/*"]}, - } - ], - Actions=[ - { - "Type": "fixed-response", - "FixedResponseConfig": { - "StatusCode": "503", - "ContentType": "text/html", - "MessageBody": maintenance_page_content, - }, - } - ], - Tags=[ - {"Key": "name", "Value": "MaintenancePage"}, - {"Key": "type", "Value": template}, - ], - ) - - -def get_maintenance_page_template(template) -> str: - template_contents = ( - Path(__file__) - .parent.parent.joinpath( - f"templates/svc/maintenance_pages/{template}.html", - ) - .read_text() - .replace("\n", "") - ) - - # [^\S]\s+ - Remove any space that is not preceded by a non-space character. 
- return re.sub(r"[^\S]\s+", "", template_contents) - - class CertificateNotFoundError(Exception): pass - - -class LoadBalancerNotFoundError(Exception): - pass - - -class ListenerNotFoundError(Exception): - pass - - -class ListenerRuleNotFoundError(Exception): - pass diff --git a/dbt_platform_helper/commands/pipeline.py b/dbt_platform_helper/commands/pipeline.py index 682f40acd..3acdedf19 100644 --- a/dbt_platform_helper/commands/pipeline.py +++ b/dbt_platform_helper/commands/pipeline.py @@ -5,6 +5,7 @@ import click +from dbt_platform_helper.constants import DEFAULT_TERRAFORM_PLATFORM_MODULES_VERSION from dbt_platform_helper.utils.application import get_application_name from dbt_platform_helper.utils.aws import get_account_details from dbt_platform_helper.utils.aws import get_codestar_connection_arn @@ -24,6 +25,7 @@ CODEBASE_PIPELINES_KEY = "codebase_pipelines" ENVIRONMENTS_KEY = "environments" +ENVIRONMENT_PIPELINES_KEY = "environment_pipelines" @click.group(chain=True, cls=ClickDocOptGroup) @@ -33,18 +35,55 @@ def pipeline(): @pipeline.command() -def generate(): - """Given a platform-config.yml file, generate environment and service - deployment pipelines.""" +@click.option( + "--terraform-platform-modules-version", + help=f"""Override the default version of terraform-platform-modules with a specific version or branch. + Precedence of version used is version supplied via CLI, then the version found in + platform-config.yml/default_versions/terraform-platform-modules. + In absence of these inputs, defaults to version '{DEFAULT_TERRAFORM_PLATFORM_MODULES_VERSION}'.""", +) +@click.option( + "--deploy-branch", + help="""Specify the branch of -deploy used to configure the source stage in the environment-pipeline resource. + This is generated from the terraform/environments-pipeline//main.tf file. 
+ (Default -deploy branch is specified in + -deploy/platform-config.yml/environment_pipelines//branch).""", + default=None, +) +def generate(terraform_platform_modules_version, deploy_branch): + """ + Given a platform-config.yml file, generate environment and service + deployment pipelines. + + This command does the following in relation to the environment pipelines: + - Reads contents of `platform-config.yml/environment-pipelines` configuration. + The `terraform/environment-pipelines//main.tf` file is generated using this configuration. + The `main.tf` file is then used to generate Terraform for creating an environment pipeline resource. + + This command does the following in relation to the codebase pipelines: + - Generates the copilot pipeline manifest.yml for copilot/pipelines/ + + (Deprecated) This command does the following for non terraform projects (legacy AWS Copilot): + - Generates the copilot manifest.yml for copilot/environments/ + """ pipeline_config = load_and_validate_platform_config() - no_codebase_pipelines = CODEBASE_PIPELINES_KEY not in pipeline_config - no_environment_pipelines = ENVIRONMENTS_KEY not in pipeline_config + has_codebase_pipelines = CODEBASE_PIPELINES_KEY in pipeline_config + has_legacy_environment_pipelines = ENVIRONMENTS_KEY in pipeline_config + has_environment_pipelines = ENVIRONMENT_PIPELINES_KEY in pipeline_config - if no_codebase_pipelines and no_environment_pipelines: + if ( + not has_codebase_pipelines + and not has_legacy_environment_pipelines + and not has_environment_pipelines + ): click.secho("No pipelines defined: nothing to do.", err=True, fg="yellow") return + platform_config_terraform_modules_default_version = pipeline_config.get( + "default_versions", {} + ).get("terraform-platform-modules", "") + templates = setup_templates() app_name = get_application_name() @@ -57,22 +96,34 @@ def generate(): abort_with_error(f'There is no CodeStar Connection named "{app_name}" to use') base_path = Path(".") - pipelines_dir = 
base_path / f"copilot/pipelines" + copilot_pipelines_dir = base_path / f"copilot/pipelines" - _clean_pipeline_config(pipelines_dir) + _clean_pipeline_config(copilot_pipelines_dir) - if not is_terraform_project() and ENVIRONMENTS_KEY in pipeline_config: + if is_terraform_project() and has_environment_pipelines: + environment_pipelines = pipeline_config[ENVIRONMENT_PIPELINES_KEY] + + for config in environment_pipelines.values(): + aws_account = config.get("account") + _generate_terraform_environment_pipeline_manifest( + pipeline_config["application"], + aws_account, + terraform_platform_modules_version, + platform_config_terraform_modules_default_version, + deploy_branch, + ) + if not is_terraform_project() and has_legacy_environment_pipelines: _generate_copilot_environments_pipeline( app_name, codestar_connection_arn, git_repo, apply_environment_defaults(pipeline_config)[ENVIRONMENTS_KEY], base_path, - pipelines_dir, + copilot_pipelines_dir, templates, ) - if CODEBASE_PIPELINES_KEY in pipeline_config: + if has_codebase_pipelines: account_id, _ = get_account_details() for codebase in pipeline_config[CODEBASE_PIPELINES_KEY]: @@ -83,7 +134,7 @@ def generate(): git_repo, codebase, base_path, - pipelines_dir, + copilot_pipelines_dir, templates, ) @@ -170,3 +221,43 @@ def _create_file_from_template( ).render(template_data) message = mkfile(base_path, pipelines_dir / file_name, contents, overwrite=True) click.echo(message) + + +def _generate_terraform_environment_pipeline_manifest( + application, + aws_account, + cli_terraform_platform_modules_version, + platform_config_terraform_modules_default_version, + deploy_branch, +): + env_pipeline_template = setup_templates().get_template("environment-pipelines/main.tf") + + terraform_platform_modules_version = _determine_terraform_platform_modules_version( + cli_terraform_platform_modules_version, platform_config_terraform_modules_default_version + ) + + contents = env_pipeline_template.render( + { + "application": application, + 
"aws_account": aws_account, + "terraform_platform_modules_version": terraform_platform_modules_version, + "deploy_branch": deploy_branch, + } + ) + + dir_path = f"terraform/environment-pipelines/{aws_account}" + makedirs(dir_path, exist_ok=True) + + click.echo(mkfile(".", f"{dir_path}/main.tf", contents, overwrite=True)) + + +def _determine_terraform_platform_modules_version( + cli_terraform_platform_modules_version, platform_config_terraform_modules_default_version +): + + version_preference_order = [ + cli_terraform_platform_modules_version, + platform_config_terraform_modules_default_version, + DEFAULT_TERRAFORM_PLATFORM_MODULES_VERSION, + ] + return [version for version in version_preference_order if version][0] diff --git a/dbt_platform_helper/domain/__init__.py b/dbt_platform_helper/domain/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/dbt_platform_helper/domain/database_copy.py b/dbt_platform_helper/domain/database_copy.py new file mode 100644 index 000000000..338319d10 --- /dev/null +++ b/dbt_platform_helper/domain/database_copy.py @@ -0,0 +1,220 @@ +import re +from collections.abc import Callable +from pathlib import Path + +import boto3 +import click +from boto3 import Session + +from dbt_platform_helper.constants import PLATFORM_CONFIG_FILE +from dbt_platform_helper.domain.maintenance_page import MaintenancePageProvider +from dbt_platform_helper.exceptions import AWSException +from dbt_platform_helper.utils.application import Application +from dbt_platform_helper.utils.application import ApplicationNotFoundError +from dbt_platform_helper.utils.application import load_application +from dbt_platform_helper.utils.aws import Vpc +from dbt_platform_helper.utils.aws import get_connection_string +from dbt_platform_helper.utils.aws import get_vpc_info_by_name +from dbt_platform_helper.utils.messages import abort_with_error +from dbt_platform_helper.utils.validation import load_and_validate_platform_config + + +class DatabaseCopy: + def 
__init__( + self, + app: str, + database: str, + auto_approve: bool = False, + load_application_fn: Callable[[str], Application] = load_application, + vpc_config_fn: Callable[[Session, str, str, str], Vpc] = get_vpc_info_by_name, + db_connection_string_fn: Callable[ + [Session, str, str, str, Callable], str + ] = get_connection_string, + maintenance_page_provider: Callable[ + [str, str, list[str], str, str], None + ] = MaintenancePageProvider(), + input_fn: Callable[[str], str] = click.prompt, + echo_fn: Callable[[str], str] = click.secho, + abort_fn: Callable[[str], None] = abort_with_error, + ): + self.app = app + self.database = database + self.auto_approve = auto_approve + self.vpc_config_fn = vpc_config_fn + self.db_connection_string_fn = db_connection_string_fn + self.maintenance_page_provider = maintenance_page_provider + self.input_fn = input_fn + self.echo_fn = echo_fn + self.abort_fn = abort_fn + + if not self.app: + if not Path(PLATFORM_CONFIG_FILE).exists(): + self.abort_fn("You must either be in a deploy repo, or provide the --app option.") + + config = load_and_validate_platform_config(disable_aws_validation=True) + self.app = config["application"] + + try: + self.application = load_application_fn(self.app) + except ApplicationNotFoundError: + abort_fn(f"No such application '{app}'.") + + def _execute_operation(self, is_dump: bool, env: str, vpc_name: str): + vpc_name = self.enrich_vpc_name(env, vpc_name) + + environments = self.application.environments + environment = environments.get(env) + if not environment: + self.abort_fn( + f"No such environment '{env}'. 
Available environments are: {', '.join(environments.keys())}" + ) + + env_session = environment.session + + try: + vpc_config = self.vpc_config_fn(env_session, self.app, env, vpc_name) + except AWSException as ex: + self.abort_fn(str(ex)) + + database_identifier = f"{self.app}-{env}-{self.database}" + + try: + db_connection_string = self.db_connection_string_fn( + env_session, self.app, env, database_identifier + ) + except Exception as exc: + self.abort_fn(f"{exc} (Database: {database_identifier})") + + try: + task_arn = self.run_database_copy_task( + env_session, env, vpc_config, is_dump, db_connection_string + ) + except Exception as exc: + self.abort_fn(f"{exc} (Account id: {self.account_id(env)})") + + if is_dump: + message = f"Dumping {self.database} from the {env} environment into S3" + else: + message = f"Loading data into {self.database} in the {env} environment from S3" + + self.echo_fn(message, fg="white", bold=True) + self.echo_fn( + f"Task {task_arn} started. Waiting for it to complete (this may take some time)...", + fg="white", + ) + self.tail_logs(is_dump, env) + + def enrich_vpc_name(self, env, vpc_name): + if not vpc_name: + if not Path(PLATFORM_CONFIG_FILE).exists(): + self.abort_fn( + "You must either be in a deploy repo, or provide the vpc name option." 
+ ) + config = load_and_validate_platform_config(disable_aws_validation=True) + vpc_name = config.get("environments", {}).get(env, {}).get("vpc") + return vpc_name + + def run_database_copy_task( + self, + session: boto3.session.Session, + env: str, + vpc_config: Vpc, + is_dump: bool, + db_connection_string: str, + ) -> str: + client = session.client("ecs") + action = "dump" if is_dump else "load" + env_vars = [ + {"name": "DATA_COPY_OPERATION", "value": action.upper()}, + {"name": "DB_CONNECTION_STRING", "value": db_connection_string}, + ] + if not is_dump: + env_vars.append({"name": "ECS_CLUSTER", "value": f"{self.app}-{env}"}) + + response = client.run_task( + taskDefinition=f"arn:aws:ecs:eu-west-2:{self.account_id(env)}:task-definition/{self.app}-{env}-{self.database}-{action}", + cluster=f"{self.app}-{env}", + capacityProviderStrategy=[ + {"capacityProvider": "FARGATE", "weight": 1, "base": 0}, + ], + networkConfiguration={ + "awsvpcConfiguration": { + "subnets": vpc_config.subnets, + "securityGroups": vpc_config.security_groups, + "assignPublicIp": "DISABLED", + } + }, + overrides={ + "containerOverrides": [ + { + "name": f"{self.app}-{env}-{self.database}-{action}", + "environment": env_vars, + } + ] + }, + ) + + return response.get("tasks", [{}])[0].get("taskArn") + + def dump(self, env: str, vpc_name: str): + self._execute_operation(True, env, vpc_name) + + def load(self, env: str, vpc_name: str): + if self.is_confirmed_ready_to_load(env): + self._execute_operation(False, env, vpc_name) + + def copy( + self, + from_env: str, + to_env: str, + from_vpc: str, + to_vpc: str, + services: tuple[str], + template: str, + no_maintenance_page: bool = False, + ): + to_vpc = self.enrich_vpc_name(to_env, to_vpc) + if not no_maintenance_page: + self.maintenance_page_provider.activate(self.app, to_env, services, template, to_vpc) + self.dump(from_env, from_vpc) + self.load(to_env, to_vpc) + if not no_maintenance_page: + self.maintenance_page_provider.deactivate(self.app, 
to_env) + + def is_confirmed_ready_to_load(self, env: str) -> bool: + if self.auto_approve: + return True + + user_input = self.input_fn( + f"\nWARNING: the load operation is destructive and will delete the {self.database} database in the {env} environment. Continue? (y/n)" + ) + return user_input.lower().strip() in ["y", "yes"] + + def tail_logs(self, is_dump: bool, env: str): + action = "dump" if is_dump else "load" + log_group_name = f"/ecs/{self.app}-{env}-{self.database}-{action}" + log_group_arn = f"arn:aws:logs:eu-west-2:{self.account_id(env)}:log-group:{log_group_name}" + self.echo_fn(f"Tailing {log_group_name} logs", fg="yellow") + session = self.application.environments[env].session + response = session.client("logs").start_live_tail(logGroupIdentifiers=[log_group_arn]) + + stopped = False + for data in response["responseStream"]: + if stopped: + break + results = data.get("sessionUpdate", {}).get("sessionResults", []) + for result in results: + message = result.get("message") + + if message: + match = re.match(r"(Stopping|Aborting) data (load|dump).*", message) + if match: + if match.group(1) == "Aborting": + self.abort_fn("Task aborted abnormally. 
See logs above for details.") + stopped = True + self.echo_fn(message) + + def account_id(self, env): + envs = self.application.environments + if env in envs: + return envs.get(env).account_id diff --git a/dbt_platform_helper/domain/maintenance_page.py b/dbt_platform_helper/domain/maintenance_page.py new file mode 100644 index 000000000..506973266 --- /dev/null +++ b/dbt_platform_helper/domain/maintenance_page.py @@ -0,0 +1,459 @@ +import itertools +import random +import re +import string +from pathlib import Path +from typing import List +from typing import Union + +import boto3 +import click + +from dbt_platform_helper.providers.load_balancers import ListenerNotFoundError +from dbt_platform_helper.providers.load_balancers import ListenerRuleNotFoundError +from dbt_platform_helper.providers.load_balancers import LoadBalancerNotFoundError +from dbt_platform_helper.providers.load_balancers import find_https_listener +from dbt_platform_helper.utils.application import Environment +from dbt_platform_helper.utils.application import Service +from dbt_platform_helper.utils.application import load_application + + +class MaintenancePageProvider: + def activate(self, app, env, svc, template, vpc): + application = load_application(app) + application_environment = get_app_environment(app, env) + + if "*" in svc: + services = [ + s for s in application.services.values() if s.kind == "Load Balanced Web Service" + ] + else: + all_services = [get_app_service(app, s) for s in list(svc)] + services = [s for s in all_services if s.kind == "Load Balanced Web Service"] + + if not services: + click.secho(f"No services deployed yet to {app} environment {env}", fg="red") + raise click.Abort + + try: + https_listener = find_https_listener(application_environment.session, app, env) + current_maintenance_page = get_maintenance_page( + application_environment.session, https_listener + ) + remove_current_maintenance_page = False + if current_maintenance_page: + remove_current_maintenance_page 
= click.confirm( + f"There is currently a '{current_maintenance_page}' maintenance page for the {env} " + f"environment in {app}.\nWould you like to replace it with a '{template}' " + f"maintenance page?" + ) + if not remove_current_maintenance_page: + raise click.Abort + + if remove_current_maintenance_page or click.confirm( + f"You are about to enable the '{template}' maintenance page for the {env} " + f"environment in {app}.\nWould you like to continue?" + ): + if current_maintenance_page and remove_current_maintenance_page: + remove_maintenance_page(application_environment.session, https_listener) + + allowed_ips = get_env_ips(vpc, application_environment) + + add_maintenance_page( + application_environment.session, + https_listener, + app, + env, + services, + allowed_ips, + template, + ) + click.secho( + f"Maintenance page '{template}' added for environment {env} in application {app}", + fg="green", + ) + else: + raise click.Abort + + except LoadBalancerNotFoundError: + click.secho( + f"No load balancer found for environment {env} in the application {app}.", fg="red" + ) + raise click.Abort + + except ListenerNotFoundError: + click.secho( + f"No HTTPS listener found for environment {env} in the application {app}.", fg="red" + ) + raise click.Abort + + def deactivate(self, app, env): + application_environment = get_app_environment(app, env) + + try: + https_listener = find_https_listener(application_environment.session, app, env) + current_maintenance_page = get_maintenance_page( + application_environment.session, https_listener + ) + if not current_maintenance_page: + click.secho("There is no current maintenance page to remove", fg="red") + raise click.Abort + + if not click.confirm( + f"There is currently a '{current_maintenance_page}' maintenance page, " + f"would you like to remove it?" 
+ ): + raise click.Abort + + remove_maintenance_page(application_environment.session, https_listener) + click.secho( + f"Maintenance page removed from environment {env} in application {app}", fg="green" + ) + + except LoadBalancerNotFoundError: + click.secho( + f"No load balancer found for environment {env} in the application {app}.", fg="red" + ) + raise click.Abort + + except ListenerNotFoundError: + click.secho( + f"No HTTPS listener found for environment {env} in the application {app}.", fg="red" + ) + raise click.Abort + + +def get_app_service(app_name: str, svc_name: str) -> Service: + application = load_application(app_name) + application_service = application.services.get(svc_name) + + if not application_service: + click.secho( + f"The service {svc_name} was not found in the application {app_name}. " + f"It either does not exist, or has not been deployed.", + fg="red", + ) + raise click.Abort + + return application_service + + +def get_app_environment(app_name: str, env_name: str) -> Environment: + application = load_application(app_name) + application_environment = application.environments.get(env_name) + + if not application_environment: + click.secho( + f"The environment {env_name} was not found in the application {app_name}. 
" + f"It either does not exist, or has not been deployed.", + fg="red", + ) + raise click.Abort + + return application_environment + + +def get_maintenance_page(session: boto3.Session, listener_arn: str) -> Union[str, None]: + lb_client = session.client("elbv2") + + rules = lb_client.describe_rules(ListenerArn=listener_arn)["Rules"] + tag_descriptions = get_rules_tag_descriptions(rules, lb_client) + + maintenance_page_type = None + for description in tag_descriptions: + tags = {t["Key"]: t["Value"] for t in description["Tags"]} + if tags.get("name") == "MaintenancePage": + maintenance_page_type = tags.get("type") + + return maintenance_page_type + + +def remove_maintenance_page(session: boto3.Session, listener_arn: str): + lb_client = session.client("elbv2") + + rules = lb_client.describe_rules(ListenerArn=listener_arn)["Rules"] + # TODO: The next line doesn't appear to do anything. + tag_descriptions = get_rules_tag_descriptions(rules, lb_client) + # TODO: In fact the following line seems to do the same but better. 
+ tag_descriptions = lb_client.describe_tags(ResourceArns=[r["RuleArn"] for r in rules])[ + "TagDescriptions" + ] + + for name in ["MaintenancePage", "AllowedIps", "BypassIpFilter", "AllowedSourceIps"]: + deleted = delete_listener_rule(tag_descriptions, name, lb_client) + + if name == "MaintenancePage" and not deleted: + raise ListenerRuleNotFoundError() + + +def get_rules_tag_descriptions(rules: list, lb_client): + tag_descriptions = [] + chunk_size = 20 + + for i in range(0, len(rules), chunk_size): + chunk = rules[i : i + chunk_size] + resource_arns = [r["RuleArn"] for r in chunk] + response = lb_client.describe_tags(ResourceArns=resource_arns) + tag_descriptions.extend(response["TagDescriptions"]) + + return tag_descriptions + + +def delete_listener_rule(tag_descriptions: list, tag_name: str, lb_client: boto3.client): + current_rule_arn = None + + for description in tag_descriptions: + tags = {t["Key"]: t["Value"] for t in description["Tags"]} + if tags.get("name") == tag_name: + current_rule_arn = description["ResourceArn"] + if current_rule_arn: + lb_client.delete_rule(RuleArn=current_rule_arn) + + return current_rule_arn + + +def add_maintenance_page( + session: boto3.Session, + listener_arn: str, + app: str, + env: str, + services: List[Service], + allowed_ips: tuple, + template: str = "default", +): + lb_client = session.client("elbv2") + maintenance_page_content = get_maintenance_page_template(template) + bypass_value = "".join(random.choices(string.ascii_lowercase + string.digits, k=12)) + + rule_priority = itertools.count(start=1) + + for svc in services: + target_group_arn = find_target_group(app, env, svc.name, session) + + # not all of an application's services are guaranteed to have been deployed to an environment + if not target_group_arn: + continue + + for ip in allowed_ips: + create_header_rule( + lb_client, + listener_arn, + target_group_arn, + "X-Forwarded-For", + [ip], + "AllowedIps", + next(rule_priority), + ) + create_source_ip_rule( + 
lb_client, + listener_arn, + target_group_arn, + [ip], + "AllowedSourceIps", + next(rule_priority), + ) + + create_header_rule( + lb_client, + listener_arn, + target_group_arn, + "Bypass-Key", + [bypass_value], + "BypassIpFilter", + next(rule_priority), + ) + + click.secho( + f"\nUse a browser plugin to add `Bypass-Key` header with value {bypass_value} to your requests. For more detail, visit https://platform.readme.trade.gov.uk/activities/holding-and-maintenance-pages/", + fg="green", + ) + + lb_client.create_rule( + ListenerArn=listener_arn, + Priority=next(rule_priority), + Conditions=[ + { + "Field": "path-pattern", + "PathPatternConfig": {"Values": ["/*"]}, + } + ], + Actions=[ + { + "Type": "fixed-response", + "FixedResponseConfig": { + "StatusCode": "503", + "ContentType": "text/html", + "MessageBody": maintenance_page_content, + }, + } + ], + Tags=[ + {"Key": "name", "Value": "MaintenancePage"}, + {"Key": "type", "Value": template}, + ], + ) + + +def get_maintenance_page_template(template) -> str: + template_contents = ( + Path(__file__) + .parent.parent.joinpath( + f"templates/svc/maintenance_pages/{template}.html", + ) + .read_text() + .replace("\n", "") + ) + + # [^\S]\s+ - Remove any space that is not preceded by a non-space character. 
+ return re.sub(r"[^\S]\s+", "", template_contents) + + +def find_target_group(app: str, env: str, svc: str, session: boto3.Session) -> str: + rg_tagging_client = session.client("resourcegroupstaggingapi") + response = rg_tagging_client.get_resources( + TagFilters=[ + { + "Key": "copilot-application", + "Values": [ + app, + ], + "Key": "copilot-environment", + "Values": [ + env, + ], + "Key": "copilot-service", + "Values": [ + svc, + ], + }, + ], + ResourceTypeFilters=[ + "elasticloadbalancing:targetgroup", + ], + ) + for resource in response["ResourceTagMappingList"]: + tags = {tag["Key"]: tag["Value"] for tag in resource["Tags"]} + + if ( + "copilot-service" in tags + and tags["copilot-service"] == svc + and "copilot-environment" in tags + and tags["copilot-environment"] == env + and "copilot-application" in tags + and tags["copilot-application"] == app + ): + return resource["ResourceARN"] + + click.secho( + f"No target group found for application: {app}, environment: {env}, service: {svc}", + fg="red", + ) + + return None + + +def create_header_rule( + lb_client: boto3.client, + listener_arn: str, + target_group_arn: str, + header_name: str, + values: list, + rule_name: str, + priority: int, +): + conditions = get_host_conditions(lb_client, listener_arn, target_group_arn) + + # add new condition to existing conditions + combined_conditions = [ + { + "Field": "http-header", + "HttpHeaderConfig": {"HttpHeaderName": header_name, "Values": values}, + } + ] + conditions + + lb_client.create_rule( + ListenerArn=listener_arn, + Priority=priority, + Conditions=combined_conditions, + Actions=[{"Type": "forward", "TargetGroupArn": target_group_arn}], + Tags=[ + {"Key": "name", "Value": rule_name}, + ], + ) + + click.secho( + f"Creating listener rule {rule_name} for HTTPS Listener with arn {listener_arn}.\n\nIf request header {header_name} contains one of the values {values}, the request will be forwarded to target group with arn {target_group_arn}.", + fg="green", + ) + 
+ +def create_source_ip_rule( + lb_client: boto3.client, + listener_arn: str, + target_group_arn: str, + values: list, + rule_name: str, + priority: int, +): + conditions = get_host_conditions(lb_client, listener_arn, target_group_arn) + + # add new condition to existing conditions + combined_conditions = [ + { + "Field": "source-ip", + "SourceIpConfig": {"Values": [value + "/32" for value in values]}, + } + ] + conditions + + lb_client.create_rule( + ListenerArn=listener_arn, + Priority=priority, + Conditions=combined_conditions, + Actions=[{"Type": "forward", "TargetGroupArn": target_group_arn}], + Tags=[ + {"Key": "name", "Value": rule_name}, + ], + ) + + click.secho( + f"Creating listener rule {rule_name} for HTTPS Listener with arn {listener_arn}.\n\nIf request source ip matches one of the values {values}, the request will be forwarded to target group with arn {target_group_arn}.", + fg="green", + ) + + +def get_host_conditions(lb_client: boto3.client, listener_arn: str, target_group_arn: str): + rules = lb_client.describe_rules(ListenerArn=listener_arn)["Rules"] + + # Get current set of forwarding conditions for the target group + for rule in rules: + for action in rule["Actions"]: + if action["Type"] == "forward" and action["TargetGroupArn"] == target_group_arn: + conditions = rule["Conditions"] + + # filter to host-header conditions + conditions = [ + {i: condition[i] for i in condition if i != "Values"} + for condition in conditions + if condition["Field"] == "host-header" + ] + + # remove internal hosts + conditions[0]["HostHeaderConfig"]["Values"] = [ + v for v in conditions[0]["HostHeaderConfig"]["Values"] + ] + + return conditions + + +def get_env_ips(vpc: str, application_environment: Environment) -> List[str]: + account_name = f"{application_environment.session.profile_name}-vpc" + vpc_name = vpc if vpc else account_name + ssm_client = application_environment.session.client("ssm") + + try: + param_value = 
ssm_client.get_parameter(Name=f"/{vpc_name}/EGRESS_IPS")["Parameter"]["Value"] + except ssm_client.exceptions.ParameterNotFound: + click.secho(f"No parameter found with name: /{vpc_name}/EGRESS_IPS") + raise click.Abort + + return [ip.strip() for ip in param_value.split(",")] diff --git a/dbt_platform_helper/providers/load_balancers.py b/dbt_platform_helper/providers/load_balancers.py new file mode 100644 index 000000000..7be823ed6 --- /dev/null +++ b/dbt_platform_helper/providers/load_balancers.py @@ -0,0 +1,51 @@ +import boto3 + + +def find_load_balancer(session: boto3.Session, app: str, env: str) -> str: + lb_client = session.client("elbv2") + + describe_response = lb_client.describe_load_balancers() + load_balancers = [lb["LoadBalancerArn"] for lb in describe_response["LoadBalancers"]] + + load_balancers = lb_client.describe_tags(ResourceArns=load_balancers)["TagDescriptions"] + + load_balancer_arn = None + for lb in load_balancers: + tags = {t["Key"]: t["Value"] for t in lb["Tags"]} + if tags.get("copilot-application") == app and tags.get("copilot-environment") == env: + load_balancer_arn = lb["ResourceArn"] + + if not load_balancer_arn: + raise LoadBalancerNotFoundError() + + return load_balancer_arn + + +def find_https_listener(session: boto3.Session, app: str, env: str) -> str: + load_balancer_arn = find_load_balancer(session, app, env) + lb_client = session.client("elbv2") + listeners = lb_client.describe_listeners(LoadBalancerArn=load_balancer_arn)["Listeners"] + + listener_arn = None + + try: + listener_arn = next(l["ListenerArn"] for l in listeners if l["Protocol"] == "HTTPS") + except StopIteration: + pass + + if not listener_arn: + raise ListenerNotFoundError() + + return listener_arn + + +class LoadBalancerNotFoundError(Exception): + pass + + +class ListenerNotFoundError(Exception): + pass + + +class ListenerRuleNotFoundError(Exception): + pass diff --git a/dbt_platform_helper/templates/environment-pipelines/main.tf 
b/dbt_platform_helper/templates/environment-pipelines/main.tf new file mode 100644 index 000000000..1b6d5f528 --- /dev/null +++ b/dbt_platform_helper/templates/environment-pipelines/main.tf @@ -0,0 +1,52 @@ +# {% extra_header %} +# {% version_info %} +locals { + platform_config = yamldecode(file("../../../platform-config.yml")) + all_pipelines = local.platform_config["environment_pipelines"] + pipelines = { for pipeline, config in local.platform_config["environment_pipelines"] : pipeline => config if config.account == "{{ aws_account }}" } + environment_config = local.platform_config["environments"] +} + +provider "aws" { + region = "eu-west-2" + profile = "{{ aws_account }}" + alias = "{{ aws_account }}" + shared_credentials_files = ["~/.aws/config"] +} + +terraform { + required_version = "~> 1.8" + backend "s3" { + bucket = "terraform-platform-state-{{ aws_account }}" + key = "tfstate/application/{{ application }}-pipelines.tfstate" + region = "eu-west-2" + encrypt = true + kms_key_id = "alias/terraform-platform-state-s3-key-{{ aws_account }}" + dynamodb_table = "terraform-platform-lockdb-{{ aws_account }}" + } + required_providers { + aws = { + source = "hashicorp/aws" + version = "~> 5" + } + } +} + + +module "environment-pipelines" { + source = "git::https://github.com/uktrade/terraform-platform-modules.git//environment-pipelines?depth=1&ref={{ terraform_platform_modules_version }}" + + for_each = local.pipelines + + application = "{{ application }}" + pipeline_name = each.key + repository = "uktrade/{{ application }}-deploy" + + environments = each.value.environments + all_pipelines = local.all_pipelines + environment_config = local.environment_config + branch = {% if deploy_branch %}"{{ deploy_branch }}"{% else %}each.value.branch{% endif %} + slack_channel = each.value.slack_channel + trigger_on_push = each.value.trigger_on_push + pipeline_to_trigger = lookup(each.value, "pipeline_to_trigger", null) +} diff --git a/dbt_platform_helper/utils/aws.py 
b/dbt_platform_helper/utils/aws.py index 476aae144..7b300d119 100644 --- a/dbt_platform_helper/utils/aws.py +++ b/dbt_platform_helper/utils/aws.py @@ -7,6 +7,7 @@ import boto3 import botocore +import botocore.exceptions import click import yaml from boto3 import Session @@ -20,62 +21,71 @@ def get_aws_session_or_abort(aws_profile: str = None) -> boto3.session.Session: - aws_profile = aws_profile if aws_profile else os.getenv("AWS_PROFILE") + REFRESH_TOKEN_MESSAGE = ( + "To refresh this SSO session run `aws sso login` with the corresponding profile" + ) + aws_profile = aws_profile or os.getenv("AWS_PROFILE") if aws_profile in AWS_SESSION_CACHE: return AWS_SESSION_CACHE[aws_profile] - # Check that the aws profile exists and is set. - click.secho(f"""Checking AWS connection for profile "{aws_profile}"...""", fg="cyan") + click.secho(f'Checking AWS connection for profile "{aws_profile}"...', fg="cyan") try: session = boto3.session.Session(profile_name=aws_profile) + sts = session.client("sts") + account_id, user_id = get_account_details(sts) + click.secho("Credentials are valid.", fg="green") + except botocore.exceptions.ProfileNotFound: - click.secho(f"""AWS profile "{aws_profile}" is not configured.""", fg="red") - exit(1) + _handle_error(f'AWS profile "{aws_profile}" is not configured.') except botocore.exceptions.ClientError as e: if e.response["Error"]["Code"] == "ExpiredToken": - click.secho( - f"Credentials are NOT valid. \nPlease login with: aws sso login --profile {aws_profile}", - fg="red", + _handle_error( + f"Credentials are NOT valid. 
\nPlease login with: aws sso login --profile {aws_profile}" ) - exit(1) - - sts = session.client("sts") - try: - account_id, user_id = get_account_details(sts) - click.secho("Credentials are valid.", fg="green") - except ( - botocore.exceptions.UnauthorizedSSOTokenError, - botocore.exceptions.TokenRetrievalError, - botocore.exceptions.SSOTokenLoadError, - ): - click.secho( - "The SSO session associated with this profile has expired or is otherwise invalid." - "To refresh this SSO session run `aws sso login` with the corresponding profile", - fg="red", + except botocore.exceptions.NoCredentialsError: + _handle_error("There are no credentials set for this session.", REFRESH_TOKEN_MESSAGE) + except botocore.exceptions.UnauthorizedSSOTokenError: + _handle_error("The SSO Token used for this session is unauthorised.", REFRESH_TOKEN_MESSAGE) + except botocore.exceptions.TokenRetrievalError: + _handle_error("Unable to retrieve the Token for this session.", REFRESH_TOKEN_MESSAGE) + except botocore.exceptions.SSOTokenLoadError: + _handle_error( + "The SSO session associated with this profile has expired, is not set or is otherwise invalid.", + REFRESH_TOKEN_MESSAGE, ) - exit(1) alias_client = session.client("iam") - account_name = alias_client.list_account_aliases()["AccountAliases"] + account_name = alias_client.list_account_aliases().get("AccountAliases", []) + + _log_account_info(account_name, account_id) + + click.echo( + click.style("User: ", fg="yellow") + + click.style(f"{user_id.split(':')[-1]}\n", fg="white", bold=True) + ) + + AWS_SESSION_CACHE[aws_profile] = session + return session + + +def _handle_error(message: str, refresh_token_message: str = None) -> None: + full_message = message + (" " + refresh_token_message if refresh_token_message else "") + click.secho(full_message, fg="red") + exit(1) + + +def _log_account_info(account_name: list, account_id: str) -> None: if account_name: click.echo( click.style("Logged in with AWS account: ", fg="yellow") - + 
click.style(f"{account_name[0]}/{account_id}", fg="white", bold=True), + + click.style(f"{account_name[0]}/{account_id}", fg="white", bold=True) ) else: click.echo( click.style("Logged in with AWS account id: ", fg="yellow") - + click.style(f"{account_id}", fg="white", bold=True), + + click.style(f"{account_id}", fg="white", bold=True) ) - click.echo( - click.style("User: ", fg="yellow") - + click.style(f"{user_id.split(':')[-1]}\n", fg="white", bold=True), - ) - - AWS_SESSION_CACHE[aws_profile] = session - - return session class NoProfileForAccountIdError(Exception): @@ -362,12 +372,12 @@ def get_connection_string( class Vpc: - def __init__(self, subnets, security_groups): + def __init__(self, subnets: list[str], security_groups: list[str]): self.subnets = subnets self.security_groups = security_groups -def get_vpc_info_by_name(session, app, env, vpc_name): +def get_vpc_info_by_name(session: Session, app: str, env: str, vpc_name: str) -> Vpc: ec2_client = session.client("ec2") vpc_response = ec2_client.describe_vpcs(Filters=[{"Name": "tag:Name", "Values": [vpc_name]}]) diff --git a/images/tools/database-copy/Dockerfile b/images/tools/database-copy/Dockerfile index 3f47f6128..9ae958289 100644 --- a/images/tools/database-copy/Dockerfile +++ b/images/tools/database-copy/Dockerfile @@ -1,21 +1,12 @@ -FROM public.ecr.aws/docker/library/debian:12-slim +FROM public.ecr.aws/docker/library/postgres:16 -RUN apt-get update && \ - apt-get upgrade -y && \ - # Add repository for postgres version 16 - apt-get install -y postgresql-common && \ - /usr/share/postgresql-common/pgdg/apt.postgresql.org.sh -y && \ - apt-get install -y curl unzip jq procps postgresql-client-16 && \ - # Start installing AWS CLI - curl "https://awscli.amazonaws.com/awscli-exe-linux-aarch64.zip" -o "awscliv2.zip" && \ - unzip awscliv2.zip && \ - ./aws/install && \ - rm -rf aws && rm awscliv2.zip && \ - apt-get remove -y apt-utils gnupg unzip && \ - # Finish installing AWS CLI - apt-get clean +RUN apt-get 
update && apt-get upgrade -y
+RUN apt-get install -y curl unzip jq
+RUN curl "https://awscli.amazonaws.com/awscli-exe-linux-aarch64.zip" -o "awscliv2.zip"
+RUN unzip awscliv2.zip
+RUN ./aws/install

-COPY shell-profile.sh /root/.bashrc
 COPY entrypoint.sh /entrypoint.sh
+COPY clear_db.sql /clear_db.sql

 ENTRYPOINT ["bash", "/entrypoint.sh"]
diff --git a/images/tools/database-copy/clear_db.sql b/images/tools/database-copy/clear_db.sql
new file mode 100644
index 000000000..4bc0e413d
--- /dev/null
+++ b/images/tools/database-copy/clear_db.sql
@@ -0,0 +1,11 @@
+DO $$ DECLARE
+    r RECORD;
+BEGIN
+    FOR r IN (SELECT tablename FROM pg_tables WHERE schemaname = 'public') LOOP
+        CASE
+            WHEN r.tablename NOT IN ('spatial_ref_sys') THEN
+                EXECUTE 'DROP TABLE IF EXISTS ' || quote_ident(r.tablename) || ' CASCADE';
+            ELSE null;
+        END CASE;
+    END LOOP;
+END $$;
diff --git a/images/tools/database-copy/entrypoint.sh b/images/tools/database-copy/entrypoint.sh
index 4fc8dd80e..71b4bd6df 100644
--- a/images/tools/database-copy/entrypoint.sh
+++ b/images/tools/database-copy/entrypoint.sh
@@ -1,30 +1,102 @@
-#!/usr/bin/env bash
+#!/bin/bash

-TASKS_RUNNING=0
+clean_up(){
+  echo "Cleaning up dump file"
+  rm data_dump.sql
+  echo "Removing dump file from S3"
+  aws s3 rm s3://${S3_BUCKET_NAME}/data_dump.sql
+  if [ ${exit_code} -ne 0 ]
+  then
+    echo "Aborting data load: Clean up failed"
+    exit $exit_code
+  fi
+}
+
+handle_errors(){
+  exit_code=$1
+  message=$2
+  if [ ${exit_code} -ne 0 ]
+  then
+    clean_up
+    echo "Aborting data load: ${message}"
+    exit $exit_code
+  fi
+}

-CHECK_COUNT=0
-CHECK_NUMBER=1
-CHECK_INTERVAL=60
+if [ "${DATA_COPY_OPERATION:-DUMP}" != "LOAD" ]
+then
+  echo "Starting data dump"
+  pg_dump --no-owner --no-acl --format c "${DB_CONNECTION_STRING}" > data_dump.sql
+  exit_code=$?
-CLIENT_TASK="ssm-session-wor" + if [ ${exit_code} -ne 0 ] + then + echo "Aborting data dump" + exit $exit_code + fi -while [ $CHECK_COUNT -lt $CHECK_NUMBER ]; do - sleep $CHECK_INTERVAL - TASKS_RUNNING="$(ps -e -o pid,comm | grep -c "$CLIENT_TASK")" + aws s3 cp data_dump.sql s3://${S3_BUCKET_NAME}/ + exit_code=$? - if [[ $TASKS_RUNNING == 0 ]]; then - CHECK_COUNT=$(( $CHECK_COUNT + 1 )) - TIME_TO_SHUTDOWN="$(( (CHECK_NUMBER - CHECK_COUNT) * CHECK_INTERVAL ))" - echo "No clients connected, will shutdown in approximately $TIME_TO_SHUTDOWN seconds" - else - CHECK_COUNT=0 - echo "$TASKS_RUNNING clients are connected" + if [ ${exit_code} -ne 0 ] + then + echo "Aborting data dump" + exit $exit_code fi -done -# Trigger CloudFormation stack delete before shutting down -if [[ ! -z $ECS_CONTAINER_METADATA_URI_V4 ]]; then - aws cloudformation delete-stack --stack-name task-$(curl $ECS_CONTAINER_METADATA_URI_V4 -s | jq -r ".Name") -fi + echo "Stopping data dump" +else + echo "Starting data load" + + echo "Copying data dump from S3" + aws s3 cp s3://${S3_BUCKET_NAME}/data_dump.sql data_dump.sql + + handle_errors $? "Copy failed" + + echo "Scaling down services" + SERVICES_DATA=$(aws ecs list-services --cluster "${ECS_CLUSTER}") + handle_errors $? "Failed to list services" + SERVICES=$(echo "${SERVICES_DATA}" | jq -r '.serviceArns[]') -echo "Shutting down" + for service in ${SERVICES} + do + COUNT_DATA=$(aws ecs describe-services --cluster "${ECS_CLUSTER}" --services "${service}") + handle_errors $? "Failed to describe service" + COUNT=$(echo "${COUNT_DATA}" | jq '.services[0].desiredCount') + + SERVICE_NAME=$(basename "${service}") + CONFIG_FILE="${SERVICE_NAME}.desired_count" + echo "${COUNT}" > "${CONFIG_FILE}" + + echo ${SERVICE_NAME} + UPDATE_DATA=$(aws ecs update-service --cluster "${ECS_CLUSTER}" --service "${service}" --desired-count 0) + handle_errors $? 
"Failed to update service ${SERVICE_NAME}" + echo "${UPDATE_DATA}" | jq -r '" Desired Count: \(.service.desiredCount)\n Running Count: \(.service.runningCount)"' + + done + + echo "Clearing down the database prior to loading new data" + psql "${DB_CONNECTION_STRING}" -f /clear_db.sql + + handle_errors $? "Clear down failed" + + echo "Restoring data from dump file" + pg_restore --format c --dbname "${DB_CONNECTION_STRING}" data_dump.sql + + handle_errors $? "Restore failed" + for service in ${SERVICES} + do + CONFIG_FILE="$(basename "${service}").desired_count" + COUNT=$(cat "${CONFIG_FILE}") + SERVICE_NAME=$(basename "${service}") + echo "Scaling up services" + echo ${SERVICE_NAME} + UPDATE_DATA=$(aws ecs update-service --cluster "${ECS_CLUSTER}" --service "${service}" --desired-count "${COUNT}") + handle_errors $? "Failed to update service ${SERVICE_NAME}" + echo "${UPDATE_DATA}" | jq -r '" Desired Count: \(.service.desiredCount)\n Running Count: \(.service.runningCount)"' + done + + clean_up + + echo "Stopping data load" +fi diff --git a/images/tools/database-copy/shell-profile.sh b/images/tools/database-copy/shell-profile.sh deleted file mode 100644 index 91bcd5880..000000000 --- a/images/tools/database-copy/shell-profile.sh +++ /dev/null @@ -1,3 +0,0 @@ -echo "Starting database copy." 
- -pg_dump $SOURCE_DB_CONNECTION | psql $TARGET_DB_CONNECTION main diff --git a/images/tools/database-copy2/Dockerfile b/images/tools/database-copy2/Dockerfile deleted file mode 100644 index d7c6568f9..000000000 --- a/images/tools/database-copy2/Dockerfile +++ /dev/null @@ -1,11 +0,0 @@ -FROM postgres:16 - -RUN apt update && apt upgrade -RUN apt install -y curl zip jq -RUN curl "https://awscli.amazonaws.com/awscli-exe-linux-aarch64.zip" -o "awscliv2.zip" -RUN unzip awscliv2.zip -RUN ./aws/install - -COPY entrypoint.sh /entrypoint.sh - -ENTRYPOINT ["bash", "/entrypoint.sh"] diff --git a/images/tools/database-copy2/entrypoint.sh b/images/tools/database-copy2/entrypoint.sh deleted file mode 100644 index c0871e35b..000000000 --- a/images/tools/database-copy2/entrypoint.sh +++ /dev/null @@ -1,14 +0,0 @@ -#!/bin/bash - -if [ "${DATA_COPY_OPERATION:-DUMP}" != "LOAD" ] -then - echo "Starting data dump" - pg_dump --format c "${DB_CONNECTION_STRING}" > data_dump.sql - aws s3 cp data_dump.sql s3://${S3_BUCKET_NAME}/ - echo "Stopping data dump" -else - echo "Starting data restore" - aws s3 cp s3://${S3_BUCKET_NAME}/data_dump.sql data_dump.sql - pg_restore --format c --dbname "${DB_CONNECTION_STRING}" data_dump.sql - echo "Stopping data restore" -fi diff --git a/pyproject.toml b/pyproject.toml index c63d2441f..6f3f30f06 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -3,7 +3,7 @@ line-length = 100 [tool.poetry] name = "dbt-platform-helper" -version = "11.0.1" +version = "11.2.0" description = "Set of tools to help transfer applications/services from GOV.UK PaaS to DBT PaaS augmenting AWS Copilot." 
authors = ["Department for Business and Trade Platform Team "] license = "MIT" diff --git a/release-manifest.json b/release-manifest.json index c656ffb20..af8cc29f4 100644 --- a/release-manifest.json +++ b/release-manifest.json @@ -1,3 +1,3 @@ { - ".": "11.0.1" + ".": "11.2.0" } diff --git a/tests/platform_helper/domain/__init__.py b/tests/platform_helper/domain/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/platform_helper/domain/test_database_copy.py b/tests/platform_helper/domain/test_database_copy.py new file mode 100644 index 000000000..9b8ef4456 --- /dev/null +++ b/tests/platform_helper/domain/test_database_copy.py @@ -0,0 +1,590 @@ +from unittest.mock import Mock +from unittest.mock import call + +import pytest +import yaml + +from dbt_platform_helper.constants import PLATFORM_CONFIG_FILE +from dbt_platform_helper.domain.database_copy import DatabaseCopy +from dbt_platform_helper.exceptions import AWSException +from dbt_platform_helper.utils.application import Application +from dbt_platform_helper.utils.application import ApplicationNotFoundError +from dbt_platform_helper.utils.aws import Vpc + + +class DataCopyMocks: + def __init__(self, app="test-app", env="test-env", acc="12345", vpc=Vpc([], [])): + self.application = Application(app) + self.environment = Mock() + self.environment.account_id = acc + self.application.environments = {env: self.environment, "test-env-2": Mock()} + self.load_application_fn = Mock(return_value=self.application) + self.client = Mock() + self.environment.session.client.return_value = self.client + + self.vpc = vpc + self.vpc_config_fn = Mock() + self.vpc_config_fn.return_value = vpc + self.db_connection_string_fn = Mock(return_value="test-db-connection-string") + self.maintenance_page_provider = Mock() + + self.input_fn = Mock(return_value="yes") + self.echo_fn = Mock() + self.abort_fn = Mock(side_effect=SystemExit(1)) + + def params(self): + return { + "load_application_fn": 
self.load_application_fn, + "vpc_config_fn": self.vpc_config_fn, + "db_connection_string_fn": self.db_connection_string_fn, + "maintenance_page_provider": self.maintenance_page_provider, + "input_fn": self.input_fn, + "echo_fn": self.echo_fn, + "abort_fn": self.abort_fn, + } + + +@pytest.mark.parametrize("is_dump, exp_operation", [(True, "dump"), (False, "load")]) +def test_run_database_copy_task(is_dump, exp_operation): + vpc = Vpc(["subnet_1", "subnet_2"], ["sec_group_1"]) + mocks = DataCopyMocks(vpc=vpc) + db_connection_string = "connection_string" + + db_copy = DatabaseCopy("test-app", "test-postgres", **mocks.params()) + + mock_client = Mock() + mock_session = Mock() + mock_session.client.return_value = mock_client + mock_client.run_task.return_value = {"tasks": [{"taskArn": "arn:aws:ecs:test-task-arn"}]} + + actual_task_arn = db_copy.run_database_copy_task( + mock_session, "test-env", vpc, is_dump, db_connection_string + ) + + assert actual_task_arn == "arn:aws:ecs:test-task-arn" + + mock_session.client.assert_called_once_with("ecs") + expected_env_vars = [ + {"name": "DATA_COPY_OPERATION", "value": exp_operation.upper()}, + {"name": "DB_CONNECTION_STRING", "value": "connection_string"}, + ] + if not is_dump: + expected_env_vars.append( + {"name": "ECS_CLUSTER", "value": "test-app-test-env"}, + ) + + mock_client.run_task.assert_called_once_with( + taskDefinition=f"arn:aws:ecs:eu-west-2:12345:task-definition/test-app-test-env-test-postgres-{exp_operation}", + cluster="test-app-test-env", + capacityProviderStrategy=[ + {"capacityProvider": "FARGATE", "weight": 1, "base": 0}, + ], + networkConfiguration={ + "awsvpcConfiguration": { + "subnets": ["subnet_1", "subnet_2"], + "securityGroups": [ + "sec_group_1", + ], + "assignPublicIp": "DISABLED", + } + }, + overrides={ + "containerOverrides": [ + { + "name": f"test-app-test-env-test-postgres-{exp_operation}", + "environment": expected_env_vars, + } + ] + }, + ) + + +def test_database_dump(): + app = "test-app" + 
env = "test-env" + vpc_name = "test-vpc" + database = "test-db" + + mocks = DataCopyMocks(app, env) + + mock_run_database_copy_task = Mock(return_value="arn://task-arn") + + db_copy = DatabaseCopy(app, database, **mocks.params()) + db_copy.run_database_copy_task = mock_run_database_copy_task + + db_copy.tail_logs = Mock() + db_copy.enrich_vpc_name = Mock() + db_copy.enrich_vpc_name.return_value = "test-vpc-override" + + db_copy.dump(env, vpc_name) + + mocks.load_application_fn.assert_called_once() + mocks.vpc_config_fn.assert_called_once_with( + mocks.environment.session, app, env, "test-vpc-override" + ) + mocks.db_connection_string_fn.assert_called_once_with( + mocks.environment.session, app, env, "test-app-test-env-test-db" + ) + mock_run_database_copy_task.assert_called_once_with( + mocks.environment.session, + env, + mocks.vpc, + True, + "test-db-connection-string", + ) + mocks.input_fn.assert_not_called() + mocks.echo_fn.assert_has_calls( + [ + call("Dumping test-db from the test-env environment into S3", fg="white", bold=True), + call( + "Task arn://task-arn started. 
Waiting for it to complete (this may take some time)...", + fg="white", + ), + ] + ) + db_copy.tail_logs.assert_called_once_with(True, env) + db_copy.enrich_vpc_name.assert_called_once_with("test-env", "test-vpc") + + +def test_database_load_with_response_of_yes(): + app = "test-app" + env = "test-env" + vpc_name = "test-vpc" + mocks = DataCopyMocks() + + mock_run_database_copy_task = Mock(return_value="arn://task-arn") + + db_copy = DatabaseCopy(app, "test-db", **mocks.params()) + db_copy.tail_logs = Mock() + db_copy.enrich_vpc_name = Mock() + db_copy.enrich_vpc_name.return_value = "test-vpc-override" + db_copy.run_database_copy_task = mock_run_database_copy_task + + db_copy.load(env, vpc_name) + + mocks.load_application_fn.assert_called_once() + + mocks.vpc_config_fn.assert_called_once_with( + mocks.environment.session, app, env, "test-vpc-override" + ) + + mocks.db_connection_string_fn.assert_called_once_with( + mocks.environment.session, app, env, "test-app-test-env-test-db" + ) + + mock_run_database_copy_task.assert_called_once_with( + mocks.environment.session, + env, + mocks.vpc, + False, + "test-db-connection-string", + ) + + mocks.input_fn.assert_called_once_with( + f"\nWARNING: the load operation is destructive and will delete the test-db database in the test-env environment. Continue? (y/n)" + ) + + mocks.echo_fn.assert_has_calls( + [ + call( + "Loading data into test-db in the test-env environment from S3", + fg="white", + bold=True, + ), + call( + "Task arn://task-arn started. 
Waiting for it to complete (this may take some time)...", + fg="white", + ), + ] + ) + db_copy.tail_logs.assert_called_once_with(False, "test-env") + db_copy.enrich_vpc_name.assert_called_once_with("test-env", "test-vpc") + + +def test_database_load_with_response_of_no(): + mocks = DataCopyMocks() + mocks.input_fn = Mock(return_value="no") + + mock_run_database_copy_task_fn = Mock() + + db_copy = DatabaseCopy("test-app", "test-db", **mocks.params()) + db_copy.tail_logs = Mock() + db_copy.run_database_copy_task = mock_run_database_copy_task_fn + + db_copy.load("test-env", "test-vpc") + + mocks.environment.session_fn.assert_not_called() + + mocks.vpc_config_fn.assert_not_called() + + mocks.db_connection_string_fn.assert_not_called() + + mock_run_database_copy_task_fn.assert_not_called() + + mocks.input_fn.assert_called_once_with( + f"\nWARNING: the load operation is destructive and will delete the test-db database in the test-env environment. Continue? (y/n)" + ) + mocks.echo_fn.assert_not_called() + db_copy.tail_logs.assert_not_called() + + +@pytest.mark.parametrize("is_dump", (True, False)) +def test_database_dump_handles_vpc_errors(is_dump): + mocks = DataCopyMocks() + mocks.vpc_config_fn.side_effect = AWSException("A VPC error occurred") + + db_copy = DatabaseCopy("test-app", "test-db", **mocks.params()) + + with pytest.raises(SystemExit) as exc: + if is_dump: + db_copy.dump("test-env", "bad-vpc-name") + else: + db_copy.load("test-env", "bad-vpc-name") + + assert exc.value.code == 1 + mocks.abort_fn.assert_called_once_with("A VPC error occurred") + + +@pytest.mark.parametrize("is_dump", (True, False)) +def test_database_dump_handles_db_name_errors(is_dump): + mocks = DataCopyMocks() + mocks.db_connection_string_fn = Mock(side_effect=Exception("Parameter not found.")) + + db_copy = DatabaseCopy("test-app", "bad-db", **mocks.params()) + + with pytest.raises(SystemExit) as exc: + if is_dump: + db_copy.dump("test-env", "vpc-name") + else: + db_copy.load("test-env", 
"vpc-name") + + assert exc.value.code == 1 + mocks.abort_fn.assert_called_once_with( + "Parameter not found. (Database: test-app-test-env-bad-db)" + ) + + +@pytest.mark.parametrize("is_dump", (True, False)) +def test_database_dump_handles_env_name_errors(is_dump): + mocks = DataCopyMocks() + + db_copy = DatabaseCopy("test-app", "test-db", **mocks.params()) + + with pytest.raises(SystemExit) as exc: + if is_dump: + db_copy.dump("bad-env", "vpc-name") + else: + db_copy.load("bad-env", "vpc-name") + + assert exc.value.code == 1 + mocks.abort_fn.assert_called_once_with( + "No such environment 'bad-env'. Available environments are: test-env, test-env-2" + ) + + +@pytest.mark.parametrize("is_dump", (True, False)) +def test_database_dump_handles_account_id_errors(is_dump): + mocks = DataCopyMocks() + db_copy = DatabaseCopy("test-app", "test-db", **mocks.params()) + error_msg = "An error occurred (InvalidParameterException) when calling the RunTask operation: AccountIDs mismatch" + db_copy.run_database_copy_task = Mock(side_effect=Exception(error_msg)) + + db_copy.tail_logs = Mock() + + with pytest.raises(SystemExit) as exc: + if is_dump: + db_copy.dump("test-env", "vpc-name") + else: + db_copy.load("test-env", "vpc-name") + + assert exc.value.code == 1 + mocks.abort_fn.assert_called_once_with(f"{error_msg} (Account id: 12345)") + + +def test_database_copy_initialization_handles_app_name_errors(): + mocks = DataCopyMocks() + mocks.load_application_fn = Mock(side_effect=ApplicationNotFoundError()) + + with pytest.raises(SystemExit) as exc: + DatabaseCopy("bad-app", "test-db", **mocks.params()) + + assert exc.value.code == 1 + mocks.abort_fn.assert_called_once_with("No such application 'bad-app'.") + + +@pytest.mark.parametrize("user_response", ["y", "Y", " y ", "\ny", "YES", "yes"]) +def test_is_confirmed_ready_to_load(user_response): + mocks = DataCopyMocks() + mocks.input_fn.return_value = user_response + + db_copy = DatabaseCopy("test-app", "test-db", **mocks.params()) + 
+ assert db_copy.is_confirmed_ready_to_load("test-env") + + mocks.input_fn.assert_called_once_with( + f"\nWARNING: the load operation is destructive and will delete the test-db database in the test-env environment. Continue? (y/n)" + ) + + +@pytest.mark.parametrize("user_response", ["n", "N", " no ", "squiggly"]) +def test_is_not_confirmed_ready_to_load(user_response): + mocks = DataCopyMocks() + mocks.input_fn.return_value = user_response + + db_copy = DatabaseCopy("test-app", "test-db", **mocks.params()) + + assert not db_copy.is_confirmed_ready_to_load("test-env") + + mocks.input_fn.assert_called_once_with( + f"\nWARNING: the load operation is destructive and will delete the test-db database in the test-env environment. Continue? (y/n)" + ) + + +def test_is_confirmed_ready_to_load_with_yes_flag(): + mocks = DataCopyMocks() + + db_copy = DatabaseCopy("test-app", "test-db", True, **mocks.params()) + + assert db_copy.is_confirmed_ready_to_load("test-env") + + mocks.input_fn.assert_not_called() + + +@pytest.mark.parametrize( + "services, template", + ( + (["web"], "default"), + (["*"], "default"), + (["web", "other"], "migrations"), + ), +) +def test_copy_command(services, template): + mocks = DataCopyMocks() + db_copy = DatabaseCopy("test-app", "test-db", True, **mocks.params()) + db_copy.dump = Mock() + db_copy.load = Mock() + db_copy.enrich_vpc_name = Mock() + db_copy.enrich_vpc_name.return_value = "test-vpc-override" + + db_copy.copy("test-from-env", "test-to-env", "test-from-vpc", "test-to-vpc", services, template) + + db_copy.enrich_vpc_name.assert_called_once_with("test-to-env", "test-to-vpc") + mocks.maintenance_page_provider.activate.assert_called_once_with( + "test-app", "test-to-env", services, template, "test-vpc-override" + ) + db_copy.dump.assert_called_once_with("test-from-env", "test-from-vpc") + db_copy.load.assert_called_once_with("test-to-env", "test-vpc-override") + mocks.maintenance_page_provider.deactivate.assert_called_once_with("test-app", 
"test-to-env") + + +@pytest.mark.parametrize( + "services, template", + ( + (["web"], "default"), + (["*"], "default"), + (["web", "other"], "migrations"), + ), +) +def test_copy_command_with_no_maintenance_page(services, template): + mocks = DataCopyMocks() + db_copy = DatabaseCopy("test-app", "test-db", True, **mocks.params()) + db_copy.dump = Mock() + db_copy.load = Mock() + db_copy.enrich_vpc_name = Mock() + db_copy.enrich_vpc_name.return_value = "test-vpc-override" + + db_copy.copy( + "test-from-env", "test-to-env", "test-from-vpc", "test-to-vpc", services, template, True + ) + + mocks.maintenance_page_provider.offline.assert_not_called() + mocks.maintenance_page_provider.online.assert_not_called() + + +@pytest.mark.parametrize("is_dump", [True, False]) +def test_tail_logs(is_dump): + action = "dump" if is_dump else "load" + + mocks = DataCopyMocks() + + mocks.client.start_live_tail.return_value = { + "responseStream": [ + {"sessionStart": {}}, + {"sessionUpdate": {"sessionResults": []}}, + {"sessionUpdate": {"sessionResults": [{"message": ""}]}}, + {"sessionUpdate": {"sessionResults": [{"message": f"Starting data {action}"}]}}, + {"sessionUpdate": {"sessionResults": [{"message": "A load of SQL shenanigans"}]}}, + {"sessionUpdate": {"sessionResults": [{"message": f"Stopping data {action}"}]}}, + ] + } + + db_copy = DatabaseCopy("test-app", "test-db", **mocks.params()) + db_copy.tail_logs(is_dump, "test-env") + + mocks.environment.session.client.assert_called_once_with("logs") + mocks.client.start_live_tail.assert_called_once_with( + logGroupIdentifiers=[ + f"arn:aws:logs:eu-west-2:12345:log-group:/ecs/test-app-test-env-test-db-{action}" + ], + ) + + mocks.echo_fn.assert_has_calls( + [ + call( + f"Tailing /ecs/test-app-test-env-test-db-{action} logs", + fg="yellow", + ), + call(f"Starting data {action}"), + call("A load of SQL shenanigans"), + call(f"Stopping data {action}"), + ] + ) + + +@pytest.mark.parametrize("is_dump", [True, False]) +def 
test_tail_logs_exits_with_error_if_task_aborts(is_dump): + action = "dump" if is_dump else "load" + + mocks = DataCopyMocks() + + mocks.client.start_live_tail.return_value = { + "responseStream": [ + {"sessionStart": {}}, + {"sessionUpdate": {"sessionResults": []}}, + {"sessionUpdate": {"sessionResults": [{"message": ""}]}}, + {"sessionUpdate": {"sessionResults": [{"message": f"Starting data {action}"}]}}, + {"sessionUpdate": {"sessionResults": [{"message": "A load of SQL shenanigans"}]}}, + {"sessionUpdate": {"sessionResults": [{"message": f"Aborting data {action}"}]}}, + ] + } + + db_copy = DatabaseCopy("test-app", "test-db", **mocks.params()) + + with pytest.raises(SystemExit) as exc: + db_copy.tail_logs(is_dump, "test-env") + + assert exc.value.code == 1 + mocks.abort_fn.assert_called_once_with("Task aborted abnormally. See logs above for details.") + + +def test_database_copy_account_id(): + mocks = DataCopyMocks() + + db_copy = DatabaseCopy("test-app", "test-db", **mocks.params()) + + assert db_copy.account_id("test-env") == "12345" + + +def test_update_application_from_platform_config_if_application_not_specified(fs): + fs.create_file(PLATFORM_CONFIG_FILE, contents=yaml.dump({"application": "test-app"})) + mocks = DataCopyMocks() + + db_copy = DatabaseCopy(None, "test-db", **mocks.params()) + + assert db_copy.app == "test-app" + + +def test_error_if_neither_platform_config_or_application_supplied(fs): + # fakefs used here to ensure the platform-config.yml isn't picked up from the filesystem + mocks = DataCopyMocks() + + with pytest.raises(SystemExit) as exc: + DatabaseCopy(None, "test-db", **mocks.params()) + + assert exc.value.code == 1 + mocks.abort_fn.assert_called_once_with( + "You must either be in a deploy repo, or provide the --app option." 
+ ) + + +@pytest.mark.parametrize("is_dump", [True, False]) +def test_database_dump_with_no_vpc_works_in_deploy_repo(fs, is_dump): + fs.create_file( + PLATFORM_CONFIG_FILE, + contents=yaml.dump( + {"application": "test-app", "environments": {"test-env": {"vpc": "test-env-vpc"}}} + ), + ) + env = "test-env" + database = "test-db" + + mocks = DataCopyMocks() + + mock_run_database_copy_task = Mock(return_value="arn://task-arn") + + db_copy = DatabaseCopy(None, database, **mocks.params()) + + db_copy.run_database_copy_task = mock_run_database_copy_task + db_copy.tail_logs = Mock() + + if is_dump: + db_copy.dump(env, None) + else: + db_copy.load(env, None) + + mocks.vpc_config_fn.assert_called_once_with( + mocks.environment.session, "test-app", env, "test-env-vpc" + ) + + +@pytest.mark.parametrize("is_dump", [True, False]) +def test_database_dump_with_no_vpc_fails_if_not_in_deploy_repo(fs, is_dump): + # fakefs used here to ensure the platform-config.yml isn't picked up from the filesystem + env = "test-env" + database = "test-db" + + mocks = DataCopyMocks() + + mock_run_database_copy_task = Mock(return_value="arn://task-arn") + + db_copy = DatabaseCopy("test-app", database, **mocks.params()) + + db_copy.run_database_copy_task = mock_run_database_copy_task + db_copy.tail_logs = Mock() + + with pytest.raises(SystemExit) as exc: + if is_dump: + db_copy.dump(env, None) + else: + db_copy.load(env, None) + + assert exc.value.code == 1 + mocks.abort_fn.assert_called_once_with( + f"You must either be in a deploy repo, or provide the vpc name option." 
+ ) + + +def test_enrich_vpc_name_returns_the_vpc_name_passed_in(): + db_copy = DatabaseCopy("test-app", "test-db", **DataCopyMocks().params()) + vpc_name = db_copy.enrich_vpc_name("test-env", "test-vpc") + + assert vpc_name == "test-vpc" + + +def test_enrich_vpc_name_aborts_if_no_platform_config(fs): + # fakefs used here to ensure the platform-config.yml isn't picked up from the filesystem + mocks = DataCopyMocks() + db_copy = DatabaseCopy("test-app", "test-db", **mocks.params()) + + with pytest.raises(SystemExit): + db_copy.enrich_vpc_name("test-env", None) + + mocks.abort_fn.assert_called_once_with( + f"You must either be in a deploy repo, or provide the vpc name option." + ) + + +def test_enrich_vpc_name_enriches_vpc_name_from_platform_config(fs): + # fakefs used here to ensure the platform-config.yml isn't picked up from the filesystem + fs.create_file( + PLATFORM_CONFIG_FILE, + contents=yaml.dump( + {"application": "test-app", "environments": {"test-env": {"vpc": "test-env-vpc"}}} + ), + ) + mocks = DataCopyMocks() + db_copy = DatabaseCopy("test-app", "test-db", **mocks.params()) + + vpc_name = db_copy.enrich_vpc_name("test-env", None) + + assert vpc_name == "test-env-vpc" diff --git a/tests/platform_helper/domain/test_maintenance_page.py b/tests/platform_helper/domain/test_maintenance_page.py new file mode 100644 index 000000000..4903538b0 --- /dev/null +++ b/tests/platform_helper/domain/test_maintenance_page.py @@ -0,0 +1,499 @@ +from unittest.mock import MagicMock +from unittest.mock import Mock +from unittest.mock import call +from unittest.mock import patch + +import pytest +from click.testing import CliRunner +from moto import mock_aws + +from dbt_platform_helper.domain.maintenance_page import * +from dbt_platform_helper.utils.application import Application + + +class TestGetMaintenancePage: + def test_when_environment_online(self): + + boto_mock = MagicMock() + boto_mock.client().describe_rules.return_value = {"Rules": [{"RuleArn": "rule_arn"}]} + 
boto_mock.client().describe_tags.return_value = { + "TagDescriptions": [{"ResourceArn": "rule_arn", "Tags": []}] + } + + maintenance_page = get_maintenance_page(boto_mock, "listener_arn") + assert maintenance_page is None + + def test_when_environment_offline_with_default_page(self): + + boto_mock = MagicMock() + boto_mock.client().describe_rules.return_value = {"Rules": [{"RuleArn": "rule_arn"}]} + boto_mock.client().describe_tags.return_value = { + "TagDescriptions": [ + { + "ResourceArn": "rule_arn", + "Tags": [ + {"Key": "name", "Value": "MaintenancePage"}, + {"Key": "type", "Value": "default"}, + ], + } + ] + } + + maintenance_page = get_maintenance_page(boto_mock, "listener_arn") + assert maintenance_page == "default" + + +class TestRemoveMaintenancePage: + def test_when_environment_online(self): + + boto_mock = MagicMock() + boto_mock.client().describe_rules.return_value = {"Rules": [{"RuleArn": "rule_arn"}]} + boto_mock.client().describe_tags.return_value = { + "TagDescriptions": [{"ResourceArn": "rule_arn", "Tags": []}] + } + + with pytest.raises(ListenerRuleNotFoundError): + remove_maintenance_page(boto_mock, "listener_arn") + + @patch("dbt_platform_helper.domain.maintenance_page.delete_listener_rule") + def test_when_environment_offline(self, delete_listener_rule): + + boto_mock = MagicMock() + boto_mock.client().describe_rules.return_value = { + "Rules": [{"RuleArn": "rule_arn"}, {"RuleArn": "allowed_ips_rule_arn"}] + } + tag_descriptions = [ + { + "ResourceArn": "rule_arn", + "Tags": [ + {"Key": "name", "Value": "MaintenancePage"}, + {"Key": "type", "Value": "default"}, + ], + }, + { + "ResourceArn": "allowed_ips_rule_arn", + "Tags": [ + {"Key": "name", "Value": "AllowedIps"}, + {"Key": "type", "Value": "default"}, + ], + }, + { + "ResourceArn": "allowed_source_ips_rule_arn", + "Tags": [ + {"Key": "name", "Value": "AllowedSourceIps"}, + {"Key": "type", "Value": "default"}, + ], + }, + ] + boto_mock.client().describe_tags.return_value = 
{"TagDescriptions": tag_descriptions} + boto_mock.client().delete_rule.return_value = None + + remove_maintenance_page(boto_mock, "listener_arn") + + delete_listener_rule.assert_has_calls( + [ + call(tag_descriptions, "MaintenancePage", boto_mock.client()), + call().__bool__(), # return value of mock is referenced in line: `if name == "MaintenancePage" and not deleted` + call(tag_descriptions, "AllowedIps", boto_mock.client()), + call(tag_descriptions, "BypassIpFilter", boto_mock.client()), + call(tag_descriptions, "AllowedSourceIps", boto_mock.client()), + ] + ) + + +class TestAddMaintenancePage: + @pytest.mark.parametrize("template", ["default", "migration", "dmas-migration"]) + @patch( + "dbt_platform_helper.domain.maintenance_page.random.choices", return_value=["a", "b", "c"] + ) + @patch("dbt_platform_helper.domain.maintenance_page.create_source_ip_rule") + @patch("dbt_platform_helper.domain.maintenance_page.create_header_rule") + @patch("dbt_platform_helper.domain.maintenance_page.find_target_group") + @patch("dbt_platform_helper.domain.maintenance_page.get_maintenance_page_template") + def test_adding_existing_template( + self, + get_maintenance_page_template, + find_target_group, + create_header_rule, + create_source_ip, + choices, + template, + mock_application, + ): + + boto_mock = MagicMock() + get_maintenance_page_template.return_value = template + find_target_group.return_value = "target_group_arn" + + add_maintenance_page( + boto_mock, + "listener_arn", + "test-application", + "development", + [mock_application.services["web"]], + ["1.2.3.4"], + template, + ) + + assert create_header_rule.call_count == 2 + create_header_rule.assert_has_calls( + [ + call( + boto_mock.client(), + "listener_arn", + "target_group_arn", + "X-Forwarded-For", + ["1.2.3.4"], + "AllowedIps", + 1, + ), + call( + boto_mock.client(), + "listener_arn", + "target_group_arn", + "Bypass-Key", + ["abc"], + "BypassIpFilter", + 3, + ), + ] + ) + create_source_ip.assert_has_calls( + [ + 
call( + boto_mock.client(), + "listener_arn", + "target_group_arn", + ["1.2.3.4"], + "AllowedSourceIps", + 2, + ) + ] + ) + boto_mock.client().create_rule.assert_called_once_with( + ListenerArn="listener_arn", + Priority=4, + Conditions=[ + { + "Field": "path-pattern", + "PathPatternConfig": {"Values": ["/*"]}, + } + ], + Actions=[ + { + "Type": "fixed-response", + "FixedResponseConfig": { + "StatusCode": "503", + "ContentType": "text/html", + "MessageBody": template, + }, + } + ], + Tags=[ + {"Key": "name", "Value": "MaintenancePage"}, + {"Key": "type", "Value": template}, + ], + ) + + +class TestEnvironmentMaintenanceTemplates: + @pytest.mark.parametrize("template", ["default", "migration", "dmas-migration"]) + def test_template_length(self, template): + + contents = get_maintenance_page_template(template) + assert len(contents) <= 1024 + + @pytest.mark.parametrize("template", ["default", "migration", "dmas-migration"]) + def test_template_no_new_lines(self, template): + + contents = get_maintenance_page_template(template) + assert "\n" not in contents + + +class TestCommandHelperMethods: + @patch("dbt_platform_helper.domain.maintenance_page.load_application") + def test_get_app_environment(self, mock_load_application): + + development = Mock() + application = Application(name="test-application") + application.environments = {"development": development} + mock_load_application.return_value = application + + app_environment = get_app_environment("test-application", "development") + + assert app_environment == development + + @patch("dbt_platform_helper.domain.maintenance_page.load_application") + def test_get_app_environment_does_not_exist(self, mock_load_application, capsys): + + CliRunner() + application = Application(name="test-application") + mock_load_application.return_value = application + + with pytest.raises(click.Abort): + get_app_environment("test-application", "development") + + captured = capsys.readouterr() + + assert ( + "The environment development 
was not found in the application test-application." + in captured.out + ) + + def _create_subnet(self, session): + ec2 = session.client("ec2") + vpc_id = ec2.create_vpc(CidrBlock="10.0.0.0/16")["Vpc"]["VpcId"] + + return ( + vpc_id, + ec2.create_subnet(VpcId=vpc_id, CidrBlock="10.0.1.0/24")["Subnet"]["SubnetId"], + ) + + def _create_listener(self, elbv2_client): + _, subnet_id = self._create_subnet(boto3.Session()) + load_balancer_arn = elbv2_client.create_load_balancer( + Name="test-load-balancer", Subnets=[subnet_id] + )["LoadBalancers"][0]["LoadBalancerArn"] + return elbv2_client.create_listener( + LoadBalancerArn=load_balancer_arn, DefaultActions=[{"Type": "forward"}] + )["Listeners"][0]["ListenerArn"] + + def _create_listener_rule(self, elbv2_client=None, listener_arn=None, priority=1): + if not elbv2_client: + elbv2_client = boto3.client("elbv2") + + if not listener_arn: + listener_arn = self._create_listener(elbv2_client) + + rule_response = elbv2_client.create_rule( + ListenerArn=listener_arn, + Tags=[{"Key": "test-key", "Value": "test-value"}], + Conditions=[{"Field": "path-pattern", "PathPatternConfig": {"Values": ["/test-path"]}}], + Priority=priority, + Actions=[ + { + "Type": "fixed-response", + "FixedResponseConfig": { + "MessageBody": "test response", + "StatusCode": "200", + "ContentType": "text/plain", + }, + } + ], + ) + + return rule_response["Rules"][0]["RuleArn"], elbv2_client, listener_arn + + def _create_target_group(self): + ec2_client = boto3.client("ec2") + vpc_response = ec2_client.create_vpc(CidrBlock="10.0.0.0/16") + vpc_id = vpc_response["Vpc"]["VpcId"] + + return boto3.client("elbv2").create_target_group( + Name="test-target-group", + Protocol="HTTPS", + Port=123, + VpcId=vpc_id, + Tags=[ + {"Key": "copilot-application", "Value": "test-application"}, + {"Key": "copilot-environment", "Value": "development"}, + {"Key": "copilot-service", "Value": "web"}, + ], + )["TargetGroups"][0]["TargetGroupArn"] + + @mock_aws + def 
test_find_target_group(self): + + target_group_arn = self._create_target_group() + + assert ( + find_target_group("test-application", "development", "web", boto3.session.Session()) + == target_group_arn + ) + + @mock_aws + def test_find_target_group_not_found(self): + + assert ( + find_target_group("test-application", "development", "web", boto3.session.Session()) + is None + ) + + @mock_aws + def test_delete_listener_rule(self): + + rule_arn, elbv2_client, listener_arn = self._create_listener_rule() + rule_2_arn, _, _ = self._create_listener_rule( + priority=2, elbv2_client=elbv2_client, listener_arn=listener_arn + ) + rules = [ + {"ResourceArn": rule_arn, "Tags": [{"Key": "name", "Value": "test-tag"}]}, + {"ResourceArn": rule_2_arn, "Tags": [{"Key": "name", "Value": "test-tag"}]}, + ] + + described_rules = elbv2_client.describe_rules(ListenerArn=listener_arn)["Rules"] + + # sanity check that default and two newly created rules exist + assert len(described_rules) == 3 + + delete_listener_rule(rules, "test-tag", elbv2_client) + + rules = elbv2_client.describe_rules(ListenerArn=listener_arn)["Rules"] + + assert len(rules) == 1 + + @mock_aws + def test_create_header_rule(self, capsys): + + elbv2_client = boto3.client("elbv2") + listener_arn = self._create_listener(elbv2_client) + target_group_arn = self._create_target_group() + elbv2_client.create_rule( + ListenerArn=listener_arn, + Tags=[{"Key": "test-key", "Value": "test-value"}], + Conditions=[{"Field": "host-header", "HostHeaderConfig": {"Values": ["/test-path"]}}], + Priority=500, + Actions=[{"Type": "forward", "TargetGroupArn": target_group_arn}], + ) + rules = elbv2_client.describe_rules(ListenerArn=listener_arn)["Rules"] + assert len(rules) == 2 + + create_header_rule( + elbv2_client, + listener_arn, + target_group_arn, + "X-Forwarded-For", + ["1.2.3.4", "5.6.7.8"], + "AllowedIps", + 333, + ) + + rules = elbv2_client.describe_rules(ListenerArn=listener_arn)["Rules"] + assert len(rules) == 3 # 1 default + 1 
forward + 1 newly created + assert rules[1]["Conditions"][0]["HttpHeaderConfig"]["Values"], ["1.2.3.4", "5.6.7.8"] + assert rules[1]["Priority"] == "333" + + captured = capsys.readouterr() + + assert ( + f"Creating listener rule AllowedIps for HTTPS Listener with arn {listener_arn}.\n\nIf request header X-Forwarded-For contains one of the values ['1.2.3.4', '5.6.7.8'], the request will be forwarded to target group with arn {target_group_arn}." + in captured.out + ) + + @mock_aws + def test_create_source_ip_rule(self, capsys): + + elbv2_client = boto3.client("elbv2") + listener_arn = self._create_listener(elbv2_client) + target_group_arn = self._create_target_group() + elbv2_client.create_rule( + ListenerArn=listener_arn, + Tags=[{"Key": "test-key", "Value": "test-value"}], + Conditions=[{"Field": "host-header", "HostHeaderConfig": {"Values": ["/test-path"]}}], + Priority=500, + Actions=[{"Type": "forward", "TargetGroupArn": target_group_arn}], + ) + rules = elbv2_client.describe_rules(ListenerArn=listener_arn)["Rules"] + assert len(rules) == 2 + + create_source_ip_rule( + elbv2_client, + listener_arn, + target_group_arn, + ["1.2.3.4", "5.6.7.8"], + "AllowedSourceIps", + 333, + ) + + rules = elbv2_client.describe_rules(ListenerArn=listener_arn)["Rules"] + assert len(rules) == 3 # 1 default + 1 forward + 1 newly created + assert rules[1]["Conditions"][0]["SourceIpConfig"]["Values"], ["1.2.3.4", "5.6.7.8"] + assert rules[1]["Priority"] == "333" + + captured = capsys.readouterr() + + assert ( + f"Creating listener rule AllowedSourceIps for HTTPS Listener with arn {listener_arn}.\n\nIf request source ip matches one of the values ['1.2.3.4', '5.6.7.8'], the request will be forwarded to target group with arn {target_group_arn}." 
+ in captured.out + ) + + @pytest.mark.parametrize( + "vpc, param_value, expected", + [ + ( + "vpc1", + "192.168.1.1,192.168.1.2,192.168.1.3", + ["192.168.1.1", "192.168.1.2", "192.168.1.3"], + ), + ( + "vpc2", + " 192.168.2.1 , 192.168.2.2 , 192.168.2.3 ", + ["192.168.2.1", "192.168.2.2", "192.168.2.3"], + ), + ( + None, + "192.168.1.1,192.168.1.2,192.168.1.3", + ["192.168.1.1", "192.168.1.2", "192.168.1.3"], + ), + ], + ) + @mock_aws + def test_get_env_ips(self, vpc, param_value, expected, mock_application): + + response = boto3.client("organizations").create_organization(FeatureSet="ALL") + response["Organization"]["Id"] + create_account_response = boto3.client("organizations").create_account( + Email="test-email@example.com", AccountName="test" + ) + account_id = create_account_response["CreateAccountStatus"]["AccountId"] + mock_application.environments["development"].account_id = account_id + mock_application.environments["development"].sessions[account_id] = boto3.session.Session() + vpc = vpc if vpc else "test" + boto3.client("ssm").put_parameter( + Name=f"/{vpc}/EGRESS_IPS", Value=param_value, Type="String" + ) + environment = mock_application.environments["development"] + result = get_env_ips(vpc, environment) + + assert result == expected + + @mock_aws + def test_get_env_ips_param_not_found(self, capsys, mock_application): + + response = boto3.client("organizations").create_organization(FeatureSet="ALL") + response["Organization"]["Id"] + create_account_response = boto3.client("organizations").create_account( + Email="test-email@example.com", AccountName="test" + ) + account_id = create_account_response["CreateAccountStatus"]["AccountId"] + mock_application.environments["development"].account_id = account_id + mock_application.environments["development"].sessions[account_id] = boto3.session.Session() + environment = mock_application.environments["development"] + + with pytest.raises(click.Abort): + get_env_ips("vpc", environment) + + captured = 
capsys.readouterr() + + assert "No parameter found with name: /vpc/EGRESS_IPS\n" in captured.out + + @patch("boto3.client") + def test_get_rules_tag_descriptions(self, mock_boto_client): + + mock_client = Mock() + mock_client.describe_tags.side_effect = [ + {"TagDescriptions": ["TagDescriptions1"]}, + {"TagDescriptions": ["TagDescriptions2"]}, + ] + + mock_boto_client.return_value = mock_client + + rules = [] + + for i in range(21): + rules.append({"RuleArn": i}) + + tag_descriptions = get_rules_tag_descriptions(rules, boto3.client("elbv2")) + + assert tag_descriptions == ["TagDescriptions1", "TagDescriptions2"] + assert mock_client.describe_tags.call_count == 2 diff --git a/tests/platform_helper/fixtures/pipeline/platform-config-for-terraform-environment-pipelines-with-tpm-version.yml b/tests/platform_helper/fixtures/pipeline/platform-config-for-terraform-environment-pipelines-with-tpm-version.yml new file mode 100644 index 000000000..cb50c58ab --- /dev/null +++ b/tests/platform_helper/fixtures/pipeline/platform-config-for-terraform-environment-pipelines-with-tpm-version.yml @@ -0,0 +1,39 @@ +application: test-app + +default_versions: + terraform-platform-modules: 4.0.0 + +environments: + dev: + accounts: + deploy: + name: "platform-sandbox-test" + id: "1111111111" + dns: + name: "platform-sandbox-test" + id: "2222222222" + prod: + accounts: + deploy: + name: "platform-prod-test" + id: "3333333333" + dns: + name: "platform-prod-test" + id: "4444444444" + requires_approval: true + +environment_pipelines: + main: + account: platform-sandbox-test + branch: main + slack_channel: "/codebuild/test-slack-channel" + trigger_on_push: false + environments: + dev: + prod-main: + account: platform-prod-test + branch: main + slack_channel: "/codebuild/test-slack-channel" + trigger_on_push: false + environments: + prod: diff --git a/tests/platform_helper/fixtures/pipeline/platform-config-for-terraform-environment-pipelines.yml 
b/tests/platform_helper/fixtures/pipeline/platform-config-for-terraform-environment-pipelines.yml new file mode 100644 index 000000000..ba08e7c90 --- /dev/null +++ b/tests/platform_helper/fixtures/pipeline/platform-config-for-terraform-environment-pipelines.yml @@ -0,0 +1,36 @@ +application: test-app + +environments: + dev: + accounts: + deploy: + name: "platform-sandbox-test" + id: "1111111111" + dns: + name: "platform-sandbox-test" + id: "2222222222" + prod: + accounts: + deploy: + name: "platform-prod-test" + id: "3333333333" + dns: + name: "platform-prod-test" + id: "4444444444" + requires_approval: true + +environment_pipelines: + main: + account: platform-sandbox-test + branch: main + slack_channel: "/codebuild/test-slack-channel" + trigger_on_push: false + environments: + dev: + prod-main: + account: platform-prod-test + branch: main + slack_channel: "/codebuild/test-slack-channel" + trigger_on_push: false + environments: + prod: diff --git a/tests/platform_helper/fixtures/pipeline/platform-config-legacy-project.yml b/tests/platform_helper/fixtures/pipeline/platform-config-legacy-project.yml new file mode 100644 index 000000000..354db6351 --- /dev/null +++ b/tests/platform_helper/fixtures/pipeline/platform-config-legacy-project.yml @@ -0,0 +1,36 @@ +application: test-app +legacy_project: True +environments: + dev: + accounts: + deploy: + name: "platform-sandbox-test" + id: "1111111111" + dns: + name: "platform-sandbox-test" + id: "2222222222" + prod: + accounts: + deploy: + name: "platform-prod-test" + id: "3333333333" + dns: + name: "platform-prod-test" + id: "4444444444" + requires_approval: true + +environment_pipelines: + main: + account: platform-sandbox-test + branch: main + slack_channel: "/codebuild/test-slack-channel" + trigger_on_push: false + environments: + dev: + prod-main: + account: platform-prod-test + branch: main + slack_channel: "/codebuild/test-slack-channel" + trigger_on_push: false + environments: + prod: diff --git 
a/tests/platform_helper/providers/test_load_balancers.py b/tests/platform_helper/providers/test_load_balancers.py new file mode 100644 index 000000000..0f9e8cb6d --- /dev/null +++ b/tests/platform_helper/providers/test_load_balancers.py @@ -0,0 +1,59 @@ +from unittest.mock import MagicMock +from unittest.mock import patch + +import pytest + +from dbt_platform_helper.providers.load_balancers import ListenerNotFoundError +from dbt_platform_helper.providers.load_balancers import LoadBalancerNotFoundError +from dbt_platform_helper.providers.load_balancers import find_https_listener +from dbt_platform_helper.providers.load_balancers import find_load_balancer + + +class TestFindHTTPSListener: + @patch("dbt_platform_helper.providers.load_balancers.find_load_balancer", return_value="lb_arn") + def test_when_no_https_listener_present(self, find_load_balancer): + boto_mock = MagicMock() + boto_mock.client().describe_listeners.return_value = {"Listeners": []} + with pytest.raises(ListenerNotFoundError): + find_https_listener(boto_mock, "test-application", "development") + + @patch("dbt_platform_helper.providers.load_balancers.find_load_balancer", return_value="lb_arn") + def test_when_https_listener_present(self, find_load_balancer): + + boto_mock = MagicMock() + boto_mock.client().describe_listeners.return_value = { + "Listeners": [{"ListenerArn": "listener_arn", "Protocol": "HTTPS"}] + } + + listener_arn = find_https_listener(boto_mock, "test-application", "development") + assert "listener_arn" == listener_arn + + +class TestFindLoadBalancer: + def test_when_no_load_balancer_exists(self): + + boto_mock = MagicMock() + boto_mock.client().describe_load_balancers.return_value = {"LoadBalancers": []} + with pytest.raises(LoadBalancerNotFoundError): + find_load_balancer(boto_mock, "test-application", "development") + + def test_when_a_load_balancer_exists(self): + + boto_mock = MagicMock() + boto_mock.client().describe_load_balancers.return_value = { + "LoadBalancers": 
[{"LoadBalancerArn": "lb_arn"}] + } + boto_mock.client().describe_tags.return_value = { + "TagDescriptions": [ + { + "ResourceArn": "lb_arn", + "Tags": [ + {"Key": "copilot-application", "Value": "test-application"}, + {"Key": "copilot-environment", "Value": "development"}, + ], + } + ] + } + + lb_arn = find_load_balancer(boto_mock, "test-application", "development") + assert "lb_arn" == lb_arn diff --git a/tests/platform_helper/test_command_database.py b/tests/platform_helper/test_command_database.py index 8647ffbb7..7a0428563 100644 --- a/tests/platform_helper/test_command_database.py +++ b/tests/platform_helper/test_command_database.py @@ -2,6 +2,7 @@ from click.testing import CliRunner +from dbt_platform_helper.commands.database import copy from dbt_platform_helper.commands.database import dump from dbt_platform_helper.commands.database import load @@ -14,24 +15,20 @@ def test_command_dump_success(mock_database_copy_object): result = runner.invoke( dump, [ - "--account-id", - "12345", "--app", "my_app", - "--env", + "--from", "my_env", "--database", "my_postgres", - "--vpc-name", + "--from-vpc", "my_vpc", ], ) assert result.exit_code == 0 - mock_database_copy_object.assert_called_once_with( - "12345", "my_app", "my_env", "my_postgres", "my_vpc" - ) - mock_database_copy_instance.dump.assert_called_once_with() + mock_database_copy_object.assert_called_once_with("my_app", "my_postgres") + mock_database_copy_instance.dump.assert_called_once_with("my_env", "my_vpc") @patch("dbt_platform_helper.commands.database.DatabaseCopy") @@ -41,21 +38,199 @@ def test_command_load_success(mock_database_copy_object): result = runner.invoke( load, [ - "--account-id", - "12345", "--app", "my_app", - "--env", + "--to", "my_env", "--database", "my_postgres", - "--vpc-name", + "--to-vpc", "my_vpc", ], ) assert result.exit_code == 0 - mock_database_copy_object.assert_called_once_with( - "12345", "my_app", "my_env", "my_postgres", "my_vpc" + 
mock_database_copy_object.assert_called_once_with("my_app", "my_postgres", False) + mock_database_copy_instance.load.assert_called_once_with("my_env", "my_vpc") + + +@patch("dbt_platform_helper.commands.database.DatabaseCopy") +def test_command_load_success_with_auto_approve(mock_database_copy_object): + mock_database_copy_instance = mock_database_copy_object.return_value + runner = CliRunner() + result = runner.invoke( + load, + [ + "--app", + "my_app", + "--to", + "my_env", + "--database", + "my_postgres", + "--to-vpc", + "my_vpc", + "--auto-approve", + ], + ) + + assert result.exit_code == 0 + mock_database_copy_object.assert_called_once_with("my_app", "my_postgres", True) + mock_database_copy_instance.load.assert_called_once_with("my_env", "my_vpc") + + +@patch("dbt_platform_helper.commands.database.DatabaseCopy") +def test_command_copy_success(mock_database_copy_object): + mock_database_copy_instance = mock_database_copy_object.return_value + runner = CliRunner() + result = runner.invoke( + copy, + [ + "--app", + "my_app", + "--from", + "my_prod_env", + "--to", + "my_hotfix_env", + "--database", + "my_postgres", + "--from-vpc", + "my_from_vpc", + "--to-vpc", + "my_to_vpc", + ], + ) + + assert result.exit_code == 0 + mock_database_copy_object.assert_called_once_with("my_app", "my_postgres", False) + mock_database_copy_instance.copy.assert_called_once_with( + "my_prod_env", + "my_hotfix_env", + "my_from_vpc", + "my_to_vpc", + ("web",), + "default", + False, + ) + + +@patch("dbt_platform_helper.commands.database.DatabaseCopy") +def test_command_copy_success_with_auto_approve(mock_database_copy_object): + mock_database_copy_instance = mock_database_copy_object.return_value + runner = CliRunner() + result = runner.invoke( + copy, + [ + "--app", + "my_app", + "--from", + "my_prod_env", + "--to", + "my_hotfix_env", + "--database", + "my_postgres", + "--from-vpc", + "my_from_vpc", + "--to-vpc", + "my_to_vpc", + "--auto-approve", + "--svc", + "other", + "--svc", + 
"service", + "--template", + "migration", + ], + ) + + assert result.exit_code == 0 + mock_database_copy_object.assert_called_once_with("my_app", "my_postgres", True) + mock_database_copy_instance.copy.assert_called_once_with( + "my_prod_env", + "my_hotfix_env", + "my_from_vpc", + "my_to_vpc", + ("other", "service"), + "migration", + False, + ) + + +@patch("dbt_platform_helper.commands.database.DatabaseCopy") +def test_command_copy_success_with_no_maintenance_page(mock_database_copy_object): + mock_database_copy_instance = mock_database_copy_object.return_value + runner = CliRunner() + result = runner.invoke( + copy, + [ + "--app", + "my_app", + "--from", + "my_prod_env", + "--to", + "my_hotfix_env", + "--database", + "my_postgres", + "--from-vpc", + "my_from_vpc", + "--to-vpc", + "my_to_vpc", + "--auto-approve", + "--svc", + "other", + "--svc", + "service", + "--no-maintenance-page", + ], + ) + + assert result.exit_code == 0 + mock_database_copy_object.assert_called_once_with("my_app", "my_postgres", True) + mock_database_copy_instance.copy.assert_called_once_with( + "my_prod_env", + "my_hotfix_env", + "my_from_vpc", + "my_to_vpc", + ("other", "service"), + "default", + True, + ) + + +@patch("dbt_platform_helper.commands.database.DatabaseCopy") +def test_command_copy_success_with_maintenance_page(mock_database_copy_object): + mock_database_copy_instance = mock_database_copy_object.return_value + runner = CliRunner() + result = runner.invoke( + copy, + [ + "--app", + "my_app", + "--from", + "my_prod_env", + "--to", + "my_hotfix_env", + "--database", + "my_postgres", + "--from-vpc", + "my_from_vpc", + "--to-vpc", + "my_to_vpc", + "--auto-approve", + "--svc", + "other", + "--svc", + "service", + ], + ) + + assert result.exit_code == 0 + mock_database_copy_object.assert_called_once_with("my_app", "my_postgres", True) + mock_database_copy_instance.copy.assert_called_once_with( + "my_prod_env", + "my_hotfix_env", + "my_from_vpc", + "my_to_vpc", + ("other", "service"), + 
"default", + False, ) - mock_database_copy_instance.load.assert_called_once_with() diff --git a/tests/platform_helper/test_command_environment.py b/tests/platform_helper/test_command_environment.py index a4514a01a..903d5135a 100644 --- a/tests/platform_helper/test_command_environment.py +++ b/tests/platform_helper/test_command_environment.py @@ -2,7 +2,6 @@ from unittest.mock import ANY from unittest.mock import MagicMock from unittest.mock import Mock -from unittest.mock import call from unittest.mock import patch import boto3 @@ -13,48 +12,32 @@ from moto import mock_aws from dbt_platform_helper.commands.environment import CertificateNotFoundError -from dbt_platform_helper.commands.environment import ListenerNotFoundError -from dbt_platform_helper.commands.environment import ListenerRuleNotFoundError -from dbt_platform_helper.commands.environment import LoadBalancerNotFoundError -from dbt_platform_helper.commands.environment import add_maintenance_page -from dbt_platform_helper.commands.environment import create_header_rule -from dbt_platform_helper.commands.environment import create_source_ip_rule -from dbt_platform_helper.commands.environment import delete_listener_rule from dbt_platform_helper.commands.environment import find_https_certificate -from dbt_platform_helper.commands.environment import find_https_listener -from dbt_platform_helper.commands.environment import find_load_balancer -from dbt_platform_helper.commands.environment import find_target_group from dbt_platform_helper.commands.environment import generate from dbt_platform_helper.commands.environment import generate_terraform -from dbt_platform_helper.commands.environment import get_app_environment from dbt_platform_helper.commands.environment import get_cert_arn -from dbt_platform_helper.commands.environment import get_env_ips -from dbt_platform_helper.commands.environment import get_listener_rule_by_tag -from dbt_platform_helper.commands.environment import get_maintenance_page -from 
dbt_platform_helper.commands.environment import get_maintenance_page_template -from dbt_platform_helper.commands.environment import get_rules_tag_descriptions from dbt_platform_helper.commands.environment import get_subnet_ids from dbt_platform_helper.commands.environment import get_vpc_id from dbt_platform_helper.commands.environment import offline from dbt_platform_helper.commands.environment import online -from dbt_platform_helper.commands.environment import remove_maintenance_page from dbt_platform_helper.constants import PLATFORM_CONFIG_FILE -from dbt_platform_helper.utils.application import Application +from dbt_platform_helper.providers.load_balancers import ListenerNotFoundError +from dbt_platform_helper.providers.load_balancers import LoadBalancerNotFoundError from dbt_platform_helper.utils.application import Service from tests.platform_helper.conftest import BASE_DIR class TestEnvironmentOfflineCommand: - @patch("dbt_platform_helper.commands.environment.load_application") + @patch("dbt_platform_helper.domain.maintenance_page.load_application") @patch( - "dbt_platform_helper.commands.environment.find_https_listener", + "dbt_platform_helper.domain.maintenance_page.find_https_listener", return_value="https_listener", ) - @patch("dbt_platform_helper.commands.environment.get_maintenance_page", return_value=None) + @patch("dbt_platform_helper.domain.maintenance_page.get_maintenance_page", return_value=None) @patch( - "dbt_platform_helper.commands.environment.get_env_ips", return_value=["0.1.2.3, 4.5.6.7"] + "dbt_platform_helper.domain.maintenance_page.get_env_ips", return_value=["0.1.2.3, 4.5.6.7"] ) - @patch("dbt_platform_helper.commands.environment.add_maintenance_page", return_value=None) + @patch("dbt_platform_helper.domain.maintenance_page.add_maintenance_page", return_value=None) def test_successful_offline( self, add_maintenance_page, @@ -94,16 +77,16 @@ def test_successful_offline( "application test-application" ) in result.output - 
@patch("dbt_platform_helper.commands.environment.load_application") + @patch("dbt_platform_helper.domain.maintenance_page.load_application") @patch( - "dbt_platform_helper.commands.environment.find_https_listener", + "dbt_platform_helper.domain.maintenance_page.find_https_listener", return_value="https_listener", ) - @patch("dbt_platform_helper.commands.environment.get_maintenance_page", return_value=None) + @patch("dbt_platform_helper.domain.maintenance_page.get_maintenance_page", return_value=None) @patch( - "dbt_platform_helper.commands.environment.get_env_ips", return_value=["0.1.2.3, 4.5.6.7"] + "dbt_platform_helper.domain.maintenance_page.get_env_ips", return_value=["0.1.2.3, 4.5.6.7"] ) - @patch("dbt_platform_helper.commands.environment.add_maintenance_page", return_value=None) + @patch("dbt_platform_helper.domain.maintenance_page.add_maintenance_page", return_value=None) def test_successful_offline_with_custom_template( self, add_maintenance_page, @@ -145,19 +128,20 @@ def test_successful_offline_with_custom_template( "application test-application" ) in result.output - @patch("dbt_platform_helper.commands.environment.load_application") + @patch("dbt_platform_helper.domain.maintenance_page.load_application") @patch( - "dbt_platform_helper.commands.environment.find_https_listener", + "dbt_platform_helper.domain.maintenance_page.find_https_listener", return_value="https_listener", ) @patch( - "dbt_platform_helper.commands.environment.get_maintenance_page", return_value="maintenance" + "dbt_platform_helper.domain.maintenance_page.get_maintenance_page", + return_value="maintenance", ) - @patch("dbt_platform_helper.commands.environment.remove_maintenance_page", return_value=None) + @patch("dbt_platform_helper.domain.maintenance_page.remove_maintenance_page", return_value=None) @patch( - "dbt_platform_helper.commands.environment.get_env_ips", return_value=["0.1.2.3, 4.5.6.7"] + "dbt_platform_helper.domain.maintenance_page.get_env_ips", return_value=["0.1.2.3, 
4.5.6.7"] ) - @patch("dbt_platform_helper.commands.environment.add_maintenance_page", return_value=None) + @patch("dbt_platform_helper.domain.maintenance_page.add_maintenance_page", return_value=None) def test_successful_offline_when_already_offline( self, add_maintenance_page, @@ -202,11 +186,11 @@ def test_successful_offline_when_already_offline( "application test-application" ) in result.output - @patch("dbt_platform_helper.commands.environment.load_application") - @patch("dbt_platform_helper.commands.environment.find_https_listener") - @patch("dbt_platform_helper.commands.environment.get_maintenance_page") - @patch("dbt_platform_helper.commands.environment.remove_maintenance_page") - @patch("dbt_platform_helper.commands.environment.add_maintenance_page") + @patch("dbt_platform_helper.domain.maintenance_page.load_application") + @patch("dbt_platform_helper.domain.maintenance_page.find_https_listener") + @patch("dbt_platform_helper.domain.maintenance_page.get_maintenance_page") + @patch("dbt_platform_helper.domain.maintenance_page.remove_maintenance_page") + @patch("dbt_platform_helper.domain.maintenance_page.add_maintenance_page") def test_offline_an_environment_when_load_balancer_not_found( self, add_maintenance_page, @@ -233,11 +217,11 @@ def test_offline_an_environment_when_load_balancer_not_found( get_maintenance_page.assert_not_called() remove_maintenance_page.assert_not_called() - @patch("dbt_platform_helper.commands.environment.load_application") - @patch("dbt_platform_helper.commands.environment.find_https_listener") - @patch("dbt_platform_helper.commands.environment.get_maintenance_page") - @patch("dbt_platform_helper.commands.environment.remove_maintenance_page") - @patch("dbt_platform_helper.commands.environment.add_maintenance_page") + @patch("dbt_platform_helper.domain.maintenance_page.load_application") + @patch("dbt_platform_helper.domain.maintenance_page.find_https_listener") + 
@patch("dbt_platform_helper.domain.maintenance_page.get_maintenance_page") + @patch("dbt_platform_helper.domain.maintenance_page.remove_maintenance_page") + @patch("dbt_platform_helper.domain.maintenance_page.add_maintenance_page") def test_offline_an_environment_when_listener_not_found( self, add_maintenance_page, @@ -266,16 +250,16 @@ def test_offline_an_environment_when_listener_not_found( remove_maintenance_page.assert_not_called() add_maintenance_page.assert_not_called() - @patch("dbt_platform_helper.commands.environment.load_application") + @patch("dbt_platform_helper.domain.maintenance_page.load_application") @patch( - "dbt_platform_helper.commands.environment.find_https_listener", + "dbt_platform_helper.domain.maintenance_page.find_https_listener", return_value="https_listener", ) - @patch("dbt_platform_helper.commands.environment.get_maintenance_page", return_value=None) + @patch("dbt_platform_helper.domain.maintenance_page.get_maintenance_page", return_value=None) @patch( - "dbt_platform_helper.commands.environment.get_env_ips", return_value=["0.1.2.3, 4.5.6.7"] + "dbt_platform_helper.domain.maintenance_page.get_env_ips", return_value=["0.1.2.3, 4.5.6.7"] ) - @patch("dbt_platform_helper.commands.environment.add_maintenance_page", return_value=None) + @patch("dbt_platform_helper.domain.maintenance_page.add_maintenance_page", return_value=None) def test_successful_offline_multiple_services( self, add_maintenance_page, @@ -321,13 +305,15 @@ def test_successful_offline_multiple_services( class TestEnvironmentOnlineCommand: - @patch("dbt_platform_helper.commands.environment.load_application") + @patch("dbt_platform_helper.domain.maintenance_page.load_application") @patch( - "dbt_platform_helper.commands.environment.find_https_listener", + "dbt_platform_helper.domain.maintenance_page.find_https_listener", return_value="https_listener", ) - @patch("dbt_platform_helper.commands.environment.get_maintenance_page", return_value="default") - 
@patch("dbt_platform_helper.commands.environment.remove_maintenance_page", return_value=None) + @patch( + "dbt_platform_helper.domain.maintenance_page.get_maintenance_page", return_value="default" + ) + @patch("dbt_platform_helper.domain.maintenance_page.remove_maintenance_page", return_value=None) def test_successful_online( self, remove_maintenance_page, @@ -357,13 +343,13 @@ def test_successful_online( "application test-application" ) in result.output - @patch("dbt_platform_helper.commands.environment.load_application") + @patch("dbt_platform_helper.domain.maintenance_page.load_application") @patch( - "dbt_platform_helper.commands.environment.find_https_listener", + "dbt_platform_helper.domain.maintenance_page.find_https_listener", return_value="https_listener", ) - @patch("dbt_platform_helper.commands.environment.get_maintenance_page", return_value=None) - @patch("dbt_platform_helper.commands.environment.remove_maintenance_page", return_value=None) + @patch("dbt_platform_helper.domain.maintenance_page.get_maintenance_page", return_value=None) + @patch("dbt_platform_helper.domain.maintenance_page.remove_maintenance_page", return_value=None) def test_online_an_environment_that_is_not_offline( self, remove_maintenance_page, @@ -385,10 +371,10 @@ def test_online_an_environment_that_is_not_offline( get_maintenance_page.assert_called_with(ANY, "https_listener") remove_maintenance_page.assert_not_called() - @patch("dbt_platform_helper.commands.environment.load_application") - @patch("dbt_platform_helper.commands.environment.find_https_listener") - @patch("dbt_platform_helper.commands.environment.get_maintenance_page") - @patch("dbt_platform_helper.commands.environment.remove_maintenance_page") + @patch("dbt_platform_helper.domain.maintenance_page.load_application") + @patch("dbt_platform_helper.domain.maintenance_page.find_https_listener") + @patch("dbt_platform_helper.domain.maintenance_page.get_maintenance_page") + 
@patch("dbt_platform_helper.domain.maintenance_page.remove_maintenance_page") def test_online_an_environment_when_listener_not_found( self, remove_maintenance_page, @@ -415,10 +401,10 @@ def test_online_an_environment_when_listener_not_found( get_maintenance_page.assert_not_called() remove_maintenance_page.assert_not_called() - @patch("dbt_platform_helper.commands.environment.load_application") - @patch("dbt_platform_helper.commands.environment.find_https_listener") - @patch("dbt_platform_helper.commands.environment.get_maintenance_page") - @patch("dbt_platform_helper.commands.environment.remove_maintenance_page") + @patch("dbt_platform_helper.domain.maintenance_page.load_application") + @patch("dbt_platform_helper.domain.maintenance_page.find_https_listener") + @patch("dbt_platform_helper.domain.maintenance_page.get_maintenance_page") + @patch("dbt_platform_helper.domain.maintenance_page.remove_maintenance_page") def test_online_an_environment_when_load_balancer_not_found( self, remove_maintenance_page, @@ -776,56 +762,6 @@ def test_cert_arn_failure(self, capsys): ) -class TestFindLoadBalancer: - def test_when_no_load_balancer_exists(self): - - boto_mock = MagicMock() - boto_mock.client().describe_load_balancers.return_value = {"LoadBalancers": []} - with pytest.raises(LoadBalancerNotFoundError): - find_load_balancer(boto_mock, "test-application", "development") - - def test_when_a_load_balancer_exists(self): - - boto_mock = MagicMock() - boto_mock.client().describe_load_balancers.return_value = { - "LoadBalancers": [{"LoadBalancerArn": "lb_arn"}] - } - boto_mock.client().describe_tags.return_value = { - "TagDescriptions": [ - { - "ResourceArn": "lb_arn", - "Tags": [ - {"Key": "copilot-application", "Value": "test-application"}, - {"Key": "copilot-environment", "Value": "development"}, - ], - } - ] - } - - lb_arn = find_load_balancer(boto_mock, "test-application", "development") - assert "lb_arn" == lb_arn - - -class TestFindHTTPSListener: - 
@patch("dbt_platform_helper.commands.environment.find_load_balancer", return_value="lb_arn") - def test_when_no_https_listener_present(self, find_load_balancer): - boto_mock = MagicMock() - boto_mock.client().describe_listeners.return_value = {"Listeners": []} - with pytest.raises(ListenerNotFoundError): - find_https_listener(boto_mock, "test-application", "development") - - @patch("dbt_platform_helper.commands.environment.find_load_balancer", return_value="lb_arn") - def test_when_https_listener_present(self, find_load_balancer): - - boto_mock = MagicMock() - boto_mock.client().describe_listeners.return_value = { - "Listeners": [{"ListenerArn": "listener_arn", "Protocol": "HTTPS"}] - } - - listener_arn = find_https_listener(boto_mock, "test-application", "development") - assert "listener_arn" == listener_arn - - class TestFindHTTPSCertificate: @patch( "dbt_platform_helper.commands.environment.find_https_listener", @@ -866,497 +802,3 @@ def test_when_multiple_https_certificate_present(self, mock_find_https_listener) certificate_arn = find_https_certificate(boto_mock, "test-application", "development") assert "certificate_arn_default" == certificate_arn - - -class TestGetMaintenancePage: - def test_when_environment_online(self): - - boto_mock = MagicMock() - boto_mock.client().describe_rules.return_value = {"Rules": [{"RuleArn": "rule_arn"}]} - boto_mock.client().describe_tags.return_value = { - "TagDescriptions": [{"ResourceArn": "rule_arn", "Tags": []}] - } - - maintenance_page = get_maintenance_page(boto_mock, "listener_arn") - assert maintenance_page is None - - def test_when_environment_offline_with_default_page(self): - - boto_mock = MagicMock() - boto_mock.client().describe_rules.return_value = {"Rules": [{"RuleArn": "rule_arn"}]} - boto_mock.client().describe_tags.return_value = { - "TagDescriptions": [ - { - "ResourceArn": "rule_arn", - "Tags": [ - {"Key": "name", "Value": "MaintenancePage"}, - {"Key": "type", "Value": "default"}, - ], - } - ] - } - - 
maintenance_page = get_maintenance_page(boto_mock, "listener_arn") - assert maintenance_page == "default" - - -class TestRemoveMaintenancePage: - def test_when_environment_online(self): - - boto_mock = MagicMock() - boto_mock.client().describe_rules.return_value = {"Rules": [{"RuleArn": "rule_arn"}]} - boto_mock.client().describe_tags.return_value = { - "TagDescriptions": [{"ResourceArn": "rule_arn", "Tags": []}] - } - - with pytest.raises(ListenerRuleNotFoundError): - remove_maintenance_page(boto_mock, "listener_arn") - - @patch("dbt_platform_helper.commands.environment.delete_listener_rule") - def test_when_environment_offline(self, delete_listener_rule): - - boto_mock = MagicMock() - boto_mock.client().describe_rules.return_value = { - "Rules": [{"RuleArn": "rule_arn"}, {"RuleArn": "allowed_ips_rule_arn"}] - } - tag_descriptions = [ - { - "ResourceArn": "rule_arn", - "Tags": [ - {"Key": "name", "Value": "MaintenancePage"}, - {"Key": "type", "Value": "default"}, - ], - }, - { - "ResourceArn": "allowed_ips_rule_arn", - "Tags": [ - {"Key": "name", "Value": "AllowedIps"}, - {"Key": "type", "Value": "default"}, - ], - }, - { - "ResourceArn": "allowed_source_ips_rule_arn", - "Tags": [ - {"Key": "name", "Value": "AllowedSourceIps"}, - {"Key": "type", "Value": "default"}, - ], - }, - ] - boto_mock.client().describe_tags.return_value = {"TagDescriptions": tag_descriptions} - boto_mock.client().delete_rule.return_value = None - - remove_maintenance_page(boto_mock, "listener_arn") - - delete_listener_rule.assert_has_calls( - [ - call(tag_descriptions, "MaintenancePage", boto_mock.client()), - call().__bool__(), # return value of mock is referenced in line: `if name == "MaintenancePage" and not deleted` - call(tag_descriptions, "AllowedIps", boto_mock.client()), - call(tag_descriptions, "BypassIpFilter", boto_mock.client()), - call(tag_descriptions, "AllowedSourceIps", boto_mock.client()), - ] - ) - - -class TestAddMaintenancePage: - @pytest.mark.parametrize("template", 
["default", "migration", "dmas-migration"]) - @patch("dbt_platform_helper.commands.environment.random.choices", return_value=["a", "b", "c"]) - @patch("dbt_platform_helper.commands.environment.create_source_ip_rule") - @patch("dbt_platform_helper.commands.environment.create_header_rule") - @patch("dbt_platform_helper.commands.environment.find_target_group") - @patch("dbt_platform_helper.commands.environment.get_maintenance_page_template") - def test_adding_existing_template( - self, - get_maintenance_page_template, - find_target_group, - create_header_rule, - create_source_ip, - choices, - template, - mock_application, - ): - - boto_mock = MagicMock() - get_maintenance_page_template.return_value = template - find_target_group.return_value = "target_group_arn" - - add_maintenance_page( - boto_mock, - "listener_arn", - "test-application", - "development", - [mock_application.services["web"]], - ["1.2.3.4"], - template, - ) - - assert create_header_rule.call_count == 2 - create_header_rule.assert_has_calls( - [ - call( - boto_mock.client(), - "listener_arn", - "target_group_arn", - "X-Forwarded-For", - ["1.2.3.4"], - "AllowedIps", - 100, - ), - call( - boto_mock.client(), - "listener_arn", - "target_group_arn", - "Bypass-Key", - ["abc"], - "BypassIpFilter", - 1, - ), - ] - ) - create_source_ip.assert_has_calls( - [ - call( - boto_mock.client(), - "listener_arn", - "target_group_arn", - ["1.2.3.4"], - "AllowedSourceIps", - 101, - ) - ] - ) - boto_mock.client().create_rule.assert_called_once_with( - ListenerArn="listener_arn", - Priority=700, - Conditions=[ - { - "Field": "path-pattern", - "PathPatternConfig": {"Values": ["/*"]}, - } - ], - Actions=[ - { - "Type": "fixed-response", - "FixedResponseConfig": { - "StatusCode": "503", - "ContentType": "text/html", - "MessageBody": template, - }, - } - ], - Tags=[ - {"Key": "name", "Value": "MaintenancePage"}, - {"Key": "type", "Value": template}, - ], - ) - - -class TestEnvironmentMaintenanceTemplates: - 
@pytest.mark.parametrize("template", ["default", "migration", "dmas-migration"]) - def test_template_length(self, template): - - contents = get_maintenance_page_template(template) - assert len(contents) <= 1024 - - @pytest.mark.parametrize("template", ["default", "migration", "dmas-migration"]) - def test_template_no_new_lines(self, template): - - contents = get_maintenance_page_template(template) - assert "\n" not in contents - - -class TestCommandHelperMethods: - @patch("dbt_platform_helper.commands.environment.load_application") - def test_get_app_environment(self, mock_load_application): - - development = Mock() - application = Application(name="test-application") - application.environments = {"development": development} - mock_load_application.return_value = application - - app_environment = get_app_environment("test-application", "development") - - assert app_environment == development - - @patch("dbt_platform_helper.commands.environment.load_application") - def test_get_app_environment_does_not_exist(self, mock_load_application, capsys): - - CliRunner() - application = Application(name="test-application") - mock_load_application.return_value = application - - with pytest.raises(click.Abort): - get_app_environment("test-application", "development") - - captured = capsys.readouterr() - - assert ( - "The environment development was not found in the application test-application." 
- in captured.out - ) - - def _create_subnet(self, session): - ec2 = session.client("ec2") - vpc_id = ec2.create_vpc(CidrBlock="10.0.0.0/16")["Vpc"]["VpcId"] - - return ( - vpc_id, - ec2.create_subnet(VpcId=vpc_id, CidrBlock="10.0.1.0/24")["Subnet"]["SubnetId"], - ) - - def _create_listener(self, elbv2_client): - _, subnet_id = self._create_subnet(boto3.Session()) - load_balancer_arn = elbv2_client.create_load_balancer( - Name="test-load-balancer", Subnets=[subnet_id] - )["LoadBalancers"][0]["LoadBalancerArn"] - return elbv2_client.create_listener( - LoadBalancerArn=load_balancer_arn, DefaultActions=[{"Type": "forward"}] - )["Listeners"][0]["ListenerArn"] - - def _create_listener_rule(self, elbv2_client=None, listener_arn=None, priority=1): - if not elbv2_client: - elbv2_client = boto3.client("elbv2") - - if not listener_arn: - listener_arn = self._create_listener(elbv2_client) - - rule_response = elbv2_client.create_rule( - ListenerArn=listener_arn, - Tags=[{"Key": "test-key", "Value": "test-value"}], - Conditions=[{"Field": "path-pattern", "PathPatternConfig": {"Values": ["/test-path"]}}], - Priority=priority, - Actions=[ - { - "Type": "fixed-response", - "FixedResponseConfig": { - "MessageBody": "test response", - "StatusCode": "200", - "ContentType": "text/plain", - }, - } - ], - ) - - return rule_response["Rules"][0]["RuleArn"], elbv2_client, listener_arn - - def _create_target_group(self): - ec2_client = boto3.client("ec2") - vpc_response = ec2_client.create_vpc(CidrBlock="10.0.0.0/16") - vpc_id = vpc_response["Vpc"]["VpcId"] - - return boto3.client("elbv2").create_target_group( - Name="test-target-group", - Protocol="HTTPS", - Port=123, - VpcId=vpc_id, - Tags=[ - {"Key": "copilot-application", "Value": "test-application"}, - {"Key": "copilot-environment", "Value": "development"}, - {"Key": "copilot-service", "Value": "web"}, - ], - )["TargetGroups"][0]["TargetGroupArn"] - - @mock_aws - def test_get_listener_rule_by_tag(self): - rule_arn, elbv2_client, 
listener_arn = self._create_listener_rule() - - rule = get_listener_rule_by_tag(elbv2_client, listener_arn, "test-key", "test-value") - - assert rule["RuleArn"] == rule_arn - - @mock_aws - def test_find_target_group(self): - - target_group_arn = self._create_target_group() - - assert ( - find_target_group("test-application", "development", "web", boto3.session.Session()) - == target_group_arn - ) - - @mock_aws - def test_find_target_group_not_found(self): - - assert ( - find_target_group("test-application", "development", "web", boto3.session.Session()) - is None - ) - - @mock_aws - def test_delete_listener_rule(self): - - rule_arn, elbv2_client, listener_arn = self._create_listener_rule() - rule_2_arn, _, _ = self._create_listener_rule( - priority=2, elbv2_client=elbv2_client, listener_arn=listener_arn - ) - rules = [ - {"ResourceArn": rule_arn, "Tags": [{"Key": "name", "Value": "test-tag"}]}, - {"ResourceArn": rule_2_arn, "Tags": [{"Key": "name", "Value": "test-tag"}]}, - ] - - described_rules = elbv2_client.describe_rules(ListenerArn=listener_arn)["Rules"] - - # sanity check that default and two newly created rules exist - assert len(described_rules) == 3 - - delete_listener_rule(rules, "test-tag", elbv2_client) - - rules = elbv2_client.describe_rules(ListenerArn=listener_arn)["Rules"] - - assert len(rules) == 1 - - @mock_aws - def test_create_header_rule(self, capsys): - - elbv2_client = boto3.client("elbv2") - listener_arn = self._create_listener(elbv2_client) - target_group_arn = self._create_target_group() - elbv2_client.create_rule( - ListenerArn=listener_arn, - Tags=[{"Key": "test-key", "Value": "test-value"}], - Conditions=[{"Field": "host-header", "HostHeaderConfig": {"Values": ["/test-path"]}}], - Priority=500, - Actions=[{"Type": "forward", "TargetGroupArn": target_group_arn}], - ) - rules = elbv2_client.describe_rules(ListenerArn=listener_arn)["Rules"] - assert len(rules) == 2 - - create_header_rule( - elbv2_client, - listener_arn, - target_group_arn, 
- "X-Forwarded-For", - ["1.2.3.4", "5.6.7.8"], - "AllowedIps", - 333, - ) - - rules = elbv2_client.describe_rules(ListenerArn=listener_arn)["Rules"] - assert len(rules) == 3 # 1 default + 1 forward + 1 newly created - assert rules[1]["Conditions"][0]["HttpHeaderConfig"]["Values"], ["1.2.3.4", "5.6.7.8"] - assert rules[1]["Priority"] == "333" - - captured = capsys.readouterr() - - assert ( - f"Creating listener rule AllowedIps for HTTPS Listener with arn {listener_arn}.\n\nIf request header X-Forwarded-For contains one of the values ['1.2.3.4', '5.6.7.8'], the request will be forwarded to target group with arn {target_group_arn}." - in captured.out - ) - - @mock_aws - def test_create_source_ip_rule(self, capsys): - - elbv2_client = boto3.client("elbv2") - listener_arn = self._create_listener(elbv2_client) - target_group_arn = self._create_target_group() - elbv2_client.create_rule( - ListenerArn=listener_arn, - Tags=[{"Key": "test-key", "Value": "test-value"}], - Conditions=[{"Field": "host-header", "HostHeaderConfig": {"Values": ["/test-path"]}}], - Priority=500, - Actions=[{"Type": "forward", "TargetGroupArn": target_group_arn}], - ) - rules = elbv2_client.describe_rules(ListenerArn=listener_arn)["Rules"] - assert len(rules) == 2 - - create_source_ip_rule( - elbv2_client, - listener_arn, - target_group_arn, - ["1.2.3.4", "5.6.7.8"], - "AllowedSourceIps", - 333, - ) - - rules = elbv2_client.describe_rules(ListenerArn=listener_arn)["Rules"] - assert len(rules) == 3 # 1 default + 1 forward + 1 newly created - assert rules[1]["Conditions"][0]["SourceIpConfig"]["Values"], ["1.2.3.4", "5.6.7.8"] - assert rules[1]["Priority"] == "333" - - captured = capsys.readouterr() - - assert ( - f"Creating listener rule AllowedSourceIps for HTTPS Listener with arn {listener_arn}.\n\nIf request source ip matches one of the values ['1.2.3.4', '5.6.7.8'], the request will be forwarded to target group with arn {target_group_arn}." 
- in captured.out - ) - - @pytest.mark.parametrize( - "vpc, param_value, expected", - [ - ( - "vpc1", - "192.168.1.1,192.168.1.2,192.168.1.3", - ["192.168.1.1", "192.168.1.2", "192.168.1.3"], - ), - ( - "vpc2", - " 192.168.2.1 , 192.168.2.2 , 192.168.2.3 ", - ["192.168.2.1", "192.168.2.2", "192.168.2.3"], - ), - ( - None, - "192.168.1.1,192.168.1.2,192.168.1.3", - ["192.168.1.1", "192.168.1.2", "192.168.1.3"], - ), - ], - ) - @mock_aws - def test_get_env_ips(self, vpc, param_value, expected, mock_application): - - response = boto3.client("organizations").create_organization(FeatureSet="ALL") - response["Organization"]["Id"] - create_account_response = boto3.client("organizations").create_account( - Email="test-email@example.com", AccountName="test" - ) - account_id = create_account_response["CreateAccountStatus"]["AccountId"] - mock_application.environments["development"].account_id = account_id - mock_application.environments["development"].sessions[account_id] = boto3.session.Session() - vpc = vpc if vpc else "test" - boto3.client("ssm").put_parameter( - Name=f"/{vpc}/EGRESS_IPS", Value=param_value, Type="String" - ) - environment = mock_application.environments["development"] - result = get_env_ips(vpc, environment) - - assert result == expected - - @mock_aws - def test_get_env_ips_param_not_found(self, capsys, mock_application): - - response = boto3.client("organizations").create_organization(FeatureSet="ALL") - response["Organization"]["Id"] - create_account_response = boto3.client("organizations").create_account( - Email="test-email@example.com", AccountName="test" - ) - account_id = create_account_response["CreateAccountStatus"]["AccountId"] - mock_application.environments["development"].account_id = account_id - mock_application.environments["development"].sessions[account_id] = boto3.session.Session() - environment = mock_application.environments["development"] - - with pytest.raises(click.Abort): - get_env_ips("vpc", environment) - - captured = 
capsys.readouterr() - - assert "No parameter found with name: /vpc/EGRESS_IPS\n" in captured.out - - @patch("boto3.client") - def test_get_rules_tag_descriptions(self, mock_boto_client): - - mock_client = Mock() - mock_client.describe_tags.side_effect = [ - {"TagDescriptions": ["TagDescriptions1"]}, - {"TagDescriptions": ["TagDescriptions2"]}, - ] - - mock_boto_client.return_value = mock_client - - rules = [] - - for i in range(21): - rules.append({"RuleArn": i}) - - tag_descriptions = get_rules_tag_descriptions(rules, boto3.client("elbv2")) - - assert tag_descriptions == ["TagDescriptions1", "TagDescriptions2"] - assert mock_client.describe_tags.call_count == 2 diff --git a/tests/platform_helper/test_command_pipeline.py b/tests/platform_helper/test_command_pipeline.py index 3b59f86ef..b684cb2da 100644 --- a/tests/platform_helper/test_command_pipeline.py +++ b/tests/platform_helper/test_command_pipeline.py @@ -3,12 +3,17 @@ from unittest.mock import Mock from unittest.mock import patch +import pytest import yaml from click.testing import CliRunner from freezegun.api import freeze_time from dbt_platform_helper.commands.pipeline import CODEBASE_PIPELINES_KEY +from dbt_platform_helper.commands.pipeline import ( + _determine_terraform_platform_modules_version, +) from dbt_platform_helper.commands.pipeline import generate +from dbt_platform_helper.constants import DEFAULT_TERRAFORM_PLATFORM_MODULES_VERSION from dbt_platform_helper.constants import PLATFORM_CONFIG_FILE from tests.platform_helper.conftest import EXPECTED_FILES_DIR from tests.platform_helper.conftest import FIXTURES_DIR @@ -354,6 +359,132 @@ def test_pipeline_generate_without_accounts_creates_the_pipeline_configuration( assert_codebase_pipeline_config_was_generated() +def assert_terraform(app_name, aws_account, expected_version, expected_branch): + expected_files_dir = Path(f"terraform/environment-pipelines/{aws_account}/main.tf") + assert expected_files_dir.exists() + content = 
expected_files_dir.read_text()
+    print(content)
+
+    assert "# WARNING: This is an autogenerated file, not for manual editing." in content
+    assert "# Generated by platform-helper v0.1-TEST / 2024-10-28 12:00:00" in content
+    assert f'profile = "{aws_account}"' in content
+    assert (
+        f"git::https://github.com/uktrade/terraform-platform-modules.git//environment-pipelines?depth=1&ref={expected_version}"
+        in content
+    )
+    assert f'application = "{app_name}"' in content
+    expected_branch_value = expected_branch if expected_branch else "each.value.branch"
+    assert f"branch = {expected_branch_value}" in content
+
+
+@freeze_time("2024-10-28 12:00:00")
+@patch("dbt_platform_helper.jinja2_tags.version", new=Mock(return_value="v0.1-TEST"))
+@patch("dbt_platform_helper.utils.aws.get_aws_session_or_abort")
+@patch("dbt_platform_helper.utils.validation.get_aws_session_or_abort")
+@patch("dbt_platform_helper.commands.pipeline.git_remote", return_value="uktrade/test-app-deploy")
+@pytest.mark.parametrize(
+    "cli_terraform_platform_version, config_terraform_platform_version, expected_terraform_platform_version, cli_demodjango_branch, expected_demodjango_branch",
+    [  # config_terraform_platform_version sets the platform-config.yml to include the TPM version at platform-config.yml/default_versions/terraform-platform-modules
+        ("7", True, "7", None, None),  # Case with cli_terraform_platform_version
+        (
+            None,
+            True,
+            "4.0.0",
+            "demodjango-branch",
+            "demodjango-branch",
+        ),  # Case with config_terraform_platform_version and specific branch
+        (None, True, "4.0.0", None, None),  # Case with config_terraform_platform_version
+        (None, None, "5", None, None),  # Case with default TPM version and without branch, defaults
+    ],
+)
+def test_generate_pipeline_command_generate_terraform_files_for_environment_pipeline_manifest(
+    git_remote,
+    get_aws_command_or_abort,
+    mock_aws_session,
+    fakefs,
+    cli_terraform_platform_version,
+    config_terraform_platform_version,
+    
expected_terraform_platform_version, + cli_demodjango_branch, + expected_demodjango_branch, +): + + app_name = "test-app" + mock_codestar_connections_boto_client(mock_aws_session, [app_name]) + + if config_terraform_platform_version: + setup_fixtures( + fakefs, + pipelines_file="pipeline/platform-config-for-terraform-environment-pipelines-with-tpm-version.yml", + ) + else: + setup_fixtures( + fakefs, + pipelines_file="pipeline/platform-config-for-terraform-environment-pipelines.yml", + ) + + args = [] + if cli_terraform_platform_version: + args.extend(["--terraform-platform-modules-version", cli_terraform_platform_version]) + if cli_demodjango_branch: + args.extend(["--deploy-branch", cli_demodjango_branch]) + + CliRunner().invoke(generate, args=args) + + assert_terraform( + app_name, + "platform-sandbox-test", + expected_terraform_platform_version, + expected_demodjango_branch, + ) + assert_terraform( + app_name, + "platform-prod-test", + expected_terraform_platform_version, + expected_demodjango_branch, + ) + + +@freeze_time("2024-10-28 12:00:00") +@patch("dbt_platform_helper.jinja2_tags.version", new=Mock(return_value="v0.1-TEST")) +@patch("dbt_platform_helper.utils.aws.get_aws_session_or_abort") +@patch("dbt_platform_helper.utils.validation.get_aws_session_or_abort") +@patch("dbt_platform_helper.commands.pipeline.git_remote", return_value="uktrade/test-app-deploy") +def test_generate_pipeline_command_doesnt_generate_terraform_files_if_legacy_project( + git_remote, + get_aws_command_or_abort, + mock_aws_session, + fakefs, +): + app_name = "test-app" + mock_codestar_connections_boto_client(mock_aws_session, [app_name]) + setup_fixtures(fakefs, pipelines_file="pipeline/platform-config-legacy-project.yml") + CliRunner().invoke(generate, args=[]) + + for aws_account in ["platform-sandbox-test", "platform-prod-test"]: + expected_files_dir = Path(f"terraform/environment-pipelines/{aws_account}/main.tf") + assert not expected_files_dir.exists() + + 
+@pytest.mark.parametrize( + "cli_terraform_platform_version, config_terraform_platform_version, expected_version", + [ + ("feature_branch", "5", "feature_branch"), + (None, "5", "5"), + (None, None, DEFAULT_TERRAFORM_PLATFORM_MODULES_VERSION), + ], +) +def test_determine_terraform_platform_modules_version( + cli_terraform_platform_version, config_terraform_platform_version, expected_version +): + assert ( + _determine_terraform_platform_modules_version( + cli_terraform_platform_version, config_terraform_platform_version + ) + == expected_version + ) + + def assert_yaml_in_output_file_matches_expected(output_file, expected_file): def get_yaml(content): return yaml.safe_load(content) diff --git a/tests/platform_helper/test_database_helpers.py b/tests/platform_helper/test_database_helpers.py deleted file mode 100644 index b609c1d40..000000000 --- a/tests/platform_helper/test_database_helpers.py +++ /dev/null @@ -1,351 +0,0 @@ -from unittest.mock import Mock -from unittest.mock import call - -import pytest - -from dbt_platform_helper.commands.database_helpers import DatabaseCopy -from dbt_platform_helper.commands.database_helpers import run_database_copy_task -from dbt_platform_helper.utils.aws import Vpc - - -@pytest.mark.parametrize("is_dump, exp_operation", [(True, "dump"), (False, "load")]) -def test_run_database_copy_task(is_dump, exp_operation): - mock_client = Mock() - mock_session = Mock() - mock_session.client.return_value = mock_client - mock_client.run_task.return_value = {"tasks": [{"taskArn": "arn:aws:ecs:test-task-arn"}]} - - account_id = "1234567" - app = "my_app" - env = "my_env" - database = "my_postgres" - vpc_config = Vpc(["subnet_1", "subnet_2"], ["sec_group_1"]) - db_connection_string = "connection_string" - - actual_task_arn = run_database_copy_task( - mock_session, account_id, app, env, database, vpc_config, is_dump, db_connection_string - ) - - assert actual_task_arn == "arn:aws:ecs:test-task-arn" - - 
mock_session.client.assert_called_once_with("ecs") - mock_client.run_task.assert_called_once_with( - taskDefinition=f"arn:aws:ecs:eu-west-2:1234567:task-definition/my_app-my_env-my_postgres-{exp_operation}", - cluster="my_app-my_env", - capacityProviderStrategy=[ - {"capacityProvider": "FARGATE", "weight": 1, "base": 0}, - ], - networkConfiguration={ - "awsvpcConfiguration": { - "subnets": ["subnet_1", "subnet_2"], - "securityGroups": [ - "sec_group_1", - ], - "assignPublicIp": "DISABLED", - } - }, - overrides={ - "containerOverrides": [ - { - "name": f"my_app-my_env-my_postgres-{exp_operation}", - "environment": [ - {"name": "DATA_COPY_OPERATION", "value": exp_operation.upper()}, - {"name": "DB_CONNECTION_STRING", "value": "connection_string"}, - ], - } - ] - }, - ) - - -def test_database_dump(): - app = "my-app" - env = "my-env" - vpc_name = "test-vpc" - database = "test-db" - - account_id = "1234567" - - mock_session = Mock() - mock_session_fn = Mock(return_value=mock_session) - mock_run_database_copy_task_fn = Mock(return_value="arn://task-arn") - - vpc = Vpc([], []) - mock_vpc_config_fn = Mock() - mock_vpc_config_fn.return_value = vpc - mock_db_connection_string_fn = Mock(return_value="test-db-connection-string") - - mock_input_fn = Mock(return_value="yes") - mock_echo_fn = Mock() - - db_copy = DatabaseCopy( - account_id, - app, - env, - database, - vpc_name, - mock_session_fn, - mock_run_database_copy_task_fn, - mock_vpc_config_fn, - mock_db_connection_string_fn, - mock_input_fn, - mock_echo_fn, - ) - - db_copy.wait_for_task_to_stop = Mock() - db_copy.tail_logs = Mock() - - db_copy.dump() - - mock_session_fn.assert_called_once() - - mock_vpc_config_fn.assert_called_once_with(mock_session, app, env, vpc_name) - - mock_db_connection_string_fn.assert_called_once_with( - mock_session, app, env, "my-app-my-env-test-db" - ) - - mock_run_database_copy_task_fn.assert_called_once_with( - mock_session, account_id, app, env, database, vpc, True, 
"test-db-connection-string" - ) - - mock_input_fn.assert_not_called() - mock_echo_fn.assert_called_once_with( - "Task arn://task-arn started. Waiting for it to complete (this may take some time)...", - fg="green", - ) - db_copy.wait_for_task_to_stop.assert_called_once_with("arn://task-arn") - db_copy.tail_logs.assert_called_once_with(True) - - -def test_database_load_with_response_of_yes(): - app = "my-app" - env = "my-env" - vpc_name = "test-vpc" - database = "test-db" - - account_id = "1234567" - - mock_session = Mock() - mock_session_fn = Mock(return_value=mock_session) - mock_run_database_copy_task_fn = Mock(return_value="arn://task-arn") - - vpc = Vpc([], []) - mock_vpc_config_fn = Mock() - mock_vpc_config_fn.return_value = vpc - mock_db_connection_string_fn = Mock(return_value="test-db-connection-string") - - mock_input_fn = Mock(return_value="yes") - mock_echo_fn = Mock() - - db_copy = DatabaseCopy( - account_id, - app, - env, - database, - vpc_name, - mock_session_fn, - mock_run_database_copy_task_fn, - mock_vpc_config_fn, - mock_db_connection_string_fn, - mock_input_fn, - mock_echo_fn, - ) - db_copy.wait_for_task_to_stop = Mock() - db_copy.tail_logs = Mock() - - db_copy.load() - - mock_session_fn.assert_called_once() - - mock_vpc_config_fn.assert_called_once_with(mock_session, app, env, vpc_name) - - mock_db_connection_string_fn.assert_called_once_with( - mock_session, app, env, "my-app-my-env-test-db" - ) - - mock_run_database_copy_task_fn.assert_called_once_with( - mock_session, account_id, app, env, database, vpc, False, "test-db-connection-string" - ) - - mock_input_fn.assert_called_once_with( - f"Are all tasks using test-db in the my-env environment stopped? (y/n)" - ) - - mock_echo_fn.assert_called_once_with( - "Task arn://task-arn started. 
Waiting for it to complete (this may take some time)...", - fg="green", - ) - db_copy.wait_for_task_to_stop.assert_called_once_with("arn://task-arn") - db_copy.tail_logs.assert_called_once_with(False) - - -def test_database_load_with_response_of_no(): - app = "my-app" - env = "my-env" - vpc_name = "test-vpc" - database = "test-db" - - account_id = "1234567" - - mock_session = Mock() - mock_session_fn = Mock(return_value=mock_session) - mock_run_database_copy_task_fn = Mock() - - vpc = Vpc([], []) - mock_vpc_config_fn = Mock() - mock_vpc_config_fn.return_value = vpc - mock_db_connection_string_fn = Mock(return_value="test-db-connection-string") - - mock_input_fn = Mock(return_value="no") - mock_echo_fn = Mock() - - db_copy = DatabaseCopy( - account_id, - app, - env, - database, - vpc_name, - mock_session_fn, - mock_run_database_copy_task_fn, - mock_vpc_config_fn, - mock_db_connection_string_fn, - mock_input_fn, - mock_echo_fn, - ) - db_copy.tail_logs = Mock() - - db_copy.load() - - mock_session_fn.assert_not_called() - - mock_vpc_config_fn.assert_not_called() - - mock_db_connection_string_fn.assert_not_called() - - mock_run_database_copy_task_fn.assert_not_called() - - mock_input_fn.assert_called_once_with( - f"Are all tasks using test-db in the my-env environment stopped? (y/n)" - ) - mock_echo_fn.assert_not_called() - db_copy.tail_logs.assert_not_called() - - -@pytest.mark.parametrize("user_response", ["y", "Y", " y ", "\ny", "YES", "yes"]) -def test_is_confirmed_ready_to_load(user_response): - mock_input = Mock() - mock_input.return_value = user_response - db_copy = DatabaseCopy("", "", "test-env", "test-db", "", None, None, None, None, mock_input) - - assert db_copy.is_confirmed_ready_to_load() - - mock_input.assert_called_once_with( - f"Are all tasks using test-db in the test-env environment stopped? 
(y/n)" - ) - - -@pytest.mark.parametrize("user_response", ["n", "N", " no ", "squiggly"]) -def test_is_not_confirmed_ready_to_load(user_response): - mock_input = Mock() - mock_input.return_value = user_response - db_copy = DatabaseCopy( - None, None, "test-env", "test-db", None, None, None, None, None, mock_input - ) - - assert not db_copy.is_confirmed_ready_to_load() - - mock_input.assert_called_once_with( - f"Are all tasks using test-db in the test-env environment stopped? (y/n)" - ) - - -def test_wait_for_task_to_stop(): - mock_session = Mock() - mock_session_fn = Mock(return_value=mock_session) - mock_client = Mock() - mock_session.client.return_value = mock_client - mock_waiter = Mock() - mock_client.get_waiter.return_value = mock_waiter - mock_echo = Mock() - - db_copy = DatabaseCopy( - None, - "test-app", - "test-env", - "test-db", - None, - mock_session_fn, - None, - None, - None, - None, - mock_echo, - ) - - db_copy.wait_for_task_to_stop("arn://the-task-arn") - - mock_session.client.assert_called_once_with("ecs") - mock_client.get_waiter.assert_called_once_with("tasks_stopped") - mock_waiter.wait.assert_called_once_with( - cluster="test-app-test-env", - tasks=["arn://the-task-arn"], - WaiterConfig={"Delay": 6, "MaxAttempts": 300}, - ) - mock_echo.assert_has_calls( - [ - call("Waiting for task to complete", fg="yellow"), - ] - ) - - -@pytest.mark.parametrize("is_dump", [True, False]) -def test_tail_logs(is_dump): - action = "dump" if is_dump else "load" - mock_session = Mock() - mock_session_fn = Mock(return_value=mock_session) - mock_client = Mock() - mock_session.client.return_value = mock_client - - mock_client.start_live_tail.return_value = { - "responseStream": [ - {"sessionStart": {}}, - {"sessionUpdate": {"sessionResults": []}}, - {"sessionUpdate": {"sessionResults": [{"message": ""}]}}, - {"sessionUpdate": {"sessionResults": [{"message": f"Starting data {action}"}]}}, - {"sessionUpdate": {"sessionResults": [{"message": "A load of SQL 
shenanigans"}]}}, - {"sessionUpdate": {"sessionResults": [{"message": f"Stopping data {action}"}]}}, - ] - } - mock_echo = Mock() - - db_copy = DatabaseCopy( - "1234", - "test-app", - "test-env", - "test-db", - None, - mock_session_fn, - None, - None, - None, - None, - echo_fn=mock_echo, - ) - db_copy.tail_logs(is_dump) - - mock_session.client.assert_called_once_with("logs") - mock_client.start_live_tail.assert_called_once_with( - logGroupIdentifiers=[ - f"arn:aws:logs:eu-west-2:1234:log-group:/ecs/test-app-test-env-test-db-{action}" - ], - ) - - mock_echo.assert_has_calls( - [ - call(f"Tailing logs for /ecs/test-app-test-env-test-db-{action}", fg="yellow"), - call(f"Starting data {action}"), - call("A load of SQL shenanigans"), - call(f"Stopping data {action}"), - ] - ) diff --git a/tests/platform_helper/utils/test_aws.py b/tests/platform_helper/utils/test_aws.py index ece25dfe6..e248fb7b8 100644 --- a/tests/platform_helper/utils/test_aws.py +++ b/tests/platform_helper/utils/test_aws.py @@ -37,6 +37,9 @@ COPILOT_IDENTIFIER = "c0PIlotiD3ntIF3r" CLUSTER_NAME_SUFFIX = f"Cluster-{COPILOT_IDENTIFIER}" SERVICE_NAME_SUFFIX = f"Service-{COPILOT_IDENTIFIER}" +REFRESH_TOKEN_MESSAGE = ( + "To refresh this SSO session run `aws sso login` with the corresponding profile" +) def test_get_aws_session_or_abort_profile_not_configured(clear_session_cache, capsys): @@ -91,20 +94,54 @@ def test_get_ssm_secrets(mock_get_aws_session_or_abort): assert result == [("/copilot/test-application/development/secrets/TEST_SECRET", "test value")] +@pytest.mark.parametrize( + "aws_profile, side_effect, expected_error_message", + [ + ( + "existing_profile", + botocore.exceptions.NoCredentialsError( + error_msg="There are no credentials set for this session." + ), + f"There are no credentials set for this session. {REFRESH_TOKEN_MESSAGE}", + ), + ( + "existing_profile", + botocore.exceptions.UnauthorizedSSOTokenError( + error_msg="The SSO Token used for this session is unauthorised." 
+ ), + f"The SSO Token used for this session is unauthorised. {REFRESH_TOKEN_MESSAGE}", + ), + ( + "existing_profile", + botocore.exceptions.TokenRetrievalError( + error_msg="Unable to retrieve the Token for this session.", provider="sso" + ), + f"Unable to retrieve the Token for this session. {REFRESH_TOKEN_MESSAGE}", + ), + ( + "existing_profile", + botocore.exceptions.SSOTokenLoadError( + error_msg="The SSO session associated with this profile has expired, is not set or is otherwise invalid." + ), + f"The SSO session associated with this profile has expired, is not set or is otherwise invalid. {REFRESH_TOKEN_MESSAGE}", + ), + ], +) @patch("dbt_platform_helper.utils.aws.get_account_details") @patch("boto3.session.Session") @patch("click.secho") -def test_get_aws_session_or_abort_with_invalid_credentials( - mock_secho, mock_session, mock_get_account_details +def test_get_aws_session_or_abort_errors( + mock_secho, + mock_session, + mock_get_account_details, + aws_profile, + side_effect, + expected_error_message, ): - aws_profile = "existing_profile" - expected_error_message = ( - "The SSO session associated with this profile has expired or is otherwise invalid." - + "To refresh this SSO session run `aws sso login` with the corresponding profile" - ) - mock_get_account_details.side_effect = botocore.exceptions.SSOTokenLoadError( - error_msg=expected_error_message - ) + if isinstance(side_effect, botocore.exceptions.ProfileNotFound): + mock_session.side_effect = side_effect + else: + mock_get_account_details.side_effect = side_effect with pytest.raises(SystemExit) as exc_info: get_aws_session_or_abort(aws_profile=aws_profile)