Skip to content

Commit

Permalink
AIP-72: Port task success overtime to the Supervisor (#44590)
Browse files Browse the repository at this point in the history
This PR ports the overtime feature on `LocalTaskJob` (added in #39890) to the Supervisor.
It allows to terminate Task process to terminate when it exceeding the configured success overtime threshold which is useful when we add Listenener to the Task process.

closes #44356

Also added `TaskState` to update state and send end_date from task process to the supervisor.
  • Loading branch information
kaxil authored Dec 3, 2024
1 parent 40821bf commit d059d4a
Show file tree
Hide file tree
Showing 5 changed files with 125 additions and 15 deletions.
2 changes: 2 additions & 0 deletions task_sdk/src/airflow/sdk/execution_time/comms.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@

from __future__ import annotations

from datetime import datetime
from typing import Annotated, Literal, Union

from pydantic import BaseModel, ConfigDict, Field
Expand Down Expand Up @@ -101,6 +102,7 @@ class TaskState(BaseModel):
"""

state: TerminalTIState
end_date: datetime | None = None
type: Literal["TaskState"] = "TaskState"


Expand Down
37 changes: 30 additions & 7 deletions task_sdk/src/airflow/sdk/execution_time/supervisor.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@
GetVariable,
GetXCom,
StartupDetails,
TaskState,
ToSupervisor,
)

Expand Down Expand Up @@ -265,9 +266,9 @@ class WatchedSubprocess:
client: Client

_process: psutil.Process
_exit_code: int | None = None
_terminal_state: str | None = None
_final_state: str | None = None
_exit_code: int | None = attrs.field(default=None, init=False)
_terminal_state: str | None = attrs.field(default=None, init=False)
_final_state: str | None = attrs.field(default=None, init=False)

_last_successful_heartbeat: float = attrs.field(default=0, init=False)
_last_heartbeat_attempt: float = attrs.field(default=0, init=False)
Expand All @@ -277,6 +278,12 @@ class WatchedSubprocess:
# does not hang around forever.
failed_heartbeats: int = attrs.field(default=0, init=False)

# Maximum possible time (in seconds) that task will have for execution of auxiliary processes
# like listeners after task is complete.
# TODO: This should come from airflow.cfg: [core] task_success_overtime
TASK_OVERTIME_THRESHOLD: ClassVar[float] = 20.0
_task_end_time_monotonic: float | None = attrs.field(default=None, init=False)

selector: selectors.BaseSelector = attrs.field(factory=selectors.DefaultSelector)

procs: ClassVar[weakref.WeakValueDictionary[int, WatchedSubprocess]] = weakref.WeakValueDictionary()
Expand Down Expand Up @@ -500,6 +507,21 @@ def _monitor_subprocess(self):

self._send_heartbeat_if_needed()

self._handle_task_overtime_if_needed()

def _handle_task_overtime_if_needed(self):
"""Handle termination of auxiliary processes if the task exceeds the configured overtime."""
# If the task has reached a terminal state, we can start monitoring the overtime
if not self._terminal_state:
return

if (
self._task_end_time_monotonic
and (time.monotonic() - self._task_end_time_monotonic) > self.TASK_OVERTIME_THRESHOLD
):
log.warning("Task success overtime reached; terminating process", ti_id=self.ti_id)
self.kill(signal.SIGTERM, force=True)

def _service_subprocess(self, max_wait_time: float, raise_on_timeout: bool = False):
"""
Service subprocess events by processing socket activity and checking for process exit.
Expand Down Expand Up @@ -631,9 +653,11 @@ def handle_requests(self, log: FilteringBoundLogger) -> Generator[None, bytes, N
log.exception("Unable to decode message", line=line)
continue

# if isinstance(msg, TaskState):
# self._terminal_state = msg.state
if isinstance(msg, GetConnection):
resp = None
if isinstance(msg, TaskState):
self._terminal_state = msg.state
self._task_end_time_monotonic = time.monotonic()
elif isinstance(msg, GetConnection):
conn = self.client.connections.get(msg.conn_id)
resp = conn.model_dump_json(exclude_unset=True).encode()
elif isinstance(msg, GetVariable):
Expand All @@ -645,7 +669,6 @@ def handle_requests(self, log: FilteringBoundLogger) -> Generator[None, bytes, N
elif isinstance(msg, DeferTask):
self._terminal_state = IntermediateTIState.DEFERRED
self.client.task_instances.defer(self.ti_id, msg)
resp = None
else:
log.error("Unhandled request", msg=msg)
continue
Expand Down
12 changes: 9 additions & 3 deletions task_sdk/src/airflow/sdk/execution_time/task_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,16 +21,17 @@

import os
import sys
from datetime import datetime, timezone
from io import FileIO
from typing import TYPE_CHECKING, TextIO

import attrs
import structlog
from pydantic import ConfigDict, TypeAdapter

from airflow.sdk.api.datamodels._generated import TaskInstance
from airflow.sdk.api.datamodels._generated import TaskInstance, TerminalTIState
from airflow.sdk.definitions.baseoperator import BaseOperator
from airflow.sdk.execution_time.comms import DeferTask, StartupDetails, ToSupervisor, ToTask
from airflow.sdk.execution_time.comms import DeferTask, StartupDetails, TaskState, ToSupervisor, ToTask

if TYPE_CHECKING:
from structlog.typing import FilteringBoundLogger as Logger
Expand Down Expand Up @@ -158,11 +159,14 @@ def run(ti: RuntimeTaskInstance, log: Logger):
if TYPE_CHECKING:
assert ti.task is not None
assert isinstance(ti.task, BaseOperator)

msg: ToSupervisor | None = None
try:
# TODO: pre execute etc.
# TODO next_method to support resuming from deferred
# TODO: Get a real context object
ti.task.execute({"task_instance": ti}) # type: ignore[attr-defined]
msg = TaskState(state=TerminalTIState.SUCCESS, end_date=datetime.now(tz=timezone.utc))
except TaskDeferred as defer:
classpath, trigger_kwargs = defer.trigger.serialize()
next_method = defer.method_name
Expand All @@ -173,7 +177,6 @@ def run(ti: RuntimeTaskInstance, log: Logger):
next_method=next_method,
trigger_timeout=timeout,
)
SUPERVISOR_COMMS.send_request(msg=msg, log=log)
except AirflowSkipException:
...
except AirflowRescheduleException:
Expand All @@ -189,6 +192,9 @@ def run(ti: RuntimeTaskInstance, log: Logger):
# TODO: Handle TI handle failure
raise

if msg:
SUPERVISOR_COMMS.send_request(msg=msg, log=log)


def finalize(log: Logger): ...

Expand Down
69 changes: 68 additions & 1 deletion task_sdk/tests/execution_time/test_supervisor.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@

from airflow.sdk.api import client as sdk_client
from airflow.sdk.api.client import ServerResponseError
from airflow.sdk.api.datamodels._generated import TaskInstance
from airflow.sdk.api.datamodels._generated import TaskInstance, TerminalTIState
from airflow.sdk.execution_time.comms import (
ConnectionResult,
DeferTask,
Expand Down Expand Up @@ -478,6 +478,73 @@ def test_heartbeat_failures_handling(self, monkeypatch, mocker, captured_logs, t
"timestamp": mocker.ANY,
} in captured_logs

@pytest.mark.parametrize(
["terminal_state", "task_end_time_monotonic", "overtime_threshold", "expected_kill"],
[
pytest.param(
None,
15.0,
10,
False,
id="no_terminal_state",
),
pytest.param(TerminalTIState.SUCCESS, 15.0, 10, False, id="below_threshold"),
pytest.param(TerminalTIState.SUCCESS, 9.0, 10, True, id="above_threshold"),
pytest.param(TerminalTIState.FAILED, 9.0, 10, True, id="above_threshold_failed_state"),
pytest.param(TerminalTIState.SKIPPED, 9.0, 10, True, id="above_threshold_skipped_state"),
pytest.param(TerminalTIState.SUCCESS, None, 20, False, id="task_end_datetime_none"),
],
)
def test_overtime_handling(
self,
mocker,
terminal_state,
task_end_time_monotonic,
overtime_threshold,
expected_kill,
monkeypatch,
):
"""Test handling of overtime under various conditions."""
# Mocking logger since we are only interested that it is called with the expected message
# and not the actual log output
mock_logger = mocker.patch("airflow.sdk.execution_time.supervisor.log")

# Mock the kill method at the class level so we can assert it was called with the correct signal
mock_kill = mocker.patch("airflow.sdk.execution_time.supervisor.WatchedSubprocess.kill")

# Mock the current monotonic time
mocker.patch("time.monotonic", return_value=20.0)

# Patch the task overtime threshold
monkeypatch.setattr(WatchedSubprocess, "TASK_OVERTIME_THRESHOLD", overtime_threshold)

mock_watched_subprocess = WatchedSubprocess(
ti_id=TI_ID,
pid=12345,
stdin=mocker.Mock(),
process=mocker.Mock(),
client=mocker.Mock(),
)

# Set the terminal state and task end datetime
mock_watched_subprocess._terminal_state = terminal_state
mock_watched_subprocess._task_end_time_monotonic = task_end_time_monotonic

# Call `wait` to trigger the overtime handling
# This will call the `kill` method if the task has been running for too long
mock_watched_subprocess._handle_task_overtime_if_needed()

# Validate process kill behavior and log messages
if expected_kill:
mock_kill.assert_called_once_with(signal.SIGTERM, force=True)
mock_logger.warning.assert_called_once_with(
"Task success overtime reached; terminating process",
ti_id=TI_ID,
)
else:
mock_kill.assert_not_called()
mock_logger.warning.assert_not_called()


class TestWatchedSubprocessKill:
@pytest.fixture
Expand Down
20 changes: 16 additions & 4 deletions task_sdk/tests/execution_time/test_task_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,8 @@
from uuid6 import uuid7

from airflow.sdk import DAG, BaseOperator
from airflow.sdk.api.datamodels._generated import TaskInstance
from airflow.sdk.execution_time.comms import DeferTask, StartupDetails
from airflow.sdk.api.datamodels._generated import TaskInstance, TerminalTIState
from airflow.sdk.execution_time.comms import DeferTask, StartupDetails, TaskState
from airflow.sdk.execution_time.task_runner import CommsDecoder, parse, run
from airflow.utils import timezone

Expand Down Expand Up @@ -78,7 +78,7 @@ def test_parse(test_dags_dir: Path):
assert isinstance(ti.task.dag, DAG)


def test_run_basic(test_dags_dir: Path):
def test_run_basic(test_dags_dir: Path, time_machine):
"""Test running a basic task."""
what = StartupDetails(
ti=TaskInstance(id=uuid7(), task_id="hello", dag_id="super_basic_run", run_id="c", try_number=1),
Expand All @@ -87,7 +87,19 @@ def test_run_basic(test_dags_dir: Path):
)

ti = parse(what)
run(ti, log=mock.MagicMock())

instant = timezone.datetime(2024, 12, 3, 10, 0)
time_machine.move_to(instant, tick=False)

with mock.patch(
"airflow.sdk.execution_time.task_runner.SUPERVISOR_COMMS", create=True
) as mock_supervisor_comms:
mock_supervisor_comms.send_request = mock.Mock()
run(ti, log=mock.MagicMock())

mock_supervisor_comms.send_request.assert_called_once_with(
msg=TaskState(state=TerminalTIState.SUCCESS, end_date=instant), log=mock.ANY
)


def test_run_deferred_basic(test_dags_dir: Path, time_machine):
Expand Down

0 comments on commit d059d4a

Please sign in to comment.