Add support for TaskGroup in ExternalTaskSensor (#24902)
pateash authored Aug 22, 2022
1 parent 0eb0b54 commit bc04c5f
Showing 5 changed files with 226 additions and 53 deletions.
17 changes: 15 additions & 2 deletions airflow/example_dags/example_external_task_marker_dag.py
@@ -80,5 +80,18 @@
         mode="reschedule",
     )
     # [END howto_operator_external_task_sensor]
-    child_task2 = EmptyOperator(task_id="child_task2")
-    child_task1 >> child_task2
+
+    # [START howto_operator_external_task_sensor_with_task_group]
+    child_task2 = ExternalTaskSensor(
+        task_id="child_task2",
+        external_dag_id=parent_dag.dag_id,
+        external_task_group_id='parent_dag_task_group_id',
+        timeout=600,
+        allowed_states=['success'],
+        failed_states=['failed', 'skipped'],
+        mode="reschedule",
+    )
+    # [END howto_operator_external_task_sensor_with_task_group]
+
+    child_task3 = EmptyOperator(task_id="child_task3")
+    child_task1 >> child_task2 >> child_task3
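
Note: for orientation, here is a minimal sketch (not part of this commit) of the parent side this sensor points at — a parent_dag containing a TaskGroup whose group_id matches the sensor's external_task_group_id. The inner task ids (step_one, step_two) are hypothetical.

import pendulum

from airflow import DAG
from airflow.operators.empty import EmptyOperator
from airflow.utils.task_group import TaskGroup

with DAG(
    dag_id="parent_dag",
    start_date=pendulum.datetime(2022, 1, 1, tz="UTC"),
    schedule_interval="@daily",
) as parent_dag:
    # The group_id must match the sensor's external_task_group_id exactly.
    with TaskGroup(group_id="parent_dag_task_group_id") as parent_group:
        step_one = EmptyOperator(task_id="step_one")  # hypothetical task
        step_two = EmptyOperator(task_id="step_two")  # hypothetical task
        step_one >> step_two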
7 changes: 7 additions & 0 deletions airflow/models/dag.py
@@ -2214,6 +2214,13 @@ def filter_task_group(group, parent_group):
     def has_task(self, task_id: str):
         return task_id in self.task_dict
 
+    def has_task_group(self, task_group_id: str) -> bool:
+        return task_group_id in self.task_group_dict
+
+    @cached_property
+    def task_group_dict(self):
+        return {k: v for k, v in self._task_group.get_task_group_dict().items() if k is not None}
+
     def get_task(self, task_id: str, include_subdags: bool = False) -> Operator:
         if task_id in self.task_dict:
             return self.task_dict[task_id]
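
Note: a quick sketch (not from the commit) of how these new helpers behave. The root TaskGroup of a DAG has group_id None, which is why the dict comprehension filters out the None key, leaving only named groups.

import pendulum

from airflow import DAG
from airflow.operators.empty import EmptyOperator
from airflow.utils.task_group import TaskGroup

with DAG(dag_id="demo", start_date=pendulum.datetime(2022, 1, 1, tz="UTC")) as dag:
    with TaskGroup(group_id="etl"):
        EmptyOperator(task_id="extract")

# Only named groups survive the `k is not None` filter above.
assert dag.has_task_group("etl")
assert not dag.has_task_group("missing")
assert "etl" in dag.task_group_dict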
125 changes: 92 additions & 33 deletions airflow/sensors/external_task.py
@@ -36,6 +36,9 @@
 from airflow.utils.session import provide_session
 from airflow.utils.state import State
 
+if TYPE_CHECKING:
+    from sqlalchemy.orm import Query
+
 
 class ExternalDagLink(BaseOperatorLink):
     """
@@ -54,9 +57,13 @@ def get_link(self, operator, dttm):
 
 class ExternalTaskSensor(BaseSensorOperator):
     """
-    Waits for a different DAG or a task in a different DAG to complete for a
+    Waits for a different DAG, a task group, or a task in a different DAG to complete for a
     specific logical date.
 
+    If both `external_task_group_id` and `external_task_id` are ``None`` (default), the sensor
+    waits for the DAG.
+    Values for `external_task_group_id` and `external_task_id` can't be set at the same time.
+
     By default the ExternalTaskSensor will wait for the external task to
     succeed, at which point it will also succeed. However, by default it will
     *not* fail if the external task fails, but will continue to check the status
@@ -78,7 +85,7 @@ class ExternalTaskSensor(BaseSensorOperator):
     :param external_dag_id: The dag_id that contains the task you want to
         wait for
     :param external_task_id: The task_id that contains the task you want to
-        wait for. If ``None`` (default value) the sensor waits for the DAG
+        wait for.
     :param external_task_ids: The list of task_ids that you want to wait for.
         If ``None`` (default value) the sensor waits for the DAG. Either
         external_task_id or external_task_ids can be passed to
@@ -111,6 +118,7 @@ def __init__(
         external_dag_id: str,
         external_task_id: Optional[str] = None,
         external_task_ids: Optional[Collection[str]] = None,
+        external_task_group_id: Optional[str] = None,
         allowed_states: Optional[Iterable[str]] = None,
         failed_states: Optional[Iterable[str]] = None,
         execution_delta: Optional[datetime.timedelta] = None,
@@ -139,18 +147,25 @@ def __init__(
         if external_task_id is not None:
             external_task_ids = [external_task_id]
 
-        if external_task_ids:
+        if external_task_group_id and external_task_ids:
+            raise ValueError(
+                "Values for `external_task_group_id` and `external_task_id` or `external_task_ids` "
+                "can't be set at the same time"
+            )
+
+        if external_task_ids or external_task_group_id:
             if not total_states <= set(State.task_states):
                 raise ValueError(
                     f'Valid values for `allowed_states` and `failed_states` '
-                    f'when `external_task_id` or `external_task_ids` is not `None`: {State.task_states}'
+                    f'when `external_task_id` or `external_task_ids` or `external_task_group_id` '
+                    f'is not `None`: {State.task_states}'
                 )
-            if len(external_task_ids) > len(set(external_task_ids)):
+            if external_task_ids and len(external_task_ids) > len(set(external_task_ids)):
                 raise ValueError('Duplicate task_ids passed in external_task_ids parameter')
         elif not total_states <= set(State.dag_states):
             raise ValueError(
                 f'Valid values for `allowed_states` and `failed_states` '
-                f'when `external_task_id` is `None`: {State.dag_states}'
+                f'when `external_task_id` and `external_task_group_id` is `None`: {State.dag_states}'
            )
 
         if execution_delta is not None and execution_date_fn is not None:
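
Note: a sketch (not from the commit) of what the new validation means in practice — passing both a task id and a task group id now fails at construction time.

from airflow.sensors.external_task import ExternalTaskSensor

try:
    ExternalTaskSensor(
        task_id="child_task2",
        external_dag_id="parent_dag",
        external_task_id="parent_task",  # conflicts with the line below
        external_task_group_id="parent_dag_task_group_id",
    )
except ValueError as err:
    # "Values for `external_task_group_id` and `external_task_id` or
    # `external_task_ids` can't be set at the same time"
    print(err)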
Expand All @@ -164,27 +179,39 @@ def __init__(
self.external_dag_id = external_dag_id
self.external_task_id = external_task_id
self.external_task_ids = external_task_ids
self.external_task_group_id = external_task_group_id
self.check_existence = check_existence
self._has_checked_existence = False

@provide_session
def poke(self, context, session=None):
def _get_dttm_filter(self, context):
if self.execution_delta:
dttm = context['logical_date'] - self.execution_delta
elif self.execution_date_fn:
dttm = self._handle_execution_date_fn(context=context)
else:
dttm = context['logical_date']
return dttm if isinstance(dttm, list) else [dttm]

dttm_filter = dttm if isinstance(dttm, list) else [dttm]
@provide_session
def poke(self, context, session=None):
dttm_filter = self._get_dttm_filter(context)
serialized_dttm_filter = ','.join(dt.isoformat() for dt in dttm_filter)

self.log.info(
'Poking for tasks %s in dag %s on %s ... ',
self.external_task_ids,
self.external_dag_id,
serialized_dttm_filter,
)
if self.external_task_ids:
self.log.info(
'Poking for tasks %s in dag %s on %s ... ',
self.external_task_ids,
self.external_dag_id,
serialized_dttm_filter,
)

if self.external_task_group_id:
self.log.info(
"Poking for task_group '%s' in dag '%s' on %s ... ",
self.external_task_group_id,
self.external_dag_id,
serialized_dttm_filter,
)

# In poke mode this will check dag existence only once
if self.check_existence and not self._has_checked_existence:
Expand All @@ -207,6 +234,17 @@ def poke(self, context, session=None):
f'Some of the external tasks {self.external_task_ids} '
f'in DAG {self.external_dag_id} failed.'
)
elif self.external_task_group_id:
if self.soft_fail:
raise AirflowSkipException(
f"The external task_group '{self.external_task_group_id}' "
f"in DAG '{self.external_dag_id}' failed. Skipping due to soft_fail."
)
raise AirflowException(
f"The external task_group '{self.external_task_group_id}' "
f"in DAG '{self.external_dag_id}' failed."
)

else:
if self.soft_fail:
raise AirflowSkipException(
Expand All @@ -217,7 +255,7 @@ def poke(self, context, session=None):
return count_allowed == len(dttm_filter)

def _check_for_existence(self, session) -> None:
dag_to_wait = session.query(DagModel).filter(DagModel.dag_id == self.external_dag_id).first()
dag_to_wait = DagModel.get_current(self.external_dag_id, session)

if not dag_to_wait:
raise AirflowException(f'The external DAG {self.external_dag_id} does not exist.')
Expand All @@ -233,6 +271,15 @@ def _check_for_existence(self, session) -> None:
f'The external task {external_task_id} in '
f'DAG {self.external_dag_id} does not exist.'
)

if self.external_task_group_id:
refreshed_dag_info = DagBag(dag_to_wait.fileloc).get_dag(self.external_dag_id)
if not refreshed_dag_info.has_task_group(self.external_task_group_id):
raise AirflowException(
f"The external task group '{self.external_task_group_id}' in "
f"DAG '{self.external_dag_id}' does not exist."
)

self._has_checked_existence = True

def get_count(self, dttm_filter, session, states) -> int:
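
Note: a usage sketch under the same assumptions as the example DAG above. With check_existence=True, a misspelled group id fails on the first poke; with check_existence=False, the fallback in get_external_task_group_task_ids (in the next hunk) keeps the sensor poking until its timeout instead.

ExternalTaskSensor(
    task_id="wait_for_group",
    external_dag_id="parent_dag",
    external_task_group_id="no_such_group",  # hypothetical typo: group does not exist
    check_existence=True,  # first poke raises AirflowException instead of timing out
    timeout=600,
)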
@@ -251,28 +298,40 @@ def get_count(self, dttm_filter, session, states) -> int:
 
         if self.external_task_ids:
             count = (
-                session.query(func.count())  # .count() is inefficient
-                .filter(
-                    TI.dag_id == self.external_dag_id,
-                    TI.task_id.in_(self.external_task_ids),
-                    TI.state.in_(states),
-                    TI.execution_date.in_(dttm_filter),
-                )
+                self._count_query(TI, session, states, dttm_filter)
+                .filter(TI.task_id.in_(self.external_task_ids))
                 .scalar()
-            )
-            count = count / len(self.external_task_ids)
-        else:
+            ) / len(self.external_task_ids)
+        elif self.external_task_group_id:
+            external_task_group_task_ids = self.get_external_task_group_task_ids(session)
             count = (
-                session.query(func.count())
-                .filter(
-                    DR.dag_id == self.external_dag_id,
-                    DR.state.in_(states),
-                    DR.execution_date.in_(dttm_filter),
-                )
+                self._count_query(TI, session, states, dttm_filter)
+                .filter(TI.task_id.in_(external_task_group_task_ids))
                 .scalar()
-            )
+            ) / len(external_task_group_task_ids)
+        else:
+            count = self._count_query(DR, session, states, dttm_filter).scalar()
         return count
 
+    def _count_query(self, model, session, states, dttm_filter) -> "Query":
+        query = session.query(func.count()).filter(
+            model.dag_id == self.external_dag_id,
+            model.state.in_(states),  # pylint: disable=no-member
+            model.execution_date.in_(dttm_filter),
+        )
+        return query
+
+    def get_external_task_group_task_ids(self, session):
+        refreshed_dag_info = DagBag(read_dags_from_db=True).get_dag(self.external_dag_id, session)
+        task_group = refreshed_dag_info.task_group_dict.get(self.external_task_group_id)
+
+        if task_group:
+            return [task.task_id for task in task_group]
+
+        # Fall back to the group_id itself as the task_id; this avoids a hard failure
+        # when 'check_existence=False' and lets the sensor fail on timeout instead.
+        return [self.external_task_group_id]
+
     def _handle_execution_date_fn(self, context) -> Any:
         """
         This function is to handle backwards compatibility with how this operator was
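
Note: a rough model (plain Python, not Airflow code) of the success condition poke() now applies to a task group: get_count() counts matching task instances and normalizes by the group size, and the sensor fires only when every task in the group is in an allowed state for every poked date.

def group_done(task_ids, dates, state_of, allowed=("success",)):
    # Mirrors get_count(): count task instances in allowed states,
    # divided by the number of tasks in the group.
    matches = sum(1 for t in task_ids for d in dates if state_of(t, d) in allowed)
    count_allowed = matches / len(task_ids)
    # Mirrors poke(): succeed only when every date is fully covered.
    return count_allowed == len(dates)

# Two tasks, one date: both must have succeeded for the sensor to return True.
states = {("t1", "2022-01-01"): "success", ("t2", "2022-01-01"): "running"}
print(group_done(["t1", "t2"], ["2022-01-01"], lambda t, d: states[(t, d)]))  # False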
10 changes: 10 additions & 0 deletions docs/apache-airflow/howto/operator/external_task_sensor.rst
@@ -53,6 +53,16 @@ via ``allowed_states`` and ``failed_states`` parameters.
     :start-after: [START howto_operator_external_task_sensor]
     :end-before: [END howto_operator_external_task_sensor]
 
+ExternalTaskSensor with task_group dependency
+---------------------------------------------
+In addition, we can use the :class:`~airflow.sensors.external_task.ExternalTaskSensor` to make tasks on a DAG
+wait for another ``task_group`` on a different DAG for a specific ``execution_date``.
+
+.. exampleinclude:: /../../airflow/example_dags/example_external_task_marker_dag.py
+    :language: python
+    :dedent: 4
+    :start-after: [START howto_operator_external_task_sensor_with_task_group]
+    :end-before: [END howto_operator_external_task_sensor_with_task_group]
+
 
 ExternalTaskMarker
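
Note: the parameters shown in the signature hunk also apply here. If the parent and child DAGs run on different schedules, execution_delta (or execution_date_fn) aligns the logical date that _get_dttm_filter pokes for; a sketch with a hypothetical one-hour offset:

import datetime

ExternalTaskSensor(
    task_id="wait_with_offset",
    external_dag_id="parent_dag",
    external_task_group_id="parent_dag_task_group_id",
    execution_delta=datetime.timedelta(hours=1),  # parent runs one hour earlier
)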