From 9e3ceca885d5ee9558fb55b676343055651ffa22 Mon Sep 17 00:00:00 2001 From: Zhyhimont Dmitry Date: Tue, 9 May 2023 16:24:44 +0300 Subject: [PATCH 1/7] Fix logic of the skip_all_except method to work correctly with a mapped branch operator --- airflow/models/skipmixin.py | 39 +++++++++++++++++++++++++------------ 1 file changed, 27 insertions(+), 12 deletions(-) diff --git a/airflow/models/skipmixin.py b/airflow/models/skipmixin.py index d75a4a0e4d9b6..f5065e71d3423 100644 --- a/airflow/models/skipmixin.py +++ b/airflow/models/skipmixin.py @@ -26,6 +26,7 @@ from airflow.utils import timezone from airflow.utils.log.logging_mixin import LoggingMixin from airflow.utils.session import NEW_SESSION, create_session, provide_session +from airflow.utils.sqlalchemy import tuple_in_condition from airflow.utils.state import State if TYPE_CHECKING: @@ -60,17 +61,24 @@ class SkipMixin(LoggingMixin): def _set_state_to_skipped( self, dag_run: DagRun | DagRunPydantic, - tasks: Iterable[Operator], + tasks: Iterable[str] | Iterable[tuple[str, int]], session: Session, + include_map_index: bool = False, ) -> None: """Used internally to set state of task instances to skipped from the same dag run.""" now = timezone.utcnow() + TI = TaskInstance + query = session.query(TI).filter( + TI.dag_id == dag_run.dag_id, + TI.run_id == dag_run.run_id, + ) + + if include_map_index: + query = query.filter(tuple_in_condition((TI.task_id, TI.map_index), tasks)) + else: + query = query.filter(TI.task_id.in_(tasks)) - session.query(TaskInstance).filter( - TaskInstance.dag_id == dag_run.dag_id, - TaskInstance.run_id == dag_run.run_id, - TaskInstance.task_id.in_(d.task_id for d in tasks), - ).update( + query.update( { TaskInstance.state: State.SKIPPED, TaskInstance.start_date: now, @@ -130,7 +138,8 @@ def skip( if dag_run is None: raise ValueError("dag_run is required") - self._set_state_to_skipped(dag_run, task_list, session) + task_ids_list = [d.task_id for d in task_list] + self._set_state_to_skipped(dag_run, task_ids_list, session) session.commit() # SkipMixin may not necessarily have a task_id attribute. Only store to XCom if one is available. @@ -140,7 +149,7 @@ def skip( XCom.set( key=XCOM_SKIPMIXIN_KEY, - value={XCOM_SKIPMIXIN_SKIPPED: [d.task_id for d in task_list]}, + value={XCOM_SKIPMIXIN_SKIPPED: task_ids_list}, task_id=task_id, dag_id=dag_run.dag_id, run_id=dag_run.run_id, @@ -182,7 +191,7 @@ def skip_all_except( f"but got {type(branch_task_ids).__name__!r}." ) - dag_run = ti.get_dagrun() + dag_run: DagRun = ti.get_dagrun() # TODO(potiuk): Handle TaskInstancePydantic case differently - we need to figure out the way to # pass task that has been set in LocalTaskJob but in the way that TaskInstancePydantic definition @@ -218,12 +227,18 @@ def skip_all_except( for branch_task_id in list(branch_task_id_set): branch_task_id_set.update(dag.get_task(branch_task_id).get_flat_relative_ids(upstream=False)) - skip_tasks = [t for t in downstream_tasks if t.task_id not in branch_task_id_set] + skip_tasks = list() + for t in downstream_tasks: + downstream_ti =\ + dag_run.get_task_instance(t.task_id, map_index=ti.map_index) # type: ignore[union-attr] + if downstream_ti and t.task_id not in branch_task_id_set: + skip_tasks.append((t.task_id, downstream_ti.map_index)) + follow_task_ids = [t.task_id for t in downstream_tasks if t.task_id in branch_task_id_set] - self.log.info("Skipping tasks %s", [t.task_id for t in skip_tasks]) + self.log.info("Skipping tasks %s", skip_tasks) with create_session() as session: - self._set_state_to_skipped(dag_run, skip_tasks, session=session) + self._set_state_to_skipped(dag_run, skip_tasks, session=session, include_map_index=True) # For some reason, session.commit() needs to happen before xcom_push. # Otherwise the session is not committed. session.commit() From e36b09597ac088fee798c02892430bb728959698 Mon Sep 17 00:00:00 2001 From: zhyhimont Date: Fri, 9 Jun 2023 10:02:32 +0300 Subject: [PATCH 2/7] Address feadback --- airflow/models/skipmixin.py | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) diff --git a/airflow/models/skipmixin.py b/airflow/models/skipmixin.py index f5065e71d3423..803bec910e92a 100644 --- a/airflow/models/skipmixin.py +++ b/airflow/models/skipmixin.py @@ -63,7 +63,6 @@ def _set_state_to_skipped( dag_run: DagRun | DagRunPydantic, tasks: Iterable[str] | Iterable[tuple[str, int]], session: Session, - include_map_index: bool = False, ) -> None: """Used internally to set state of task instances to skipped from the same dag run.""" now = timezone.utcnow() @@ -72,8 +71,7 @@ def _set_state_to_skipped( TI.dag_id == dag_run.dag_id, TI.run_id == dag_run.run_id, ) - - if include_map_index: + if isinstance(tasks[0], tuple): query = query.filter(tuple_in_condition((TI.task_id, TI.map_index), tasks)) else: query = query.filter(TI.task_id.in_(tasks)) @@ -235,10 +233,9 @@ def skip_all_except( skip_tasks.append((t.task_id, downstream_ti.map_index)) follow_task_ids = [t.task_id for t in downstream_tasks if t.task_id in branch_task_id_set] - self.log.info("Skipping tasks %s", skip_tasks) with create_session() as session: - self._set_state_to_skipped(dag_run, skip_tasks, session=session, include_map_index=True) + self._set_state_to_skipped(dag_run, skip_tasks, session=session) # For some reason, session.commit() needs to happen before xcom_push. # Otherwise the session is not committed. session.commit() From b04db093753155d4ab7d9e0a7712853136ff8703 Mon Sep 17 00:00:00 2001 From: zhyhimont Date: Fri, 9 Jun 2023 10:03:56 +0300 Subject: [PATCH 3/7] Add an unit test --- tests/models/test_skipmixin.py | 36 ++++++++++++++++++++++++++++++++++ 1 file changed, 36 insertions(+) diff --git a/tests/models/test_skipmixin.py b/tests/models/test_skipmixin.py index c9912a64d4190..547dbec5b4208 100644 --- a/tests/models/test_skipmixin.py +++ b/tests/models/test_skipmixin.py @@ -24,6 +24,7 @@ import pytest from airflow import settings +from airflow.decorators import task, task_group from airflow.exceptions import AirflowException from airflow.models.skipmixin import SkipMixin from airflow.models.taskinstance import TaskInstance as TI @@ -133,6 +134,41 @@ def get_state(ti): assert executed_states == expected_states + def test_mapped_tasks_skip_all_except(self, dag_maker): + with dag_maker("dag_test_skip_all_except") as dag: + + @task + def branch_op(k): + ... + + @task_group + def task_group_op(k): + branch_a = EmptyOperator(task_id="branch_a") + branch_b = EmptyOperator(task_id="branch_b") + branch_op(k) >> [branch_a, branch_b] + + task_group_op.expand(k=[i for i in range(2)]) + + dag_maker.create_dagrun() + branch_op_ti_0 = TI(dag.get_task("task_group_op.branch_op"), execution_date=DEFAULT_DATE, map_index=0) + branch_op_ti_1 = TI(dag.get_task("task_group_op.branch_op"), execution_date=DEFAULT_DATE, map_index=1) + branch_a_ti_0 = TI(dag.get_task("task_group_op.branch_a"), execution_date=DEFAULT_DATE, map_index=0) + branch_a_ti_1 = TI(dag.get_task("task_group_op.branch_a"), execution_date=DEFAULT_DATE, map_index=1) + branch_b_ti_0 = TI(dag.get_task("task_group_op.branch_b"), execution_date=DEFAULT_DATE, map_index=0) + branch_b_ti_1 = TI(dag.get_task("task_group_op.branch_b"), execution_date=DEFAULT_DATE, map_index=1) + + SkipMixin().skip_all_except(ti=branch_op_ti_0, branch_task_ids="task_group_op.branch_a") + SkipMixin().skip_all_except(ti=branch_op_ti_1, branch_task_ids="task_group_op.branch_b") + + def get_state(ti): + ti.refresh_from_db() + return ti.state + + assert get_state(branch_a_ti_0) == State.NONE + assert get_state(branch_b_ti_0) == State.SKIPPED + assert get_state(branch_a_ti_1) == State.SKIPPED + assert get_state(branch_b_ti_1) == State.NONE + def test_raise_exception_on_not_accepted_branch_task_ids_type(self, dag_maker): with dag_maker("dag_test_skip_all_except_wrong_type"): task = EmptyOperator(task_id="task") From 3c5660d15919480c8051bf780a34b3e2684321d2 Mon Sep 17 00:00:00 2001 From: zhyhimont Date: Fri, 9 Jun 2023 11:16:28 +0300 Subject: [PATCH 4/7] Skipp empty tasks list --- airflow/models/skipmixin.py | 39 +++++++++++++++++++------------------ 1 file changed, 20 insertions(+), 19 deletions(-) diff --git a/airflow/models/skipmixin.py b/airflow/models/skipmixin.py index 803bec910e92a..e44443e666cd7 100644 --- a/airflow/models/skipmixin.py +++ b/airflow/models/skipmixin.py @@ -65,25 +65,26 @@ def _set_state_to_skipped( session: Session, ) -> None: """Used internally to set state of task instances to skipped from the same dag run.""" - now = timezone.utcnow() - TI = TaskInstance - query = session.query(TI).filter( - TI.dag_id == dag_run.dag_id, - TI.run_id == dag_run.run_id, - ) - if isinstance(tasks[0], tuple): - query = query.filter(tuple_in_condition((TI.task_id, TI.map_index), tasks)) - else: - query = query.filter(TI.task_id.in_(tasks)) - - query.update( - { - TaskInstance.state: State.SKIPPED, - TaskInstance.start_date: now, - TaskInstance.end_date: now, - }, - synchronize_session=False, - ) + if tasks: + now = timezone.utcnow() + TI = TaskInstance + query = session.query(TI).filter( + TI.dag_id == dag_run.dag_id, + TI.run_id == dag_run.run_id, + ) + if isinstance(tasks[0], tuple): + query = query.filter(tuple_in_condition((TI.task_id, TI.map_index), tasks)) + else: + query = query.filter(TI.task_id.in_(tasks)) + + query.update( + { + TaskInstance.state: State.SKIPPED, + TaskInstance.start_date: now, + TaskInstance.end_date: now, + }, + synchronize_session=False, + ) @provide_session def skip( From f4fbf4417a62b05ba4d2e5bcb8859e927d1c57cf Mon Sep 17 00:00:00 2001 From: zhyhimont Date: Fri, 9 Jun 2023 12:26:46 +0300 Subject: [PATCH 5/7] Fix static checks --- airflow/models/skipmixin.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/airflow/models/skipmixin.py b/airflow/models/skipmixin.py index e44443e666cd7..dc37fdfa6b7a0 100644 --- a/airflow/models/skipmixin.py +++ b/airflow/models/skipmixin.py @@ -61,7 +61,7 @@ class SkipMixin(LoggingMixin): def _set_state_to_skipped( self, dag_run: DagRun | DagRunPydantic, - tasks: Iterable[str] | Iterable[tuple[str, int]], + tasks: Sequence[str] | Sequence[tuple[str, int]], session: Session, ) -> None: """Used internally to set state of task instances to skipped from the same dag run.""" @@ -190,7 +190,7 @@ def skip_all_except( f"but got {type(branch_task_ids).__name__!r}." ) - dag_run: DagRun = ti.get_dagrun() + dag_run = ti.get_dagrun() # TODO(potiuk): Handle TaskInstancePydantic case differently - we need to figure out the way to # pass task that has been set in LocalTaskJob but in the way that TaskInstancePydantic definition @@ -226,10 +226,11 @@ def skip_all_except( for branch_task_id in list(branch_task_id_set): branch_task_id_set.update(dag.get_task(branch_task_id).get_flat_relative_ids(upstream=False)) - skip_tasks = list() + skip_tasks = [] for t in downstream_tasks: - downstream_ti =\ - dag_run.get_task_instance(t.task_id, map_index=ti.map_index) # type: ignore[union-attr] + downstream_ti = dag_run.get_task_instance( # type: ignore[union-attr] + t.task_id, map_index=ti.map_index + ) if downstream_ti and t.task_id not in branch_task_id_set: skip_tasks.append((t.task_id, downstream_ti.map_index)) From 2d34c48579dbbbd8ea93702b814a477b61940ebd Mon Sep 17 00:00:00 2001 From: zhyhimont Date: Tue, 13 Jun 2023 20:00:51 +0300 Subject: [PATCH 6/7] Address feadback --- airflow/models/skipmixin.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/airflow/models/skipmixin.py b/airflow/models/skipmixin.py index dc37fdfa6b7a0..96866dfd9361c 100644 --- a/airflow/models/skipmixin.py +++ b/airflow/models/skipmixin.py @@ -21,6 +21,7 @@ from typing import TYPE_CHECKING, Iterable, Sequence from airflow.exceptions import AirflowException, RemovedInAirflow3Warning +from airflow.models import DagRun from airflow.models.taskinstance import TaskInstance from airflow.serialization.pydantic.dag_run import DagRunPydantic from airflow.utils import timezone @@ -33,7 +34,6 @@ from pendulum import DateTime from sqlalchemy import Session - from airflow.models.dagrun import DagRun from airflow.models.operator import Operator from airflow.models.taskmixin import DAGNode from airflow.serialization.pydantic.taskinstance import TaskInstancePydantic @@ -191,6 +191,7 @@ def skip_all_except( ) dag_run = ti.get_dagrun() + assert isinstance(dag_run, DagRun) # TODO(potiuk): Handle TaskInstancePydantic case differently - we need to figure out the way to # pass task that has been set in LocalTaskJob but in the way that TaskInstancePydantic definition @@ -226,13 +227,12 @@ def skip_all_except( for branch_task_id in list(branch_task_id_set): branch_task_id_set.update(dag.get_task(branch_task_id).get_flat_relative_ids(upstream=False)) - skip_tasks = [] - for t in downstream_tasks: - downstream_ti = dag_run.get_task_instance( # type: ignore[union-attr] - t.task_id, map_index=ti.map_index - ) - if downstream_ti and t.task_id not in branch_task_id_set: - skip_tasks.append((t.task_id, downstream_ti.map_index)) + skip_tasks = [ + (t.task_id, downstream_ti.map_index) + for t in downstream_tasks + if (downstream_ti := dag_run.get_task_instance(t.task_id, map_index=ti.map_index)) + and t.task_id not in branch_task_id_set + ] follow_task_ids = [t.task_id for t in downstream_tasks if t.task_id in branch_task_id_set] self.log.info("Skipping tasks %s", skip_tasks) From 2888976989fad38d66c84a2296bc9a80c051847e Mon Sep 17 00:00:00 2001 From: Tzu-ping Chung Date: Wed, 14 Jun 2023 15:33:44 +0800 Subject: [PATCH 7/7] Use fully qualified import --- airflow/models/skipmixin.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/airflow/models/skipmixin.py b/airflow/models/skipmixin.py index 96866dfd9361c..849083e38b951 100644 --- a/airflow/models/skipmixin.py +++ b/airflow/models/skipmixin.py @@ -21,7 +21,7 @@ from typing import TYPE_CHECKING, Iterable, Sequence from airflow.exceptions import AirflowException, RemovedInAirflow3Warning -from airflow.models import DagRun +from airflow.models.dagrun import DagRun from airflow.models.taskinstance import TaskInstance from airflow.serialization.pydantic.dag_run import DagRunPydantic from airflow.utils import timezone