Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add a "force" option to emr serverless stop/delete operator #30757

Merged
merged 11 commits into from
Apr 24, 2023
40 changes: 35 additions & 5 deletions airflow/providers/amazon/aws/hooks/emr.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@

from botocore.exceptions import ClientError

from airflow.compat.functools import cached_property
from airflow.exceptions import AirflowException, AirflowNotFoundException
from airflow.providers.amazon.aws.hooks.base_aws import AwsBaseHook
from airflow.utils.helpers import prune_dict
Expand Down Expand Up @@ -253,10 +252,41 @@ def __init__(self, *args: Any, **kwargs: Any) -> None:
kwargs["client_type"] = "emr-serverless"
super().__init__(*args, **kwargs)

@cached_property
def conn(self):
"""Get the underlying boto3 EmrServerlessAPIService client (cached)"""
return super().conn
def cancel_running_jobs(self, application_id: str, waiter_config: dict = {}):
"""
List all jobs in an intermediate state and cancel them.
Then wait for those jobs to reach a terminal state.
Note: if new jobs are triggered while this operation is ongoing,
it's going to time out and return an error.
"""
paginator = self.conn.get_paginator("list_job_runs")
results_per_response = 50
iterator = paginator.paginate(
applicationId=application_id,
states=list(self.JOB_INTERMEDIATE_STATES),
PaginationConfig={
"PageSize": results_per_response,
},
)
count = 0
for r in iterator:
job_ids = [jr["id"] for jr in r["jobRuns"]]
count += len(job_ids)
if len(job_ids) > 0:
self.log.info(
"Cancelling %s pending job(s) for the application %s so that it can be stopped",
len(job_ids),
application_id,
)
for job_id in job_ids:
self.conn.cancel_job_run(applicationId=application_id, jobRunId=job_id)
if count > 0:
self.log.info("now waiting for the %s cancelled job(s) to terminate", count)
self.get_waiter("no_job_running").wait(
applicationId=application_id,
states=list(self.JOB_INTERMEDIATE_STATES.union({"CANCELLING"})),
WaiterConfig=waiter_config,
)


class EmrContainerHook(AwsBaseHook):
Expand Down
22 changes: 22 additions & 0 deletions airflow/providers/amazon/aws/operators/emr.py
Original file line number Diff line number Diff line change
Expand Up @@ -1025,6 +1025,10 @@ class EmrServerlessStopApplicationOperator(BaseOperator):
the application be stopped. Defaults to 5 minutes.
:param waiter_check_interval_seconds: Number of seconds between polling the state of the application.
Defaults to 30 seconds.
:param force_stop: If set to True, any job for that app that is not in a terminal state will be cancelled.
Otherwise, trying to stop an app with running jobs will return an error.
If you want to wait for the jobs to finish gracefully, use
:class:`airflow.providers.amazon.aws.sensors.emr.EmrServerlessJobSensor`
"""

template_fields: Sequence[str] = ("application_id",)
Expand All @@ -1036,13 +1040,15 @@ def __init__(
aws_conn_id: str = "aws_default",
waiter_countdown: int = 5 * 60,
waiter_check_interval_seconds: int = 30,
force_stop: bool = False,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

WDYT about updating the docs/apache-airflow-providers-amazon/operators/emr/emr_serverless.rst to contain this additional param

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I added it to the system test (https://github.com/apache/airflow/pull/30757/files/8370c8172d8c28ffeec3ca460621b9d0447bfba6#diff-69be45953c5be696ca3159bb385a89c90930fc06e68357e8bdb33a1b31694f88R120) inside the howto_operator thing, so it'll be embedded in the doc automatically (as sample usage).
It doesn't explain what it's doing, but it being present there highlights its presence, and the code doc contains the explanation.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

looks good

**kwargs,
):
self.aws_conn_id = aws_conn_id
self.application_id = application_id
self.wait_for_completion = wait_for_completion
self.waiter_countdown = waiter_countdown
self.waiter_check_interval_seconds = waiter_check_interval_seconds
self.force_stop = force_stop
super().__init__(**kwargs)

@cached_property
Expand All @@ -1052,6 +1058,16 @@ def hook(self) -> EmrServerlessHook:

def execute(self, context: Context) -> None:
self.log.info("Stopping application: %s", self.application_id)

if self.force_stop:
self.hook.cancel_running_jobs(
self.application_id,
waiter_config={
"Delay": self.waiter_check_interval_seconds,
"MaxAttempts": self.waiter_countdown / self.waiter_check_interval_seconds,
},
)

self.hook.conn.stop_application(applicationId=self.application_id)

if self.wait_for_completion:
Expand Down Expand Up @@ -1088,6 +1104,10 @@ class EmrServerlessDeleteApplicationOperator(EmrServerlessStopApplicationOperato
the application to be stopped, and then deleted. Defaults to 25 minutes.
:param waiter_check_interval_seconds: Number of seconds between polling the state of the application.
Defaults to 60 seconds.
:param force_stop: If set to True, any job for that app that is not in a terminal state will be cancelled.
Otherwise, trying to delete an app with running jobs will return an error.
If you want to wait for the jobs to finish gracefully, use
:class:`airflow.providers.amazon.aws.sensors.emr.EmrServerlessJobSensor`
"""

template_fields: Sequence[str] = ("application_id",)
Expand All @@ -1099,6 +1119,7 @@ def __init__(
aws_conn_id: str = "aws_default",
waiter_countdown: int = 25 * 60,
waiter_check_interval_seconds: int = 60,
force_stop: bool = False,
**kwargs,
):
self.wait_for_delete_completion = wait_for_completion
Expand All @@ -1110,6 +1131,7 @@ def __init__(
aws_conn_id=aws_conn_id,
waiter_countdown=waiter_countdown,
waiter_check_interval_seconds=waiter_check_interval_seconds,
force_stop=force_stop,
**kwargs,
)

Expand Down
18 changes: 18 additions & 0 deletions airflow/providers/amazon/aws/waiters/emr-serverless.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
{
"version": 2,
"waiters": {
"no_job_running": {
"operation": "ListJobRuns",
"delay": 10,
"maxAttempts": 60,
"acceptors": [
{
"matcher": "path",
"argument": "length(jobRuns) == `0`",
"expected": true,
"state": "success"
}
]
Comment on lines +8 to +15
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't like too much the fact that I don't have a failure case for this waiter, but I think there is nothing we can do about it...
The failure case would be count > prev_count, or detecting a new job_id that we didn't see in the previous calls, which is way beyond the capabilities of waiters.

}
}
}
41 changes: 41 additions & 0 deletions tests/providers/amazon/aws/hooks/test_emr_serverless.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@
# under the License.
from __future__ import annotations

from unittest.mock import MagicMock, PropertyMock, patch

from airflow.providers.amazon.aws.hooks.emr import EmrServerlessHook

task_id = "test_emr_serverless_create_application_operator"
Expand All @@ -34,3 +36,42 @@ def test_conn_attribute(self):
conn = hook.conn
conn2 = hook.conn
assert conn is conn2

@patch.object(EmrServerlessHook, "conn", new_callable=PropertyMock)
def test_cancel_jobs(self, conn_mock: MagicMock):
    # A single page of results containing two pending job runs.
    page = {"jobRuns": [{"id": "job1"}, {"id": "job2"}]}
    conn_mock().get_paginator().paginate.return_value = [page]
    hook = EmrServerlessHook(aws_conn_id="aws_default")
    get_waiter_mock = MagicMock()
    hook.get_waiter = get_waiter_mock

    hook.cancel_running_jobs("app")

    # Both runs must be cancelled, then we wait for them to terminate.
    assert conn_mock().cancel_job_run.call_count == 2
    for run_id in ("job1", "job2"):
        conn_mock().cancel_job_run.assert_any_call(applicationId="app", jobRunId=run_id)
    get_waiter_mock.assert_called_with("no_job_running")

@patch.object(EmrServerlessHook, "conn", new_callable=PropertyMock)
def test_cancel_jobs_several_calls(self, conn_mock: MagicMock):
    # Two pages of results, two pending job runs per page.
    pages = [
        {"jobRuns": [{"id": "job1"}, {"id": "job2"}]},
        {"jobRuns": [{"id": "job3"}, {"id": "job4"}]},
    ]
    conn_mock().get_paginator().paginate.return_value = pages
    hook = EmrServerlessHook(aws_conn_id="aws_default")
    wait_mock = MagicMock()
    hook.get_waiter = wait_mock

    hook.cancel_running_jobs("app")

    # Every run across both pages gets cancelled...
    assert conn_mock().cancel_job_run.call_count == 4
    # ...but we only wait once for all of them, not once per page.
    wait_mock.assert_called_once()

@patch.object(EmrServerlessHook, "conn", new_callable=PropertyMock)
def test_cancel_jobs_but_no_jobs(self, conn_mock: MagicMock):
    """cancel_running_jobs is a no-op when the application has no active job runs."""
    # The hook lists jobs via get_paginator("list_job_runs"), not via
    # conn.list_job_runs directly, so stub the paginator with an explicit
    # empty page (the previous list_job_runs stub was never exercised).
    conn_mock().get_paginator().paginate.return_value = [{"jobRuns": []}]
    hook = EmrServerlessHook(aws_conn_id="aws_default")

    hook.cancel_running_jobs("app")

    # nothing very interesting should happen
    conn_mock().cancel_job_run.assert_not_called()
15 changes: 14 additions & 1 deletion tests/providers/amazon/aws/operators/test_emr_serverless.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
from __future__ import annotations

from unittest import mock
from unittest.mock import MagicMock
from unittest.mock import MagicMock, PropertyMock
from uuid import UUID

import pytest
Expand Down Expand Up @@ -654,3 +654,16 @@ def test_stop_no_wait(self, mock_conn: MagicMock, mock_waiter: MagicMock):

mock_waiter.assert_not_called()
mock_conn.stop_application.assert_called_once()

@mock.patch("airflow.providers.amazon.aws.operators.emr.waiter")
@mock.patch.object(EmrServerlessStopApplicationOperator, "hook", new_callable=PropertyMock)
def test_force_stop(self, mock_hook: MagicMock, mock_waiter: MagicMock):
    # With force_stop=True, executing the operator should first cancel the
    # running jobs, then stop the application and wait for it to stop.
    stop_op = EmrServerlessStopApplicationOperator(
        task_id=task_id, application_id="test", force_stop=True
    )

    stop_op.execute(None)

    mock_hook.return_value.cancel_running_jobs.assert_called_once()
    mock_hook.return_value.conn.stop_application.assert_called_once()
    mock_waiter.assert_called_once()
5 changes: 4 additions & 1 deletion tests/system/providers/amazon/aws/example_emr_serverless.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,13 +100,15 @@
configuration_overrides=SPARK_CONFIGURATION_OVERRIDES,
)
# [END howto_operator_emr_serverless_start_job]
start_job.waiter_check_interval_seconds = 10
start_job.wait_for_completion = False
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

instead of waiting for the job to finish, we just trigger it and continue...


# [START howto_sensor_emr_serverless_job]
wait_for_job = EmrServerlessJobSensor(
task_id="wait_for_job",
application_id=emr_serverless_app_id,
job_run_id=start_job.output,
# the default is to wait for job completion, here we just wait for the job to be running.
target_states={"RUNNING"},
)
# [END howto_sensor_emr_serverless_job]
wait_for_job.poke_interval = 10
Expand All @@ -115,6 +117,7 @@
stop_app = EmrServerlessStopApplicationOperator(
task_id="stop_application",
application_id=emr_serverless_app_id,
force_stop=True,
)
# [END howto_operator_emr_serverless_stop_application]
stop_app.waiter_check_interval_seconds = 1
Expand Down