Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix logic of the skip_all_except method #31153

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
56 changes: 35 additions & 21 deletions airflow/models/skipmixin.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,18 +21,19 @@
from typing import TYPE_CHECKING, Iterable, Sequence

from airflow.exceptions import AirflowException, RemovedInAirflow3Warning
from airflow.models.dagrun import DagRun
from airflow.models.taskinstance import TaskInstance
from airflow.serialization.pydantic.dag_run import DagRunPydantic
from airflow.utils import timezone
from airflow.utils.log.logging_mixin import LoggingMixin
from airflow.utils.session import NEW_SESSION, create_session, provide_session
from airflow.utils.sqlalchemy import tuple_in_condition
from airflow.utils.state import State

if TYPE_CHECKING:
from pendulum import DateTime
from sqlalchemy import Session

from airflow.models.dagrun import DagRun
from airflow.models.operator import Operator
from airflow.models.taskmixin import DAGNode
from airflow.serialization.pydantic.taskinstance import TaskInstancePydantic
Expand Down Expand Up @@ -60,24 +61,30 @@ class SkipMixin(LoggingMixin):
def _set_state_to_skipped(
self,
dag_run: DagRun | DagRunPydantic,
tasks: Iterable[Operator],
tasks: Sequence[str] | Sequence[tuple[str, int]],
session: Session,
) -> None:
"""Used internally to set state of task instances to skipped from the same dag run."""
now = timezone.utcnow()

session.query(TaskInstance).filter(
TaskInstance.dag_id == dag_run.dag_id,
TaskInstance.run_id == dag_run.run_id,
TaskInstance.task_id.in_(d.task_id for d in tasks),
).update(
{
TaskInstance.state: State.SKIPPED,
TaskInstance.start_date: now,
TaskInstance.end_date: now,
},
synchronize_session=False,
)
if tasks:
now = timezone.utcnow()
TI = TaskInstance
query = session.query(TI).filter(
TI.dag_id == dag_run.dag_id,
TI.run_id == dag_run.run_id,
)
if isinstance(tasks[0], tuple):
query = query.filter(tuple_in_condition((TI.task_id, TI.map_index), tasks))
else:
query = query.filter(TI.task_id.in_(tasks))

query.update(
{
TaskInstance.state: State.SKIPPED,
TaskInstance.start_date: now,
TaskInstance.end_date: now,
},
synchronize_session=False,
)

@provide_session
def skip(
Expand Down Expand Up @@ -130,7 +137,8 @@ def skip(
if dag_run is None:
raise ValueError("dag_run is required")

self._set_state_to_skipped(dag_run, task_list, session)
task_ids_list = [d.task_id for d in task_list]
self._set_state_to_skipped(dag_run, task_ids_list, session)
session.commit()

# SkipMixin may not necessarily have a task_id attribute. Only store to XCom if one is available.
Expand All @@ -140,7 +148,7 @@ def skip(

XCom.set(
key=XCOM_SKIPMIXIN_KEY,
value={XCOM_SKIPMIXIN_SKIPPED: [d.task_id for d in task_list]},
value={XCOM_SKIPMIXIN_SKIPPED: task_ids_list},
task_id=task_id,
dag_id=dag_run.dag_id,
run_id=dag_run.run_id,
Expand Down Expand Up @@ -183,6 +191,7 @@ def skip_all_except(
)

dag_run = ti.get_dagrun()
assert isinstance(dag_run, DagRun)

# TODO(potiuk): Handle TaskInstancePydantic case differently - we need to figure out the way to
# pass task that has been set in LocalTaskJob but in the way that TaskInstancePydantic definition
Expand Down Expand Up @@ -218,10 +227,15 @@ def skip_all_except(
for branch_task_id in list(branch_task_id_set):
branch_task_id_set.update(dag.get_task(branch_task_id).get_flat_relative_ids(upstream=False))

skip_tasks = [t for t in downstream_tasks if t.task_id not in branch_task_id_set]
follow_task_ids = [t.task_id for t in downstream_tasks if t.task_id in branch_task_id_set]
skip_tasks = [
(t.task_id, downstream_ti.map_index)
for t in downstream_tasks
if (downstream_ti := dag_run.get_task_instance(t.task_id, map_index=ti.map_index))
and t.task_id not in branch_task_id_set
]

self.log.info("Skipping tasks %s", [t.task_id for t in skip_tasks])
follow_task_ids = [t.task_id for t in downstream_tasks if t.task_id in branch_task_id_set]
self.log.info("Skipping tasks %s", skip_tasks)
with create_session() as session:
self._set_state_to_skipped(dag_run, skip_tasks, session=session)
# For some reason, session.commit() needs to happen before xcom_push.
Expand Down
36 changes: 36 additions & 0 deletions tests/models/test_skipmixin.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
import pytest

from airflow import settings
from airflow.decorators import task, task_group
from airflow.exceptions import AirflowException
from airflow.models.skipmixin import SkipMixin
from airflow.models.taskinstance import TaskInstance as TI
Expand Down Expand Up @@ -133,6 +134,41 @@ def get_state(ti):

assert executed_states == expected_states

def test_mapped_tasks_skip_all_except(self, dag_maker):
with dag_maker("dag_test_skip_all_except") as dag:

@task
def branch_op(k):
...

@task_group
def task_group_op(k):
branch_a = EmptyOperator(task_id="branch_a")
branch_b = EmptyOperator(task_id="branch_b")
branch_op(k) >> [branch_a, branch_b]

task_group_op.expand(k=[i for i in range(2)])

dag_maker.create_dagrun()
branch_op_ti_0 = TI(dag.get_task("task_group_op.branch_op"), execution_date=DEFAULT_DATE, map_index=0)
branch_op_ti_1 = TI(dag.get_task("task_group_op.branch_op"), execution_date=DEFAULT_DATE, map_index=1)
branch_a_ti_0 = TI(dag.get_task("task_group_op.branch_a"), execution_date=DEFAULT_DATE, map_index=0)
branch_a_ti_1 = TI(dag.get_task("task_group_op.branch_a"), execution_date=DEFAULT_DATE, map_index=1)
branch_b_ti_0 = TI(dag.get_task("task_group_op.branch_b"), execution_date=DEFAULT_DATE, map_index=0)
branch_b_ti_1 = TI(dag.get_task("task_group_op.branch_b"), execution_date=DEFAULT_DATE, map_index=1)

SkipMixin().skip_all_except(ti=branch_op_ti_0, branch_task_ids="task_group_op.branch_a")
SkipMixin().skip_all_except(ti=branch_op_ti_1, branch_task_ids="task_group_op.branch_b")

def get_state(ti):
ti.refresh_from_db()
return ti.state

assert get_state(branch_a_ti_0) == State.NONE
assert get_state(branch_b_ti_0) == State.SKIPPED
assert get_state(branch_a_ti_1) == State.SKIPPED
assert get_state(branch_b_ti_1) == State.NONE

def test_raise_exception_on_not_accepted_branch_task_ids_type(self, dag_maker):
with dag_maker("dag_test_skip_all_except_wrong_type"):
task = EmptyOperator(task_id="task")
Expand Down