diff --git a/airflow/providers/amazon/aws/operators/ecs.py b/airflow/providers/amazon/aws/operators/ecs.py index ff5694ca01a9c..1d7165442451c 100644 --- a/airflow/providers/amazon/aws/operators/ecs.py +++ b/airflow/providers/amazon/aws/operators/ecs.py @@ -388,6 +388,10 @@ class EcsRunTaskOperator(EcsBaseOperator): AirflowException if an ECS task is stopped (to receive Airflow alerts with the logs of what failed in the code running in ECS). :param wait_for_completion: If True, waits for creation of the cluster to complete. (default: True) + :param waiter_delay: The amount of time in seconds to wait between attempts, + if not set then the default waiter value will be used. + :param waiter_max_attempts: The maximum number of attempts to be made, + if not set then the default waiter value will be used. """ ui_color = "#f0ede4" @@ -443,6 +447,8 @@ def __init__( reattach: bool = False, number_logs_exception: int = 10, wait_for_completion: bool = True, + waiter_delay: int | None = None, + waiter_max_attempts: int | None = None, **kwargs, ): super().__init__(**kwargs) @@ -474,6 +480,8 @@ def __init__( self.retry_args = quota_retry self.task_log_fetcher: EcsTaskLogFetcher | None = None self.wait_for_completion = wait_for_completion + self.waiter_delay = waiter_delay + self.waiter_max_attempts = waiter_max_attempts @provide_session def execute(self, context, session=None): @@ -596,7 +604,16 @@ def _wait_for_task_ended(self) -> None: waiter = self.client.get_waiter("tasks_stopped") waiter.config.max_attempts = sys.maxsize # timeout is managed by airflow - waiter.wait(cluster=self.cluster, tasks=[self.arn]) + waiter.wait( + cluster=self.cluster, + tasks=[self.arn], + WaiterConfig=prune_dict( + { + "Delay": self.waiter_delay, + "MaxAttempts": self.waiter_max_attempts, + } + ), + ) return diff --git a/tests/providers/amazon/aws/operators/test_ecs.py b/tests/providers/amazon/aws/operators/test_ecs.py index 2908cdb264e56..cadaa6e329462 100644 --- a/tests/providers/amazon/aws/operators/test_ecs.py +++ b/tests/providers/amazon/aws/operators/test_ecs.py @@ -339,7 +339,9 @@ def test_wait_end_tasks(self, client_mock): self.ecs._wait_for_task_ended() client_mock.get_waiter.assert_called_once_with("tasks_stopped") - client_mock.get_waiter.return_value.wait.assert_called_once_with(cluster="c", tasks=["arn"]) + client_mock.get_waiter.return_value.wait.assert_called_once_with( + cluster="c", tasks=["arn"], WaiterConfig={} + ) assert sys.maxsize == client_mock.get_waiter.return_value.config.max_attempts @mock.patch.object(EcsBaseOperator, "client")