From 72ed26f27392219aa24427f4dcfee2abc46374b6 Mon Sep 17 00:00:00 2001
From: Kaxil Naik
Date: Tue, 3 Dec 2024 01:57:26 +0000
Subject: [PATCH] AIP-72: Port task success overtime to the Supervisor

This PR ports the overtime feature on `LocalTaskJob` (added in
https://github.com/apache/airflow/pull/39890) to the Supervisor. It allows the
Task process to be terminated when it exceeds the configured success overtime
threshold, which is useful when we add Listeners to the Task process.

closes https://github.com/apache/airflow/issues/44356

Also added `TaskState` to update the state and send `end_date` from the task
process to the supervisor.
---
 .../src/airflow/sdk/execution_time/comms.py   |  2 +
 .../airflow/sdk/execution_time/supervisor.py  | 37 +++++--
 .../airflow/sdk/execution_time/task_runner.py | 14 ++-
 .../tests/execution_time/test_supervisor.py   | 97 ++++++++++++++++++-
 .../tests/execution_time/test_task_runner.py  | 22 ++++-
 5 files changed, 156 insertions(+), 16 deletions(-)

diff --git a/task_sdk/src/airflow/sdk/execution_time/comms.py b/task_sdk/src/airflow/sdk/execution_time/comms.py
index 0e45e45700e88..307eacd462846 100644
--- a/task_sdk/src/airflow/sdk/execution_time/comms.py
+++ b/task_sdk/src/airflow/sdk/execution_time/comms.py
@@ -43,6 +43,7 @@
 
 from __future__ import annotations
 
+from datetime import datetime
 from typing import Annotated, Literal, Union
 
 from pydantic import BaseModel, ConfigDict, Field
@@ -101,6 +102,7 @@ class TaskState(BaseModel):
     """
 
     state: TerminalTIState
+    end_date: datetime | None = None
     type: Literal["TaskState"] = "TaskState"
 
 
diff --git a/task_sdk/src/airflow/sdk/execution_time/supervisor.py b/task_sdk/src/airflow/sdk/execution_time/supervisor.py
index 015cc231f7dd4..4a002177d855c 100644
--- a/task_sdk/src/airflow/sdk/execution_time/supervisor.py
+++ b/task_sdk/src/airflow/sdk/execution_time/supervisor.py
@@ -51,6 +51,7 @@
     GetVariable,
     GetXCom,
     StartupDetails,
+    TaskState,
     ToSupervisor,
 )
 
@@ -265,9 +266,9 @@ class WatchedSubprocess:
     client: Client
 
     _process: psutil.Process
-    _exit_code: int | None = None
-    _terminal_state: str | None = None
-    _final_state: str | None = None
+    _exit_code: int | None = attrs.field(default=None, init=False)
+    _terminal_state: str | None = attrs.field(default=None, init=False)
+    _final_state: str | None = attrs.field(default=None, init=False)
 
     _last_successful_heartbeat: float = attrs.field(default=0, init=False)
     _last_heartbeat_attempt: float = attrs.field(default=0, init=False)
@@ -277,6 +278,13 @@ class WatchedSubprocess:
     # does not hang around forever.
    failed_heartbeats: int = attrs.field(default=0, init=False)
 
+    # Maximum possible time (in seconds) that the task will have for execution of auxiliary
+    # processes, like listeners, after the task is marked as success.
+    # TODO: This should come from airflow.cfg: [core] task_success_overtime
+    task_success_overtime_threshold: float = attrs.field(default=20.0, init=False)
+    _overtime: float = attrs.field(default=0.0, init=False)
+    _task_end_datetime: datetime | None = attrs.field(default=None, init=False)
+
     selector: selectors.BaseSelector = attrs.field(factory=selectors.DefaultSelector)
 
     procs: ClassVar[weakref.WeakValueDictionary[int, WatchedSubprocess]] = weakref.WeakValueDictionary()
@@ -500,6 +508,20 @@ def _monitor_subprocess(self):
 
             self._send_heartbeat_if_needed()
 
+            self._handle_task_overtime_if_needed()
+
+    def _handle_task_overtime_if_needed(self):
+        """Handle termination of auxiliary processes if the task exceeds the configured success overtime."""
+        if self._terminal_state != TerminalTIState.SUCCESS:
+            return
+
+        now = datetime.now(tz=timezone.utc)
+        self._overtime = (now - (self._task_end_datetime or now)).total_seconds()
+
+        if self._overtime > self.task_success_overtime_threshold:
+            log.warning("Task success overtime reached; terminating process", ti_id=self.ti_id)
+            self.kill(signal.SIGTERM, force=True)
+
     def _service_subprocess(self, max_wait_time: float, raise_on_timeout: bool = False):
         """
         Service subprocess events by processing socket activity and checking for process exit.
@@ -631,9 +653,11 @@ def handle_requests(self, log: FilteringBoundLogger) -> Generator[None, bytes, N
                 log.exception("Unable to decode message", line=line)
                 continue
 
-            # if isinstance(msg, TaskState):
-            #     self._terminal_state = msg.state
-            if isinstance(msg, GetConnection):
+            resp = None
+            if isinstance(msg, TaskState):
+                self._terminal_state = msg.state
+                self._task_end_datetime = msg.end_date
+            elif isinstance(msg, GetConnection):
                 conn = self.client.connections.get(msg.conn_id)
                 resp = conn.model_dump_json(exclude_unset=True).encode()
             elif isinstance(msg, GetVariable):
@@ -645,7 +669,6 @@ def handle_requests(self, log: FilteringBoundLogger) -> Generator[None, bytes, N
             elif isinstance(msg, DeferTask):
                 self._terminal_state = IntermediateTIState.DEFERRED
                 self.client.task_instances.defer(self.ti_id, msg)
-                resp = None
             else:
                 log.error("Unhandled request", msg=msg)
                 continue
diff --git a/task_sdk/src/airflow/sdk/execution_time/task_runner.py b/task_sdk/src/airflow/sdk/execution_time/task_runner.py
index 9c8bc4942294f..36f04fa0a7693 100644
--- a/task_sdk/src/airflow/sdk/execution_time/task_runner.py
+++ b/task_sdk/src/airflow/sdk/execution_time/task_runner.py
@@ -21,6 +21,7 @@
 
 import os
 import sys
+from datetime import datetime, timezone
 from io import FileIO
 from typing import TYPE_CHECKING, TextIO
 
@@ -28,9 +29,9 @@
 import structlog
 from pydantic import ConfigDict, TypeAdapter
 
-from airflow.sdk.api.datamodels._generated import TaskInstance
+from airflow.sdk.api.datamodels._generated import TaskInstance, TerminalTIState
 from airflow.sdk.definitions.baseoperator import BaseOperator
-from airflow.sdk.execution_time.comms import DeferTask, StartupDetails, ToSupervisor, ToTask
+from airflow.sdk.execution_time.comms import DeferTask, StartupDetails, TaskState, ToSupervisor, ToTask
 
 if TYPE_CHECKING:
     from structlog.typing import FilteringBoundLogger as Logger
@@ -158,11 +159,14 @@ def run(ti: RuntimeTaskInstance, log: Logger):
     if TYPE_CHECKING:
         assert ti.task is not None
         assert isinstance(ti.task, BaseOperator)
+
+    msg = None
     try:
         # TODO: pre execute etc.
         # TODO next_method to support resuming from deferred
         # TODO: Get a real context object
         ti.task.execute({"task_instance": ti})  # type: ignore[attr-defined]
+        msg = TaskState(state=TerminalTIState.SUCCESS, end_date=datetime.now(tz=timezone.utc))
     except TaskDeferred as defer:
         classpath, trigger_kwargs = defer.trigger.serialize()
         next_method = defer.method_name
@@ -173,9 +177,8 @@ def run(ti: RuntimeTaskInstance, log: Logger):
             next_method=next_method,
             trigger_timeout=timeout,
         )
-        SUPERVISOR_COMMS.send_request(msg=msg, log=log)
     except AirflowSkipException:
-        ...
+        msg = TaskState(state=TerminalTIState.SKIPPED)
     except AirflowRescheduleException:
         ...
     except (AirflowFailException, AirflowSensorTimeout):
@@ -189,6 +192,9 @@ def run(ti: RuntimeTaskInstance, log: Logger):
         # TODO: Handle TI handle failure
         raise
 
+    if msg:
+        SUPERVISOR_COMMS.send_request(msg=msg, log=log)
+
 
 def finalize(log: Logger): ...
 
diff --git a/task_sdk/tests/execution_time/test_supervisor.py b/task_sdk/tests/execution_time/test_supervisor.py
index cd9abae55c48c..78d18fb2c5bf3 100644
--- a/task_sdk/tests/execution_time/test_supervisor.py
+++ b/task_sdk/tests/execution_time/test_supervisor.py
@@ -36,7 +36,7 @@
 
 from airflow.sdk.api import client as sdk_client
 from airflow.sdk.api.client import ServerResponseError
-from airflow.sdk.api.datamodels._generated import TaskInstance
+from airflow.sdk.api.datamodels._generated import TaskInstance, TerminalTIState
 from airflow.sdk.execution_time.comms import (
     ConnectionResult,
     DeferTask,
@@ -478,6 +478,101 @@ def test_heartbeat_failures_handling(self, monkeypatch, mocker, captured_logs, t
             "timestamp": mocker.ANY,
         } in captured_logs
 
+    @pytest.mark.parametrize(
+        "terminal_state, task_end_datetime, overtime_threshold, expected_kill",
+        [
+            # The current date is fixed at tz.datetime(2024, 12, 1, 10, 10, 20)
+            # Current time minus 5 seconds | Threshold: 10s
+            pytest.param(
+                None,
+                tz.datetime(2024, 12, 1, 10, 10, 15),
+                10,
+                False,
+                id="no_terminal_state",
+            ),
+            # Terminal state is not SUCCESS; even though we are above the threshold, it should not kill the process
+            pytest.param(
+                TerminalTIState.SKIPPED,
+                tz.datetime(2024, 12, 1, 10, 10, 0),
+                1,
+                False,
+                id="non_success_state",
+            ),
+            # Current time minus 5 seconds | Threshold: 10s
+            pytest.param(
+                TerminalTIState.SUCCESS,
+                tz.datetime(2024, 12, 1, 10, 10, 15),
+                10,
+                False,
+                id="below_threshold",
+            ),
+            # Current time minus 10 seconds | Threshold: 9s
+            pytest.param(
+                TerminalTIState.SUCCESS,
+                tz.datetime(2024, 12, 1, 10, 10, 10),
+                9,
+                True,
+                id="above_threshold",
+            ),
+            # End datetime is None | Threshold: 20s
+            pytest.param(
+                TerminalTIState.SUCCESS,
+                None,
+                20,
+                False,
+                id="task_end_datetime_none",
+            ),
+        ],
+    )
+    def test_overtime_handling(
+        self,
+        mocker,
+        terminal_state,
+        task_end_datetime,
+        overtime_threshold,
+        expected_kill,
+        time_machine,
+    ):
+        """Test handling of overtime under various conditions."""
+        # Mock the logger since we are only interested in checking that it is called with
+        # the expected message, not in the actual log output
+        mock_logger = mocker.patch("airflow.sdk.execution_time.supervisor.log")
+
+        # Mock the kill method at the class level so we can assert it was called with the correct signal
+        mock_kill = mocker.patch("airflow.sdk.execution_time.supervisor.WatchedSubprocess.kill")
+
+        mock_watched_subprocess = WatchedSubprocess(
+            ti_id=TI_ID,
+            pid=12345,
+            stdin=mocker.Mock(),
+            process=mocker.Mock(),
+            client=mocker.Mock(),
+        )
+
+        # Fix the current datetime
+        instant = tz.datetime(2024, 12, 1, 10, 10, 20)
+        time_machine.move_to(instant, tick=False)
+
+        # Set the terminal state and task end datetime
+        mock_watched_subprocess._terminal_state = terminal_state
+        mock_watched_subprocess._task_end_datetime = task_end_datetime
+        mock_watched_subprocess.task_success_overtime_threshold = overtime_threshold
+
+        # Call `wait` to trigger the overtime handling
+        # This will call the `kill` method if the task has been running for too long
+        mock_watched_subprocess.wait()
+
+        # Validate process kill behavior and log messages
+        if expected_kill:
+            mock_kill.assert_called_once_with(signal.SIGTERM, force=True)
+            mock_logger.warning.assert_called_once_with(
+                "Task success overtime reached; terminating process",
+                ti_id=TI_ID,
+            )
+        else:
+            mock_kill.assert_not_called()
+            mock_logger.warning.assert_not_called()
+
 
 class TestWatchedSubprocessKill:
     @pytest.fixture
diff --git a/task_sdk/tests/execution_time/test_task_runner.py b/task_sdk/tests/execution_time/test_task_runner.py
index a66f54c709fc3..f021c35091e4c 100644
--- a/task_sdk/tests/execution_time/test_task_runner.py
+++ b/task_sdk/tests/execution_time/test_task_runner.py
@@ -28,9 +28,10 @@
 
 from airflow.sdk import DAG, BaseOperator
 from airflow.sdk.api.datamodels._generated import TaskInstance
-from airflow.sdk.execution_time.comms import DeferTask, StartupDetails
+from airflow.sdk.execution_time.comms import DeferTask, StartupDetails, TaskState
 from airflow.sdk.execution_time.task_runner import CommsDecoder, parse, run
 from airflow.utils import timezone
+from airflow.utils.state import TerminalTIState
 
 
 class TestCommsDecoder:
@@ -78,7 +79,7 @@ def test_parse(test_dags_dir: Path):
     assert isinstance(ti.task.dag, DAG)
 
 
-def test_run_basic(test_dags_dir: Path):
+def test_run_basic(test_dags_dir: Path, time_machine, mocked_supervisor_comms):
     """Test running a basic task."""
     what = StartupDetails(
         ti=TaskInstance(id=uuid7(), task_id="hello", dag_id="super_basic_run", run_id="c", try_number=1),
@@ -87,10 +88,23 @@ def test_run_basic(test_dags_dir: Path):
     )
 
     ti = parse(what)
-    run(ti, log=mock.MagicMock())
+
+    instant = timezone.datetime(2024, 12, 3, 10, 0)
+    time_machine.move_to(instant, tick=False)
+
+    with mock.patch(
+        "airflow.sdk.execution_time.task_runner.SUPERVISOR_COMMS", create=True
+    ) as mock_supervisor_comms:
+        # Mocking the communication interface
+        mock_supervisor_comms.send_request = mock.Mock()
+        run(ti, log=mock.MagicMock())
+
+        mock_supervisor_comms.send_request.assert_called_once_with(
+            msg=TaskState(state=TerminalTIState.SUCCESS, end_date=instant), log=mock.ANY
+        )
 
 
-def test_run_deferred_basic(test_dags_dir: Path, time_machine):
+def test_run_deferred_basic(test_dags_dir: Path, time_machine, mocked_supervisor_comms):
     """Test that a task can transition to a deferred state."""
     what = StartupDetails(
         ti=TaskInstance(