Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

pass error for on_task_instance_failed in task sdk #46941

Merged
merged 1 commit into from
Feb 21, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -178,17 +178,21 @@ def on_running():
self.log.debug("Skipping this instance of rescheduled task - START event was emitted already")
return

date = dagrun.logical_date
if AIRFLOW_V_3_0_PLUS and date is None:
date = dagrun.run_after

parent_run_id = self.adapter.build_dag_run_id(
dag_id=dag.dag_id,
logical_date=dagrun.logical_date,
logical_date=date,
clear_number=clear_number,
)

task_uuid = self.adapter.build_task_instance_run_id(
dag_id=dag.dag_id,
task_id=task.task_id,
try_number=task_instance.try_number,
logical_date=dagrun.logical_date,
logical_date=date,
map_index=task_instance.map_index,
)
event_type = RunState.RUNNING.value.lower()
Expand Down Expand Up @@ -276,17 +280,21 @@ def _on_task_instance_success(self, task_instance: RuntimeTaskInstance, dag, dag

@print_warning(self.log)
def on_success():
date = dagrun.logical_date
if AIRFLOW_V_3_0_PLUS and date is None:
date = dagrun.run_after

parent_run_id = self.adapter.build_dag_run_id(
dag_id=dag.dag_id,
logical_date=dagrun.logical_date,
logical_date=date,
clear_number=dagrun.clear_number,
)

task_uuid = self.adapter.build_task_instance_run_id(
dag_id=dag.dag_id,
task_id=task.task_id,
try_number=_get_try_number_success(task_instance),
logical_date=dagrun.logical_date,
logical_date=date,
map_index=task_instance.map_index,
)
event_type = RunState.COMPLETE.value.lower()
Expand Down Expand Up @@ -393,17 +401,21 @@ def _on_task_instance_failed(

@print_warning(self.log)
def on_failure():
date = dagrun.logical_date
if AIRFLOW_V_3_0_PLUS and date is None:
date = dagrun.run_after

parent_run_id = self.adapter.build_dag_run_id(
dag_id=dag.dag_id,
logical_date=dagrun.logical_date,
logical_date=date,
clear_number=dagrun.clear_number,
)

task_uuid = self.adapter.build_task_instance_run_id(
dag_id=dag.dag_id,
task_id=task.task_id,
try_number=task_instance.try_number,
logical_date=dagrun.logical_date,
logical_date=date,
map_index=task_instance.map_index,
)
event_type = RunState.FAIL.value.lower()
Expand Down
32 changes: 20 additions & 12 deletions task_sdk/src/airflow/sdk/execution_time/task_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -572,7 +572,7 @@ def _prepare(ti: RuntimeTaskInstance, log: Logger, context: Context) -> ToSuperv

def run(
ti: RuntimeTaskInstance, log: Logger
) -> tuple[IntermediateTIState | TerminalTIState, ToSupervisor | None]:
) -> tuple[IntermediateTIState | TerminalTIState, ToSupervisor | None, BaseException | None]:
"""Run the task in this process."""
from airflow.exceptions import (
AirflowException,
Expand All @@ -591,6 +591,7 @@ def run(

msg: ToSupervisor | None = None
state: IntermediateTIState | TerminalTIState
error: BaseException | None = None
try:
context = ti.get_template_context()
with set_current_context(context):
Expand All @@ -599,7 +600,7 @@ def run(
if early_exit := _prepare(ti, log, context):
msg = early_exit
state = TerminalTIState.FAILED
return state, msg
return state, msg, error

result = _execute_task(context, ti)

Expand Down Expand Up @@ -639,7 +640,7 @@ def run(
reschedule_date=reschedule.reschedule_date, end_date=datetime.now(tz=timezone.utc)
)
state = IntermediateTIState.UP_FOR_RESCHEDULE
except (AirflowFailException, AirflowSensorTimeout):
except (AirflowFailException, AirflowSensorTimeout) as e:
# If AirflowFailException is raised, task should not retry.
# If a sensor in reschedule mode reaches timeout, task should not retry.
log.exception("Task failed with exception")
Expand All @@ -650,15 +651,17 @@ def run(
end_date=datetime.now(tz=timezone.utc),
)
state = TerminalTIState.FAIL_WITHOUT_RETRY
except (AirflowTaskTimeout, AirflowException):
error = e
except (AirflowTaskTimeout, AirflowException) as e:
# We should allow retries if the task has defined it.
log.exception("Task failed with exception")
msg = TaskState(
state=TerminalTIState.FAILED,
end_date=datetime.now(tz=timezone.utc),
)
state = TerminalTIState.FAILED
except AirflowTaskTerminated:
error = e
except AirflowTaskTerminated as e:
# External state updates are already handled with `ti_heartbeat` and will be
# updated already by another UI API. So, these exceptions should ideally never be thrown.
# If these are thrown, we should mark the TI state as failed.
Expand All @@ -668,23 +671,26 @@ def run(
end_date=datetime.now(tz=timezone.utc),
)
state = TerminalTIState.FAIL_WITHOUT_RETRY
except SystemExit:
error = e
except SystemExit as e:
# A SystemExit should be retried if the task is eligible for retries.
log.exception("Task failed with exception")
msg = TaskState(
state=TerminalTIState.FAILED,
end_date=datetime.now(tz=timezone.utc),
)
state = TerminalTIState.FAILED
except BaseException:
error = e
except BaseException as e:
log.exception("Task failed with exception")
msg = TaskState(state=TerminalTIState.FAILED, end_date=datetime.now(tz=timezone.utc))
state = TerminalTIState.FAILED
error = e
finally:
if msg:
SUPERVISOR_COMMS.send_request(msg=msg, log=log)
# Return the message to make unit tests easier too
return state, msg
return state, msg, error


def _execute_task(context: Context, ti: RuntimeTaskInstance):
Expand Down Expand Up @@ -759,7 +765,9 @@ def _push_xcom_if_needed(result: Any, ti: RuntimeTaskInstance, log: Logger):
_xcom_push(ti, "return_value", result, mapped_length=mapped_length)


def finalize(ti: RuntimeTaskInstance, state: TerminalTIState, log: Logger):
def finalize(
ti: RuntimeTaskInstance, state: TerminalTIState, log: Logger, error: BaseException | None = None
):
# Pushing xcom for each operator extra links defined on the operator only.
for oe in ti.task.operator_extra_links:
link, xcom_key = oe.get_link(operator=ti.task, ti_key=ti.id), oe.xcom_key # type: ignore[arg-type]
Expand All @@ -774,7 +782,7 @@ def finalize(ti: RuntimeTaskInstance, state: TerminalTIState, log: Logger):
# TODO: Run task success callbacks here
if state in [TerminalTIState.FAILED, TerminalTIState.FAIL_WITHOUT_RETRY]:
get_listener_manager().hook.on_task_instance_failed(
previous_state=TaskInstanceState.RUNNING, task_instance=ti
previous_state=TaskInstanceState.RUNNING, task_instance=ti, error=error
)
# TODO: Run task failure callbacks here

Expand All @@ -787,8 +795,8 @@ def main():
SUPERVISOR_COMMS = CommsDecoder[ToTask, ToSupervisor](input=sys.stdin)
try:
ti, log = startup()
state, msg = run(ti, log)
finalize(ti, state, log)
state, msg, error = run(ti, log)
finalize(ti, state, log, error)
except KeyboardInterrupt:
log = structlog.get_logger(logger_name="task")
log.exception("Ctrl-c hit")
Expand Down
15 changes: 9 additions & 6 deletions task_sdk/tests/execution_time/test_task_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -1147,7 +1147,7 @@ def execute(self, context):
"a_simple_list": ["one", "two", "three", "actually one value is made per line"],
},
)
_, msg = run(runtime_ti, log=mock.MagicMock())
_, msg, _ = run(runtime_ti, log=mock.MagicMock())
assert isinstance(msg, SucceedTask)

def test_task_run_with_operator_extra_links(self, create_runtime_ti, mock_supervisor_comms, time_machine):
Expand Down Expand Up @@ -1502,6 +1502,7 @@ class CustomListener:
def __init__(self):
self.state = []
self.component = None
self.error = None

@hookimpl
def on_starting(self, component):
Expand All @@ -1516,8 +1517,9 @@ def on_task_instance_success(self, previous_state, task_instance):
self.state.append(TaskInstanceState.SUCCESS)

@hookimpl
def on_task_instance_failed(self, previous_state, task_instance):
def on_task_instance_failed(self, previous_state, task_instance, error):
self.state.append(TaskInstanceState.FAILED)
self.error = error

@hookimpl
def before_stopping(self, component):
Expand Down Expand Up @@ -1566,7 +1568,7 @@ def execute(self, context):
assert isinstance(listener.component, TaskRunnerMarker)
del listener.component

state, _ = run(runtime_ti, log)
state, _, _ = run(runtime_ti, log)
finalize(runtime_ti, state, log)
assert isinstance(listener.component, TaskRunnerMarker)

Expand Down Expand Up @@ -1595,7 +1597,7 @@ def execute(self, context):
)
log = mock.MagicMock()

state, _ = run(runtime_ti, log)
state, _, _ = run(runtime_ti, log)
finalize(runtime_ti, state, log)

assert listener.state == [TaskInstanceState.RUNNING, TaskInstanceState.SUCCESS]
Expand Down Expand Up @@ -1633,7 +1635,8 @@ def execute(self, context):
)
log = mock.MagicMock()

state, _ = run(runtime_ti, log)
finalize(runtime_ti, state, log)
state, _, error = run(runtime_ti, log)
finalize(runtime_ti, state, log, error)

assert listener.state == [TaskInstanceState.RUNNING, TaskInstanceState.FAILED]
assert listener.error == error
Loading