Skip to content

Commit

Permalink
Fix logic of the skip_all_except method (#31153)
Browse files Browse the repository at this point in the history
* Fix logic of the skip_all_except method to work correctly with a mapped branch operator

* Address feadback

* Add an unit test

* Skipp empty tasks list

* Fix static checks

* Address feadback

* Use fully qualified import

---------

Co-authored-by: Zhyhimont Dmitry <[email protected]>
Co-authored-by: zhyhimont <[email protected]>
Co-authored-by: Tzu-ping Chung <[email protected]>
  • Loading branch information
4 people authored Jul 6, 2023
1 parent ef75a3a commit 9985c35
Show file tree
Hide file tree
Showing 2 changed files with 71 additions and 21 deletions.
56 changes: 35 additions & 21 deletions airflow/models/skipmixin.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,18 +21,19 @@
from typing import TYPE_CHECKING, Iterable, Sequence

from airflow.exceptions import AirflowException, RemovedInAirflow3Warning
from airflow.models.dagrun import DagRun
from airflow.models.taskinstance import TaskInstance
from airflow.serialization.pydantic.dag_run import DagRunPydantic
from airflow.utils import timezone
from airflow.utils.log.logging_mixin import LoggingMixin
from airflow.utils.session import NEW_SESSION, create_session, provide_session
from airflow.utils.sqlalchemy import tuple_in_condition
from airflow.utils.state import State

if TYPE_CHECKING:
from pendulum import DateTime
from sqlalchemy import Session

from airflow.models.dagrun import DagRun
from airflow.models.operator import Operator
from airflow.models.taskmixin import DAGNode
from airflow.serialization.pydantic.taskinstance import TaskInstancePydantic
Expand Down Expand Up @@ -60,24 +61,30 @@ class SkipMixin(LoggingMixin):
def _set_state_to_skipped(
self,
dag_run: DagRun | DagRunPydantic,
tasks: Iterable[Operator],
tasks: Sequence[str] | Sequence[tuple[str, int]],
session: Session,
) -> None:
"""Used internally to set state of task instances to skipped from the same dag run."""
now = timezone.utcnow()

session.query(TaskInstance).filter(
TaskInstance.dag_id == dag_run.dag_id,
TaskInstance.run_id == dag_run.run_id,
TaskInstance.task_id.in_(d.task_id for d in tasks),
).update(
{
TaskInstance.state: State.SKIPPED,
TaskInstance.start_date: now,
TaskInstance.end_date: now,
},
synchronize_session=False,
)
if tasks:
now = timezone.utcnow()
TI = TaskInstance
query = session.query(TI).filter(
TI.dag_id == dag_run.dag_id,
TI.run_id == dag_run.run_id,
)
if isinstance(tasks[0], tuple):
query = query.filter(tuple_in_condition((TI.task_id, TI.map_index), tasks))
else:
query = query.filter(TI.task_id.in_(tasks))

query.update(
{
TaskInstance.state: State.SKIPPED,
TaskInstance.start_date: now,
TaskInstance.end_date: now,
},
synchronize_session=False,
)

@provide_session
def skip(
Expand Down Expand Up @@ -130,7 +137,8 @@ def skip(
if dag_run is None:
raise ValueError("dag_run is required")

self._set_state_to_skipped(dag_run, task_list, session)
task_ids_list = [d.task_id for d in task_list]
self._set_state_to_skipped(dag_run, task_ids_list, session)
session.commit()

# SkipMixin may not necessarily have a task_id attribute. Only store to XCom if one is available.
Expand All @@ -140,7 +148,7 @@ def skip(

XCom.set(
key=XCOM_SKIPMIXIN_KEY,
value={XCOM_SKIPMIXIN_SKIPPED: [d.task_id for d in task_list]},
value={XCOM_SKIPMIXIN_SKIPPED: task_ids_list},
task_id=task_id,
dag_id=dag_run.dag_id,
run_id=dag_run.run_id,
Expand Down Expand Up @@ -183,6 +191,7 @@ def skip_all_except(
)

dag_run = ti.get_dagrun()
assert isinstance(dag_run, DagRun)

# TODO(potiuk): Handle TaskInstancePydantic case differently - we need to figure out the way to
# pass task that has been set in LocalTaskJob but in the way that TaskInstancePydantic definition
Expand Down Expand Up @@ -218,10 +227,15 @@ def skip_all_except(
for branch_task_id in list(branch_task_id_set):
branch_task_id_set.update(dag.get_task(branch_task_id).get_flat_relative_ids(upstream=False))

skip_tasks = [t for t in downstream_tasks if t.task_id not in branch_task_id_set]
follow_task_ids = [t.task_id for t in downstream_tasks if t.task_id in branch_task_id_set]
skip_tasks = [
(t.task_id, downstream_ti.map_index)
for t in downstream_tasks
if (downstream_ti := dag_run.get_task_instance(t.task_id, map_index=ti.map_index))
and t.task_id not in branch_task_id_set
]

self.log.info("Skipping tasks %s", [t.task_id for t in skip_tasks])
follow_task_ids = [t.task_id for t in downstream_tasks if t.task_id in branch_task_id_set]
self.log.info("Skipping tasks %s", skip_tasks)
with create_session() as session:
self._set_state_to_skipped(dag_run, skip_tasks, session=session)
# For some reason, session.commit() needs to happen before xcom_push.
Expand Down
36 changes: 36 additions & 0 deletions tests/models/test_skipmixin.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
import pytest

from airflow import settings
from airflow.decorators import task, task_group
from airflow.exceptions import AirflowException
from airflow.models.skipmixin import SkipMixin
from airflow.models.taskinstance import TaskInstance as TI
Expand Down Expand Up @@ -133,6 +134,41 @@ def get_state(ti):

assert executed_states == expected_states

def test_mapped_tasks_skip_all_except(self, dag_maker):
with dag_maker("dag_test_skip_all_except") as dag:

@task
def branch_op(k):
...

@task_group
def task_group_op(k):
branch_a = EmptyOperator(task_id="branch_a")
branch_b = EmptyOperator(task_id="branch_b")
branch_op(k) >> [branch_a, branch_b]

task_group_op.expand(k=[i for i in range(2)])

dag_maker.create_dagrun()
branch_op_ti_0 = TI(dag.get_task("task_group_op.branch_op"), execution_date=DEFAULT_DATE, map_index=0)
branch_op_ti_1 = TI(dag.get_task("task_group_op.branch_op"), execution_date=DEFAULT_DATE, map_index=1)
branch_a_ti_0 = TI(dag.get_task("task_group_op.branch_a"), execution_date=DEFAULT_DATE, map_index=0)
branch_a_ti_1 = TI(dag.get_task("task_group_op.branch_a"), execution_date=DEFAULT_DATE, map_index=1)
branch_b_ti_0 = TI(dag.get_task("task_group_op.branch_b"), execution_date=DEFAULT_DATE, map_index=0)
branch_b_ti_1 = TI(dag.get_task("task_group_op.branch_b"), execution_date=DEFAULT_DATE, map_index=1)

SkipMixin().skip_all_except(ti=branch_op_ti_0, branch_task_ids="task_group_op.branch_a")
SkipMixin().skip_all_except(ti=branch_op_ti_1, branch_task_ids="task_group_op.branch_b")

def get_state(ti):
ti.refresh_from_db()
return ti.state

assert get_state(branch_a_ti_0) == State.NONE
assert get_state(branch_b_ti_0) == State.SKIPPED
assert get_state(branch_a_ti_1) == State.SKIPPED
assert get_state(branch_b_ti_1) == State.NONE

def test_raise_exception_on_not_accepted_branch_task_ids_type(self, dag_maker):
with dag_maker("dag_test_skip_all_except_wrong_type"):
task = EmptyOperator(task_id="task")
Expand Down

0 comments on commit 9985c35

Please sign in to comment.