Update system test example_dms_serverless
vincbeck committed Feb 12, 2025
1 parent 340c994 commit bdc8114
Showing 3 changed files with 22 additions and 123 deletions.
@@ -493,10 +493,6 @@ def execute(self, context: Context) -> None:
Filters=[{"Name": "replication-config-arn", "Values": [self.replication_config_arn]}],
WaiterConfig={"Delay": self.waiter_delay, "MaxAttempts": self.waiter_max_attempts},
)
self.hook.get_waiter("replication_deprovisioned").wait(
Filters=[{"Name": "replication-config-arn", "Values": [self.replication_config_arn]}],
WaiterConfig={"Delay": self.waiter_delay, "MaxAttempts": self.waiter_max_attempts},
)
self.hook.delete_replication_config(self.replication_config_arn)
self.handle_delete_wait()
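
The removed block above was the second of two back-to-back waits in `execute`: the explicit `replication_deprovisioned` wait becomes redundant once the acceptor added to `dms.json` below folds the deprovisioned state into the remaining waiter. For context, a minimal sketch of how a custom waiter model like `dms.json` is turned into a boto3 waiter; the waiter name and values here are illustrative, not the provider's actual wiring, though Airflow's hook-level `get_waiter` does something equivalent internally.

import boto3
from botocore.waiter import WaiterModel, create_waiter_with_client

# Illustrative model mirroring the acceptors added in dms.json below.
model = WaiterModel(
    {
        "version": 2,
        "waiters": {
            "replication_terminal": {  # assumed name, for the sketch only
                "operation": "DescribeReplications",
                "delay": 60,
                "maxAttempts": 200,
                "acceptors": [
                    {"matcher": "path", "argument": "Replications[0].Status",
                     "expected": "stopped", "state": "success"},
                    {"matcher": "path",
                     "argument": "Replications[0].ProvisionData.ProvisionState",
                     "expected": "deprovisioned", "state": "success"},
                ],
            }
        },
    }
)
client = boto3.client("dms")
waiter = create_waiter_with_client("replication_terminal", model, client)
waiter.wait(Filters=[{"Name": "replication-config-arn", "Values": ["<arn>"]}])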

12 changes: 12 additions & 0 deletions providers/amazon/src/airflow/providers/amazon/aws/waiters/dms.json
@@ -36,6 +36,18 @@
"argument": "Replications[0].Status",
"expected": "stopped",
"state": "success"
},
{
"matcher": "path",
"argument": "Replications[0].Status",
"expected": "created",
"state": "success"
},
{
"matcher": "path",
"argument": "Replications[0].ProvisionData.ProvisionState",
"expected": "deprovisioned",
"state": "success"
}
]
},
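
With a "path" matcher, each poll of the underlying API call is evaluated by applying the JMESPath expression in "argument" to the response and comparing the result to "expected"; the first acceptor that matches decides the waiter's outcome. A small illustration of that evaluation, using a fabricated response dict:

import jmespath

# Fabricated response, for illustration only.
response = {
    "Replications": [
        {"Status": "created", "ProvisionData": {"ProvisionState": "deprovisioned"}}
    ]
}
# These are the two checks the new acceptors perform on every poll.
assert jmespath.search("Replications[0].Status", response) == "created"
assert (
    jmespath.search("Replications[0].ProvisionData.ProvisionState", response)
    == "deprovisioned"
)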
129 changes: 10 additions & 119 deletions providers/amazon/tests/system/amazon/aws/example_dms_serverless.py
@@ -27,7 +27,6 @@

import boto3
from providers.amazon.tests.system.amazon.aws.utils import ENV_ID_KEY, SystemTestContextBuilder
from providers.amazon.tests.system.amazon.aws.utils.ec2 import get_default_vpc_id
from sqlalchemy import Column, MetaData, String, Table, create_engine

from airflow.decorators import task
@@ -38,8 +37,6 @@
DmsDeleteReplicationConfigOperator,
DmsDescribeReplicationConfigsOperator,
DmsDescribeReplicationsOperator,
DmsStartReplicationOperator,
DmsStopReplicationOperator,
)
from airflow.providers.amazon.aws.operators.rds import (
RdsCreateDbInstanceOperator,
@@ -76,11 +73,6 @@
("Subversion", "2000"),
("NiFi", "2006"),
]
SG_IP_PERMISSION = {
"FromPort": 5432,
"IpProtocol": "All",
"IpRanges": [{"CidrIp": "0.0.0.0/0"}],
}


def _get_rds_instance_endpoint(instance_name: str):
@@ -92,25 +84,6 @@ def _get_rds_instance_endpoint(instance_name):
return rds_instance_endpoint
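
The body of `_get_rds_instance_endpoint` is collapsed in the diff; a plausible reconstruction, assumed from the helper's name and the boto3 RDS API, is:

import boto3

def _get_rds_instance_endpoint(instance_name: str):
    # Assumed body: look up the instance and return its Endpoint dict
    # ({"Address": ..., "Port": ...}).
    rds_client = boto3.client("rds")
    response = rds_client.describe_db_instances(DBInstanceIdentifier=instance_name)
    return response["DBInstances"][0]["Endpoint"]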


@task
def create_security_group(security_group_name: str, vpc_id: str):
client = boto3.client("ec2")
security_group = client.create_security_group(
GroupName=security_group_name,
Description="Created for DMS system test",
VpcId=vpc_id,
)
client.get_waiter("security_group_exists").wait(
GroupIds=[security_group["GroupId"]],
)
client.authorize_security_group_ingress(
GroupId=security_group["GroupId"],
IpPermissions=[SG_IP_PERMISSION],
)

return security_group["GroupId"]


@task
def create_sample_table(instance_name: str, db_name: str, table_name: str):
print("Creating sample table.")
@@ -138,37 +111,6 @@ def create_sample_table(instance_name: str, db_name: str, table_name: str):
connection.execute(table.select())
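
`create_sample_table` is likewise mostly collapsed. A hedged sketch of the usual pattern, in which the connection URL, credentials, and column layout are all assumptions rather than the test's actual code:

from sqlalchemy import Column, MetaData, String, Table, create_engine

def create_sample_table_sketch(instance_name: str, db_name: str, table_name: str) -> None:
    # Assumed wiring: reuse the endpoint helper and the module's credentials.
    endpoint = _get_rds_instance_endpoint(instance_name)
    engine = create_engine(
        f"postgresql+psycopg2://{RDS_USERNAME}:{RDS_PASSWORD}"
        f"@{endpoint['Address']}:{endpoint['Port']}/{db_name}"
    )
    metadata = MetaData()
    table = Table(table_name, metadata,
                  Column("name", String(50)), Column("value", String(50)))
    with engine.begin() as connection:  # begin() commits on success
        metadata.create_all(connection)
        connection.execute(
            table.insert(),
            [{"name": n, "value": v} for n, v in [("Subversion", "2000"), ("NiFi", "2006")]],
        )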


@task(trigger_rule=TriggerRule.ALL_SUCCESS)
def create_vpc_endpoints(vpc_id: str):
print("Creating VPC endpoints in vpc: %s", vpc_id)
client = boto3.client("ec2")
session = boto3.session.Session()
region = session.region_name
route_tbls = client.describe_route_tables(Filters=[{"Name": "vpc-id", "Values": [vpc_id]}])
endpoints = client.create_vpc_endpoint(
VpcId=vpc_id,
ServiceName=f"com.amazonaws.{region}.s3",
VpcEndpointType="Gateway",
RouteTableIds=[tbl["RouteTableId"] for tbl in route_tbls["RouteTables"]],
)

return endpoints.get("VpcEndpoint", {}).get("VpcEndpointId")


@task(trigger_rule=TriggerRule.ALL_DONE)
def delete_vpc_endpoints(endpoint_ids: list[str]):
if len(endpoint_ids) == 0:
print("No VPC endpoints to delete.")
return

print("Deleting VPC endpoints.")
client = boto3.client("ec2")

client.delete_vpc_endpoints(VpcEndpointIds=endpoint_ids, DryRun=False)

print("Deleted endpoints: %s", endpoint_ids)


@task(multiple_outputs=True)
def create_dms_assets(
db_name: str,
@@ -223,20 +165,8 @@ def delete_dms_assets(
target_endpoint_identifier: str,
):
dms_client = boto3.client("dms")

print("Deleting DMS assets.")

print(source_endpoint_arn)
print(target_endpoint_arn)

try:
dms_client.delete_endpoint(EndpointArn=source_endpoint_arn)
dms_client.delete_endpoint(EndpointArn=target_endpoint_arn)
except Exception as ex:
print("Exception while cleaning up endpoints:%s", ex)

print("Awaiting DMS assets tear-down.")

dms_client.delete_endpoint(EndpointArn=source_endpoint_arn)
dms_client.delete_endpoint(EndpointArn=target_endpoint_arn)
dms_client.get_waiter("endpoint_deleted").wait(
Filters=[
{
@@ -247,44 +177,26 @@
)
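
The Filters of the `endpoint_deleted` wait are truncated by the collapsed region. `endpoint_deleted` is a boto3-native DMS waiter, and the call plausibly completes like this (the filter name is an assumption):

# Assumed completion of the truncated wait call above.
dms_client.get_waiter("endpoint_deleted").wait(
    Filters=[
        {
            "Name": "endpoint-arn",
            "Values": [source_endpoint_arn, target_endpoint_arn],
        }
    ],
)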


@task(trigger_rule=TriggerRule.ALL_DONE)
def delete_security_group(security_group_id: str, security_group_name: str):
boto3.client("ec2").delete_security_group(GroupId=security_group_id, GroupName=security_group_name)


# setup
# source: aurora serverless
# dest: S3
# S3

with DAG(
dag_id=DAG_ID,
schedule="@once",
start_date=datetime(2021, 1, 1),
tags=["example"],
catchup=False,
) as dag:
test_context = sys_test_context_task()
env_id = test_context[ENV_ID_KEY]
role_arn = test_context[ROLE_ARN_KEY]

bucket_name = f"{env_id}-dms-bucket"
bucket_name = f"{env_id}-dms-serverless-bucket"
rds_instance_name = f"{env_id}-instance"
rds_db_name = f"{env_id}_source_database" # dashes are not allowed in db name
rds_table_name = f"{env_id}-table"
dms_replication_instance_name = f"{env_id}-replication-instance"
dms_replication_task_id = f"{env_id}-replication-task"
source_endpoint_identifier = f"{env_id}-source-endpoint"
target_endpoint_identifier = f"{env_id}-target-endpoint"
security_group_name = f"{env_id}-dms-security-group"
replication_id = f"{env_id}-replication-id"

create_s3_bucket = S3CreateBucketOperator(task_id="create_s3_bucket", bucket_name=bucket_name)

get_vpc_id = get_default_vpc_id()

create_sg = create_security_group(security_group_name, get_vpc_id)

create_db_instance = RdsCreateDbInstanceOperator(
task_id="create_db_instance",
db_instance_identifier=rds_instance_name,
@@ -296,9 +208,6 @@ def delete_security_group(security_group_id: str, security_group_name: str):
"MasterUsername": RDS_USERNAME,
"MasterUserPassword": RDS_PASSWORD,
"PubliclyAccessible": True,
"VpcSecurityGroupIds": [
create_sg,
],
},
)
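
Only the tail of the `create_db_instance` task survives the collapse. A hedged reconstruction of the whole call after this change, where the instance class, engine, and storage values are assumptions:

create_db_instance = RdsCreateDbInstanceOperator(
    task_id="create_db_instance",
    db_instance_identifier=rds_instance_name,
    db_instance_class="db.t3.micro",  # assumed
    engine="postgres",  # assumed
    rds_kwargs={
        "DBName": rds_db_name,  # assumed
        "AllocatedStorage": 20,  # assumed
        "MasterUsername": RDS_USERNAME,
        "MasterUserPassword": RDS_PASSWORD,
        "PubliclyAccessible": True,
        # "VpcSecurityGroupIds" is gone, removed with the security-group tasks
    },
)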

@@ -360,24 +269,24 @@
},
replication_type="full-load",
table_mappings=json.dumps(table_mappings),
trigger_rule=TriggerRule.ALL_SUCCESS,
)
# [END howto_operator_dms_create_replication_config]
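
The head of the `create_replication_config` task is collapsed. Based on the `DmsCreateReplicationConfigOperator` API, the full call plausibly reads as follows; the ARN wiring and compute values are assumptions:

create_replication_config = DmsCreateReplicationConfigOperator(
    task_id="create_replication_config",
    replication_config_id=replication_id,
    source_endpoint_arn=create_assets["source_endpoint_arn"],  # assumed
    target_endpoint_arn=create_assets["target_endpoint_arn"],  # assumed
    compute_config={  # values assumed
        "MaxCapacityUnits": 4,
        "MinCapacityUnits": 1,
    },
    replication_type="full-load",
    table_mappings=json.dumps(table_mappings),
)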

# [START howto_operator_dms_describe_replication_config]
describe_replication_configs = DmsDescribeReplicationConfigsOperator(
task_id="describe_replication_configs",
trigger_rule=TriggerRule.ALL_SUCCESS,
)
# [END howto_operator_dms_describe_replication_config]

# [START howto_operator_dms_serverless_describe_replication]
describe_replications = DmsDescribeReplicationsOperator(
task_id="describe_replications",
trigger_rule=TriggerRule.ALL_SUCCESS,
)
# [END howto_operator_dms_serverless_describe_replication]
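
Both describe operators return their results, so the lists land in XCom. A hedged example of consuming that output downstream; the task and its wiring are hypothetical, but the field names follow the DMS DescribeReplicationConfigs response shape:

@task
def log_replication_configs(configs: list[dict]) -> None:
    # Each item mirrors an entry of the DescribeReplicationConfigs response.
    for config in configs:
        print(config["ReplicationConfigIdentifier"], config["ReplicationConfigArn"])

# Hypothetical wiring:
# log_replication_configs(describe_replication_configs.output)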

# The next two tasks are commented out because they take too long to run in CI.
# They are kept for documentation purposes.
"""
# [START howto_operator_dms_serverless_start_replication]
replicate = DmsStartReplicationOperator(
task_id="replicate",
@@ -386,34 +295,28 @@ def delete_security_group(security_group_id: str, security_group_name: str):
wait_for_completion=True,
waiter_delay=60,
waiter_max_attempts=200,
trigger_rule=TriggerRule.ALL_SUCCESS,
deferrable=False,
)
# [END howto_operator_dms_serverless_start_replication]
# [START howto_operator_dms_serverless_stop_replication]
stop_relication = DmsStopReplicationOperator(
stop_replication = DmsStopReplicationOperator(
task_id="stop_replication",
replication_config_arn="{{ task_instance.xcom_pull(task_ids='create_replication_config', key='return_value') }}",
wait_for_completion=True,
waiter_delay=120,
waiter_max_attempts=200,
trigger_rule=TriggerRule.ALL_SUCCESS,
deferrable=False,
)
# [END howto_operator_dms_serverless_stop_replication]
"""

# [START howto_operator_dms_serverless_delete_replication_config]
delete_replication_config = DmsDeleteReplicationConfigOperator(
task_id="delete_replication_config",
wait_for_completion=True,
waiter_delay=60,
waiter_max_attempts=200,
deferrable=False,
replication_config_arn="{{ task_instance.xcom_pull(task_ids='create_replication_config', key='return_value') }}",
trigger_rule=TriggerRule.ALL_DONE,
)
# [END howto_operator_dms_serverless_delete_replication_config]
delete_replication_config.trigger_rule = TriggerRule.ALL_DONE
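
The commit moves `trigger_rule` from a constructor kwarg to an attribute assignment placed after the `[END ...]` marker, which keeps the documented snippet between the markers free of test-only settings. The two forms are equivalent; a minimal illustration with a stand-in operator:

from airflow.operators.empty import EmptyOperator
from airflow.utils.trigger_rule import TriggerRule

# Form 1 (what the diff removes): passed at construction time.
op1 = EmptyOperator(task_id="t1", trigger_rule=TriggerRule.ALL_DONE)

# Form 2 (what the diff adds): assigned after construction.
op2 = EmptyOperator(task_id="t2")
op2.trigger_rule = TriggerRule.ALL_DONE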

delete_assets = delete_dms_assets(
source_endpoint_arn=create_assets["source_endpoint_arn"],
@@ -440,31 +343,19 @@ def delete_security_group(security_group_id: str, security_group_name: str):

chain(
# TEST SETUP
test_context,
create_s3_bucket,
get_vpc_id,
create_sg,
create_db_instance,
create_sample_table(rds_instance_name, rds_db_name, rds_table_name),
create_vpc_endpoints(
vpc_id="{{ task_instance.xcom_pull(task_ids='get_default_vpc_id',key='return_value')}}"
),
create_assets,
# TEST BODY
create_replication_config,
describe_replication_configs,
replicate,
stop_relication,
describe_replications,
delete_replication_config,
# TEST TEARDOWN
delete_vpc_endpoints(
endpoint_ids=[
"{{ task_instance.xcom_pull(task_ids='create_vpc_endpoints', key='return_value') }}"
]
),
delete_assets,
delete_db_instance,
delete_security_group(create_sg, security_group_name),
delete_s3_bucket,
)
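
The file's tail is collapsed by the diff. AWS system tests conventionally end with a watcher hookup and a `get_test_run` export, along the lines of this sketch; the module paths are assumptions for this repository layout:

from tests_common.test_utils.watcher import watcher

# The watcher is needed to report success/failure correctly when the DAG
# contains teardown tasks with trigger_rule=ALL_DONE.
list(dag.tasks) >> watcher()

from tests_common.test_utils.system_tests import get_test_run  # noqa: E402

test_run = get_test_run(dag)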

