From abcd6d1fcefad792df7318fb6eee01b27bc683f7 Mon Sep 17 00:00:00 2001 From: Maciej Obuchowski Date: Thu, 20 Feb 2025 22:21:47 +0100 Subject: [PATCH] pass error for on_task_instance_failed in task sdk Signed-off-by: Maciej Obuchowski --- .../providers/openlineage/plugins/listener.py | 24 ++++++++++---- .../airflow/sdk/execution_time/task_runner.py | 32 ++++++++++++------- .../tests/execution_time/test_task_runner.py | 15 +++++---- 3 files changed, 47 insertions(+), 24 deletions(-) diff --git a/providers/openlineage/src/airflow/providers/openlineage/plugins/listener.py b/providers/openlineage/src/airflow/providers/openlineage/plugins/listener.py index c49880237b8a7..880320babe4b2 100644 --- a/providers/openlineage/src/airflow/providers/openlineage/plugins/listener.py +++ b/providers/openlineage/src/airflow/providers/openlineage/plugins/listener.py @@ -178,9 +178,13 @@ def on_running(): self.log.debug("Skipping this instance of rescheduled task - START event was emitted already") return + date = dagrun.logical_date + if AIRFLOW_V_3_0_PLUS and date is None: + date = dagrun.run_after + parent_run_id = self.adapter.build_dag_run_id( dag_id=dag.dag_id, - logical_date=dagrun.logical_date, + logical_date=date, clear_number=clear_number, ) @@ -188,7 +192,7 @@ def on_running(): dag_id=dag.dag_id, task_id=task.task_id, try_number=task_instance.try_number, - logical_date=dagrun.logical_date, + logical_date=date, map_index=task_instance.map_index, ) event_type = RunState.RUNNING.value.lower() @@ -276,9 +280,13 @@ def _on_task_instance_success(self, task_instance: RuntimeTaskInstance, dag, dag @print_warning(self.log) def on_success(): + date = dagrun.logical_date + if AIRFLOW_V_3_0_PLUS and date is None: + date = dagrun.run_after + parent_run_id = self.adapter.build_dag_run_id( dag_id=dag.dag_id, - logical_date=dagrun.logical_date, + logical_date=date, clear_number=dagrun.clear_number, ) @@ -286,7 +294,7 @@ def on_success(): dag_id=dag.dag_id, task_id=task.task_id, try_number=_get_try_number_success(task_instance), - logical_date=dagrun.logical_date, + logical_date=date, map_index=task_instance.map_index, ) event_type = RunState.COMPLETE.value.lower() @@ -393,9 +401,13 @@ def _on_task_instance_failed( @print_warning(self.log) def on_failure(): + date = dagrun.logical_date + if AIRFLOW_V_3_0_PLUS and date is None: + date = dagrun.run_after + parent_run_id = self.adapter.build_dag_run_id( dag_id=dag.dag_id, - logical_date=dagrun.logical_date, + logical_date=date, clear_number=dagrun.clear_number, ) @@ -403,7 +415,7 @@ def on_failure(): dag_id=dag.dag_id, task_id=task.task_id, try_number=task_instance.try_number, - logical_date=dagrun.logical_date, + logical_date=date, map_index=task_instance.map_index, ) event_type = RunState.FAIL.value.lower() diff --git a/task_sdk/src/airflow/sdk/execution_time/task_runner.py b/task_sdk/src/airflow/sdk/execution_time/task_runner.py index 99967579bb493..7924ea1cfb85f 100644 --- a/task_sdk/src/airflow/sdk/execution_time/task_runner.py +++ b/task_sdk/src/airflow/sdk/execution_time/task_runner.py @@ -572,7 +572,7 @@ def _prepare(ti: RuntimeTaskInstance, log: Logger, context: Context) -> ToSuperv def run( ti: RuntimeTaskInstance, log: Logger -) -> tuple[IntermediateTIState | TerminalTIState, ToSupervisor | None]: +) -> tuple[IntermediateTIState | TerminalTIState, ToSupervisor | None, BaseException | None]: """Run the task in this process.""" from airflow.exceptions import ( AirflowException, @@ -591,6 +591,7 @@ def run( msg: ToSupervisor | None = None state: IntermediateTIState | TerminalTIState + error: BaseException | None = None try: context = ti.get_template_context() with set_current_context(context): @@ -599,7 +600,7 @@ def run( if early_exit := _prepare(ti, log, context): msg = early_exit state = TerminalTIState.FAILED - return state, msg + return state, msg, error result = _execute_task(context, ti) @@ -639,7 +640,7 @@ def run( reschedule_date=reschedule.reschedule_date, end_date=datetime.now(tz=timezone.utc) ) state = IntermediateTIState.UP_FOR_RESCHEDULE - except (AirflowFailException, AirflowSensorTimeout): + except (AirflowFailException, AirflowSensorTimeout) as e: # If AirflowFailException is raised, task should not retry. # If a sensor in reschedule mode reaches timeout, task should not retry. log.exception("Task failed with exception") @@ -650,7 +651,8 @@ def run( end_date=datetime.now(tz=timezone.utc), ) state = TerminalTIState.FAIL_WITHOUT_RETRY - except (AirflowTaskTimeout, AirflowException): + error = e + except (AirflowTaskTimeout, AirflowException) as e: # We should allow retries if the task has defined it. log.exception("Task failed with exception") msg = TaskState( @@ -658,7 +660,8 @@ def run( end_date=datetime.now(tz=timezone.utc), ) state = TerminalTIState.FAILED - except AirflowTaskTerminated: + error = e + except AirflowTaskTerminated as e: # External state updates are already handled with `ti_heartbeat` and will be # updated already be another UI API. So, these exceptions should ideally never be thrown. # If these are thrown, we should mark the TI state as failed. @@ -668,7 +671,8 @@ def run( end_date=datetime.now(tz=timezone.utc), ) state = TerminalTIState.FAIL_WITHOUT_RETRY - except SystemExit: + error = e + except SystemExit as e: # SystemExit needs to be retried if they are eligible. log.exception("Task failed with exception") msg = TaskState( @@ -676,15 +680,17 @@ def run( end_date=datetime.now(tz=timezone.utc), ) state = TerminalTIState.FAILED - except BaseException: + error = e + except BaseException as e: log.exception("Task failed with exception") msg = TaskState(state=TerminalTIState.FAILED, end_date=datetime.now(tz=timezone.utc)) state = TerminalTIState.FAILED + error = e finally: if msg: SUPERVISOR_COMMS.send_request(msg=msg, log=log) # Return the message to make unit tests easier too - return state, msg + return state, msg, error def _execute_task(context: Context, ti: RuntimeTaskInstance): @@ -759,7 +765,9 @@ def _push_xcom_if_needed(result: Any, ti: RuntimeTaskInstance, log: Logger): _xcom_push(ti, "return_value", result, mapped_length=mapped_length) -def finalize(ti: RuntimeTaskInstance, state: TerminalTIState, log: Logger): +def finalize( + ti: RuntimeTaskInstance, state: TerminalTIState, log: Logger, error: BaseException | None = None +): # Pushing xcom for each operator extra links defined on the operator only. for oe in ti.task.operator_extra_links: link, xcom_key = oe.get_link(operator=ti.task, ti_key=ti.id), oe.xcom_key # type: ignore[arg-type] @@ -774,7 +782,7 @@ def finalize(ti: RuntimeTaskInstance, state: TerminalTIState, log: Logger): # TODO: Run task success callbacks here if state in [TerminalTIState.FAILED, TerminalTIState.FAIL_WITHOUT_RETRY]: get_listener_manager().hook.on_task_instance_failed( - previous_state=TaskInstanceState.RUNNING, task_instance=ti + previous_state=TaskInstanceState.RUNNING, task_instance=ti, error=error ) # TODO: Run task failure callbacks here @@ -787,8 +795,8 @@ def main(): SUPERVISOR_COMMS = CommsDecoder[ToTask, ToSupervisor](input=sys.stdin) try: ti, log = startup() - state, msg = run(ti, log) - finalize(ti, state, log) + state, msg, error = run(ti, log) + finalize(ti, state, log, error) except KeyboardInterrupt: log = structlog.get_logger(logger_name="task") log.exception("Ctrl-c hit") diff --git a/task_sdk/tests/execution_time/test_task_runner.py b/task_sdk/tests/execution_time/test_task_runner.py index 5e4405165e235..ba53c618e1dd4 100644 --- a/task_sdk/tests/execution_time/test_task_runner.py +++ b/task_sdk/tests/execution_time/test_task_runner.py @@ -1147,7 +1147,7 @@ def execute(self, context): "a_simple_list": ["one", "two", "three", "actually one value is made per line"], }, ) - _, msg = run(runtime_ti, log=mock.MagicMock()) + _, msg, _ = run(runtime_ti, log=mock.MagicMock()) assert isinstance(msg, SucceedTask) def test_task_run_with_operator_extra_links(self, create_runtime_ti, mock_supervisor_comms, time_machine): @@ -1502,6 +1502,7 @@ class CustomListener: def __init__(self): self.state = [] self.component = None + self.error = None @hookimpl def on_starting(self, component): @@ -1516,8 +1517,9 @@ def on_task_instance_success(self, previous_state, task_instance): self.state.append(TaskInstanceState.SUCCESS) @hookimpl - def on_task_instance_failed(self, previous_state, task_instance): + def on_task_instance_failed(self, previous_state, task_instance, error): self.state.append(TaskInstanceState.FAILED) + self.error = error @hookimpl def before_stopping(self, component): @@ -1566,7 +1568,7 @@ def execute(self, context): assert isinstance(listener.component, TaskRunnerMarker) del listener.component - state, _ = run(runtime_ti, log) + state, _, _ = run(runtime_ti, log) finalize(runtime_ti, state, log) assert isinstance(listener.component, TaskRunnerMarker) @@ -1595,7 +1597,7 @@ def execute(self, context): ) log = mock.MagicMock() - state, _ = run(runtime_ti, log) + state, _, _ = run(runtime_ti, log) finalize(runtime_ti, state, log) assert listener.state == [TaskInstanceState.RUNNING, TaskInstanceState.SUCCESS] @@ -1633,7 +1635,8 @@ def execute(self, context): ) log = mock.MagicMock() - state, _ = run(runtime_ti, log) - finalize(runtime_ti, state, log) + state, _, error = run(runtime_ti, log) + finalize(runtime_ti, state, log, error) assert listener.state == [TaskInstanceState.RUNNING, TaskInstanceState.FAILED] + assert listener.error == error