diff --git a/providers/src/airflow/providers/amazon/aws/links/ec2.py b/providers/src/airflow/providers/amazon/aws/links/ec2.py new file mode 100644 index 0000000000000..38a23956cddbb --- /dev/null +++ b/providers/src/airflow/providers/amazon/aws/links/ec2.py @@ -0,0 +1,46 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from __future__ import annotations + +from airflow.providers.amazon.aws.links.base_aws import BASE_AWS_CONSOLE_LINK, BaseAwsLink + + +class EC2InstanceLink(BaseAwsLink): + """Helper class for constructing Amazon EC2 instance links.""" + + name = "Instance" + key = "_instance_id" + format_str = ( + BASE_AWS_CONSOLE_LINK + "/ec2/home?region={region_name}#InstanceDetails:instanceId={instance_id}" + ) + + +class EC2InstanceDashboardLink(BaseAwsLink): + """ + Helper class for constructing Amazon EC2 console links. + + This is useful for displaying the list of EC2 instances, rather + than a single instance. + """ + + name = "EC2 Instances" + key = "_instance_dashboard" + format_str = BASE_AWS_CONSOLE_LINK + "/ec2/home?region={region_name}#Instances:instanceId=:{instance_ids}" + + @staticmethod + def format_instance_id_filter(instance_ids: list[str]) -> str: + return ",:".join(instance_ids) diff --git a/providers/src/airflow/providers/amazon/aws/operators/ec2.py b/providers/src/airflow/providers/amazon/aws/operators/ec2.py index 5b25b27fd0555..f3d0e9fc2af25 100644 --- a/providers/src/airflow/providers/amazon/aws/operators/ec2.py +++ b/providers/src/airflow/providers/amazon/aws/operators/ec2.py @@ -23,6 +23,10 @@ from airflow.exceptions import AirflowException from airflow.models import BaseOperator from airflow.providers.amazon.aws.hooks.ec2 import EC2Hook +from airflow.providers.amazon.aws.links.ec2 import ( + EC2InstanceDashboardLink, + EC2InstanceLink, +) if TYPE_CHECKING: from airflow.utils.context import Context @@ -47,6 +51,7 @@ class EC2StartInstanceOperator(BaseOperator): between each instance state checks until operation is completed """ + operator_extra_links = (EC2InstanceLink(),) template_fields: Sequence[str] = ("instance_id", "region_name") ui_color = "#eeaa11" ui_fgcolor = "#ffffff" @@ -71,6 +76,13 @@ def execute(self, context: Context): self.log.info("Starting EC2 instance %s", self.instance_id) instance = ec2_hook.get_instance(instance_id=self.instance_id) instance.start() + EC2InstanceLink.persist( + context=context, + operator=self, + aws_partition=ec2_hook.conn_partition, + instance_id=self.instance_id, + region_name=ec2_hook.conn_region_name, + ) ec2_hook.wait_for_state( instance_id=self.instance_id, target_state="running", @@ -97,6 +109,7 @@ class EC2StopInstanceOperator(BaseOperator): between each instance state checks until operation is completed """ + operator_extra_links = (EC2InstanceLink(),) template_fields: Sequence[str] = ("instance_id", "region_name") ui_color = "#eeaa11" ui_fgcolor = "#ffffff" @@ -120,7 +133,15 @@ def execute(self, context: Context): ec2_hook = EC2Hook(aws_conn_id=self.aws_conn_id, region_name=self.region_name) self.log.info("Stopping EC2 instance %s", self.instance_id) instance = ec2_hook.get_instance(instance_id=self.instance_id) + EC2InstanceLink.persist( + context=context, + operator=self, + aws_partition=ec2_hook.conn_partition, + instance_id=self.instance_id, + region_name=ec2_hook.conn_region_name, + ) instance.stop() + ec2_hook.wait_for_state( instance_id=self.instance_id, target_state="stopped", @@ -154,6 +175,7 @@ class EC2CreateInstanceOperator(BaseOperator): in the `running` state before returning. """ + operator_extra_links = (EC2InstanceDashboardLink(),) template_fields: Sequence[str] = ( "image_id", "max_count", @@ -198,6 +220,15 @@ def execute(self, context: Context): )["Instances"] instance_ids = self._on_kill_instance_ids = [instance["InstanceId"] for instance in instances] + # Console link is for EC2 dashboard list, not individual instances when more than 1 instance + + EC2InstanceDashboardLink.persist( + context=context, + operator=self, + region_name=ec2_hook.conn_region_name, + aws_partition=ec2_hook.conn_partition, + instance_ids=EC2InstanceDashboardLink.format_instance_id_filter(instance_ids), + ) for instance_id in instance_ids: self.log.info("Created EC2 instance %s", instance_id) @@ -311,6 +342,7 @@ class EC2RebootInstanceOperator(BaseOperator): in the `running` state before returning. """ + operator_extra_links = (EC2InstanceDashboardLink(),) template_fields: Sequence[str] = ("instance_ids", "region_name") ui_color = "#eeaa11" ui_fgcolor = "#ffffff" @@ -341,6 +373,14 @@ def execute(self, context: Context): self.log.info("Rebooting EC2 instances %s", ", ".join(self.instance_ids)) ec2_hook.conn.reboot_instances(InstanceIds=self.instance_ids) + # Console link is for EC2 dashboard list, not individual instances + EC2InstanceDashboardLink.persist( + context=context, + operator=self, + region_name=ec2_hook.conn_region_name, + aws_partition=ec2_hook.conn_partition, + instance_ids=EC2InstanceDashboardLink.format_instance_id_filter(self.instance_ids), + ) if self.wait_for_completion: ec2_hook.get_waiter("instance_running").wait( InstanceIds=self.instance_ids, @@ -374,6 +414,7 @@ class EC2HibernateInstanceOperator(BaseOperator): in the `stopped` state before returning. """ + operator_extra_links = (EC2InstanceDashboardLink(),) template_fields: Sequence[str] = ("instance_ids", "region_name") ui_color = "#eeaa11" ui_fgcolor = "#ffffff" @@ -404,6 +445,15 @@ def execute(self, context: Context): self.log.info("Hibernating EC2 instances %s", ", ".join(self.instance_ids)) instances = ec2_hook.get_instances(instance_ids=self.instance_ids) + # Console link is for EC2 dashboard list, not individual instances + EC2InstanceDashboardLink.persist( + context=context, + operator=self, + region_name=ec2_hook.conn_region_name, + aws_partition=ec2_hook.conn_partition, + instance_ids=EC2InstanceDashboardLink.format_instance_id_filter(self.instance_ids), + ) + for instance in instances: hibernation_options = instance.get("HibernationOptions") if not hibernation_options or not hibernation_options["Configured"]: diff --git a/providers/src/airflow/providers/amazon/provider.yaml b/providers/src/airflow/providers/amazon/provider.yaml index 43569a28827ab..5192052800079 100644 --- a/providers/src/airflow/providers/amazon/provider.yaml +++ b/providers/src/airflow/providers/amazon/provider.yaml @@ -891,7 +891,8 @@ extra-links: - airflow.providers.amazon.aws.links.comprehend.ComprehendDocumentClassifierLink - airflow.providers.amazon.aws.links.datasync.DataSyncTaskLink - airflow.providers.amazon.aws.links.datasync.DataSyncTaskExecutionLink - + - airflow.providers.amazon.aws.links.ec2.EC2InstanceLink + - airflow.providers.amazon.aws.links.ec2.EC2InstanceDashboardLink connection-types: - hook-class-name: airflow.providers.amazon.aws.hooks.base_aws.AwsGenericHook diff --git a/providers/tests/amazon/aws/links/test_ec2.py b/providers/tests/amazon/aws/links/test_ec2.py new file mode 100644 index 0000000000000..922b12275e5aa --- /dev/null +++ b/providers/tests/amazon/aws/links/test_ec2.py @@ -0,0 +1,59 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from __future__ import annotations + +from airflow.providers.amazon.aws.links.ec2 import EC2InstanceDashboardLink, EC2InstanceLink + +from providers.tests.amazon.aws.links.test_base_aws import BaseAwsLinksTestCase + + +class TestEC2InstanceLink(BaseAwsLinksTestCase): + link_class = EC2InstanceLink + + INSTANCE_ID = "i-xxxxxxxxxxxx" + + def test_extra_link(self): + self.assert_extra_link_url( + expected_url=( + "https://console.aws.amazon.com/ec2/home" + f"?region=eu-west-1#InstanceDetails:instanceId={self.INSTANCE_ID}" + ), + region_name="eu-west-1", + aws_partition="aws", + instance_id=self.INSTANCE_ID, + ) + + +class TestEC2InstanceDashboardLink(BaseAwsLinksTestCase): + link_class = EC2InstanceDashboardLink + + BASE_URL = "https://console.aws.amazon.com/ec2/home" + INSTANCE_IDS = ["i-xxxxxxxxxxxx", "i-yyyyyyyyyyyy"] + + def test_instance_id_filter(self): + instance_list = ",:".join(self.INSTANCE_IDS) + result = EC2InstanceDashboardLink.format_instance_id_filter(self.INSTANCE_IDS) + assert result == instance_list + + def test_extra_link(self): + instance_list = ",:".join(self.INSTANCE_IDS) + self.assert_extra_link_url( + expected_url=(f"{self.BASE_URL}?region=eu-west-1#Instances:instanceId=:{instance_list}"), + region_name="eu-west-1", + aws_partition="aws", + instance_ids=EC2InstanceDashboardLink.format_instance_id_filter(self.INSTANCE_IDS), + )