Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add a stop operator to emr serverless #30720

Merged
merged 6 commits into from
Apr 19, 2023
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
94 changes: 71 additions & 23 deletions airflow/providers/amazon/aws/operators/emr.py
Original file line number Diff line number Diff line change
Expand Up @@ -1010,21 +1010,21 @@ def execute(self, context: Context) -> dict:
return response["jobRunId"]


class EmrServerlessDeleteApplicationOperator(BaseOperator):
class EmrServerlessStopApplicationOperator(BaseOperator):
"""
Operator to delete EMR Serverless application
Operator to stop an EMR Serverless application

.. seealso::
For more information on how to use this operator, take a look at the guide:
:ref:`howto/operator:EmrServerlessDeleteApplicationOperator`

:param application_id: ID of the EMR Serverless application to delete.
:param wait_for_completion: If true, wait for the Application to start before returning. Default to True
:param application_id: ID of the EMR Serverless application to stop.
:param wait_for_completion: If true, wait for the Application to stop before returning. Default to True
:param aws_conn_id: AWS connection to use
:param waiter_countdown: Total amount of time, in seconds, the operator will wait for
the application be deleted. Defaults to 25 minutes.
the application be stopped. Defaults to 5 minutes.
:param waiter_check_interval_seconds: Number of seconds between polling the state of the application.
Defaults to 60 seconds.
Defaults to 30 seconds.
"""

template_fields: Sequence[str] = ("application_id",)
Expand All @@ -1034,8 +1034,8 @@ def __init__(
application_id: str,
wait_for_completion: bool = True,
aws_conn_id: str = "aws_default",
waiter_countdown: int = 25 * 60,
waiter_check_interval_seconds: int = 60,
waiter_countdown: int = 5 * 60,
waiter_check_interval_seconds: int = 30,
**kwargs,
):
self.aws_conn_id = aws_conn_id
Expand All @@ -1054,28 +1054,76 @@ def execute(self, context: Context) -> None:
self.log.info("Stopping application: %s", self.application_id)
self.hook.conn.stop_application(applicationId=self.application_id)

# This should be replaced with a boto waiter when available.
waiter(
get_state_callable=self.hook.conn.get_application,
get_state_args={
"applicationId": self.application_id,
},
parse_response=["application", "state"],
desired_state=EmrServerlessHook.APPLICATION_FAILURE_STATES,
failure_states=set(),
object_type="application",
action="stopped",
countdown=self.waiter_countdown,
check_interval_seconds=self.waiter_check_interval_seconds,
if self.wait_for_completion:
# This should be replaced with a boto waiter when available.
waiter(
get_state_callable=self.hook.conn.get_application,
get_state_args={
"applicationId": self.application_id,
},
parse_response=["application", "state"],
desired_state=EmrServerlessHook.APPLICATION_FAILURE_STATES,
failure_states=set(),
object_type="application",
action="stopped",
countdown=self.waiter_countdown,
check_interval_seconds=self.waiter_check_interval_seconds,
)
self.log.info("EMR serverless application %s stopped successfully", self.application_id)


class EmrServerlessDeleteApplicationOperator(EmrServerlessStopApplicationOperator):
"""
Operator to delete EMR Serverless application

.. seealso::
For more information on how to use this operator, take a look at the guide:
:ref:`howto/operator:EmrServerlessDeleteApplicationOperator`

:param application_id: ID of the EMR Serverless application to delete.
:param wait_for_completion: If true, wait for the Application to be deleted before returning.
Defaults to True. Note that this operator will always wait for the application to be STOPPED first.
:param aws_conn_id: AWS connection to use
:param waiter_countdown: Total amount of time, in seconds, the operator will wait for each step of first,
the application to be stopped, and then deleted. Defaults to 25 minutes.
:param waiter_check_interval_seconds: Number of seconds between polling the state of the application.
Defaults to 60 seconds.
"""

template_fields: Sequence[str] = ("application_id",)

def __init__(
self,
application_id: str,
wait_for_completion: bool = True,
aws_conn_id: str = "aws_default",
waiter_countdown: int = 25 * 60,
waiter_check_interval_seconds: int = 60,
**kwargs,
):
self.wait_for_delete_completion = wait_for_completion
# super stops the app
super().__init__(
application_id=application_id,
# when deleting an app, we always need to wait for it to stop before we can call delete()
wait_for_completion=True,
aws_conn_id=aws_conn_id,
waiter_countdown=waiter_countdown,
waiter_check_interval_seconds=waiter_check_interval_seconds,
**kwargs,
)

self.log.info("Deleting application: %s", self.application_id)
def execute(self, context: Context) -> None:
# super stops the app (or makes sure it's already stopped)
super().execute(context)

self.log.info("Now deleting application: %s", self.application_id)
response = self.hook.conn.delete_application(applicationId=self.application_id)

if response["ResponseMetadata"]["HTTPStatusCode"] != 200:
raise AirflowException(f"Application deletion failed: {response}")

if self.wait_for_completion:
if self.wait_for_delete_completion:
# This should be replaced with a boto waiter when available.
waiter(
get_state_callable=self.hook.conn.get_application,
Expand Down
30 changes: 27 additions & 3 deletions tests/providers/amazon/aws/operators/test_emr_serverless.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from __future__ import annotations

from unittest import mock
from unittest.mock import MagicMock
from uuid import UUID

import pytest
Expand All @@ -26,6 +27,7 @@
EmrServerlessCreateApplicationOperator,
EmrServerlessDeleteApplicationOperator,
EmrServerlessStartJobOperator,
EmrServerlessStopApplicationOperator,
)

task_id = "test_emr_serverless_task_id"
Expand Down Expand Up @@ -606,14 +608,13 @@ def test_delete_application_without_wait_for_completion_successfully(self, mock_

operator.execute(None)

assert operator.wait_for_completion is False
mock_waiter.assert_called_once()
mock_conn.stop_application.assert_called_once()
mock_conn.delete_application.assert_called_once_with(applicationId=application_id_delete_operator)

@mock.patch("airflow.providers.amazon.aws.operators.emr.waiter")
@mock.patch("airflow.providers.amazon.aws.hooks.emr.EmrServerlessHook.conn")
def test_delete_application_failed_deleteion(self, mock_conn, mock_waiter):
def test_delete_application_failed_deletion(self, mock_conn, mock_waiter):
mock_waiter.return_value = True
mock_conn.stop_application.return_value = {}
mock_conn.delete_application.return_value = {"ResponseMetadata": {"HTTPStatusCode": 400}}
Expand All @@ -626,7 +627,30 @@ def test_delete_application_failed_deleteion(self, mock_conn, mock_waiter):

assert "Application deletion failed:" in str(ex_message.value)

assert operator.wait_for_completion is True
mock_waiter.assert_called_once()
mock_conn.stop_application.assert_called_once()
mock_conn.delete_application.assert_called_once_with(applicationId=application_id_delete_operator)


class TestEmrServerlessStopOperator:
@mock.patch("airflow.providers.amazon.aws.operators.emr.waiter")
@mock.patch("airflow.providers.amazon.aws.hooks.emr.EmrServerlessHook.conn")
def test_stop(self, mock_conn: MagicMock, mock_waiter: MagicMock):
operator = EmrServerlessStopApplicationOperator(task_id=task_id, application_id="test")

operator.execute(None)

mock_waiter.assert_called_once()
mock_conn.stop_application.assert_called_once()

@mock.patch("airflow.providers.amazon.aws.operators.emr.waiter")
@mock.patch("airflow.providers.amazon.aws.hooks.emr.EmrServerlessHook.conn")
def test_stop_no_wait(self, mock_conn: MagicMock, mock_waiter: MagicMock):
operator = EmrServerlessStopApplicationOperator(
task_id=task_id, application_id="test", wait_for_completion=False
)

operator.execute(None)

mock_waiter.assert_not_called()
mock_conn.stop_application.assert_called_once()
11 changes: 11 additions & 0 deletions tests/system/providers/amazon/aws/example_emr_serverless.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
EmrServerlessCreateApplicationOperator,
EmrServerlessDeleteApplicationOperator,
EmrServerlessStartJobOperator,
EmrServerlessStopApplicationOperator,
)
from airflow.providers.amazon.aws.operators.s3 import S3CreateBucketOperator, S3DeleteBucketOperator
from airflow.providers.amazon.aws.sensors.emr import EmrServerlessApplicationSensor, EmrServerlessJobSensor
Expand Down Expand Up @@ -108,6 +109,15 @@
job_run_id=start_job.output,
)
# [END howto_sensor_emr_serverless_job]
wait_for_job.poke_interval = 10

# [START howto_operator_emr_serverless_stop_application]
stop_app = EmrServerlessStopApplicationOperator(
task_id="stop_application",
application_id=emr_serverless_app_id,
)
# [END howto_operator_emr_serverless_stop_application]
stop_app.waiter_check_interval_seconds = 1

# [START howto_operator_emr_serverless_delete_application]
delete_app = EmrServerlessDeleteApplicationOperator(
Expand All @@ -134,6 +144,7 @@
wait_for_app_creation,
start_job,
wait_for_job,
stop_app,
# TEST TEARDOWN
delete_app,
delete_s3_bucket,
Expand Down