Skip to content

Commit

Permalink
Fix async KPO by waiting for pod termination in execute_complete before…
Browse files Browse the repository at this point in the history
… cleanup (#32467)

* Fix async KPO by waiting for pod termination in `execute_complete` before cleanup (#32467)

---------

Signed-off-by: Hussein Awala <[email protected]>
  • Loading branch information
hussein-awala authored Jul 12, 2023
1 parent 04a6e85 commit b3ce116
Show file tree
Hide file tree
Showing 5 changed files with 114 additions and 68 deletions.
13 changes: 3 additions & 10 deletions airflow/providers/cncf/kubernetes/operators/pod.py
Original file line number Diff line number Diff line change
Expand Up @@ -636,15 +636,11 @@ def invoke_defer_method(self):

def execute_complete(self, context: Context, event: dict, **kwargs):
pod = None
remote_pod = None
try:
pod = self.hook.get_pod(
event["name"],
event["namespace"],
)
# It is done to coincide with the current implementation of the general logic of the cleanup
# method. If it's going to be remade in future then it must be changed
remote_pod = pod
if event["status"] in ("error", "failed", "timeout"):
# fetch some logs when pod is failed
if self.get_logs:
Expand All @@ -661,16 +657,13 @@ def execute_complete(self, context: Context, event: dict, **kwargs):

if self.do_xcom_push:
xcom_sidecar_output = self.extract_xcom(pod=pod)
pod = self.pod_manager.await_pod_completion(pod)
# It is done to coincide with the current implementation of the general logic of
# the cleanup method. If it's going to be remade in future then it must be changed
remote_pod = pod
return xcom_sidecar_output
finally:
if pod is not None and remote_pod is not None:
pod = self.pod_manager.await_pod_completion(pod)
if pod is not None:
self.post_complete_action(
pod=pod,
remote_pod=remote_pod,
remote_pod=pod,
)

def write_logs(self, pod: k8s.V1Pod):
Expand Down
26 changes: 9 additions & 17 deletions airflow/providers/cncf/kubernetes/triggers/pod.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,23 +154,15 @@ async def run(self) -> AsyncIterator[TriggerEvent]: # type: ignore[override]
self.log.debug("Container %s status: %s", self.base_container_name, container_state)

if container_state == ContainerState.TERMINATED:
if pod_status not in PodPhase.terminal_states:
self.log.info(
"Pod %s is still running. Sleeping for %s seconds.",
self.pod_name,
self.poll_interval,
)
await asyncio.sleep(self.poll_interval)
else:
yield TriggerEvent(
{
"name": self.pod_name,
"namespace": self.pod_namespace,
"status": "success",
"message": "All containers inside pod have started successfully.",
}
)
return
yield TriggerEvent(
{
"name": self.pod_name,
"namespace": self.pod_namespace,
"status": "success",
"message": "All containers inside pod have started successfully.",
}
)
return
elif self.should_wait(pod_phase=pod_status, container_state=container_state):
self.log.info("Container is not completed and still working.")

Expand Down
99 changes: 99 additions & 0 deletions tests/providers/cncf/kubernetes/operators/test_pod.py
Original file line number Diff line number Diff line change
Expand Up @@ -1389,10 +1389,12 @@ def test_async_create_pod_should_throw_exception(self, mocked_hook, mocked_clean
({"skip_on_exit_code": None}, 100, AirflowException, "Failed", "error"),
],
)
@patch(KUB_OP_PATH.format("pod_manager"))
@patch(HOOK_CLASS)
def test_async_create_pod_with_skip_on_exit_code_should_skip(
self,
mocked_hook,
mock_manager,
extra_kwargs,
actual_exit_code,
expected_exc,
Expand Down Expand Up @@ -1426,6 +1428,7 @@ def test_async_create_pod_with_skip_on_exit_code_should_skip(
remote_pod.status.phase = pod_status
remote_pod.status.container_statuses = [base_container, sidecar_container]
mocked_hook.return_value.get_pod.return_value = remote_pod
mock_manager.await_pod_completion.return_value = remote_pod

context = {
"ti": MagicMock(),
Expand Down Expand Up @@ -1608,3 +1611,99 @@ def test_cleanup_log_pod_spec_on_failure(self, log_pod_spec_on_failure, expect_m
pod.status = V1PodStatus(phase=PodPhase.FAILED)
with pytest.raises(AirflowException, match=expect_match):
k.cleanup(pod, pod)


@pytest.mark.parametrize("do_xcom_push", [True, False])
@patch(KUB_OP_PATH.format("extract_xcom"))
@patch(KUB_OP_PATH.format("post_complete_action"))
@patch(HOOK_CLASS)
def test_async_kpo_wait_termination_before_cleanup_on_success(
    mocked_hook, post_complete_action, mock_extract_xcom, do_xcom_push
):
    """On a success trigger event, ``execute_complete`` must keep polling the pod
    until it reaches a terminal phase before running the post-complete cleanup."""
    pod_metadata = {"metadata.name": TEST_NAME, "metadata.namespace": TEST_NAMESPACE}
    pod_running = mock.MagicMock(**pod_metadata, **{"status.phase": "Running"})
    pod_succeeded = mock.MagicMock(**pod_metadata, **{"status.phase": "Succeeded"})
    mocked_hook.return_value.get_pod.return_value = pod_running
    read_pod_mock = mocked_hook.return_value.core_v1_client.read_namespaced_pod
    # Two "Running" reads followed by a terminal "Succeeded" read: the operator
    # should keep polling until the third call.
    read_pod_mock.side_effect = [pod_running, pod_running, pod_succeeded]

    ti_mock = MagicMock()

    success_event = {
        "status": "success",
        "message": TEST_SUCCESS_MESSAGE,
        "name": TEST_NAME,
        "namespace": TEST_NAMESPACE,
    }

    operator = KubernetesPodOperator(task_id="task", deferrable=True, do_xcom_push=do_xcom_push)
    operator.execute_complete({"ti": ti_mock}, success_event)

    # The pod named in the trigger event is fetched exactly once.
    mocked_hook.return_value.get_pod.assert_called_once_with(TEST_NAME, TEST_NAMESPACE)

    # Pod name and namespace are pushed to XCom.
    assert ti_mock.xcom_push.call_count == 2
    ti_mock.xcom_push.assert_any_call(key="pod_name", value=TEST_NAME)
    ti_mock.xcom_push.assert_any_call(key="pod_namespace", value=TEST_NAMESPACE)

    # Sidecar XCom extraction happens only when do_xcom_push is enabled.
    if do_xcom_push:
        mock_extract_xcom.assert_called_once()
    else:
        mock_extract_xcom.assert_not_called()

    # Three reads prove the operator waited for the pod to terminate.
    assert read_pod_mock.call_count == 3

    # Cleanup runs once, after the pod reached a terminal state.
    post_complete_action.assert_called_once()


@pytest.mark.parametrize("do_xcom_push", [True, False])
@patch(KUB_OP_PATH.format("extract_xcom"))
@patch(KUB_OP_PATH.format("post_complete_action"))
@patch(HOOK_CLASS)
def test_async_kpo_wait_termination_before_cleanup_on_failure(
    mocked_hook, post_complete_action, mock_extract_xcom, do_xcom_push
):
    """On a failure trigger event, ``execute_complete`` must still wait for the pod
    to reach a terminal phase before cleanup, push no XCom values, and propagate
    the exception raised by the cleanup action."""
    metadata = {"metadata.name": TEST_NAME, "metadata.namespace": TEST_NAMESPACE}
    running_state = mock.MagicMock(**metadata, **{"status.phase": "Running"})
    failed_state = mock.MagicMock(**metadata, **{"status.phase": "Failed"})
    mocked_hook.return_value.get_pod.return_value = running_state
    read_pod_mock = mocked_hook.return_value.core_v1_client.read_namespaced_pod
    # Two "Running" reads followed by the terminal "Failed" read.
    read_pod_mock.side_effect = [
        running_state,
        running_state,
        failed_state,
    ]

    ti_mock = MagicMock()

    # Renamed from the misleading `success_event`: this event reports a failure.
    failure_event = {"status": "failed", "message": "error", "name": TEST_NAME, "namespace": TEST_NAMESPACE}

    post_complete_action.side_effect = AirflowException()

    k = KubernetesPodOperator(task_id="task", deferrable=True, do_xcom_push=do_xcom_push)

    with pytest.raises(AirflowException):
        k.execute_complete({"ti": ti_mock}, failure_event)

    # check if it gets the pod
    mocked_hook.return_value.get_pod.assert_called_once_with(TEST_NAME, TEST_NAMESPACE)

    # assert that it does not push the xcom
    ti_mock.xcom_push.assert_not_called()

    # assert that the xcom are not extracted
    mock_extract_xcom.assert_not_called()

    # check if it waits for the pod to complete (three reads => polled to terminal state)
    assert read_pod_mock.call_count == 3

    # assert that the cleanup is called
    post_complete_action.assert_called_once()
32 changes: 1 addition & 31 deletions tests/providers/cncf/kubernetes/triggers/test_pod.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,8 +96,7 @@ def test_serialize(self, trigger):
@mock.patch(f"{TRIGGER_PATH}.define_container_state")
@mock.patch(f"{TRIGGER_PATH}._get_async_hook")
async def test_run_loop_return_success_event(self, mock_hook, mock_method, trigger):
pod_mock = mock.MagicMock(**{"status.phase": "Succeeded"})
mock_hook.return_value.get_pod.return_value = self._mock_pod_result(pod_mock)
mock_hook.return_value.get_pod.return_value = self._mock_pod_result(mock.MagicMock())
mock_method.return_value = ContainerState.TERMINATED

expected_event = TriggerEvent(
Expand All @@ -112,35 +111,6 @@ async def test_run_loop_return_success_event(self, mock_hook, mock_method, trigg

assert actual_event == expected_event

@pytest.mark.asyncio
@mock.patch(f"{TRIGGER_PATH}.define_container_state")
@mock.patch(f"{TRIGGER_PATH}._get_async_hook")
async def test_run_loop_wait_pod_termination_before_returning_success_event(
self, mock_hook, mock_method, trigger
):
running_state = mock.MagicMock(**{"status.phase": "Running"})
succeeded_state = mock.MagicMock(**{"status.phase": "Succeeded"})
mock_hook.return_value.get_pod.side_effect = [
self._mock_pod_result(running_state),
self._mock_pod_result(running_state),
self._mock_pod_result(succeeded_state),
]
mock_method.return_value = ContainerState.TERMINATED

expected_event = TriggerEvent(
{
"name": POD_NAME,
"namespace": NAMESPACE,
"status": "success",
"message": "All containers inside pod have started successfully.",
}
)
with mock.patch.object(asyncio, "sleep") as mock_sleep:
actual_event = await (trigger.run()).asend(None)

assert actual_event == expected_event
assert mock_sleep.call_count == 2

@pytest.mark.asyncio
@mock.patch(f"{TRIGGER_PATH}.define_container_state")
@mock.patch(f"{TRIGGER_PATH}._get_async_hook")
Expand Down
12 changes: 2 additions & 10 deletions tests/providers/google/cloud/triggers/test_kubernetes_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,13 +110,7 @@ def test_serialize_should_execute_successfully(self, trigger):
async def test_run_loop_return_success_event_should_execute_successfully(
self, mock_hook, mock_method, trigger
):
running_state = mock.MagicMock(**{"status.phase": "Running"})
succeeded_state = mock.MagicMock(**{"status.phase": "Succeeded"})
mock_hook.return_value.get_pod.side_effect = [
self._mock_pod_result(running_state),
self._mock_pod_result(running_state),
self._mock_pod_result(succeeded_state),
]
mock_hook.return_value.get_pod.return_value = self._mock_pod_result(mock.MagicMock())
mock_method.return_value = ContainerState.TERMINATED

expected_event = TriggerEvent(
Expand All @@ -127,11 +121,9 @@ async def test_run_loop_return_success_event_should_execute_successfully(
"message": "All containers inside pod have started successfully.",
}
)
with mock.patch.object(asyncio, "sleep") as mock_sleep:
actual_event = await (trigger.run()).asend(None)
actual_event = await (trigger.run()).asend(None)

assert actual_event == expected_event
assert mock_sleep.call_count == 2

@pytest.mark.asyncio
@mock.patch(f"{TRIGGER_KUB_PATH}.define_container_state")
Expand Down

0 comments on commit b3ce116

Please sign in to comment.