Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor waiter function and improve unit tests #28753

Merged
merged 1 commit into from
Jan 5, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 5 additions & 5 deletions airflow/providers/amazon/aws/utils/waiter.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,16 +54,16 @@ def waiter(
:param check_interval_seconds: Number of seconds waiter should wait before attempting
to retry get_state_callable. Defaults to 60 seconds.
"""
response = get_state_callable(**get_state_args)
state: str = get_state(response, parse_response)
while state not in desired_state:
while True:
state = get_state(get_state_callable(**get_state_args), parse_response)
if state in desired_state:
break
if state in failure_states:
raise AirflowException(f"{object_type.title()} reached failure state {state}.")
if countdown >= check_interval_seconds:
if countdown > check_interval_seconds:
countdown -= check_interval_seconds
log.info("Waiting for %s to be %s.", object_type.lower(), action.lower())
time.sleep(check_interval_seconds)
state = get_state(get_state_callable(**get_state_args), parse_response)
else:
message = f"{object_type.title()} still not {action.lower()} after the allocated time limit."
log.error(message)
Expand Down
152 changes: 56 additions & 96 deletions tests/providers/amazon/aws/utils/test_waiter.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,95 +19,75 @@

from typing import Any
from unittest import mock
from unittest.mock import MagicMock

import pytest

from airflow.exceptions import AirflowException
from airflow.providers.amazon.aws.utils.waiter import waiter


def get_state_test(test_id: str) -> dict[str, Any]:
pass


SUCCESS_STATES = {"Created"}
FAILURE_STATES = {"Failed"}


class TestWaiter:
def _generate_response(self, test_id, state: str) -> dict[str, Any]:
return {
"Id": test_id,
"Status": {
"State": state,
},
}
def generate_response(state: str) -> dict[str, Any]:
return {
"Status": {
"State": state,
},
}

@mock.patch("tests.providers.amazon.aws.utils.test_waiter.get_state_test")
def test_waiter(self, mock_get_state):
test_id = "test_id"
mock_get_state.return_value = self._generate_response(test_id, "Created")
waiter(
get_state_callable=get_state_test,
get_state_args={"test_id": test_id},
parse_response=["Status", "State"],
desired_state=SUCCESS_STATES,
failure_states=FAILURE_STATES,
object_type="test_object",
action="testing",
)

assert mock_get_state.called_once()

@mock.patch("tests.providers.amazon.aws.utils.test_waiter.get_state_test")
def test_waiter_failure(self, mock_get_state):
test_id = "test_id"
mock_get_state.return_value = self._generate_response(test_id, "Failed")
with pytest.raises(AirflowException) as ex_message:
waiter(
get_state_callable=get_state_test,
get_state_args={"test_id": test_id},
parse_response=["Status", "State"],
desired_state=SUCCESS_STATES,
failure_states=FAILURE_STATES,
object_type="test_object",
action="testing",
)
assert mock_get_state.called_once()
assert "Test_Object reached failure state Failed." in str(ex_message.value)

@mock.patch("tests.providers.amazon.aws.utils.test_waiter.get_state_test")
@mock.patch("time.sleep", return_value=None)
def test_waiter_multiple_attempts_success(self, _, mock_get_state):
test_id = "test_id"
test_data = [self._generate_response(test_id, "Pending") for i in range(2)]
test_data.append(self._generate_response(test_id, "Created"))
mock_get_state.side_effect = test_data

waiter(
get_state_callable=get_state_test,
get_state_args={"test_id": test_id},
parse_response=["Status", "State"],
desired_state=SUCCESS_STATES,
failure_states=FAILURE_STATES,
object_type="test_object",
action="testing",
check_interval_seconds=1,
countdown=5,
)
assert mock_get_state.call_count == 3

@mock.patch("tests.providers.amazon.aws.utils.test_waiter.get_state_test")
class TestWaiter:
@pytest.mark.parametrize(
"get_state_responses, fails, expected_exception, expected_num_calls",
[
([generate_response("Created")], False, None, 1),
([generate_response("Failed")], True, AirflowException, 1),
(
[generate_response("Pending"), generate_response("Pending"), generate_response("Created")],
False,
None,
3,
),
(
[generate_response("Pending"), generate_response("Failed")],
True,
AirflowException,
2,
),
(
[generate_response("Pending"), generate_response("Pending"), generate_response("Failed")],
True,
AirflowException,
3,
),
([generate_response("Pending") for i in range(10)], True, RuntimeError, 5),
],
)
@mock.patch("time.sleep", return_value=None)
def test_waiter_multiple_attempts_fail(self, _, mock_get_state):
test_id = "test_id"
test_data = [self._generate_response(test_id, "Pending") for i in range(2)]
test_data.append(self._generate_response(test_id, "Failed"))
mock_get_state.side_effect = test_data
with pytest.raises(AirflowException) as ex_message:
def test_waiter(self, _, get_state_responses, fails, expected_exception, expected_num_calls):
mock_get_state = MagicMock()
mock_get_state.side_effect = get_state_responses
get_state_args = {}

if fails:
with pytest.raises(expected_exception):
waiter(
get_state_callable=mock_get_state,
get_state_args=get_state_args,
parse_response=["Status", "State"],
desired_state=SUCCESS_STATES,
failure_states=FAILURE_STATES,
object_type="test_object",
action="testing",
check_interval_seconds=1,
countdown=5,
)
else:
waiter(
get_state_callable=get_state_test,
get_state_args={"test_id": test_id},
get_state_callable=mock_get_state,
get_state_args=get_state_args,
parse_response=["Status", "State"],
desired_state=SUCCESS_STATES,
failure_states=FAILURE_STATES,
Expand All @@ -116,25 +96,5 @@ def test_waiter_multiple_attempts_fail(self, _, mock_get_state):
check_interval_seconds=1,
countdown=5,
)
assert mock_get_state.call_count == 3
assert "Test_Object reached failure state Failed." in str(ex_message.value)

@mock.patch("tests.providers.amazon.aws.utils.test_waiter.get_state_test")
@mock.patch("time.sleep", return_value=None)
def test_waiter_multiple_attempts_pending(self, _, mock_get_state):
test_id = "test_id"
mock_get_state.return_value = self._generate_response(test_id, "Pending")
with pytest.raises(RuntimeError) as ex_message:
waiter(
get_state_callable=get_state_test,
get_state_args={"test_id": test_id},
parse_response=["Status", "State"],
desired_state=SUCCESS_STATES,
failure_states=FAILURE_STATES,
object_type="test_object",
action="testing",
check_interval_seconds=1,
countdown=5,
)
assert mock_get_state.call_count == 5
assert "Test_Object still not testing after the allocated time limit." in str(ex_message.value)
assert mock_get_state.call_count == expected_num_calls