Fix async KPO by waiting pod termination in execute_complete before cleanup #32467

Merged · 2 commits · Jul 12, 2023
13 changes: 3 additions & 10 deletions airflow/providers/cncf/kubernetes/operators/pod.py
@@ -636,15 +636,11 @@ def invoke_defer_method(self):
 
     def execute_complete(self, context: Context, event: dict, **kwargs):
         pod = None
-        remote_pod = None
         try:
             pod = self.hook.get_pod(
                 event["name"],
                 event["namespace"],
             )
-            # It is done to coincide with the current implementation of the general logic of the cleanup
-            # method. If it's going to be remade in future then it must be changed
-            remote_pod = pod
             if event["status"] in ("error", "failed", "timeout"):
                 # fetch some logs when pod is failed
                 if self.get_logs:
@@ -661,16 +657,13 @@ def execute_complete(self, context: Context, event: dict, **kwargs):
 
             if self.do_xcom_push:
                 xcom_sidecar_output = self.extract_xcom(pod=pod)
-                pod = self.pod_manager.await_pod_completion(pod)
-                # It is done to coincide with the current implementation of the general logic of
-                # the cleanup method. If it's going to be remade in future then it must be changed
-                remote_pod = pod
                 return xcom_sidecar_output
         finally:
-            if pod is not None and remote_pod is not None:
+            pod = self.pod_manager.await_pod_completion(pod)
+            if pod is not None:
                 self.post_complete_action(
                     pod=pod,
-                    remote_pod=remote_pod,
+                    remote_pod=pod,
                 )
 
     def write_logs(self, pod: k8s.V1Pod):
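Taken together, the two hunks above move the pod-completion wait out of the `do_xcom_push` branch and into the `finally` block, so cleanup always runs against a terminal pod. Previously, `remote_pod` was captured right after `get_pod` and only refreshed on the xcom path, so `post_complete_action` could inspect a pod that was still Running. As a sketch assembled from the diff context above (elided sections marked; not a verbatim copy of the file), the method now reads roughly:

```python
def execute_complete(self, context: Context, event: dict, **kwargs):
    pod = None
    try:
        pod = self.hook.get_pod(event["name"], event["namespace"])
        if event["status"] in ("error", "failed", "timeout"):
            ...  # fetch logs and raise; elided in the diff above
        if self.do_xcom_push:
            xcom_sidecar_output = self.extract_xcom(pod=pod)
            return xcom_sidecar_output
    finally:
        # Block until the pod reaches a terminal phase *before* cleanup,
        # so post_complete_action never races a still-running pod.
        pod = self.pod_manager.await_pod_completion(pod)
        if pod is not None:
            self.post_complete_action(pod=pod, remote_pod=pod)
```

Note the ordering guarantee this relies on: the `return xcom_sidecar_output` in the `try` body only hands its value back after the `finally` block has finished waiting.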
26 changes: 9 additions & 17 deletions airflow/providers/cncf/kubernetes/triggers/pod.py
@@ -154,23 +154,15 @@ async def run(self) -> AsyncIterator[TriggerEvent]:  # type: ignore[override]
             self.log.debug("Container %s status: %s", self.base_container_name, container_state)
 
             if container_state == ContainerState.TERMINATED:
-                if pod_status not in PodPhase.terminal_states:
-                    self.log.info(
-                        "Pod %s is still running. Sleeping for %s seconds.",
-                        self.pod_name,
-                        self.poll_interval,
-                    )
-                    await asyncio.sleep(self.poll_interval)
-                else:
-                    yield TriggerEvent(
-                        {
-                            "name": self.pod_name,
-                            "namespace": self.pod_namespace,
-                            "status": "success",
-                            "message": "All containers inside pod have started successfully.",
-                        }
-                    )
-                    return
+                yield TriggerEvent(
+                    {
+                        "name": self.pod_name,
+                        "namespace": self.pod_namespace,
+                        "status": "success",
+                        "message": "All containers inside pod have started successfully.",
+                    }
+                )
+                return
             elif self.should_wait(pod_phase=pod_status, container_state=container_state):
                 self.log.info("Container is not completed and still working.")
 
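With the sleep-poll removed, the trigger fires as soon as `define_container_state` reports `TERMINATED`; waiting for the pod itself to finish is now the operator's job, as shown in the `finally` block above. For orientation, the companion `should_wait` check referenced in the `elif` branch looks roughly like this (reconstructed from the provider source of this period; treat it as a sketch rather than the authoritative definition):

```python
@staticmethod
def should_wait(pod_phase: PodPhase, container_state: ContainerState) -> bool:
    # Keep polling while the base container is starting or running,
    # or while its state is unknown and the pod is still Pending.
    return (
        container_state == ContainerState.WAITING
        or container_state == ContainerState.RUNNING
        or (container_state == ContainerState.UNDEFINED and pod_phase == PodPhase.PENDING)
    )
```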
99 changes: 99 additions & 0 deletions tests/providers/cncf/kubernetes/operators/test_pod.py
@@ -1389,10 +1389,12 @@ def test_async_create_pod_should_throw_exception(self, mocked_hook, mocked_clean
             ({"skip_on_exit_code": None}, 100, AirflowException, "Failed", "error"),
         ],
     )
+    @patch(KUB_OP_PATH.format("pod_manager"))
     @patch(HOOK_CLASS)
     def test_async_create_pod_with_skip_on_exit_code_should_skip(
         self,
         mocked_hook,
+        mock_manager,
         extra_kwargs,
         actual_exit_code,
         expected_exc,
@@ -1426,6 +1428,7 @@ def test_async_create_pod_with_skip_on_exit_code_should_skip(
         remote_pod.status.phase = pod_status
         remote_pod.status.container_statuses = [base_container, sidecar_container]
         mocked_hook.return_value.get_pod.return_value = remote_pod
+        mock_manager.await_pod_completion.return_value = remote_pod
 
         context = {
             "ti": MagicMock(),
@@ -1608,3 +1611,99 @@ def test_cleanup_log_pod_spec_on_failure(self, log_pod_spec_on_failure, expect_match):
         pod.status = V1PodStatus(phase=PodPhase.FAILED)
         with pytest.raises(AirflowException, match=expect_match):
             k.cleanup(pod, pod)
+
+
+@pytest.mark.parametrize("do_xcom_push", [True, False])
+@patch(KUB_OP_PATH.format("extract_xcom"))
+@patch(KUB_OP_PATH.format("post_complete_action"))
+@patch(HOOK_CLASS)
+def test_async_kpo_wait_termination_before_cleanup_on_success(
+    mocked_hook, post_complete_action, mock_extract_xcom, do_xcom_push
+):
+    metadata = {"metadata.name": TEST_NAME, "metadata.namespace": TEST_NAMESPACE}
+    running_state = mock.MagicMock(**metadata, **{"status.phase": "Running"})
+    succeeded_state = mock.MagicMock(**metadata, **{"status.phase": "Succeeded"})
+    mocked_hook.return_value.get_pod.return_value = running_state
+    read_pod_mock = mocked_hook.return_value.core_v1_client.read_namespaced_pod
+    read_pod_mock.side_effect = [
+        running_state,
+        running_state,
+        succeeded_state,
+    ]
+
+    ti_mock = MagicMock()
+
+    success_event = {
+        "status": "success",
+        "message": TEST_SUCCESS_MESSAGE,
+        "name": TEST_NAME,
+        "namespace": TEST_NAMESPACE,
+    }
+
+    k = KubernetesPodOperator(task_id="task", deferrable=True, do_xcom_push=do_xcom_push)
+    k.execute_complete({"ti": ti_mock}, success_event)
+
+    # check if it gets the pod
+    mocked_hook.return_value.get_pod.assert_called_once_with(TEST_NAME, TEST_NAMESPACE)
+
+    # check if it pushes the xcom
+    assert ti_mock.xcom_push.call_count == 2
+    ti_mock.xcom_push.assert_any_call(key="pod_name", value=TEST_NAME)
+    ti_mock.xcom_push.assert_any_call(key="pod_namespace", value=TEST_NAMESPACE)
+
+    # assert that the xcom are extracted/not extracted
+    if do_xcom_push:
+        mock_extract_xcom.assert_called_once()
+    else:
+        mock_extract_xcom.assert_not_called()
+
+    # check if it waits for the pod to complete
+    assert read_pod_mock.call_count == 3
+
+    # assert that the cleanup is called
+    post_complete_action.assert_called_once()
+
+
+@pytest.mark.parametrize("do_xcom_push", [True, False])
+@patch(KUB_OP_PATH.format("extract_xcom"))
+@patch(KUB_OP_PATH.format("post_complete_action"))
+@patch(HOOK_CLASS)
+def test_async_kpo_wait_termination_before_cleanup_on_failure(
+    mocked_hook, post_complete_action, mock_extract_xcom, do_xcom_push
+):
+    metadata = {"metadata.name": TEST_NAME, "metadata.namespace": TEST_NAMESPACE}
+    running_state = mock.MagicMock(**metadata, **{"status.phase": "Running"})
+    failed_state = mock.MagicMock(**metadata, **{"status.phase": "Failed"})
+    mocked_hook.return_value.get_pod.return_value = running_state
+    read_pod_mock = mocked_hook.return_value.core_v1_client.read_namespaced_pod
+    read_pod_mock.side_effect = [
+        running_state,
+        running_state,
+        failed_state,
+    ]
+
+    ti_mock = MagicMock()
+
+    success_event = {"status": "failed", "message": "error", "name": TEST_NAME, "namespace": TEST_NAMESPACE}
+
+    post_complete_action.side_effect = AirflowException()
+
+    k = KubernetesPodOperator(task_id="task", deferrable=True, do_xcom_push=do_xcom_push)
+
+    with pytest.raises(AirflowException):
+        k.execute_complete({"ti": ti_mock}, success_event)
+
+    # check if it gets the pod
+    mocked_hook.return_value.get_pod.assert_called_once_with(TEST_NAME, TEST_NAMESPACE)
+
+    # assert that it does not push the xcom
+    ti_mock.xcom_push.assert_not_called()
+
+    # assert that the xcom are not extracted
+    mock_extract_xcom.assert_not_called()
+
+    # check if it waits for the pod to complete
+    assert read_pod_mock.call_count == 3
+
+    # assert that the cleanup is called
+    post_complete_action.assert_called_once()
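Both new tests pin down the wait itself: the hook hands back a Running pod, and `read_namespaced_pod` must be consulted exactly three times (Running, Running, then a terminal phase) before `post_complete_action` fires. That assertion only holds if `await_pod_completion` re-reads the pod until it terminates, along the lines of this minimal sketch (an assumption about the shape of `PodManager.await_pod_completion`, which also logs and goes through the Kubernetes client under the hood):

```python
import time

def await_pod_completion(read_pod, terminal_states=("Succeeded", "Failed"), interval=2.0):
    """Poll read_pod() until the pod reports a terminal phase, then return it."""
    while True:
        remote_pod = read_pod()  # in the tests above: core_v1_client.read_namespaced_pod
        if remote_pod.status.phase in terminal_states:
            return remote_pod
        time.sleep(interval)  # the real implementation throttles between reads
```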
32 changes: 1 addition & 31 deletions tests/providers/cncf/kubernetes/triggers/test_pod.py
@@ -96,8 +96,7 @@ def test_serialize(self, trigger):
     @mock.patch(f"{TRIGGER_PATH}.define_container_state")
     @mock.patch(f"{TRIGGER_PATH}._get_async_hook")
     async def test_run_loop_return_success_event(self, mock_hook, mock_method, trigger):
-        pod_mock = mock.MagicMock(**{"status.phase": "Succeeded"})
-        mock_hook.return_value.get_pod.return_value = self._mock_pod_result(pod_mock)
+        mock_hook.return_value.get_pod.return_value = self._mock_pod_result(mock.MagicMock())
         mock_method.return_value = ContainerState.TERMINATED
 
         expected_event = TriggerEvent(
@@ -112,35 +111,6 @@ async def test_run_loop_return_success_event(self, mock_hook, mock_method, trigger):
 
         assert actual_event == expected_event
 
-    @pytest.mark.asyncio
-    @mock.patch(f"{TRIGGER_PATH}.define_container_state")
-    @mock.patch(f"{TRIGGER_PATH}._get_async_hook")
-    async def test_run_loop_wait_pod_termination_before_returning_success_event(
-        self, mock_hook, mock_method, trigger
-    ):
-        running_state = mock.MagicMock(**{"status.phase": "Running"})
-        succeeded_state = mock.MagicMock(**{"status.phase": "Succeeded"})
-        mock_hook.return_value.get_pod.side_effect = [
-            self._mock_pod_result(running_state),
-            self._mock_pod_result(running_state),
-            self._mock_pod_result(succeeded_state),
-        ]
-        mock_method.return_value = ContainerState.TERMINATED
-
-        expected_event = TriggerEvent(
-            {
-                "name": POD_NAME,
-                "namespace": NAMESPACE,
-                "status": "success",
-                "message": "All containers inside pod have started successfully.",
-            }
-        )
-        with mock.patch.object(asyncio, "sleep") as mock_sleep:
-            actual_event = await (trigger.run()).asend(None)
-
-        assert actual_event == expected_event
-        assert mock_sleep.call_count == 2
-
     @pytest.mark.asyncio
     @mock.patch(f"{TRIGGER_PATH}.define_container_state")
     @mock.patch(f"{TRIGGER_PATH}._get_async_hook")
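The deleted test covered the trigger-side wait, which this PR relocates to the operator and re-tests there (see the two `test_async_kpo_wait_termination_before_cleanup_*` tests above). Both this file and the GKE trigger tests below lean on a small `_mock_pod_result` helper to make a plain `MagicMock` awaitable; it presumably has roughly this shape (shown for orientation only, as a method of the test class):

```python
from asyncio import Future

@staticmethod
def _mock_pod_result(result_to_mock):
    # Wrap the fake pod in an already-resolved Future so that
    # `await hook.get_pod(...)` works inside trigger.run().
    f = Future()
    f.set_result(result_to_mock)
    return f
```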
12 changes: 2 additions & 10 deletions tests/providers/google/cloud/triggers/test_kubernetes_engine.py
@@ -110,13 +110,7 @@ def test_serialize_should_execute_successfully(self, trigger):
     async def test_run_loop_return_success_event_should_execute_successfully(
         self, mock_hook, mock_method, trigger
     ):
-        running_state = mock.MagicMock(**{"status.phase": "Running"})
-        succeeded_state = mock.MagicMock(**{"status.phase": "Succeeded"})
-        mock_hook.return_value.get_pod.side_effect = [
-            self._mock_pod_result(running_state),
-            self._mock_pod_result(running_state),
-            self._mock_pod_result(succeeded_state),
-        ]
+        mock_hook.return_value.get_pod.return_value = self._mock_pod_result(mock.MagicMock())
         mock_method.return_value = ContainerState.TERMINATED
 
         expected_event = TriggerEvent(
@@ -127,11 +121,9 @@ async def test_run_loop_return_success_event_should_execute_successfully(
                 "message": "All containers inside pod have started successfully.",
             }
         )
-        with mock.patch.object(asyncio, "sleep") as mock_sleep:
-            actual_event = await (trigger.run()).asend(None)
+        actual_event = await (trigger.run()).asend(None)
 
         assert actual_event == expected_event
-        assert mock_sleep.call_count == 2
 
     @pytest.mark.asyncio
     @mock.patch(f"{TRIGGER_KUB_PATH}.define_container_state")