diff --git a/airflow/api/common/mark_tasks.py b/airflow/api/common/mark_tasks.py index 76988f0f82aee..f249199651849 100644 --- a/airflow/api/common/mark_tasks.py +++ b/airflow/api/common/mark_tasks.py @@ -83,6 +83,9 @@ def set_state( if not run_id: raise ValueError("Received tasks with no run_id") + if TYPE_CHECKING: + assert isinstance(dag, DAG) + dag_run_ids = get_run_ids(dag, run_id, future, past, session=session) task_id_map_index_list = list(find_task_relatives(tasks, downstream, upstream)) # now look for the task instances that are affected diff --git a/airflow/api_connexion/schemas/common_schema.py b/airflow/api_connexion/schemas/common_schema.py index 569a745a62f52..9c08d6932af28 100644 --- a/airflow/api_connexion/schemas/common_schema.py +++ b/airflow/api_connexion/schemas/common_schema.py @@ -25,7 +25,7 @@ from dateutil import relativedelta from marshmallow import Schema, fields, validate -from airflow.models.mappedoperator import MappedOperator +from airflow.sdk.definitions.mappedoperator import MappedOperator from airflow.serialization.serialized_objects import SerializedBaseOperator @@ -128,7 +128,7 @@ def _get_module(self, obj): def _get_class_name(self, obj): if isinstance(obj, (MappedOperator, SerializedBaseOperator)): - return obj._task_type + return obj.task_type if isinstance(obj, type): return obj.__name__ return type(obj).__name__ diff --git a/airflow/api_connexion/schemas/task_schema.py b/airflow/api_connexion/schemas/task_schema.py index 1329a351d7b76..648187c0b70cc 100644 --- a/airflow/api_connexion/schemas/task_schema.py +++ b/airflow/api_connexion/schemas/task_schema.py @@ -26,7 +26,6 @@ TimeDeltaSchema, WeightRuleField, ) -from airflow.models.mappedoperator import MappedOperator if TYPE_CHECKING: from airflow.models.operator import Operator @@ -62,7 +61,7 @@ class TaskSchema(Schema): template_fields = fields.List(fields.String(), dump_only=True) downstream_task_ids = fields.List(fields.String(), dump_only=True) params = fields.Method("_get_params", dump_only=True) - is_mapped = fields.Method("_get_is_mapped", dump_only=True) + is_mapped = fields.Boolean(dump_only=True) doc_md = fields.String(dump_only=True) @staticmethod @@ -80,10 +79,6 @@ def _get_params(obj): params = obj.params return {k: v.dump() for k, v in params.items()} - @staticmethod - def _get_is_mapped(obj): - return isinstance(obj, MappedOperator) - class TaskCollection(NamedTuple): """List of Tasks with metadata.""" diff --git a/airflow/api_fastapi/core_api/datamodels/tasks.py b/airflow/api_fastapi/core_api/datamodels/tasks.py index 0806d4453c49a..13a9af7043680 100644 --- a/airflow/api_fastapi/core_api/datamodels/tasks.py +++ b/airflow/api_fastapi/core_api/datamodels/tasks.py @@ -26,29 +26,18 @@ from airflow.api_fastapi.common.types import TimeDeltaWithValidation from airflow.api_fastapi.core_api.base import BaseModel -from airflow.models.mappedoperator import MappedOperator -from airflow.serialization.serialized_objects import SerializedBaseOperator, encode_priority_weight_strategy +from airflow.serialization.serialized_objects import encode_priority_weight_strategy from airflow.task.priority_strategy import PriorityWeightStrategy def _get_class_ref(obj) -> dict[str, str | None]: """Return the class_ref dict for obj.""" - is_mapped_or_serialized = isinstance(obj, (MappedOperator, SerializedBaseOperator)) - - module_path = None - if is_mapped_or_serialized: - module_path = obj._task_module - else: + module_path = getattr(obj, "_task_module", None) + if module_path is None: module_type = 
inspect.getmodule(obj) module_path = module_type.__name__ if module_type else None - class_name = None - if is_mapped_or_serialized: - class_name = obj._task_type - elif obj.__class__ is type: - class_name = obj.__name__ - else: - class_name = type(obj).__name__ + class_name = obj.task_type return { "module_path": module_path, @@ -89,9 +78,7 @@ class TaskResponse(BaseModel): @model_validator(mode="before") @classmethod def validate_model(cls, task: Any) -> Any: - task.__dict__.update( - {"class_ref": _get_class_ref(task), "is_mapped": isinstance(task, MappedOperator)} - ) + task.__dict__.update({"class_ref": _get_class_ref(task), "is_mapped": task.is_mapped}) return task @field_validator("weight_rule", mode="before") diff --git a/airflow/api_fastapi/core_api/services/ui/grid.py b/airflow/api_fastapi/core_api/services/ui/grid.py index 739cf25b15153..74c44318a5a7a 100644 --- a/airflow/api_fastapi/core_api/services/ui/grid.py +++ b/airflow/api_fastapi/core_api/services/ui/grid.py @@ -32,11 +32,12 @@ ) from airflow.configuration import conf from airflow.exceptions import AirflowConfigException -from airflow.models import MappedOperator -from airflow.models.baseoperator import BaseOperator +from airflow.models.baseoperator import BaseOperator as DBBaseOperator from airflow.models.taskmap import TaskMap +from airflow.sdk import BaseOperator +from airflow.sdk.definitions.mappedoperator import MappedOperator +from airflow.sdk.definitions.taskgroup import MappedTaskGroup, TaskGroup from airflow.utils.state import TaskInstanceState -from airflow.utils.task_group import MappedTaskGroup, TaskGroup @cache @@ -139,7 +140,7 @@ def _get_total_task_count( node if isinstance(node, int) else ( - node.get_mapped_ti_count(run_id=run_id, session=session) + DBBaseOperator.get_mapped_ti_count(node, run_id=run_id, session=session) or 0 if isinstance(node, (MappedTaskGroup, MappedOperator)) else node ) diff --git a/airflow/cli/commands/remote_commands/task_command.py b/airflow/cli/commands/remote_commands/task_command.py index 51198af74961a..28ecb9d7bd798 100644 --- a/airflow/cli/commands/remote_commands/task_command.py +++ b/airflow/cli/commands/remote_commands/task_command.py @@ -206,6 +206,11 @@ def _get_ti( dag = task.dag if dag is None: raise ValueError("Cannot get task instance for a task not assigned to a DAG") + if not isinstance(dag, DAG): + # TODO: Task-SDK: Shouldn't really happen, and this command will go away before 3.0 + raise ValueError( + f"We need a {DAG.__module__}.DAG, but we got {type(dag).__module__}.{type(dag).__name__}!" 
+ ) # this check is imperfect because diff dags could have tasks with same name # but in a task, dag_id is a property that accesses its dag, and we don't diff --git a/airflow/decorators/base.py b/airflow/decorators/base.py index 3e4589b3c78db..9db6d058adeb5 100644 --- a/airflow/decorators/base.py +++ b/airflow/decorators/base.py @@ -41,11 +41,11 @@ ListOfDictsExpandInput, is_mappable, ) -from airflow.models.mappedoperator import MappedOperator, ensure_xcomarg_return_value -from airflow.models.xcom_arg import XComArg from airflow.sdk.definitions._internal.contextmanager import DagContext, TaskGroupContext from airflow.sdk.definitions.asset import Asset from airflow.sdk.definitions.baseoperator import BaseOperator as TaskSDKBaseOperator +from airflow.sdk.definitions.mappedoperator import MappedOperator, ensure_xcomarg_return_value +from airflow.sdk.definitions.xcom_arg import XComArg from airflow.typing_compat import ParamSpec from airflow.utils import timezone from airflow.utils.context import KNOWN_CONTEXT_KEYS @@ -62,9 +62,9 @@ OperatorExpandArgument, OperatorExpandKwargsArgument, ) - from airflow.models.mappedoperator import ValidationSource from airflow.sdk.definitions.context import Context from airflow.sdk.definitions.dag import DAG + from airflow.sdk.definitions.mappedoperator import ValidationSource from airflow.utils.task_group import TaskGroup diff --git a/airflow/models/abstractoperator.py b/airflow/models/abstractoperator.py index 732180d372946..b8fb54f6966fd 100644 --- a/airflow/models/abstractoperator.py +++ b/airflow/models/abstractoperator.py @@ -19,40 +19,36 @@ import datetime import inspect -from collections.abc import Iterable, Iterator, Sequence +from collections.abc import Iterable, Sequence from functools import cached_property from typing import TYPE_CHECKING, Any, Callable -import methodtools from sqlalchemy import select from airflow.configuration import conf from airflow.exceptions import AirflowException from airflow.models.expandinput import NotFullyPopulated -from airflow.sdk.definitions._internal.abstractoperator import AbstractOperator as TaskSDKAbstractOperator +from airflow.sdk.definitions._internal.abstractoperator import ( + AbstractOperator as TaskSDKAbstractOperator, + NotMapped as NotMapped, # Re-export this for compat +) from airflow.sdk.definitions.context import Context from airflow.utils.db import exists_query from airflow.utils.log.logging_mixin import LoggingMixin -from airflow.utils.setup_teardown import SetupTeardownContext from airflow.utils.sqlalchemy import with_row_locks from airflow.utils.state import State, TaskInstanceState -from airflow.utils.task_group import MappedTaskGroup from airflow.utils.trigger_rule import TriggerRule -from airflow.utils.weight_rule import WeightRule, db_safe_priority +from airflow.utils.weight_rule import db_safe_priority if TYPE_CHECKING: - import jinja2 # Slow imports. 
from sqlalchemy.orm import Session from airflow.models.baseoperatorlink import BaseOperatorLink from airflow.models.dag import DAG as SchedulerDAG - from airflow.models.mappedoperator import MappedOperator from airflow.models.taskinstance import TaskInstance - from airflow.sdk.definitions._internal.node import DAGNode from airflow.sdk.definitions.baseoperator import BaseOperator from airflow.task.priority_strategy import PriorityWeightStrategy from airflow.triggers.base import StartTriggerArgs - from airflow.utils.task_group import TaskGroup TaskStateChangeCallback = Callable[[Context], None] @@ -71,19 +67,12 @@ ) MAX_RETRY_DELAY: int = conf.getint("core", "max_task_retry_delay", fallback=24 * 60 * 60) -DEFAULT_WEIGHT_RULE: WeightRule = WeightRule( - conf.get("core", "default_task_weight_rule", fallback=WeightRule.DOWNSTREAM) -) DEFAULT_TRIGGER_RULE: TriggerRule = TriggerRule.ALL_SUCCESS DEFAULT_TASK_EXECUTION_TIMEOUT: datetime.timedelta | None = conf.gettimedelta( "core", "default_task_execution_timeout" ) -class NotMapped(Exception): - """Raise if a task is neither mapped nor has any parent mapped groups.""" - - class AbstractOperator(LoggingMixin, TaskSDKAbstractOperator): """ Common implementation for operators, including unmapped and mapped. @@ -98,124 +87,8 @@ class AbstractOperator(LoggingMixin, TaskSDKAbstractOperator): :meta private: """ - trigger_rule: TriggerRule weight_rule: PriorityWeightStrategy - @property - def on_failure_fail_dagrun(self): - """ - Whether the operator should fail the dagrun on failure. - - :meta private: - """ - return self._on_failure_fail_dagrun - - @on_failure_fail_dagrun.setter - def on_failure_fail_dagrun(self, value): - """ - Setter for on_failure_fail_dagrun property. - - :meta private: - """ - if value is True and self.is_teardown is not True: - raise ValueError( - f"Cannot set task on_failure_fail_dagrun for " - f"'{self.task_id}' because it is not a teardown task." - ) - self._on_failure_fail_dagrun = value - - def _iter_all_mapped_downstreams(self) -> Iterator[MappedOperator | MappedTaskGroup]: - """ - Return mapped nodes that are direct dependencies of the current task. - - For now, this walks the entire DAG to find mapped nodes that has this - current task as an upstream. We cannot use ``downstream_list`` since it - only contains operators, not task groups. In the future, we should - provide a way to record an DAG node's all downstream nodes instead. - - Note that this does not guarantee the returned tasks actually use the - current task for task mapping, but only checks those task are mapped - operators, and are downstreams of the current task. - - To get a list of tasks that uses the current task for task mapping, use - :meth:`iter_mapped_dependants` instead. - """ - from airflow.models.mappedoperator import MappedOperator - from airflow.utils.task_group import TaskGroup - - def _walk_group(group: TaskGroup) -> Iterable[tuple[str, DAGNode]]: - """ - Recursively walk children in a task group. - - This yields all direct children (including both tasks and task - groups), and all children of any task groups. 
- """ - for key, child in group.children.items(): - yield key, child - if isinstance(child, TaskGroup): - yield from _walk_group(child) - - dag = self.get_dag() - if not dag: - raise RuntimeError("Cannot check for mapped dependants when not attached to a DAG") - for key, child in _walk_group(dag.task_group): - if key == self.node_id: - continue - if not isinstance(child, (MappedOperator, MappedTaskGroup)): - continue - if self.node_id in child.upstream_task_ids: - yield child - - def iter_mapped_dependants(self) -> Iterator[MappedOperator | MappedTaskGroup]: - """ - Return mapped nodes that depend on the current task the expansion. - - For now, this walks the entire DAG to find mapped nodes that has this - current task as an upstream. We cannot use ``downstream_list`` since it - only contains operators, not task groups. In the future, we should - provide a way to record an DAG node's all downstream nodes instead. - """ - return ( - downstream - for downstream in self._iter_all_mapped_downstreams() - if any(p.node_id == self.node_id for p in downstream.iter_mapped_dependencies()) - ) - - def iter_mapped_task_groups(self) -> Iterator[MappedTaskGroup]: - """ - Return mapped task groups this task belongs to. - - Groups are returned from the innermost to the outmost. - - :meta private: - """ - if (group := self.task_group) is None: - return - # TODO: Task-SDK: this type ignore shouldn't be necessary, revisit once mapping support is fully in the - # SDK - yield from group.iter_mapped_task_groups() # type: ignore[misc] - - def get_closest_mapped_task_group(self) -> MappedTaskGroup | None: - """ - Get the mapped task group "closest" to this task in the DAG. - - :meta private: - """ - return next(self.iter_mapped_task_groups(), None) - - def get_needs_expansion(self) -> bool: - """ - Return true if the task is MappedOperator or is in a mapped task group. - - :meta private: - """ - if self._needs_expansion is None: - if self.get_closest_mapped_task_group() is not None: - self._needs_expansion = True - else: - self._needs_expansion = False - return self._needs_expansion - def unmap(self, resolve: None | dict[str, Any] | tuple[Context, Session]) -> BaseOperator: """ Get the "normal" operator from current abstract operator. @@ -343,43 +216,6 @@ def get_extra_links(self, ti: TaskInstance, link_name: str) -> str | None: return link.get_link(self.unmap(None), ti.dag_run.logical_date) # type: ignore[misc] return link.get_link(self.unmap(None), ti_key=ti.key) - @methodtools.lru_cache(maxsize=None) - def get_parse_time_mapped_ti_count(self) -> int: - """ - Return the number of mapped task instances that can be created on DAG run creation. - - This only considers literal mapped arguments, and would return *None* - when any non-literal values are used for mapping. - - :raise NotFullyPopulated: If non-literal mapped arguments are encountered. - :raise NotMapped: If the operator is neither mapped, nor has any parent - mapped task groups. - :return: Total number of mapped TIs this task should have. - """ - group = self.get_closest_mapped_task_group() - if group is None: - raise NotMapped - return group.get_parse_time_mapped_ti_count() - - def get_mapped_ti_count(self, run_id: str, *, session: Session) -> int: - """ - Return the number of mapped TaskInstances that can be created at run time. - - This considers both literal and non-literal mapped arguments, and the - result is therefore available when all depended tasks have finished. 
The - return value should be identical to ``parse_time_mapped_ti_count`` if - all mapped arguments are literal. - - :raise NotFullyPopulated: If upstream tasks are not all complete yet. - :raise NotMapped: If the operator is neither mapped, nor has any parent - mapped task groups. - :return: Total number of mapped TIs this task should have. - """ - group = self.get_closest_mapped_task_group() - if group is None: - raise NotMapped - return group.get_mapped_ti_count(run_id, session=session) - def expand_mapped_task(self, run_id: str, *, session: Session) -> tuple[Sequence[TaskInstance], int]: """ Create the mapped task instances for mapped task. @@ -390,16 +226,20 @@ def expand_mapped_task(self, run_id: str, *, session: Session) -> tuple[Sequence """ from sqlalchemy import func, or_ - from airflow.models.mappedoperator import MappedOperator from airflow.models.taskinstance import TaskInstance from airflow.sdk.definitions.baseoperator import BaseOperator + from airflow.sdk.definitions.mappedoperator import MappedOperator from airflow.settings import task_instance_mutation_hook if not isinstance(self, (BaseOperator, MappedOperator)): - raise RuntimeError(f"cannot expand unrecognized operator type {type(self).__name__}") + raise RuntimeError( + f"cannot expand unrecognized operator type {type(self).__module__}.{type(self).__name__}" + ) + + from airflow.models.baseoperator import BaseOperator as DBBaseOperator try: - total_length: int | None = self.get_mapped_ti_count(run_id, session=session) + total_length: int | None = DBBaseOperator.get_mapped_ti_count(self, run_id, session=session) except NotFullyPopulated as e: # It's possible that the upstream tasks are not yet done, but we # don't have upstream of upstreams in partial DAGs (possible in the @@ -509,29 +349,3 @@ def expand_mapped_task(self, run_id: str, *, session: Session) -> tuple[Sequence ti.state = TaskInstanceState.REMOVED session.flush() return all_expanded_tis, total_expanded_ti_count - 1 - - def render_template_fields( - self, - context: Context, - jinja_env: jinja2.Environment | None = None, - ) -> None: - """ - Template all attributes listed in *self.template_fields*. - - If the operator is mapped, this should return the unmapped, fully - rendered, and map-expanded operator. The mapped operator should not be - modified. However, *context* may be modified in-place to reference the - unmapped operator for template rendering. - - If the operator is not mapped, this should modify the operator in-place. 
- """ - raise NotImplementedError() - - def __enter__(self): - if not self.is_setup and not self.is_teardown: - raise AirflowException("Only setup/teardown tasks can be used as context managers.") - SetupTeardownContext.push_setup_teardown_task(self) - return SetupTeardownContext - - def __exit__(self, exc_type, exc_val, exc_tb): - SetupTeardownContext.set_work_task_roots_and_leaves() diff --git a/airflow/models/baseoperator.py b/airflow/models/baseoperator.py index 88a656eb7b1ed..b4931ac16d324 100644 --- a/airflow/models/baseoperator.py +++ b/airflow/models/baseoperator.py @@ -23,13 +23,12 @@ from __future__ import annotations -import collections.abc -import contextlib import functools import logging -from collections.abc import Collection, Iterable, Sequence +import operator +from collections.abc import Collection, Iterable, Iterator, Sequence from datetime import datetime, timedelta -from functools import wraps +from functools import singledispatchmethod, wraps from threading import local from types import FunctionType from typing import ( @@ -53,36 +52,27 @@ TaskDeferred, ) from airflow.lineage import apply_lineage, prepare_lineage + +# Keeping this file at all is a temp thing as we migrate the repo to the task sdk as the base, but to keep +# main working and useful for others to develop against we use the TaskSDK here but keep this file around from airflow.models.abstractoperator import ( - DEFAULT_EXECUTOR, - DEFAULT_IGNORE_FIRST_DEPENDS_ON_PAST, - DEFAULT_OWNER, - DEFAULT_POOL_SLOTS, - DEFAULT_PRIORITY_WEIGHT, - DEFAULT_QUEUE, - DEFAULT_RETRIES, - DEFAULT_RETRY_DELAY, - DEFAULT_TASK_EXECUTION_TIMEOUT, - DEFAULT_TRIGGER_RULE, - DEFAULT_WAIT_FOR_PAST_DEPENDS_BEFORE_SKIPPING, - DEFAULT_WEIGHT_RULE, AbstractOperator, + NotMapped, ) from airflow.models.base import _sentinel -from airflow.models.mappedoperator import OperatorPartial, validate_mapping_kwargs from airflow.models.taskinstance import TaskInstance, clear_task_instances from airflow.models.taskmixin import DependencyMixin from airflow.models.trigger import TRIGGER_FAIL_REPR, TriggerFailureReason +from airflow.sdk.definitions._internal.abstractoperator import AbstractOperator as TaskSDKAbstractOperator from airflow.sdk.definitions.baseoperator import ( BaseOperatorMeta as TaskSDKBaseOperatorMeta, - get_merged_defaults, + get_merged_defaults as get_merged_defaults, # Re-export for compat ) - -# Keeping this file at all is a temp thing as we migrate the repo to the task sdk as the base, but to keep -# main working and useful for others to develop against we use the TaskSDK here but keep this file around from airflow.sdk.definitions.context import Context -from airflow.sdk.definitions.dag import DAG, BaseOperator as TaskSDKBaseOperator +from airflow.sdk.definitions.dag import BaseOperator as TaskSDKBaseOperator from airflow.sdk.definitions.edges import EdgeModifier as TaskSDKEdgeModifier +from airflow.sdk.definitions.mappedoperator import MappedOperator +from airflow.sdk.definitions.taskgroup import MappedTaskGroup, TaskGroup from airflow.serialization.enums import DagAttributeTypes from airflow.ti_deps.deps.mapped_task_upstream_dep import MappedTaskUpstreamDep from airflow.ti_deps.deps.not_in_retry_period_dep import NotInRetryPeriodDep @@ -95,27 +85,19 @@ from airflow.utils.operator_helpers import ExecutionCallableRunner from airflow.utils.operator_resources import Resources from airflow.utils.session import NEW_SESSION, provide_session -from airflow.utils.types import NOTSET, DagRunTriggeredByType +from airflow.utils.types import 
DagRunTriggeredByType from airflow.utils.xcom import XCOM_RETURN_KEY if TYPE_CHECKING: - from types import ClassMethodDescriptorType - from sqlalchemy.orm import Session from airflow.models.abstractoperator import TaskStateChangeCallback from airflow.models.baseoperatorlink import BaseOperatorLink from airflow.models.dag import DAG as SchedulerDAG from airflow.models.operator import Operator - from airflow.task.priority_strategy import PriorityWeightStrategy + from airflow.sdk.definitions.node import DAGNode from airflow.ti_deps.deps.base_ti_dep import BaseTIDep from airflow.triggers.base import BaseTrigger, StartTriggerArgs - from airflow.utils.types import ArgNotSet - - -# Todo: AIP-44: Once we get rid of AIP-44 we can remove this. But without this here pydantic fails to resolve -# types for serialization -from airflow.utils.task_group import TaskGroup # noqa: TC001 TaskPreExecuteHook = Callable[[Context], None] TaskPostExecuteHook = Callable[[Context, Any], None] @@ -153,193 +135,6 @@ def coerce_resources(resources: dict[str, Any] | None) -> Resources | None: return Resources(**resources) -class _PartialDescriptor: - """A descriptor that guards against ``.partial`` being called on Task objects.""" - - class_method: ClassMethodDescriptorType | None = None - - def __get__( - self, obj: BaseOperator, cls: type[BaseOperator] | None = None - ) -> Callable[..., OperatorPartial]: - # Call this "partial" so it looks nicer in stack traces. - def partial(**kwargs): - raise TypeError("partial can only be called on Operator classes, not Tasks themselves") - - if obj is not None: - return partial - return self.class_method.__get__(cls, cls) - - -_PARTIAL_DEFAULTS: dict[str, Any] = { - "map_index_template": None, - "owner": DEFAULT_OWNER, - "trigger_rule": DEFAULT_TRIGGER_RULE, - "depends_on_past": False, - "ignore_first_depends_on_past": DEFAULT_IGNORE_FIRST_DEPENDS_ON_PAST, - "wait_for_past_depends_before_skipping": DEFAULT_WAIT_FOR_PAST_DEPENDS_BEFORE_SKIPPING, - "wait_for_downstream": False, - "retries": DEFAULT_RETRIES, - "executor": DEFAULT_EXECUTOR, - "queue": DEFAULT_QUEUE, - "pool_slots": DEFAULT_POOL_SLOTS, - "execution_timeout": DEFAULT_TASK_EXECUTION_TIMEOUT, - "retry_delay": DEFAULT_RETRY_DELAY, - "retry_exponential_backoff": False, - "priority_weight": DEFAULT_PRIORITY_WEIGHT, - "weight_rule": DEFAULT_WEIGHT_RULE, - "inlets": [], - "outlets": [], - "allow_nested_operators": True, -} - - -# This is what handles the actual mapping. 
- -if TYPE_CHECKING: - - def partial( - operator_class: type[BaseOperator], - *, - task_id: str, - dag: DAG | None = None, - task_group: TaskGroup | None = None, - start_date: datetime | ArgNotSet = NOTSET, - end_date: datetime | ArgNotSet = NOTSET, - owner: str | ArgNotSet = NOTSET, - email: None | str | Iterable[str] | ArgNotSet = NOTSET, - params: collections.abc.MutableMapping | None = None, - resources: dict[str, Any] | None | ArgNotSet = NOTSET, - trigger_rule: str | ArgNotSet = NOTSET, - depends_on_past: bool | ArgNotSet = NOTSET, - ignore_first_depends_on_past: bool | ArgNotSet = NOTSET, - wait_for_past_depends_before_skipping: bool | ArgNotSet = NOTSET, - wait_for_downstream: bool | ArgNotSet = NOTSET, - retries: int | None | ArgNotSet = NOTSET, - queue: str | ArgNotSet = NOTSET, - pool: str | ArgNotSet = NOTSET, - pool_slots: int | ArgNotSet = NOTSET, - execution_timeout: timedelta | None | ArgNotSet = NOTSET, - max_retry_delay: None | timedelta | float | ArgNotSet = NOTSET, - retry_delay: timedelta | float | ArgNotSet = NOTSET, - retry_exponential_backoff: bool | ArgNotSet = NOTSET, - priority_weight: int | ArgNotSet = NOTSET, - weight_rule: str | PriorityWeightStrategy | ArgNotSet = NOTSET, - sla: timedelta | None | ArgNotSet = NOTSET, - map_index_template: str | None | ArgNotSet = NOTSET, - max_active_tis_per_dag: int | None | ArgNotSet = NOTSET, - max_active_tis_per_dagrun: int | None | ArgNotSet = NOTSET, - on_execute_callback: None - | TaskStateChangeCallback - | list[TaskStateChangeCallback] - | ArgNotSet = NOTSET, - on_failure_callback: None - | TaskStateChangeCallback - | list[TaskStateChangeCallback] - | ArgNotSet = NOTSET, - on_success_callback: None - | TaskStateChangeCallback - | list[TaskStateChangeCallback] - | ArgNotSet = NOTSET, - on_retry_callback: None - | TaskStateChangeCallback - | list[TaskStateChangeCallback] - | ArgNotSet = NOTSET, - on_skipped_callback: None - | TaskStateChangeCallback - | list[TaskStateChangeCallback] - | ArgNotSet = NOTSET, - run_as_user: str | None | ArgNotSet = NOTSET, - executor: str | None | ArgNotSet = NOTSET, - executor_config: dict | None | ArgNotSet = NOTSET, - inlets: Any | None | ArgNotSet = NOTSET, - outlets: Any | None | ArgNotSet = NOTSET, - doc: str | None | ArgNotSet = NOTSET, - doc_md: str | None | ArgNotSet = NOTSET, - doc_json: str | None | ArgNotSet = NOTSET, - doc_yaml: str | None | ArgNotSet = NOTSET, - doc_rst: str | None | ArgNotSet = NOTSET, - task_display_name: str | None | ArgNotSet = NOTSET, - logger_name: str | None | ArgNotSet = NOTSET, - allow_nested_operators: bool = True, - **kwargs, - ) -> OperatorPartial: ... -else: - - def partial( - operator_class: type[BaseOperator], - *, - task_id: str, - dag: DAG | None = None, - task_group: TaskGroup | None = None, - params: collections.abc.MutableMapping | None = None, - **kwargs, - ): - from airflow.sdk.definitions._internal.contextmanager import DagContext, TaskGroupContext - - validate_mapping_kwargs(operator_class, "partial", kwargs) - - dag = dag or DagContext.get_current() - if dag: - task_group = task_group or TaskGroupContext.get_current(dag) - if task_group: - task_id = task_group.child_id(task_id) - - # Merge DAG and task group level defaults into user-supplied values. 
- dag_default_args, partial_params = get_merged_defaults( - dag=dag, - task_group=task_group, - task_params=params, - task_default_args=kwargs.pop("default_args", None), - ) - - # Create partial_kwargs from args and kwargs - partial_kwargs: dict[str, Any] = { - "task_id": task_id, - "dag": dag, - "task_group": task_group, - **kwargs, - } - - # Inject DAG-level default args into args provided to this function. - partial_kwargs.update( - (k, v) for k, v in dag_default_args.items() if partial_kwargs.get(k, NOTSET) is NOTSET - ) - - # Fill fields not provided by the user with default values. - for k, v in _PARTIAL_DEFAULTS.items(): - partial_kwargs.setdefault(k, v) - - # Post-process arguments. Should be kept in sync with _TaskDecorator.expand(). - if "task_concurrency" in kwargs: # Reject deprecated option. - raise TypeError("unexpected argument: task_concurrency") - if wait := partial_kwargs.get("wait_for_downstream", False): - partial_kwargs["depends_on_past"] = wait - if start_date := partial_kwargs.get("start_date", None): - partial_kwargs["start_date"] = timezone.convert_to_utc(start_date) - if end_date := partial_kwargs.get("end_date", None): - partial_kwargs["end_date"] = timezone.convert_to_utc(end_date) - if partial_kwargs["pool_slots"] < 1: - dag_str = "" - if dag: - dag_str = f" in dag {dag.dag_id}" - raise ValueError(f"pool slots for {task_id}{dag_str} cannot be less than 1") - if retries := partial_kwargs.get("retries"): - partial_kwargs["retries"] = parse_retries(retries) - partial_kwargs["retry_delay"] = coerce_timedelta(partial_kwargs["retry_delay"], key="retry_delay") - if partial_kwargs.get("max_retry_delay", None) is not None: - partial_kwargs["max_retry_delay"] = coerce_timedelta( - partial_kwargs["max_retry_delay"], - key="max_retry_delay", - ) - partial_kwargs.setdefault("executor_config", {}) - - return OperatorPartial( - operator_class=operator_class, - kwargs=partial_kwargs, - params=partial_params, - ) - - class ExecutorSafeguard: """ The ExecutorSafeguard decorator. @@ -388,12 +183,6 @@ def __new__(cls, name, bases, namespace, **kwargs): if callable(execute_method) and not getattr(execute_method, "__isabstractmethod__", False): namespace["execute"] = ExecutorSafeguard().decorator(execute_method) new_cls = super().__new__(cls, name, bases, namespace, **kwargs) - with contextlib.suppress(KeyError): - # Update the partial descriptor with the class method, so it calls the actual function - # (but let subclasses override it if they need to) - partial_desc = vars(new_cls)["partial"] - if isinstance(partial_desc, _PartialDescriptor): - partial_desc.class_method = classmethod(partial) return new_cls @@ -653,8 +442,6 @@ def dag(self, val: SchedulerDAG): # For type checking only ... - partial: Callable[..., OperatorPartial] = _PartialDescriptor() # type: ignore - @classmethod @methodtools.lru_cache(maxsize=None) def get_serialized_fields(cls): @@ -1018,6 +805,89 @@ def expand_start_trigger_args(self, *, context: Context, session: Session) -> St """ return self.start_trigger_args + if TYPE_CHECKING: + + @classmethod + def get_mapped_ti_count( + cls, node: DAGNode | MappedTaskGroup, run_id: str, *, session: Session + ) -> int: + """ + Return the number of mapped TaskInstances that can be created at run time. + + This considers both literal and non-literal mapped arguments, and the + result is therefore available when all depended tasks have finished. The + return value should be identical to ``parse_time_mapped_ti_count`` if + all mapped arguments are literal. 
+
+            :raise NotFullyPopulated: If upstream tasks are not all complete yet.
+            :raise NotMapped: If the operator is neither mapped, nor has any parent
+                mapped task groups.
+            :return: Total number of mapped TIs this task should have.
+            """
+    else:
+
+        @singledispatchmethod
+        @classmethod
+        def get_mapped_ti_count(cls, task: DAGNode, run_id: str, *, session: Session) -> int:
+            raise NotImplementedError(f"Not implemented for {type(task)}")
+
+        # https://github.com/python/cpython/issues/86153
+        # While we support Python 3.9, we can't rely on the type hint; we need to pass the type explicitly to
+        # register.
+        @get_mapped_ti_count.register(TaskSDKAbstractOperator)
+        @classmethod
+        def _(cls, task: TaskSDKAbstractOperator, run_id: str, *, session: Session) -> int:
+            group = task.get_closest_mapped_task_group()
+            if group is None:
+                raise NotMapped()
+            return cls.get_mapped_ti_count(group, run_id, session=session)
+
+        @get_mapped_ti_count.register(MappedOperator)
+        @classmethod
+        def _(cls, task: MappedOperator, run_id: str, *, session: Session) -> int:
+            from airflow.serialization.serialized_objects import _ExpandInputRef
+
+            exp_input = task._get_specified_expand_input()
+            if isinstance(exp_input, _ExpandInputRef):
+                exp_input = exp_input.deref(task.dag)
+            current_count = exp_input.get_total_map_length(run_id, session=session)
+
+            group = task.get_closest_mapped_task_group()
+            if group is None:
+                return current_count
+            parent_count = cls.get_mapped_ti_count(group, run_id, session=session)
+            return parent_count * current_count
+
+        @get_mapped_ti_count.register(TaskGroup)
+        @classmethod
+        def _(cls, group: TaskGroup, run_id: str, *, session: Session) -> int:
+            """
+            Return the number of instances a task in this group should be mapped to at run time.
+
+            This considers both literal and non-literal mapped arguments, and the
+            result is therefore available when all depended tasks have finished. The
+            return value should be identical to ``parse_time_mapped_ti_count`` if
+            all mapped arguments are literal.
+
+            If this group is inside mapped task groups, all the nested counts are
+            multiplied and accounted for.
+
+            :raise NotFullyPopulated: If upstream tasks are not all complete yet.
+ """ + + def iter_mapped_task_groups(group) -> Iterator[MappedTaskGroup]: + while group is not None: + if isinstance(group, MappedTaskGroup): + yield group + group = group.parent_group + + groups = iter_mapped_task_groups(group) + return functools.reduce( + operator.mul, + (g._expand_input.get_total_map_length(run_id, session=session) for g in groups), + ) + def chain(*tasks: DependencyMixin | Sequence[DependencyMixin]) -> None: r""" diff --git a/airflow/models/dag.py b/airflow/models/dag.py index 7da56bca624ab..9b88e6d70f71c 100644 --- a/airflow/models/dag.py +++ b/airflow/models/dag.py @@ -1600,6 +1600,7 @@ def test( :param mark_success_pattern: regex of task_ids to mark as success instead of running :param session: database connection (optional) """ + from airflow.serialization.serialized_objects import SerializedDAG def add_logger_if_needed(ti: TaskInstance): """ @@ -1642,8 +1643,10 @@ def add_logger_if_needed(ti: TaskInstance): self.log.debug("Getting dagrun for dag %s", self.dag_id) logical_date = timezone.coerce_datetime(logical_date) data_interval = self.timetable.infer_manual_data_interval(run_after=logical_date) + scheduler_dag = SerializedDAG.deserialize_dag(SerializedDAG.serialize_dag(self)) + dr: DagRun = _get_or_create_dagrun( - dag=self, + dag=scheduler_dag, start_date=logical_date, logical_date=logical_date, run_id=DagRun.generate_run_id(DagRunType.MANUAL, logical_date), diff --git a/airflow/models/dagrun.py b/airflow/models/dagrun.py index 8060901f948a3..8795c0507087b 100644 --- a/airflow/models/dagrun.py +++ b/airflow/models/dagrun.py @@ -66,6 +66,7 @@ from airflow.models.expandinput import NotFullyPopulated from airflow.models.taskinstance import TaskInstance as TI from airflow.models.tasklog import LogTemplate +from airflow.models.taskmap import TaskMap from airflow.stats import Stats from airflow.ti_deps.dep_context import DepContext from airflow.ti_deps.dependencies_states import SCHEDULEABLE_STATES @@ -85,6 +86,7 @@ from sqlalchemy.orm import Query, Session + from airflow.models.baseoperator import BaseOperator from airflow.models.dag import DAG from airflow.models.operator import Operator from airflow.typing_compat import Literal @@ -1134,15 +1136,15 @@ def _expand_mapped_task_if_needed(ti: TI) -> Iterable[TI] | None: if ti.map_index >= 0: # Already expanded, we're good. return None - from airflow.models.mappedoperator import MappedOperator + from airflow.sdk.definitions.mappedoperator import MappedOperator as TaskSDKMappedOperator - if isinstance(ti.task, MappedOperator): + if isinstance(ti.task, TaskSDKMappedOperator): # If we get here, it could be that we are moving from non-mapped to mapped # after task instance clearing or this ti is not yet expanded. Safe to clear # the db references. ti.clear_db_references(session=session) try: - expanded_tis, _ = ti.task.expand_mapped_task(self.run_id, session=session) + expanded_tis, _ = TaskMap.expand_mapped_task(ti.task, self.run_id, session=session) except NotMapped: # Not a mapped task, nothing needed. 
                 return None
             if expanded_tis:
@@ -1155,7 +1157,7 @@ def _expand_mapped_task_if_needed(ti: TI) -> Iterable[TI] | None:
         revised_map_index_task_ids = set()
         for schedulable in itertools.chain(schedulable_tis, additional_tis):
             if TYPE_CHECKING:
-                assert schedulable.task
+                assert isinstance(schedulable.task, BaseOperator)
             old_state = schedulable.state
             if not schedulable.are_dependencies_met(session=session, dep_context=dep_context):
                 old_states[schedulable.key] = old_state
@@ -1327,6 +1329,8 @@ def _check_for_removed_or_restored_tasks(
         :return: Task IDs in the DAG run
         """
+        from airflow.models.baseoperator import BaseOperator
+
         tis = self.get_task_instances(session=session)
 
         # check for removed or restored tasks
@@ -1362,7 +1366,7 @@ def _check_for_removed_or_restored_tasks(
             except NotFullyPopulated:
                 # What if it is _now_ dynamically mapped, but wasn't before?
                 try:
-                    total_length = task.get_mapped_ti_count(self.run_id, session=session)
+                    total_length = BaseOperator.get_mapped_ti_count(task, self.run_id, session=session)
                 except NotFullyPopulated:
                     # Not all upstreams finished, so we can't tell what should be here. Remove everything.
                     if ti.map_index >= 0:
@@ -1462,10 +1466,12 @@ def _create_tasks(
         :param tasks: Tasks to create jobs for in the DAG run
         :param task_creator: Function to create task instances
         """
+        from airflow.models.baseoperator import BaseOperator
+
         map_indexes: Iterable[int]
         for task in tasks:
             try:
-                count = task.get_mapped_ti_count(self.run_id, session=session)
+                count = BaseOperator.get_mapped_ti_count(task, self.run_id, session=session)
             except (NotMapped, NotFullyPopulated):
                 map_indexes = (-1,)
             else:
@@ -1531,10 +1537,11 @@ def _revise_map_indexes_if_mapped(self, task: Operator, *, session: Session) ->
         we delay expansion to the "last resort". See comments at the call site for more details.
         """
+        from airflow.models.baseoperator import BaseOperator
         from airflow.settings import task_instance_mutation_hook
 
         try:
-            total_length = task.get_mapped_ti_count(self.run_id, session=session)
+            total_length = BaseOperator.get_mapped_ti_count(task, self.run_id, session=session)
         except NotMapped:
             return  # Not a mapped task, don't need to do anything.
         except NotFullyPopulated:
@@ -1613,7 +1620,7 @@ def schedule_tis(
         schedulable_ti_ids = []
         for ti in schedulable_tis:
             if TYPE_CHECKING:
-                assert ti.task
+                assert isinstance(ti.task, BaseOperator)
             if (
                 ti.task.inherits_from_empty_operator
                 and not ti.task.on_execute_callback
diff --git a/airflow/models/expandinput.py b/airflow/models/expandinput.py
index fcbb55dc3d24d..8fb35f7032965 100644
--- a/airflow/models/expandinput.py
+++ b/airflow/models/expandinput.py
@@ -31,8 +31,8 @@
 if TYPE_CHECKING:
     from sqlalchemy.orm import Session
 
-    from airflow.models.operator import Operator
     from airflow.models.xcom_arg import XComArg
+    from airflow.sdk.types import Operator
     from airflow.serialization.serialized_objects import _ExpandInputRef
     from airflow.typing_compat import TypeGuard
 
@@ -59,11 +59,6 @@ class MappedArgument(ResolveMixin):
     _input: ExpandInput
     _key: str
 
-    def get_task_map_length(self, run_id: str, *, session: Session) -> int | None:
-        # TODO (AIP-42): Implement run-time task map length inspection. This is
-        # needed when we implement task mapping inside a mapped task group.
- raise NotImplementedError() - def iter_references(self) -> Iterable[tuple[Operator, str]]: yield from self._input.iter_references() @@ -145,8 +140,11 @@ def _get_map_lengths(self, run_id: str, *, session: Session) -> dict[str, int]: # TODO: This initiates one database call for each XComArg. Would it be # more efficient to do one single db call and unpack the value here? def _get_length(v: OperatorExpandArgument) -> int | None: + from airflow.models.xcom_arg import get_task_map_length + if _needs_run_time_resolution(v): - return v.get_task_map_length(run_id, session=session) + return get_task_map_length(v, run_id, session=session) + # Unfortunately a user-defined TypeGuard cannot apply negative type # narrowing. https://github.com/python/typing/discussions/1013 if TYPE_CHECKING: @@ -243,9 +241,11 @@ def get_parse_time_mapped_ti_count(self) -> int: raise NotFullyPopulated({"expand_kwargs() argument"}) def get_total_map_length(self, run_id: str, *, session: Session) -> int: + from airflow.models.xcom_arg import get_task_map_length + if isinstance(self.value, collections.abc.Sized): return len(self.value) - length = self.value.get_task_map_length(run_id, session=session) + length = get_task_map_length(self.value, run_id, session=session) if length is None: raise NotFullyPopulated({"expand_kwargs() argument"}) return length diff --git a/airflow/models/mappedoperator.py b/airflow/models/mappedoperator.py index 8c61bdc42fd11..e7352ad1323d3 100644 --- a/airflow/models/mappedoperator.py +++ b/airflow/models/mappedoperator.py @@ -17,226 +17,22 @@ # under the License. from __future__ import annotations -import collections.abc import contextlib import copy -import warnings -from collections.abc import Collection, Iterable, Iterator, Mapping, Sequence -from typing import TYPE_CHECKING, Any, ClassVar, Union +from collections.abc import Mapping +from typing import TYPE_CHECKING, Any import attrs -import methodtools -from airflow.models.abstractoperator import ( - DEFAULT_EXECUTOR, - DEFAULT_IGNORE_FIRST_DEPENDS_ON_PAST, - DEFAULT_OWNER, - DEFAULT_POOL_SLOTS, - DEFAULT_PRIORITY_WEIGHT, - DEFAULT_QUEUE, - DEFAULT_RETRIES, - DEFAULT_RETRY_DELAY, - DEFAULT_TRIGGER_RULE, - DEFAULT_WAIT_FOR_PAST_DEPENDS_BEFORE_SKIPPING, - DEFAULT_WEIGHT_RULE, - AbstractOperator, - NotMapped, -) -from airflow.models.expandinput import ( - DictOfListsExpandInput, - ListOfDictsExpandInput, - is_mappable, -) -from airflow.models.pool import Pool -from airflow.serialization.enums import DagAttributeTypes -from airflow.task.priority_strategy import PriorityWeightStrategy, validate_and_load_priority_weight_strategy +from airflow.models.abstractoperator import AbstractOperator +from airflow.sdk.definitions.mappedoperator import MappedOperator as TaskSDKMappedOperator from airflow.triggers.base import StartTriggerArgs -from airflow.utils.context import context_update_for_unmapped -from airflow.utils.helpers import is_container, prevent_duplicates -from airflow.utils.task_instance_session import get_current_task_instance_session -from airflow.utils.types import NOTSET -from airflow.utils.xcom import XCOM_RETURN_KEY +from airflow.utils.helpers import prevent_duplicates if TYPE_CHECKING: - import datetime - from typing import Literal - - import jinja2 # Slow import. 
- import pendulum from sqlalchemy.orm.session import Session - from airflow.models.abstractoperator import ( - TaskStateChangeCallback, - ) - from airflow.models.baseoperator import BaseOperator - from airflow.models.baseoperatorlink import BaseOperatorLink - from airflow.models.dag import DAG - from airflow.models.expandinput import ( - ExpandInput, - OperatorExpandArgument, - OperatorExpandKwargsArgument, - ) - from airflow.models.operator import Operator - from airflow.models.param import ParamsDict - from airflow.models.xcom_arg import XComArg from airflow.sdk.definitions.context import Context - from airflow.ti_deps.deps.base_ti_dep import BaseTIDep - from airflow.utils.operator_resources import Resources - from airflow.utils.task_group import TaskGroup - from airflow.utils.trigger_rule import TriggerRule - - TaskStateChangeCallbackAttrType = Union[None, TaskStateChangeCallback, list[TaskStateChangeCallback]] - - ValidationSource = Literal["expand", "partial"] - - -def validate_mapping_kwargs(op: type[BaseOperator], func: ValidationSource, value: dict[str, Any]) -> None: - # use a dict so order of args is same as code order - unknown_args = value.copy() - for klass in op.mro(): - init = klass.__init__ # type: ignore[misc] - try: - param_names = init._BaseOperatorMeta__param_names - except AttributeError: - continue - for name in param_names: - value = unknown_args.pop(name, NOTSET) - if func != "expand": - continue - if value is NOTSET: - continue - if is_mappable(value): - continue - type_name = type(value).__name__ - error = f"{op.__name__}.expand() got an unexpected type {type_name!r} for keyword argument {name}" - raise ValueError(error) - if not unknown_args: - return # If we have no args left to check: stop looking at the MRO chain. - - if len(unknown_args) == 1: - error = f"an unexpected keyword argument {unknown_args.popitem()[0]!r}" - else: - names = ", ".join(repr(n) for n in unknown_args) - error = f"unexpected keyword arguments {names}" - raise TypeError(f"{op.__name__}.{func}() got {error}") - - -def ensure_xcomarg_return_value(arg: Any) -> None: - from airflow.models.xcom_arg import XComArg - - if isinstance(arg, XComArg): - for operator, key in arg.iter_references(): - if key != XCOM_RETURN_KEY: - raise ValueError(f"cannot map over XCom with custom key {key!r} from {operator}") - elif not is_container(arg): - return - elif isinstance(arg, collections.abc.Mapping): - for v in arg.values(): - ensure_xcomarg_return_value(v) - elif isinstance(arg, collections.abc.Iterable): - for v in arg: - ensure_xcomarg_return_value(v) - - -@attrs.define(kw_only=True, repr=False) -class OperatorPartial: - """ - An "intermediate state" returned by ``BaseOperator.partial()``. - - This only exists at DAG-parsing time; the only intended usage is for the - user to call ``.expand()`` on it at some point (usually in a method chain) to - create a ``MappedOperator`` to add into the DAG. - """ - - operator_class: type[BaseOperator] - kwargs: dict[str, Any] - params: ParamsDict | dict - - _expand_called: bool = False # Set when expand() is called to ease user debugging. 
- - def __attrs_post_init__(self): - validate_mapping_kwargs(self.operator_class, "partial", self.kwargs) - - def __repr__(self) -> str: - args = ", ".join(f"{k}={v!r}" for k, v in self.kwargs.items()) - return f"{self.operator_class.__name__}.partial({args})" - - def __del__(self): - if not self._expand_called: - try: - task_id = repr(self.kwargs["task_id"]) - except KeyError: - task_id = f"at {hex(id(self))}" - warnings.warn(f"Task {task_id} was never mapped!", category=UserWarning, stacklevel=1) - - def expand(self, **mapped_kwargs: OperatorExpandArgument) -> MappedOperator: - if not mapped_kwargs: - raise TypeError("no arguments to expand against") - validate_mapping_kwargs(self.operator_class, "expand", mapped_kwargs) - prevent_duplicates(self.kwargs, mapped_kwargs, fail_reason="unmappable or already specified") - # Since the input is already checked at parse time, we can set strict - # to False to skip the checks on execution. - return self._expand(DictOfListsExpandInput(mapped_kwargs), strict=False) - - def expand_kwargs(self, kwargs: OperatorExpandKwargsArgument, *, strict: bool = True) -> MappedOperator: - from airflow.models.xcom_arg import XComArg - - if isinstance(kwargs, collections.abc.Sequence): - for item in kwargs: - if not isinstance(item, (XComArg, collections.abc.Mapping)): - raise TypeError(f"expected XComArg or list[dict], not {type(kwargs).__name__}") - elif not isinstance(kwargs, XComArg): - raise TypeError(f"expected XComArg or list[dict], not {type(kwargs).__name__}") - return self._expand(ListOfDictsExpandInput(kwargs), strict=strict) - - def _expand(self, expand_input: ExpandInput, *, strict: bool) -> MappedOperator: - from airflow.operators.empty import EmptyOperator - from airflow.sensors.base import BaseSensorOperator - - self._expand_called = True - ensure_xcomarg_return_value(expand_input.value) - - partial_kwargs = self.kwargs.copy() - task_id = partial_kwargs.pop("task_id") - dag = partial_kwargs.pop("dag") - task_group = partial_kwargs.pop("task_group") - start_date = partial_kwargs.pop("start_date", None) - end_date = partial_kwargs.pop("end_date", None) - - try: - operator_name = self.operator_class.custom_operator_name # type: ignore - except AttributeError: - operator_name = self.operator_class.__name__ - - op = MappedOperator( - operator_class=self.operator_class, - expand_input=expand_input, - partial_kwargs=partial_kwargs, - task_id=task_id, - params=self.params, - operator_extra_links=self.operator_class.operator_extra_links, - template_ext=self.operator_class.template_ext, - template_fields=self.operator_class.template_fields, - template_fields_renderers=self.operator_class.template_fields_renderers, - ui_color=self.operator_class.ui_color, - ui_fgcolor=self.operator_class.ui_fgcolor, - is_empty=issubclass(self.operator_class, EmptyOperator), - is_sensor=issubclass(self.operator_class, BaseSensorOperator), - task_module=self.operator_class.__module__, - task_type=self.operator_class.__name__, - operator_name=operator_name, - dag=dag, - task_group=task_group, - start_date=start_date, - end_date=end_date, - disallow_kwargs_override=strict, - # For classic operators, this points to expand_input because kwargs - # to BaseOperator.expand() contribute to operator arguments. 
- expand_input_attr="expand_input", - start_trigger_args=self.operator_class.start_trigger_args, - start_from_trigger=self.operator_class.start_from_trigger, - ) - return op @attrs.define( @@ -249,422 +45,12 @@ def _expand(self, expand_input: ExpandInput, *, strict: bool) -> MappedOperator: # special here (the logic is only important for slots=True), we use Python's # built-in implementation, which works (as proven by good old BaseOperator). getstate_setstate=False, + repr=False, ) -class MappedOperator(AbstractOperator): +# TODO: Task-SDK: Multiple inheritance is a crime. There must be a better way +class MappedOperator(TaskSDKMappedOperator, AbstractOperator): # type: ignore[misc] # It complains about weight_rule being different """Object representing a mapped operator in a DAG.""" - # This attribute serves double purpose. For a "normal" operator instance - # loaded from DAG, this holds the underlying non-mapped operator class that - # can be used to create an unmapped operator for execution. For an operator - # recreated from a serialized DAG, however, this holds the serialized data - # that can be used to unmap this into a SerializedBaseOperator. - operator_class: type[BaseOperator] | dict[str, Any] - - expand_input: ExpandInput - partial_kwargs: dict[str, Any] - - # Needed for serialization. - task_id: str - params: ParamsDict | dict - deps: frozenset[BaseTIDep] = attrs.field(init=False) - operator_extra_links: Collection[BaseOperatorLink] - template_ext: Sequence[str] - template_fields: Collection[str] - template_fields_renderers: dict[str, str] - ui_color: str - ui_fgcolor: str - _is_empty: bool - _is_sensor: bool = False - _task_module: str - _task_type: str - _operator_name: str - start_trigger_args: StartTriggerArgs | None - start_from_trigger: bool - _needs_expansion: bool = True - - dag: DAG | None - task_group: TaskGroup | None - start_date: pendulum.DateTime | None - end_date: pendulum.DateTime | None - upstream_task_ids: set[str] = attrs.field(factory=set, init=False) - downstream_task_ids: set[str] = attrs.field(factory=set, init=False) - - _disallow_kwargs_override: bool - """Whether execution fails if ``expand_input`` has duplicates to ``partial_kwargs``. - - If *False*, values from ``expand_input`` under duplicate keys override those - under corresponding keys in ``partial_kwargs``. - """ - - _expand_input_attr: str - """Where to get kwargs to calculate expansion length against. - - This should be a name to call ``getattr()`` on. 
- """ - - supports_lineage: bool = False - - HIDE_ATTRS_FROM_UI: ClassVar[frozenset[str]] = AbstractOperator.HIDE_ATTRS_FROM_UI | frozenset( - ("parse_time_mapped_ti_count", "operator_class", "start_trigger_args", "start_from_trigger") - ) - - @deps.default - def _deps(self): - from airflow.models.baseoperator import BaseOperator - - return BaseOperator.deps - - def __hash__(self): - return id(self) - - def __repr__(self): - return f"" - - def __attrs_post_init__(self): - from airflow.models.xcom_arg import XComArg - - if self.get_closest_mapped_task_group() is not None: - raise NotImplementedError("operator expansion in an expanded task group is not yet supported") - - if self.task_group: - self.task_group.add(self) - if self.dag: - self.dag.add_task(self) - XComArg.apply_upstream_relationship(self, self.expand_input.value) - for k, v in self.partial_kwargs.items(): - if k in self.template_fields: - XComArg.apply_upstream_relationship(self, v) - - @methodtools.lru_cache(maxsize=None) - @classmethod - def get_serialized_fields(cls): - # Not using 'cls' here since we only want to serialize base fields. - return (frozenset(attrs.fields_dict(MappedOperator)) | {"task_type"}) - { - "_task_type", - "dag", - "deps", - "expand_input", # This is needed to be able to accept XComArg. - "task_group", - "upstream_task_ids", - "supports_lineage", - "_is_setup", - "_is_teardown", - "_on_failure_fail_dagrun", - } - - @property - def task_type(self) -> str: - """Implementing Operator.""" - return self._task_type - - @property - def operator_name(self) -> str: - return self._operator_name - - @property - def inherits_from_empty_operator(self) -> bool: - """Implementing Operator.""" - return self._is_empty - - @property - def roots(self) -> Sequence[AbstractOperator]: - """Implementing DAGNode.""" - return [self] - - @property - def leaves(self) -> Sequence[AbstractOperator]: - """Implementing DAGNode.""" - return [self] - - @property - def task_display_name(self) -> str: - return self.partial_kwargs.get("task_display_name") or self.task_id - - @property - def owner(self) -> str: # type: ignore[override] - return self.partial_kwargs.get("owner", DEFAULT_OWNER) - - @property - def email(self) -> None | str | Iterable[str]: - return self.partial_kwargs.get("email") - - @property - def map_index_template(self) -> None | str: - return self.partial_kwargs.get("map_index_template") - - @map_index_template.setter - def map_index_template(self, value: str | None) -> None: - self.partial_kwargs["map_index_template"] = value - - @property - def trigger_rule(self) -> TriggerRule: - return self.partial_kwargs.get("trigger_rule", DEFAULT_TRIGGER_RULE) - - @trigger_rule.setter - def trigger_rule(self, value): - self.partial_kwargs["trigger_rule"] = value - - @property - def is_setup(self) -> bool: - return bool(self.partial_kwargs.get("is_setup")) - - @is_setup.setter - def is_setup(self, value: bool) -> None: - self.partial_kwargs["is_setup"] = value - - @property - def is_teardown(self) -> bool: - return bool(self.partial_kwargs.get("is_teardown")) - - @is_teardown.setter - def is_teardown(self, value: bool) -> None: - self.partial_kwargs["is_teardown"] = value - - @property - def depends_on_past(self) -> bool: - return bool(self.partial_kwargs.get("depends_on_past")) - - @depends_on_past.setter - def depends_on_past(self, value: bool) -> None: - self.partial_kwargs["depends_on_past"] = value - - @property - def ignore_first_depends_on_past(self) -> bool: - value = 
self.partial_kwargs.get("ignore_first_depends_on_past", DEFAULT_IGNORE_FIRST_DEPENDS_ON_PAST) - return bool(value) - - @ignore_first_depends_on_past.setter - def ignore_first_depends_on_past(self, value: bool) -> None: - self.partial_kwargs["ignore_first_depends_on_past"] = value - - @property - def wait_for_past_depends_before_skipping(self) -> bool: - value = self.partial_kwargs.get( - "wait_for_past_depends_before_skipping", DEFAULT_WAIT_FOR_PAST_DEPENDS_BEFORE_SKIPPING - ) - return bool(value) - - @wait_for_past_depends_before_skipping.setter - def wait_for_past_depends_before_skipping(self, value: bool) -> None: - self.partial_kwargs["wait_for_past_depends_before_skipping"] = value - - @property - def wait_for_downstream(self) -> bool: - return bool(self.partial_kwargs.get("wait_for_downstream")) - - @wait_for_downstream.setter - def wait_for_downstream(self, value: bool) -> None: - self.partial_kwargs["wait_for_downstream"] = value - - @property - def retries(self) -> int: - return self.partial_kwargs.get("retries", DEFAULT_RETRIES) - - @retries.setter - def retries(self, value: int) -> None: - self.partial_kwargs["retries"] = value - - @property - def queue(self) -> str: - return self.partial_kwargs.get("queue", DEFAULT_QUEUE) - - @queue.setter - def queue(self, value: str) -> None: - self.partial_kwargs["queue"] = value - - @property - def pool(self) -> str: - return self.partial_kwargs.get("pool", Pool.DEFAULT_POOL_NAME) - - @pool.setter - def pool(self, value: str) -> None: - self.partial_kwargs["pool"] = value - - @property - def pool_slots(self) -> int: - return self.partial_kwargs.get("pool_slots", DEFAULT_POOL_SLOTS) - - @pool_slots.setter - def pool_slots(self, value: int) -> None: - self.partial_kwargs["pool_slots"] = value - - @property - def execution_timeout(self) -> datetime.timedelta | None: - return self.partial_kwargs.get("execution_timeout") - - @execution_timeout.setter - def execution_timeout(self, value: datetime.timedelta | None) -> None: - self.partial_kwargs["execution_timeout"] = value - - @property - def max_retry_delay(self) -> datetime.timedelta | None: - return self.partial_kwargs.get("max_retry_delay") - - @max_retry_delay.setter - def max_retry_delay(self, value: datetime.timedelta | None) -> None: - self.partial_kwargs["max_retry_delay"] = value - - @property - def retry_delay(self) -> datetime.timedelta: - return self.partial_kwargs.get("retry_delay", DEFAULT_RETRY_DELAY) - - @retry_delay.setter - def retry_delay(self, value: datetime.timedelta) -> None: - self.partial_kwargs["retry_delay"] = value - - @property - def retry_exponential_backoff(self) -> bool: - return bool(self.partial_kwargs.get("retry_exponential_backoff")) - - @retry_exponential_backoff.setter - def retry_exponential_backoff(self, value: bool) -> None: - self.partial_kwargs["retry_exponential_backoff"] = value - - @property - def priority_weight(self) -> int: # type: ignore[override] - return self.partial_kwargs.get("priority_weight", DEFAULT_PRIORITY_WEIGHT) - - @priority_weight.setter - def priority_weight(self, value: int) -> None: - self.partial_kwargs["priority_weight"] = value - - @property - def weight_rule(self) -> PriorityWeightStrategy: # type: ignore[override] - return validate_and_load_priority_weight_strategy( - self.partial_kwargs.get("weight_rule", DEFAULT_WEIGHT_RULE) - ) - - @weight_rule.setter - def weight_rule(self, value: str | PriorityWeightStrategy) -> None: - self.partial_kwargs["weight_rule"] = validate_and_load_priority_weight_strategy(value) - - @property - 
def max_active_tis_per_dag(self) -> int | None: - return self.partial_kwargs.get("max_active_tis_per_dag") - - @max_active_tis_per_dag.setter - def max_active_tis_per_dag(self, value: int | None) -> None: - self.partial_kwargs["max_active_tis_per_dag"] = value - - @property - def max_active_tis_per_dagrun(self) -> int | None: - return self.partial_kwargs.get("max_active_tis_per_dagrun") - - @max_active_tis_per_dagrun.setter - def max_active_tis_per_dagrun(self, value: int | None) -> None: - self.partial_kwargs["max_active_tis_per_dagrun"] = value - - @property - def resources(self) -> Resources | None: - return self.partial_kwargs.get("resources") - - @property - def on_execute_callback(self) -> TaskStateChangeCallbackAttrType: - return self.partial_kwargs.get("on_execute_callback") - - @on_execute_callback.setter - def on_execute_callback(self, value: TaskStateChangeCallbackAttrType) -> None: - self.partial_kwargs["on_execute_callback"] = value - - @property - def on_failure_callback(self) -> TaskStateChangeCallbackAttrType: - return self.partial_kwargs.get("on_failure_callback") - - @on_failure_callback.setter - def on_failure_callback(self, value: TaskStateChangeCallbackAttrType) -> None: - self.partial_kwargs["on_failure_callback"] = value - - @property - def on_retry_callback(self) -> TaskStateChangeCallbackAttrType: - return self.partial_kwargs.get("on_retry_callback") - - @on_retry_callback.setter - def on_retry_callback(self, value: TaskStateChangeCallbackAttrType) -> None: - self.partial_kwargs["on_retry_callback"] = value - - @property - def on_success_callback(self) -> TaskStateChangeCallbackAttrType: - return self.partial_kwargs.get("on_success_callback") - - @on_success_callback.setter - def on_success_callback(self, value: TaskStateChangeCallbackAttrType) -> None: - self.partial_kwargs["on_success_callback"] = value - - @property - def on_skipped_callback(self) -> TaskStateChangeCallbackAttrType: - return self.partial_kwargs.get("on_skipped_callback") - - @on_skipped_callback.setter - def on_skipped_callback(self, value: TaskStateChangeCallbackAttrType) -> None: - self.partial_kwargs["on_skipped_callback"] = value - - @property - def run_as_user(self) -> str | None: - return self.partial_kwargs.get("run_as_user") - - @property - def executor(self) -> str | None: - return self.partial_kwargs.get("executor", DEFAULT_EXECUTOR) - - @property - def executor_config(self) -> dict: - return self.partial_kwargs.get("executor_config", {}) - - @property # type: ignore[override] - def inlets(self) -> list[Any]: # type: ignore[override] - return self.partial_kwargs.get("inlets", []) - - @inlets.setter - def inlets(self, value: list[Any]) -> None: # type: ignore[override] - self.partial_kwargs["inlets"] = value - - @property # type: ignore[override] - def outlets(self) -> list[Any]: # type: ignore[override] - return self.partial_kwargs.get("outlets", []) - - @outlets.setter - def outlets(self, value: list[Any]) -> None: # type: ignore[override] - self.partial_kwargs["outlets"] = value - - @property - def doc(self) -> str | None: - return self.partial_kwargs.get("doc") - - @property - def doc_md(self) -> str | None: - return self.partial_kwargs.get("doc_md") - - @property - def doc_json(self) -> str | None: - return self.partial_kwargs.get("doc_json") - - @property - def doc_yaml(self) -> str | None: - return self.partial_kwargs.get("doc_yaml") - - @property - def doc_rst(self) -> str | None: - return self.partial_kwargs.get("doc_rst") - - @property - def allow_nested_operators(self) -> 
bool: - return bool(self.partial_kwargs.get("allow_nested_operators")) - - def get_dag(self) -> DAG | None: - """Implement Operator.""" - return self.dag - - @property - def output(self) -> XComArg: - """Return reference to XCom pushed by current operator.""" - from airflow.models.xcom_arg import XComArg - - return XComArg(operator=self) - - def serialize_for_task_group(self) -> tuple[DagAttributeTypes, Any]: - """Implement DAGNode.""" - return DagAttributeTypes.OP, self.task_id - def _expand_mapped_kwargs( self, context: Mapping[str, Any], session: Session, *, include_xcom: bool ) -> tuple[Mapping[str, Any], set[int]]: @@ -770,136 +156,3 @@ def expand_start_trigger_args(self, *, context: Context, session: Session) -> St next_kwargs=next_kwargs, timeout=timeout, ) - - def unmap(self, resolve: None | Mapping[str, Any] | tuple[Context, Session]) -> BaseOperator: - """ - Get the "normal" Operator after applying the current mapping. - - The *resolve* argument is only used if ``operator_class`` is a real - class, i.e. if this operator is not serialized. If ``operator_class`` is - not a class (i.e. this DAG has been deserialized), this returns a - SerializedBaseOperator that "looks like" the actual unmapping result. - - If *resolve* is a two-tuple (context, session), the information is used - to resolve the mapped arguments into init arguments. If it is a mapping, - no resolving happens, the mapping directly provides those init arguments - resolved from mapped kwargs. - - :meta private: - """ - if isinstance(self.operator_class, type): - if isinstance(resolve, collections.abc.Mapping): - kwargs = resolve - elif resolve is not None: - kwargs, _ = self._expand_mapped_kwargs(*resolve, include_xcom=True) - else: - raise RuntimeError("cannot unmap a non-serialized operator without context") - kwargs = self._get_unmap_kwargs(kwargs, strict=self._disallow_kwargs_override) - is_setup = kwargs.pop("is_setup", False) - is_teardown = kwargs.pop("is_teardown", False) - on_failure_fail_dagrun = kwargs.pop("on_failure_fail_dagrun", False) - op = self.operator_class(**kwargs, _airflow_from_mapped=True) - # We need to overwrite task_id here because BaseOperator further - # mangles the task_id based on the task hierarchy (namely, group_id - # is prepended, and '__N' appended to deduplicate). This is hacky, - # but better than duplicating the whole mangling logic. - op.task_id = self.task_id - op.is_setup = is_setup - op.is_teardown = is_teardown - op.on_failure_fail_dagrun = on_failure_fail_dagrun - return op - - # After a mapped operator is serialized, there's no real way to actually - # unmap it since we've lost access to the underlying operator class. - # This tries its best to simply "forward" all the attributes on this - # mapped operator to a new SerializedBaseOperator instance. - from airflow.serialization.serialized_objects import SerializedBaseOperator - - op = SerializedBaseOperator(task_id=self.task_id, params=self.params, _airflow_from_mapped=True) - for partial_attr, value in self.partial_kwargs.items(): - setattr(op, partial_attr, value) - SerializedBaseOperator.populate_operator(op, self.operator_class) - if self.dag is not None: # For Mypy; we only serialize tasks in a DAG so the check always satisfies. 
- SerializedBaseOperator.set_task_dag_references(op, self.dag) - return op - - def _get_specified_expand_input(self) -> ExpandInput: - """Input received from the expand call on the operator.""" - return getattr(self, self._expand_input_attr) - - def prepare_for_execution(self) -> MappedOperator: - # Since a mapped operator cannot be used for execution, and an unmapped - # BaseOperator needs to be created later (see render_template_fields), - # we don't need to create a copy of the MappedOperator here. - return self - - def iter_mapped_dependencies(self) -> Iterator[Operator]: - """Upstream dependencies that provide XComs used by this task for task mapping.""" - from airflow.models.xcom_arg import XComArg - - for operator, _ in XComArg.iter_xcom_references(self._get_specified_expand_input()): - yield operator - - @methodtools.lru_cache(maxsize=None) - def get_parse_time_mapped_ti_count(self) -> int: - current_count = self._get_specified_expand_input().get_parse_time_mapped_ti_count() - try: - parent_count = super().get_parse_time_mapped_ti_count() - except NotMapped: - return current_count - return parent_count * current_count - - def get_mapped_ti_count(self, run_id: str, *, session: Session) -> int: - from airflow.serialization.serialized_objects import _ExpandInputRef - - exp_input = self._get_specified_expand_input() - if isinstance(exp_input, _ExpandInputRef): - exp_input = exp_input.deref(self.dag) - current_count = exp_input.get_total_map_length(run_id, session=session) - try: - parent_count = super().get_mapped_ti_count(run_id, session=session) - except NotMapped: - return current_count - return parent_count * current_count - - def render_template_fields( - self, - context: Context, - jinja_env: jinja2.Environment | None = None, - ) -> None: - """ - Template all attributes listed in *self.template_fields*. - - This updates *context* to reference the map-expanded task and relevant - information, without modifying the mapped operator. The expanded task - in *context* is then rendered in-place. - - :param context: Context dict with values to apply on content. - :param jinja_env: Jinja environment to use for rendering. - """ - if not jinja_env: - jinja_env = self.get_template_env() - - # We retrieve the session here, stored by _run_raw_task in set_current_task_session - # context manager - we cannot pass the session via @provide_session because the signature - # of render_template_fields is defined by BaseOperator and there are already many subclasses - # overriding it, so changing the signature is not an option. However render_template_fields is - # always executed within "_run_raw_task" so we make sure that _run_raw_task uses the - # set_current_task_session context manager to store the session in the current task. - session = get_current_task_instance_session() - - mapped_kwargs, seen_oids = self._expand_mapped_kwargs(context, session, include_xcom=True) - unmapped_task = self.unmap(mapped_kwargs) - context_update_for_unmapped(context, unmapped_task) - - # Since the operators that extend `BaseOperator` are not subclasses of - # `MappedOperator`, we need to call `_do_render_template_fields` from - # the unmapped task in order to call the operator method when we override - # it to customize the parsing of nested fields. 
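The `get_mapped_ti_count` logic removed above (now driven through the DB-layer `BaseOperator`) multiplies the operator's own expansion length by the count of every enclosing mapped task group. A pure-Python illustration of that arithmetic; names and numbers here are made up:

```python
# Sketch of the parent_count * current_count multiplication in
# get_mapped_ti_count: each level of mapping contributes a factor.
def total_mapped_ti_count(own_length: int, enclosing_group_lengths: list[int]) -> int:
    total = own_length
    for group_length in enclosing_group_lengths:
        total *= group_length
    return total


# An operator expanded over 3 values inside a task group mapped over 2 values
# yields 6 task instances.
assert total_mapped_ti_count(3, [2]) == 6
```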
- unmapped_task._do_render_template_fields( - parent=unmapped_task, - template_fields=self.template_fields, - context=context, - jinja_env=jinja_env, - seen_oids=seen_oids, - ) diff --git a/airflow/models/param.py b/airflow/models/param.py index a25d4bba8ee76..cd3ccec26a48a 100644 --- a/airflow/models/param.py +++ b/airflow/models/param.py @@ -28,9 +28,9 @@ from airflow.utils.types import NOTSET, ArgNotSet if TYPE_CHECKING: - from airflow.models.operator import Operator from airflow.sdk.definitions.context import Context from airflow.sdk.definitions.dag import DAG + from airflow.sdk.types import Operator logger = logging.getLogger(__name__) diff --git a/airflow/models/renderedtifields.py b/airflow/models/renderedtifields.py index c56517f0815ca..f2d7d83920fff 100644 --- a/airflow/models/renderedtifields.py +++ b/airflow/models/renderedtifields.py @@ -47,8 +47,8 @@ from sqlalchemy.orm import Session from sqlalchemy.sql import FromClause - from airflow.models import Operator from airflow.models.taskinstance import TaskInstance + from airflow.sdk.types import Operator def get_serialized_template_fields(task: Operator): diff --git a/airflow/models/taskinstance.py b/airflow/models/taskinstance.py index 5a605f40cd518..d61331dd620d3 100644 --- a/airflow/models/taskinstance.py +++ b/airflow/models/taskinstance.py @@ -106,9 +106,9 @@ from airflow.models.taskreschedule import TaskReschedule from airflow.models.xcom import LazyXComSelectSequence, XCom from airflow.plugins_manager import integrate_macros_plugins -from airflow.sdk.definitions._internal.contextmanager import _CURRENT_CONTEXT from airflow.sdk.definitions._internal.templater import SandboxedEnvironment from airflow.sdk.definitions.asset import Asset, AssetAlias, AssetNameRef, AssetUniqueKey, AssetUriRef +from airflow.sdk.definitions.taskgroup import MappedTaskGroup from airflow.sentry import Sentry from airflow.settings import task_instance_mutation_hook from airflow.stats import Stats @@ -135,10 +135,8 @@ from airflow.utils.session import NEW_SESSION, create_session, provide_session from airflow.utils.sqlalchemy import ExecutorConfigType, ExtendedJSON, UtcDateTime from airflow.utils.state import DagRunState, State, TaskInstanceState -from airflow.utils.task_group import MappedTaskGroup from airflow.utils.task_instance_session import set_current_task_instance_session from airflow.utils.timeout import timeout -from airflow.utils.types import AttributeRemoved from airflow.utils.xcom import XCOM_RETURN_KEY TR = TaskReschedule @@ -161,7 +159,7 @@ from airflow.models.baseoperator import BaseOperator from airflow.models.dag import DAG as SchedulerDAG, DagModel from airflow.models.dagrun import DagRun - from airflow.models.operator import Operator + from airflow.sdk.definitions._internal.abstractoperator import Operator from airflow.sdk.definitions.dag import DAG from airflow.sdk.types import OutletEventAccessorsProtocol, RuntimeTaskInstanceProtocol from airflow.timetables.base import DataInterval @@ -232,7 +230,7 @@ def _run_raw_task( :param session: SQLAlchemy ORM Session """ if TYPE_CHECKING: - assert ti.task + assert isinstance(ti.task, BaseOperator) ti.test_mode = test_mode ti.refresh_from_task(ti.task, pool_override=pool) @@ -374,6 +372,8 @@ def set_current_context(context: Context) -> Generator[Context, None, None]: This method should be called once per Task execution, before calling operator.execute. 
""" + from airflow.sdk.definitions._internal.contextmanager import _CURRENT_CONTEXT + _CURRENT_CONTEXT.append(context) try: yield context @@ -677,12 +677,14 @@ def _execute_task(task_instance: TaskInstance, context: Context, task_orig: Oper :meta private: """ - from airflow.models.mappedoperator import MappedOperator + from airflow.sdk.definitions.mappedoperator import MappedOperator task_to_execute = task_instance.task if TYPE_CHECKING: - assert task_to_execute + # TODO: TaskSDK this function will need 100% re-writing + # This only works with a "rich" BaseOperator, not the SDK version + assert isinstance(task_to_execute, BaseOperator) if isinstance(task_to_execute, MappedOperator): raise AirflowException("MappedOperator cannot be executed.") @@ -925,6 +927,7 @@ def _get_template_context( from airflow import macros from airflow.models.abstractoperator import NotMapped + from airflow.models.baseoperator import BaseOperator integrate_macros_plugins() @@ -934,10 +937,6 @@ def _get_template_context( assert task assert task.dag - if task.dag.__class__ is AttributeRemoved: - # TODO: Task-SDK: Remove this after AIP-44 code is removed - task.dag = dag # type: ignore[assignment] # required after deserialization - dag_run = task_instance.get_dagrun(session) data_interval = dag.get_run_data_interval(dag_run) @@ -1002,11 +1001,6 @@ def get_triggering_events() -> dict[str, list[AssetEvent]]: return triggering_events - try: - expanded_ti_count: int | None = task.get_mapped_ti_count(task_instance.run_id, session=session) - except NotMapped: - expanded_ti_count = None - # NOTE: If you add to this dict, make sure to also update the following: # * Context in task_sdk/src/airflow/sdk/definitions/context.py # * KNOWN_CONTEXT_KEYS in airflow/utils/context.py @@ -1019,7 +1013,6 @@ def get_triggering_events() -> dict[str, list[AssetEvent]]: "outlet_events": OutletEventAccessors(), "ds": ds, "ds_nodash": ds_nodash, - "expanded_ti_count": expanded_ti_count, "inlets": task.inlets, "inlet_events": InletEventsAccessors(task.inlets, session=session), "logical_date": logical_date, @@ -1032,7 +1025,7 @@ def get_triggering_events() -> dict[str, list[AssetEvent]]: "prev_start_date_success": get_prev_start_date_success(), "prev_end_date_success": get_prev_end_date_success(), "run_id": task_instance.run_id, - "task": task, + "task": task, # type: ignore[typeddict-item] "task_instance": task_instance, "task_instance_key_str": f"{task.dag_id}__{task.task_id}__{ds_nodash}", "test_mode": task_instance.test_mode, @@ -1047,6 +1040,24 @@ def get_triggering_events() -> dict[str, list[AssetEvent]]: }, "conn": ConnectionAccessor(), } + + try: + expanded_ti_count: int | None = BaseOperator.get_mapped_ti_count( + task, task_instance.run_id, session=session + ) + context["expanded_ti_count"] = expanded_ti_count + if expanded_ti_count: + context["_upstream_map_indexes"] = { # type: ignore[typeddict-unknown-key] + upstream.task_id: task_instance.get_relevant_upstream_map_indexes( + upstream, + expanded_ti_count, + session=session, + ) + for upstream in task.upstream_list + } + except NotMapped: + pass + # Mypy doesn't like turning existing dicts in to a TypeDict -- and we "lie" in the type stub to say it # is one, but in practice it isn't. 
See https://github.com/python/mypy/issues/8890 return context @@ -1205,12 +1216,7 @@ def _record_task_map_for_downstreams( :meta private: """ - from airflow.models.mappedoperator import MappedOperator - - # TODO: Task-SDK: Remove this after AIP-44 code is removed - if task.dag.__class__ is AttributeRemoved: - # required after deserialization - task.dag = dag # type: ignore[assignment] + from airflow.sdk.definitions.mappedoperator import MappedOperator if next(task.iter_mapped_dependants(), None) is None: # No mapped dependants, no need to validate. return @@ -1253,6 +1259,8 @@ def _get_previous_dagrun( if dag is None: return None + if TYPE_CHECKING: + assert isinstance(dag, SchedulerDAG) dr = task_instance.get_dagrun(session=session) dr.dag = dag @@ -1538,6 +1546,10 @@ def _defer_task( ) -> TaskInstance: from airflow.models.trigger import Trigger + # TODO: TaskSDK add start_trigger_args to SDK definitions + if TYPE_CHECKING: + assert ti.task is None or isinstance(ti.task, BaseOperator) + timeout: timedelta | None if exception is not None: trigger_row = Trigger.from_object(exception.trigger) @@ -1910,6 +1922,7 @@ def _command_as_list( if hasattr(ti, "task") and getattr(ti.task, "dag", None) is not None: if TYPE_CHECKING: assert ti.task + assert isinstance(ti.task.dag, SchedulerDAG) dag = ti.task.dag else: dag = ti.dag_model @@ -2351,7 +2364,7 @@ def are_dependencies_met( def get_failed_dep_statuses(self, dep_context: DepContext | None = None, session: Session = NEW_SESSION): """Get failed Dependencies.""" if TYPE_CHECKING: - assert self.task + assert isinstance(self.task, BaseOperator) dep_context = dep_context or DepContext() for dep in dep_context.deps | self.task.deps: @@ -2449,6 +2462,7 @@ def get_dagrun(self, session: Session = NEW_SESSION) -> DagRun: if getattr(self, "task", None) is not None: if TYPE_CHECKING: assert self.task + assert isinstance(self.task.dag, SchedulerDAG) dr.dag = self.task.dag # Record it in the instance for next time. 
This means that `self.logical_date` will work correctly set_committed_value(self, "dag_run", dr) @@ -2461,7 +2475,7 @@ def ensure_dag(cls, task_instance: TaskInstance, session: Session = NEW_SESSION) """Ensure that task has a dag object associated, might have been removed by serialization.""" if TYPE_CHECKING: assert task_instance.task - if task_instance.task.dag is None or task_instance.task.dag.__class__ is AttributeRemoved: + if task_instance.task.dag is None: task_instance.task.dag = DagBag(read_dags_from_db=True).get_dag( dag_id=task_instance.dag_id, session=session ) @@ -3123,7 +3137,7 @@ def fetch_handle_failure_context( try: if getattr(ti, "task", None) and context: if TYPE_CHECKING: - assert ti.task + assert isinstance(ti.task, BaseOperator) task = ti.task.unmap((context, session)) except Exception: cls.logger().error("Unable to unmap task to determine if we need to send an alert email") @@ -3220,7 +3234,7 @@ def get_template_context( """ if TYPE_CHECKING: assert self.task - assert self.task.dag + assert isinstance(self.task.dag, SchedulerDAG) return _get_template_context( task_instance=self, dag=self.task.dag, @@ -3238,7 +3252,7 @@ def get_rendered_template_fields(self, session: Session = NEW_SESSION) -> None: from airflow.models.renderedtifields import RenderedTaskInstanceFields if TYPE_CHECKING: - assert self.task + assert isinstance(self.task, BaseOperator) rendered_task_instance_fields = RenderedTaskInstanceFields.get_templated_fields(self, session=session) if rendered_task_instance_fields: @@ -3279,7 +3293,7 @@ def render_templates( the unmapped, fully rendered BaseOperator. The original ``self.task`` before replacement is returned. """ - from airflow.models.mappedoperator import MappedOperator + from airflow.sdk.definitions.mappedoperator import MappedOperator if not context: context = self.get_template_context() @@ -3292,9 +3306,6 @@ def render_templates( assert self.task assert ti.task - if ti.task.dag.__class__ is AttributeRemoved: - ti.task.dag = self.task.dag # type: ignore[assignment] - # If self.task is mapped, this call replaces self.task to point to the # unmapped BaseOperator created by this function! This is because the # MappedOperator is useless for template rendering, and we need to be @@ -3590,6 +3601,8 @@ def tg2(inp): :return: Specific map index or map indexes to pull, or ``None`` if we want to "whole" return value (i.e. no mapped task groups involved). """ + from airflow.models.baseoperator import BaseOperator + if TYPE_CHECKING: assert self.task @@ -3612,7 +3625,8 @@ def tg2(inp): # At this point we know the two tasks share a mapped task group, and we # should use a "partial" value. Let's break down the mapped ti count # between the ancestor and further expansion happened inside it. 
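Several hunks above and below replace bare `assert self.task` with `assert isinstance(...)` guarded by `TYPE_CHECKING`. A self-contained sketch of that narrowing idiom; the operator classes here are stand-ins, not the Airflow ones:

```python
# The narrowing idiom used throughout this diff: the assert is only seen by
# the type checker (TYPE_CHECKING is False at runtime), so there is no runtime
# isinstance cost, but mypy treats `task` as the richer DB-backed class for
# the rest of the function. The caller must guarantee this at runtime.
from typing import TYPE_CHECKING


class SDKOperator:  # stand-in for the Task SDK operator
    pass


class DBOperator(SDKOperator):  # stand-in for the scheduler/DB operator
    def get_mapped_ti_count(self) -> int:
        return 0


def handle(task: SDKOperator) -> int:
    if TYPE_CHECKING:
        assert isinstance(task, DBOperator)
    return task.get_mapped_ti_count()


print(handle(DBOperator()))
```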
- ancestor_ti_count = common_ancestor.get_mapped_ti_count(self.run_id, session=session) + + ancestor_ti_count = BaseOperator.get_mapped_ti_count(common_ancestor, self.run_id, session=session) ancestor_map_index = self.map_index * ancestor_ti_count // ti_count # If the task is NOT further expanded inside the common ancestor, we @@ -3732,7 +3746,7 @@ def _find_common_ancestor_mapped_group(node1: Operator, node2: Operator) -> Mapp def _is_further_mapped_inside(operator: Operator, container: TaskGroup) -> bool: """Whether given operator is *further* mapped inside a task group.""" - from airflow.models.mappedoperator import MappedOperator + from airflow.sdk.definitions.mappedoperator import MappedOperator if isinstance(operator, MappedOperator): return True diff --git a/airflow/models/taskmap.py b/airflow/models/taskmap.py index 2702b906df034..fdd37f1f5b55c 100644 --- a/airflow/models/taskmap.py +++ b/airflow/models/taskmap.py @@ -21,15 +21,20 @@ import collections.abc import enum -from collections.abc import Collection +from collections.abc import Collection, Iterable, Sequence from typing import TYPE_CHECKING, Any -from sqlalchemy import CheckConstraint, Column, ForeignKeyConstraint, Integer, String +from sqlalchemy import CheckConstraint, Column, ForeignKeyConstraint, Integer, String, func, or_, select from airflow.models.base import COLLATION_ARGS, ID_LEN, TaskInstanceDependencies -from airflow.utils.sqlalchemy import ExtendedJSON +from airflow.utils.db import exists_query +from airflow.utils.sqlalchemy import ExtendedJSON, with_row_locks +from airflow.utils.state import State, TaskInstanceState if TYPE_CHECKING: + from sqlalchemy.orm import Session + + from airflow.models.dag import DAG as SchedulerDAG from airflow.models.taskinstance import TaskInstance @@ -114,3 +119,133 @@ def variant(self) -> TaskMapVariant: if self.keys is None: return TaskMapVariant.LIST return TaskMapVariant.DICT + + @classmethod + def expand_mapped_task(cls, task, run_id: str, *, session: Session) -> tuple[Sequence[TaskInstance], int]: + """ + Create the mapped task instances for mapped task. + + :raise NotMapped: If this task does not need expansion. + :return: The newly created mapped task instances (if any) in ascending + order by map index, and the maximum map index value. 
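A worked illustration of the ancestor map-index scaling used just above (`ancestor_map_index = map_index * ancestor_ti_count // ti_count`); the counts are illustrative:

```python
# With an ancestor mapped group of 2 instances further expanded to 6 task
# instances in total, each block of 3 consecutive map indexes collapses back
# to one ancestor index.
ancestor_ti_count = 2
ti_count = 6

mapping = {m: m * ancestor_ti_count // ti_count for m in range(ti_count)}
assert mapping == {0: 0, 1: 0, 2: 0, 3: 1, 4: 1, 5: 1}
```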
+ """ + from airflow.models.baseoperator import BaseOperator as DBBaseOperator + from airflow.models.expandinput import NotFullyPopulated + from airflow.models.taskinstance import TaskInstance + from airflow.sdk.definitions.baseoperator import BaseOperator + from airflow.sdk.definitions.mappedoperator import MappedOperator + from airflow.settings import task_instance_mutation_hook + + if not isinstance(task, (BaseOperator, MappedOperator)): + raise RuntimeError( + f"cannot expand unrecognized operator type {type(task).__module__}.{type(task).__name__}" + ) + + try: + total_length: int | None = DBBaseOperator.get_mapped_ti_count(task, run_id, session=session) + except NotFullyPopulated as e: + if not task.dag or not task.dag.partial: + task.log.error( + "Cannot expand %r for run %s; missing upstream values: %s", + task, + run_id, + sorted(e.missing), + ) + total_length = None + + state: TaskInstanceState | None = None + unmapped_ti: TaskInstance | None = session.scalars( + select(TaskInstance).where( + TaskInstance.dag_id == task.dag_id, + TaskInstance.task_id == task.task_id, + TaskInstance.run_id == run_id, + TaskInstance.map_index == -1, + or_(TaskInstance.state.in_(State.unfinished), TaskInstance.state.is_(None)), + ) + ).one_or_none() + + all_expanded_tis: list[TaskInstance] = [] + + if unmapped_ti: + if TYPE_CHECKING: + assert task.dag is None or isinstance(task.dag, SchedulerDAG) + + # The unmapped task instance still exists and is unfinished, i.e. we + # haven't tried to run it before. + if total_length is None: + # If the DAG is partial, it's likely that the upstream tasks + # are not done yet, so the task can't fail yet. + if not task.dag or not task.dag.partial: + unmapped_ti.state = TaskInstanceState.UPSTREAM_FAILED + elif total_length < 1: + # If the upstream maps this to a zero-length value, simply mark + # the unmapped task instance as SKIPPED (if needed). + task.log.info( + "Marking %s as SKIPPED since the map has %d values to expand", + unmapped_ti, + total_length, + ) + unmapped_ti.state = TaskInstanceState.SKIPPED + else: + zero_index_ti_exists = exists_query( + TaskInstance.dag_id == task.dag_id, + TaskInstance.task_id == task.task_id, + TaskInstance.run_id == run_id, + TaskInstance.map_index == 0, + session=session, + ) + if not zero_index_ti_exists: + # Otherwise convert this into the first mapped index, and create + # TaskInstance for other indexes. + unmapped_ti.map_index = 0 + task.log.debug("Updated in place to become %s", unmapped_ti) + all_expanded_tis.append(unmapped_ti) + # execute hook for task instance map index 0 + task_instance_mutation_hook(unmapped_ti) + session.flush() + else: + task.log.debug("Deleting the original task instance: %s", unmapped_ti) + session.delete(unmapped_ti) + state = unmapped_ti.state + + if total_length is None or total_length < 1: + # Nothing to fixup. + indexes_to_map: Iterable[int] = () + else: + # Only create "missing" ones. + current_max_mapping = session.scalar( + select(func.max(TaskInstance.map_index)).where( + TaskInstance.dag_id == task.dag_id, + TaskInstance.task_id == task.task_id, + TaskInstance.run_id == run_id, + ) + ) + indexes_to_map = range(current_max_mapping + 1, total_length) + + for index in indexes_to_map: + # TODO: Make more efficient with bulk_insert_mappings/bulk_save_mappings. 
+ ti = TaskInstance(task, run_id=run_id, map_index=index, state=state) + task.log.debug("Expanding TIs upserted %s", ti) + task_instance_mutation_hook(ti) + ti = session.merge(ti) + ti.refresh_from_task(task) # session.merge() loses task information. + all_expanded_tis.append(ti) + + # Coerce the None case to 0 -- these two are almost treated identically, + # except the unmapped ti (if exists) is marked to different states. + total_expanded_ti_count = total_length or 0 + + # Any (old) task instances with inapplicable indexes (>= the total + # number we need) are set to "REMOVED". + query = select(TaskInstance).where( + TaskInstance.dag_id == task.dag_id, + TaskInstance.task_id == task.task_id, + TaskInstance.run_id == run_id, + TaskInstance.map_index >= total_expanded_ti_count, + ) + query = with_row_locks(query, of=TaskInstance, session=session, skip_locked=True) + to_update = session.scalars(query) + for ti in to_update: + ti.state = TaskInstanceState.REMOVED + session.flush() + return all_expanded_tis, total_expanded_ti_count - 1 diff --git a/airflow/models/xcom_arg.py b/airflow/models/xcom_arg.py index 4bf91a68bee53..078a9e6ff5223 100644 --- a/airflow/models/xcom_arg.py +++ b/airflow/models/xcom_arg.py @@ -17,714 +17,100 @@ from __future__ import annotations -import contextlib -import inspect -import itertools -from collections.abc import Iterable, Iterator, Mapping, Sequence -from typing import TYPE_CHECKING, Any, Callable, Union, overload +from functools import singledispatch +from typing import TYPE_CHECKING from sqlalchemy import func, or_, select - -from airflow.exceptions import AirflowException, XComNotFound -from airflow.models import MappedOperator, TaskInstance -from airflow.models.abstractoperator import AbstractOperator -from airflow.models.taskmixin import DependencyMixin -from airflow.sdk.definitions._internal.mixins import ResolveMixin -from airflow.sdk.definitions._internal.types import NOTSET, ArgNotSet +from sqlalchemy.orm import Session + +from airflow.sdk.definitions._internal.types import ArgNotSet +from airflow.sdk.definitions.mappedoperator import MappedOperator +from airflow.sdk.definitions.xcom_arg import ( + ConcatXComArg, + MapXComArg, + PlainXComArg, + XComArg, + ZipXComArg, +) from airflow.utils.db import exists_query -from airflow.utils.session import NEW_SESSION, provide_session -from airflow.utils.setup_teardown import SetupTeardownContext from airflow.utils.state import State -from airflow.utils.trigger_rule import TriggerRule from airflow.utils.xcom import XCOM_RETURN_KEY -if TYPE_CHECKING: - from sqlalchemy.orm import Session - - from airflow.models.operator import Operator - from airflow.sdk.definitions.baseoperator import BaseOperator - from airflow.sdk.definitions.dag import DAG - from airflow.utils.edgemodifier import EdgeModifier - -# Callable objects contained by MapXComArg. We only accept callables from -# the user, but deserialize them into strings in a serialized XComArg for -# safety (those callables are arbitrary user code). -MapCallables = Sequence[Union[Callable[[Any], Any], str]] - - -class XComArg(ResolveMixin, DependencyMixin): - """ - Reference to an XCom value pushed from another operator. 
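To summarize the index bookkeeping at the end of `expand_mapped_task` added above: only the missing map indexes are created, and any existing index at or beyond the new total is flagged for the `REMOVED` state. A condensed sketch with no ORM involved (the helper name is hypothetical):

```python
# Sketch of the expansion planning: None is coerced to 0, as in the diff.
def plan_expansion(total_length: int | None, current_max_index: int) -> tuple[range, range]:
    total = total_length or 0
    indexes_to_create = range(current_max_index + 1, total)
    indexes_to_remove = range(total, current_max_index + 1)
    return indexes_to_create, indexes_to_remove


create, remove = plan_expansion(total_length=5, current_max_index=2)
assert list(create) == [3, 4] and list(remove) == []

create, remove = plan_expansion(total_length=2, current_max_index=4)
assert list(create) == [] and list(remove) == [2, 3, 4]
```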
- - The implementation supports:: - - xcomarg >> op - xcomarg << op - op >> xcomarg # By BaseOperator code - op << xcomarg # By BaseOperator code - - **Example**: The moment you get a result from any operator (decorated or regular) you can :: - - any_op = AnyOperator() - xcomarg = XComArg(any_op) - # or equivalently - xcomarg = any_op.output - my_op = MyOperator() - my_op >> xcomarg - - This object can be used in legacy Operators via Jinja. - - **Example**: You can make this result to be part of any generated string:: - - any_op = AnyOperator() - xcomarg = any_op.output - op1 = MyOperator(my_text_message=f"the value is {xcomarg}") - op2 = MyOperator(my_text_message=f"the value is {xcomarg['topic']}") - - :param operator: Operator instance to which the XComArg references. - :param key: Key used to pull the XCom value. Defaults to *XCOM_RETURN_KEY*, - i.e. the referenced operator's return value. - """ - - @overload - def __new__(cls: type[XComArg], operator: Operator, key: str = XCOM_RETURN_KEY) -> XComArg: - """Execute when the user writes ``XComArg(...)`` directly.""" - - @overload - def __new__(cls: type[XComArg]) -> XComArg: - """Execute by Python internals from subclasses.""" - - def __new__(cls, *args, **kwargs) -> XComArg: - if cls is XComArg: - return PlainXComArg(*args, **kwargs) - return super().__new__(cls) - - @staticmethod - def iter_xcom_references(arg: Any) -> Iterator[tuple[Operator, str]]: - """ - Return XCom references in an arbitrary value. - - Recursively traverse ``arg`` and look for XComArg instances in any - collection objects, and instances with ``template_fields`` set. - """ - if isinstance(arg, ResolveMixin): - yield from arg.iter_references() - elif isinstance(arg, (tuple, set, list)): - for elem in arg: - yield from XComArg.iter_xcom_references(elem) - elif isinstance(arg, dict): - for elem in arg.values(): - yield from XComArg.iter_xcom_references(elem) - elif isinstance(arg, AbstractOperator): - for attr in arg.template_fields: - yield from XComArg.iter_xcom_references(getattr(arg, attr)) - - @staticmethod - def apply_upstream_relationship(op: DependencyMixin, arg: Any): - """ - Set dependency for XComArgs. - - This looks for XComArg objects in ``arg`` "deeply" (looking inside - collections objects and classes decorated with ``template_fields``), and - sets the relationship to ``op`` on any found. - """ - for operator, _ in XComArg.iter_xcom_references(arg): - op.set_upstream(operator) - - @property - def roots(self) -> list[Operator]: - """Required by DependencyMixin.""" - return [op for op, _ in self.iter_references()] - - @property - def leaves(self) -> list[Operator]: - """Required by DependencyMixin.""" - return [op for op, _ in self.iter_references()] - - def set_upstream( - self, - task_or_task_list: DependencyMixin | Sequence[DependencyMixin], - edge_modifier: EdgeModifier | None = None, - ): - """Proxy to underlying operator set_upstream method. Required by DependencyMixin.""" - for operator, _ in self.iter_references(): - operator.set_upstream(task_or_task_list, edge_modifier) - - def set_downstream( - self, - task_or_task_list: DependencyMixin | Sequence[DependencyMixin], - edge_modifier: EdgeModifier | None = None, - ): - """Proxy to underlying operator set_downstream method. Required by DependencyMixin.""" - for operator, _ in self.iter_references(): - operator.set_downstream(task_or_task_list, edge_modifier) - - def _serialize(self) -> dict[str, Any]: - """ - Serialize an XComArg. 
- - The implementation should be the inverse function to ``deserialize``, - returning a data dict converted from this XComArg derivative. DAG - serialization does not call this directly, but ``serialize_xcom_arg`` - instead, which adds additional information to dispatch deserialization - to the correct class. - """ - raise NotImplementedError() - - @classmethod - def _deserialize(cls, data: dict[str, Any], dag: DAG) -> XComArg: - """ - Deserialize an XComArg. - - The implementation should be the inverse function to ``serialize``, - implementing given a data dict converted from this XComArg derivative, - how the original XComArg should be created. DAG serialization relies on - additional information added in ``serialize_xcom_arg`` to dispatch data - dicts to the correct ``_deserialize`` information, so this function does - not need to validate whether the incoming data contains correct keys. - """ - raise NotImplementedError() - - def map(self, f: Callable[[Any], Any]) -> MapXComArg: - return MapXComArg(self, [f]) - - def zip(self, *others: XComArg, fillvalue: Any = NOTSET) -> ZipXComArg: - return ZipXComArg([self, *others], fillvalue=fillvalue) - - def concat(self, *others: XComArg) -> ConcatXComArg: - return ConcatXComArg([self, *others]) - - def get_task_map_length(self, run_id: str, *, session: Session) -> int | None: - """ - Inspect length of pushed value for task-mapping. - - This is used to determine how many task instances the scheduler should - create for a downstream using this XComArg for task-mapping. - - *None* may be returned if the depended XCom has not been pushed. - """ - raise NotImplementedError() - - def resolve( - self, context: Mapping[str, Any], session: Session | None = None, *, include_xcom: bool = True - ) -> Any: - """ - Pull XCom value. - - This should only be called during ``op.execute()`` with an appropriate - context (e.g. generated from ``TaskInstance.get_template_context()``). - Although the ``ResolveMixin`` parent mixin also has a ``resolve`` - protocol, this adds the optional ``session`` argument that some of the - subclasses need. - - :meta private: - """ - raise NotImplementedError() - - def __enter__(self): - if not self.operator.is_setup and not self.operator.is_teardown: - raise AirflowException("Only setup/teardown tasks can be used as context managers.") - SetupTeardownContext.push_setup_teardown_task(self.operator) - return SetupTeardownContext - - def __exit__(self, exc_type, exc_val, exc_tb): - SetupTeardownContext.set_work_task_roots_and_leaves() - - -class PlainXComArg(XComArg): - """ - Reference to one single XCom without any additional semantics. - - This class should not be accessed directly, but only through XComArg. The - class inheritance chain and ``__new__`` is implemented in this slightly - convoluted way because we want to - - a. Allow the user to continue using XComArg directly for the simple - semantics (see documentation of the base class for details). - b. Make ``isinstance(thing, XComArg)`` be able to detect all kinds of XCom - references. - c. Not allow many properties of PlainXComArg (including ``__getitem__`` and - ``__str__``) to exist on other kinds of XComArg implementations since - they don't make sense. 
- - :meta private: - """ - - def __init__(self, operator: Operator, key: str = XCOM_RETURN_KEY): - self.operator = operator - self.key = key - - def __eq__(self, other: Any) -> bool: - if not isinstance(other, PlainXComArg): - return NotImplemented - return self.operator == other.operator and self.key == other.key - - def __getitem__(self, item: str) -> XComArg: - """Implement xcomresult['some_result_key'].""" - if not isinstance(item, str): - raise ValueError(f"XComArg only supports str lookup, received {type(item).__name__}") - return PlainXComArg(operator=self.operator, key=item) - - def __iter__(self): - """ - Override iterable protocol to raise error explicitly. - - The default ``__iter__`` implementation in Python calls ``__getitem__`` - with 0, 1, 2, etc. until it hits an ``IndexError``. This does not work - well with our custom ``__getitem__`` implementation, and results in poor - DAG-writing experience since a misplaced ``*`` expansion would create an - infinite loop consuming the entire DAG parser. - - This override catches the error eagerly, so an incorrectly implemented - DAG fails fast and avoids wasting resources on nonsensical iterating. - """ - raise TypeError("'XComArg' object is not iterable") - - def __repr__(self) -> str: - if self.key == XCOM_RETURN_KEY: - return f"XComArg({self.operator!r})" - return f"XComArg({self.operator!r}, {self.key!r})" - - def __str__(self) -> str: - """ - Backward compatibility for old-style jinja used in Airflow Operators. - - **Example**: to use XComArg at BashOperator:: - - BashOperator(cmd=f"... { xcomarg } ...") - - :return: - """ - xcom_pull_kwargs = [ - f"task_ids='{self.operator.task_id}'", - f"dag_id='{self.operator.dag_id}'", - ] - if self.key is not None: - xcom_pull_kwargs.append(f"key='{self.key}'") - - xcom_pull_str = ", ".join(xcom_pull_kwargs) - # {{{{ are required for escape {{ in f-string - xcom_pull = f"{{{{ task_instance.xcom_pull({xcom_pull_str}) }}}}" - return xcom_pull - - def _serialize(self) -> dict[str, Any]: - return {"task_id": self.operator.task_id, "key": self.key} - - @classmethod - def _deserialize(cls, data: dict[str, Any], dag: DAG) -> XComArg: - return cls(dag.get_task(data["task_id"]), data["key"]) - - @property - def is_setup(self) -> bool: - return self.operator.is_setup - - @is_setup.setter - def is_setup(self, val: bool): - self.operator.is_setup = val - - @property - def is_teardown(self) -> bool: - return self.operator.is_teardown - - @is_teardown.setter - def is_teardown(self, val: bool): - self.operator.is_teardown = val +__all__ = ["XComArg", "get_task_map_length"] - @property - def on_failure_fail_dagrun(self) -> bool: - return self.operator.on_failure_fail_dagrun - - @on_failure_fail_dagrun.setter - def on_failure_fail_dagrun(self, val: bool): - self.operator.on_failure_fail_dagrun = val - - def as_setup(self) -> DependencyMixin: - for operator, _ in self.iter_references(): - operator.is_setup = True - return self - - def as_teardown( - self, - *, - setups: BaseOperator | Iterable[BaseOperator] | None = None, - on_failure_fail_dagrun: bool | None = None, - ): - for operator, _ in self.iter_references(): - operator.is_teardown = True - operator.trigger_rule = TriggerRule.ALL_DONE_SETUP_SUCCESS - if on_failure_fail_dagrun is not None: - operator.on_failure_fail_dagrun = on_failure_fail_dagrun - if setups is not None: - setups = [setups] if isinstance(setups, DependencyMixin) else setups - for s in setups: - s.is_setup = True - s >> operator - return self - - def iter_references(self) -> 
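The removed `PlainXComArg.__str__` above (now living in the Task SDK) renders a legacy `xcom_pull` Jinja expression for old-style operators. A small stand-alone illustration of the string it produces, with made-up task and DAG ids:

```python
# What the removed __str__ renders to for legacy Jinja templating.
task_id, dag_id, key = "extract", "my_dag", "return_value"

xcom_pull_kwargs = [f"task_ids='{task_id}'", f"dag_id='{dag_id}'", f"key='{key}'"]
xcom_pull_str = ", ".join(xcom_pull_kwargs)
# Doubled braces escape literal {{ }} inside the f-string.
rendered = f"{{{{ task_instance.xcom_pull({xcom_pull_str}) }}}}"

assert rendered == (
    "{{ task_instance.xcom_pull(task_ids='extract', dag_id='my_dag', key='return_value') }}"
)
```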
Iterator[tuple[Operator, str]]: - yield self.operator, self.key - - def map(self, f: Callable[[Any], Any]) -> MapXComArg: - if self.key != XCOM_RETURN_KEY: - raise ValueError("cannot map against non-return XCom") - return super().map(f) - - def zip(self, *others: XComArg, fillvalue: Any = NOTSET) -> ZipXComArg: - if self.key != XCOM_RETURN_KEY: - raise ValueError("cannot map against non-return XCom") - return super().zip(*others, fillvalue=fillvalue) - - def concat(self, *others: XComArg) -> ConcatXComArg: - if self.key != XCOM_RETURN_KEY: - raise ValueError("cannot concatenate non-return XCom") - return super().concat(*others) - - def get_task_map_length(self, run_id: str, *, session: Session) -> int | None: - from airflow.models.taskinstance import TaskInstance - from airflow.models.taskmap import TaskMap - from airflow.models.xcom import XCom - - dag_id = self.operator.dag_id - task_id = self.operator.task_id - is_mapped = isinstance(self.operator, MappedOperator) - - if is_mapped: - unfinished_ti_exists = exists_query( - TaskInstance.dag_id == dag_id, - TaskInstance.run_id == run_id, - TaskInstance.task_id == task_id, - # Special NULL treatment is needed because 'state' can be NULL. - # The "IN" part would produce "NULL NOT IN ..." and eventually - # "NULl = NULL", which is a big no-no in SQL. - or_( - TaskInstance.state.is_(None), - TaskInstance.state.in_(s.value for s in State.unfinished if s is not None), - ), - session=session, - ) - if unfinished_ti_exists: - return None # Not all of the expanded tis are done yet. - query = select(func.count(XCom.map_index)).where( - XCom.dag_id == dag_id, - XCom.run_id == run_id, - XCom.task_id == task_id, - XCom.map_index >= 0, - XCom.key == XCOM_RETURN_KEY, - ) - else: - query = select(TaskMap.length).where( - TaskMap.dag_id == dag_id, - TaskMap.run_id == run_id, - TaskMap.task_id == task_id, - TaskMap.map_index < 0, - ) - return session.scalar(query) - - # TODO: Task-SDK: Remove session argument once everything is ported over to Task SDK - @provide_session - def resolve( - self, context: Mapping[str, Any], session: Session = NEW_SESSION, *, include_xcom: bool = True - ) -> Any: - ti = context["ti"] - if TYPE_CHECKING: - assert isinstance(ti, TaskInstance) - task_id = self.operator.task_id - map_indexes = ti.get_relevant_upstream_map_indexes( - self.operator, - context["expanded_ti_count"], +if TYPE_CHECKING: + from airflow.models.expandinput import OperatorExpandArgument + + +@singledispatch +def get_task_map_length(xcom_arg: OperatorExpandArgument, run_id: str, *, session: Session) -> int | None: + # The base implementation -- specific XComArg subclasses have specialised implementations + raise NotImplementedError() + + +@get_task_map_length.register +def _(xcom_arg: PlainXComArg, run_id: str, *, session: Session): + from airflow.models.taskinstance import TaskInstance + from airflow.models.taskmap import TaskMap + from airflow.models.xcom import XCom + + dag_id = xcom_arg.operator.dag_id + task_id = xcom_arg.operator.task_id + is_mapped = isinstance(xcom_arg.operator, MappedOperator) + + if is_mapped: + unfinished_ti_exists = exists_query( + TaskInstance.dag_id == dag_id, + TaskInstance.run_id == run_id, + TaskInstance.task_id == task_id, + # Special NULL treatment is needed because 'state' can be NULL. + # The "IN" part would produce "NULL NOT IN ..." and eventually + # "NULl = NULL", which is a big no-no in SQL. 
+ or_( + TaskInstance.state.is_(None), + TaskInstance.state.in_(s.value for s in State.unfinished if s is not None), + ), session=session, ) - - result = ti.xcom_pull( - task_ids=task_id, - map_indexes=map_indexes, - key=self.key, - default=NOTSET, + if unfinished_ti_exists: + return None # Not all of the expanded tis are done yet. + query = select(func.count(XCom.map_index)).where( + XCom.dag_id == dag_id, + XCom.run_id == run_id, + XCom.task_id == task_id, + XCom.map_index >= 0, + XCom.key == XCOM_RETURN_KEY, ) - if not isinstance(result, ArgNotSet): - return result - if self.key == XCOM_RETURN_KEY: - return None - if getattr(self.operator, "multiple_outputs", False): - # If the operator is set to have multiple outputs and it was not executed, - # we should return "None" instead of showing an error. This is because when - # multiple outputs XComs are created, the XCom keys associated with them will have - # different names than the predefined "XCOM_RETURN_KEY" and won't be found. - # Therefore, it's better to return "None" like we did above where self.key==XCOM_RETURN_KEY. - return None - raise XComNotFound(ti.dag_id, task_id, self.key) - - -def _get_callable_name(f: Callable | str) -> str: - """Try to "describe" a callable by getting its name.""" - if callable(f): - return f.__name__ - # Parse the source to find whatever is behind "def". For safety, we don't - # want to evaluate the code in any meaningful way! - with contextlib.suppress(Exception): - kw, name, _ = f.lstrip().split(None, 2) - if kw == "def": - return name - return "" - - -class _MapResult(Sequence): - def __init__(self, value: Sequence | dict, callables: MapCallables) -> None: - self.value = value - self.callables = callables - - def __getitem__(self, index: Any) -> Any: - value = self.value[index] - - # In the worker, we can access all actual callables. Call them. - callables = [f for f in self.callables if callable(f)] - if len(callables) == len(self.callables): - for f in callables: - value = f(value) - return value - - # In the scheduler, we don't have access to the actual callables, nor do - # we want to run it since it's arbitrary code. This builds a string to - # represent the call chain in the UI or logs instead. - for v in self.callables: - value = f"{_get_callable_name(v)}({value})" - return value - - def __len__(self) -> int: - return len(self.value) - - -class MapXComArg(XComArg): - """ - An XCom reference with ``map()`` call(s) applied. - - This is based on an XComArg, but also applies a series of "transforms" that - convert the pulled XCom value. - - :meta private: - """ - - def __init__(self, arg: XComArg, callables: MapCallables) -> None: - for c in callables: - if getattr(c, "_airflow_is_task_decorator", False): - raise ValueError("map() argument must be a plain function, not a @task operator") - self.arg = arg - self.callables = callables - - def __repr__(self) -> str: - map_calls = "".join(f".map({_get_callable_name(f)})" for f in self.callables) - return f"{self.arg!r}{map_calls}" - - def _serialize(self) -> dict[str, Any]: - return { - "arg": serialize_xcom_arg(self.arg), - "callables": [inspect.getsource(c) if callable(c) else c for c in self.callables], - } - - @classmethod - def _deserialize(cls, data: dict[str, Any], dag: DAG) -> XComArg: - # We are deliberately NOT deserializing the callables. These are shown - # in the UI, and displaying a function object is useless. 
- return cls(deserialize_xcom_arg(data["arg"], dag), data["callables"]) - - def iter_references(self) -> Iterator[tuple[Operator, str]]: - yield from self.arg.iter_references() - - def map(self, f: Callable[[Any], Any]) -> MapXComArg: - # Flatten arg.map(f1).map(f2) into one MapXComArg. - return MapXComArg(self.arg, [*self.callables, f]) - - def get_task_map_length(self, run_id: str, *, session: Session) -> int | None: - return self.arg.get_task_map_length(run_id, session=session) - - @provide_session - def resolve( - self, context: Mapping[str, Any], session: Session = NEW_SESSION, *, include_xcom: bool = True - ) -> Any: - value = self.arg.resolve(context, session=session, include_xcom=include_xcom) - if not isinstance(value, (Sequence, dict)): - raise ValueError(f"XCom map expects sequence or dict, not {type(value).__name__}") - return _MapResult(value, self.callables) - - -class _ZipResult(Sequence): - def __init__(self, values: Sequence[Sequence | dict], *, fillvalue: Any = NOTSET) -> None: - self.values = values - self.fillvalue = fillvalue - - @staticmethod - def _get_or_fill(container: Sequence | dict, index: Any, fillvalue: Any) -> Any: - try: - return container[index] - except (IndexError, KeyError): - return fillvalue - - def __getitem__(self, index: Any) -> Any: - if index >= len(self): - raise IndexError(index) - return tuple(self._get_or_fill(value, index, self.fillvalue) for value in self.values) - - def __len__(self) -> int: - lengths = (len(v) for v in self.values) - if isinstance(self.fillvalue, ArgNotSet): - return min(lengths) - return max(lengths) - - -class ZipXComArg(XComArg): - """ - An XCom reference with ``zip()`` applied. - - This is constructed from multiple XComArg instances, and presents an - iterable that "zips" them together like the built-in ``zip()`` (and - ``itertools.zip_longest()`` if ``fillvalue`` is provided). - """ - - def __init__(self, args: Sequence[XComArg], *, fillvalue: Any = NOTSET) -> None: - if not args: - raise ValueError("At least one input is required") - self.args = args - self.fillvalue = fillvalue - - def __repr__(self) -> str: - args_iter = iter(self.args) - first = repr(next(args_iter)) - rest = ", ".join(repr(arg) for arg in args_iter) - if isinstance(self.fillvalue, ArgNotSet): - return f"{first}.zip({rest})" - return f"{first}.zip({rest}, fillvalue={self.fillvalue!r})" - - def _serialize(self) -> dict[str, Any]: - args = [serialize_xcom_arg(arg) for arg in self.args] - if isinstance(self.fillvalue, ArgNotSet): - return {"args": args} - return {"args": args, "fillvalue": self.fillvalue} - - @classmethod - def _deserialize(cls, data: dict[str, Any], dag: DAG) -> XComArg: - return cls( - [deserialize_xcom_arg(arg, dag) for arg in data["args"]], - fillvalue=data.get("fillvalue", NOTSET), + else: + query = select(TaskMap.length).where( + TaskMap.dag_id == dag_id, + TaskMap.run_id == run_id, + TaskMap.task_id == task_id, + TaskMap.map_index < 0, ) - - def iter_references(self) -> Iterator[tuple[Operator, str]]: - for arg in self.args: - yield from arg.iter_references() - - def get_task_map_length(self, run_id: str, *, session: Session) -> int | None: - all_lengths = (arg.get_task_map_length(run_id, session=session) for arg in self.args) - ready_lengths = [length for length in all_lengths if length is not None] - if len(ready_lengths) != len(self.args): - return None # If any of the referenced XComs is not ready, we are not ready either. 
- if isinstance(self.fillvalue, ArgNotSet): - return min(ready_lengths) - return max(ready_lengths) - - @provide_session - def resolve( - self, context: Mapping[str, Any], session: Session = NEW_SESSION, *, include_xcom: bool = True - ) -> Any: - values = [arg.resolve(context, session=session, include_xcom=include_xcom) for arg in self.args] - for value in values: - if not isinstance(value, (Sequence, dict)): - raise ValueError(f"XCom zip expects sequence or dict, not {type(value).__name__}") - return _ZipResult(values, fillvalue=self.fillvalue) - - -class _ConcatResult(Sequence): - def __init__(self, values: Sequence[Sequence | dict]) -> None: - self.values = values - - def __getitem__(self, index: Any) -> Any: - if index >= 0: - i = index - else: - i = len(self) + index - for value in self.values: - if i < 0: - break - elif i >= (curlen := len(value)): - i -= curlen - elif isinstance(value, Sequence): - return value[i] - else: - return next(itertools.islice(iter(value), i, None)) - raise IndexError("list index out of range") - - def __len__(self) -> int: - return sum(len(v) for v in self.values) - - -class ConcatXComArg(XComArg): - """ - Concatenating multiple XCom references into one. - - This is done by calling ``concat()`` on an XComArg to combine it with - others. The effect is similar to Python's :func:`itertools.chain`, but the - return value also supports index access. - """ - - def __init__(self, args: Sequence[XComArg]) -> None: - if not args: - raise ValueError("At least one input is required") - self.args = args - - def __repr__(self) -> str: - args_iter = iter(self.args) - first = repr(next(args_iter)) - rest = ", ".join(repr(arg) for arg in args_iter) - return f"{first}.concat({rest})" - - def _serialize(self) -> dict[str, Any]: - return {"args": [serialize_xcom_arg(arg) for arg in self.args]} - - @classmethod - def _deserialize(cls, data: dict[str, Any], dag: DAG) -> XComArg: - return cls([deserialize_xcom_arg(arg, dag) for arg in data["args"]]) - - def iter_references(self) -> Iterator[tuple[Operator, str]]: - for arg in self.args: - yield from arg.iter_references() - - def concat(self, *others: XComArg) -> ConcatXComArg: - # Flatten foo.concat(x).concat(y) into one call. - return ConcatXComArg([*self.args, *others]) - - def get_task_map_length(self, run_id: str, *, session: Session) -> int | None: - all_lengths = (arg.get_task_map_length(run_id, session=session) for arg in self.args) - ready_lengths = [length for length in all_lengths if length is not None] - if len(ready_lengths) != len(self.args): - return None # If any of the referenced XComs is not ready, we are not ready either. 
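Both the removed implementations above and the new singledispatch handlers further down keep the same map-length rules for zipped and concatenated references. A pure-Python illustration with made-up lengths:

```python
# A zip without a fillvalue is bounded by its shortest input (like zip),
# with a fillvalue by its longest (like itertools.zip_longest), and a concat
# sums its inputs.
lengths = [3, 5]

zip_without_fill = min(lengths)   # 3
zip_with_fill = max(lengths)      # 5
concat_total = sum(lengths)       # 8

assert (zip_without_fill, zip_with_fill, concat_total) == (3, 5, 8)
```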
- return sum(ready_lengths) - - @provide_session - def resolve( - self, context: Mapping[str, Any], session: Session = NEW_SESSION, *, include_xcom: bool = True - ) -> Any: - values = [arg.resolve(context, session=session, include_xcom=include_xcom) for arg in self.args] - for value in values: - if not isinstance(value, (Sequence, dict)): - raise ValueError(f"XCom concat expects sequence or dict, not {type(value).__name__}") - return _ConcatResult(values) + return session.scalar(query) -_XCOM_ARG_TYPES: Mapping[str, type[XComArg]] = { - "": PlainXComArg, - "concat": ConcatXComArg, - "map": MapXComArg, - "zip": ZipXComArg, -} +@get_task_map_length.register +def _(xcom_arg: MapXComArg, run_id: str, *, session: Session): + return get_task_map_length(xcom_arg.arg, run_id, session=session) -def serialize_xcom_arg(value: XComArg) -> dict[str, Any]: - """DAG serialization interface.""" - key = next(k for k, v in _XCOM_ARG_TYPES.items() if isinstance(value, v)) - if key: - return {"type": key, **value._serialize()} - return value._serialize() +@get_task_map_length.register +def _(xcom_arg: ZipXComArg, run_id: str, *, session: Session): + all_lengths = (get_task_map_length(arg, run_id, session=session) for arg in xcom_arg.args) + ready_lengths = [length for length in all_lengths if length is not None] + if len(ready_lengths) != len(xcom_arg.args): + return None # If any of the referenced XComs is not ready, we are not ready either. + if isinstance(xcom_arg.fillvalue, ArgNotSet): + return min(ready_lengths) + return max(ready_lengths) -def deserialize_xcom_arg(data: dict[str, Any], dag: DAG) -> XComArg: - """DAG serialization interface.""" - klass = _XCOM_ARG_TYPES[data.get("type", "")] - return klass._deserialize(data, dag) +@get_task_map_length.register +def _(xcom_arg: ConcatXComArg, run_id: str, *, session: Session): + all_lengths = (get_task_map_length(arg, run_id, session=session) for arg in xcom_arg.args) + ready_lengths = [length for length in all_lengths if length is not None] + if len(ready_lengths) != len(xcom_arg.args): + return None # If any of the referenced XComs is not ready, we are not ready either. 
+ return sum(ready_lengths) diff --git a/airflow/serialization/serialized_objects.py b/airflow/serialization/serialized_objects.py index 88e0f200bb24d..b7e08a45aed74 100644 --- a/airflow/serialization/serialized_objects.py +++ b/airflow/serialization/serialized_objects.py @@ -47,11 +47,9 @@ create_expand_input, get_map_type_key, ) -from airflow.models.mappedoperator import MappedOperator from airflow.models.param import Param, ParamsDict from airflow.models.taskinstance import SimpleTaskInstance from airflow.models.taskinstancekey import TaskInstanceKey -from airflow.models.xcom_arg import XComArg, deserialize_xcom_arg, serialize_xcom_arg from airflow.providers_manager import ProvidersManager from airflow.sdk.definitions.asset import ( Asset, @@ -65,6 +63,9 @@ BaseAsset, ) from airflow.sdk.definitions.baseoperator import BaseOperator as TaskSDKBaseOperator +from airflow.sdk.definitions.mappedoperator import MappedOperator +from airflow.sdk.definitions.taskgroup import MappedTaskGroup, TaskGroup +from airflow.sdk.definitions.xcom_arg import XComArg, deserialize_xcom_arg, serialize_xcom_arg from airflow.sdk.execution_time.context import OutletEventAccessor, OutletEventAccessors from airflow.serialization.dag_dependency import DagDependency from airflow.serialization.enums import DagAttributeTypes as DAT, Encoding @@ -87,9 +88,8 @@ from airflow.utils.docs import get_docs_url from airflow.utils.module_loading import import_string, qualname from airflow.utils.operator_resources import Resources -from airflow.utils.task_group import MappedTaskGroup, TaskGroup from airflow.utils.timezone import from_timestamp, parse_timezone -from airflow.utils.types import NOTSET, ArgNotSet, AttributeRemoved +from airflow.utils.types import NOTSET, ArgNotSet if TYPE_CHECKING: from inspect import Parameter @@ -97,8 +97,8 @@ from airflow.models import DagRun from airflow.models.baseoperatorlink import BaseOperatorLink from airflow.models.expandinput import ExpandInput - from airflow.models.operator import Operator from airflow.sdk.definitions._internal.node import DAGNode + from airflow.sdk.types import Operator from airflow.serialization.json_schema import Validator from airflow.timetables.base import DagRunInfo, DataInterval, Timetable @@ -1339,7 +1339,7 @@ def populate_operator(cls, op: Operator, encoded_op: dict[str, Any]) -> None: if v is False: raise RuntimeError("_is_sensor=False should never have been serialized!") - object.__setattr__(op, "deps", op.deps | {ReadyToRescheduleDep()}) + object.__setattr__(op, "deps", op.deps | {ReadyToRescheduleDep()}) # type: ignore[union-attr] continue elif ( k in cls._decorated_fields @@ -1408,13 +1408,18 @@ def deserialize_operator(cls, encoded_op: dict[str, Any]) -> Operator: op: Operator if encoded_op.get("_is_mapped", False): # Most of these will be loaded later, these are just some stand-ins. 
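The xcom_arg.py rewrite above replaces per-class `get_task_map_length` methods with a module-level `functools.singledispatch` entry point plus one registered implementation per XComArg flavour. A self-contained sketch of that dispatch pattern; the classes below are stand-ins, not the Airflow ones:

```python
from functools import singledispatch


class PlainRef:
    """Stand-in for a plain XCom reference; length is None until pushed."""

    def __init__(self, length):
        self.length = length


class ConcatRef:
    """Stand-in for a concatenation of references."""

    def __init__(self, *args):
        self.args = args


@singledispatch
def get_length(ref, run_id: str):
    # Generic entry point; dispatch happens on the type of the first argument.
    raise NotImplementedError(type(ref).__name__)


@get_length.register
def _(ref: PlainRef, run_id: str):
    return ref.length


@get_length.register
def _(ref: ConcatRef, run_id: str):
    lengths = [get_length(arg, run_id) for arg in ref.args]
    if any(length is None for length in lengths):
        return None  # an upstream value is not ready yet
    return sum(lengths)


assert get_length(ConcatRef(PlainRef(2), PlainRef(3)), "run_1") == 5
assert get_length(ConcatRef(PlainRef(2), PlainRef(None)), "run_1") is None
```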
- op_data = {k: v for k, v in encoded_op.items() if k in BaseOperator.get_serialized_fields()} + op_data = { + k: v for k, v in encoded_op.items() if k in TaskSDKBaseOperator.get_serialized_fields() + } + + from airflow.models.mappedoperator import MappedOperator as MappedOperatorWithDB + try: operator_name = encoded_op["_operator_name"] except KeyError: operator_name = encoded_op["task_type"] - op = MappedOperator( + op = MappedOperatorWithDB( operator_class=op_data, expand_input=EXPAND_INPUT_EMPTY, partial_kwargs={}, @@ -1442,7 +1447,6 @@ def deserialize_operator(cls, encoded_op: dict[str, Any]) -> Operator: ) else: op = SerializedBaseOperator(task_id=encoded_op["task_id"]) - op.dag = AttributeRemoved("dag") # type: ignore[assignment] cls.populate_operator(op, encoded_op) return op @@ -1455,12 +1459,7 @@ def detect_dependencies(cls, op: Operator) -> set[DagDependency]: @classmethod def _is_excluded(cls, var: Any, attrname: str, op: DAGNode): - if ( - var is not None - and op.has_dag() - and op.dag.__class__ is not AttributeRemoved - and attrname.endswith("_date") - ): + if var is not None and op.has_dag() and attrname.endswith("_date"): # If this date is the same as the matching field in the dag, then # don't store it again at the task level. dag_date = getattr(op.dag, attrname, None) diff --git a/airflow/ti_deps/deps/mapped_task_upstream_dep.py b/airflow/ti_deps/deps/mapped_task_upstream_dep.py index 247dc84f3b478..97531ef4257e6 100644 --- a/airflow/ti_deps/deps/mapped_task_upstream_dep.py +++ b/airflow/ti_deps/deps/mapped_task_upstream_dep.py @@ -51,7 +51,7 @@ def _get_dep_statuses( session: Session, dep_context: DepContext, ) -> Iterator[TIDepStatus]: - from airflow.models.mappedoperator import MappedOperator + from airflow.sdk.definitions.mappedoperator import MappedOperator if isinstance(ti.task, MappedOperator): mapped_dependencies = ti.task.iter_mapped_dependencies() diff --git a/airflow/ti_deps/deps/prev_dagrun_dep.py b/airflow/ti_deps/deps/prev_dagrun_dep.py index c756e6ec1c645..9ce5c1134240a 100644 --- a/airflow/ti_deps/deps/prev_dagrun_dep.py +++ b/airflow/ti_deps/deps/prev_dagrun_dep.py @@ -32,7 +32,7 @@ if TYPE_CHECKING: from sqlalchemy.orm import Session - from airflow.models.operator import Operator + from airflow.sdk.types import Operator _SUCCESSFUL_STATES = (TaskInstanceState.SKIPPED, TaskInstanceState.SUCCESS) diff --git a/airflow/ti_deps/deps/trigger_rule_dep.py b/airflow/ti_deps/deps/trigger_rule_dep.py index ccd11ab303e66..f49e62d4b4605 100644 --- a/airflow/ti_deps/deps/trigger_rule_dep.py +++ b/airflow/ti_deps/deps/trigger_rule_dep.py @@ -139,10 +139,12 @@ def _get_expanded_ti_count() -> int: This extra closure allows us to query the database only when needed, and at most once. 
""" + from airflow.models.baseoperator import BaseOperator + if TYPE_CHECKING: assert ti.task - return ti.task.get_mapped_ti_count(ti.run_id, session=session) + return BaseOperator.get_mapped_ti_count(ti.task, ti.run_id, session=session) @functools.lru_cache def _get_relevant_upstream_map_indexes(upstream_id: str) -> int | range | None: diff --git a/airflow/utils/context.py b/airflow/utils/context.py index a36202f0793ec..74479a328a1a2 100644 --- a/airflow/utils/context.py +++ b/airflow/utils/context.py @@ -62,7 +62,7 @@ from sqlalchemy.orm import Session from sqlalchemy.sql.expression import Select, TextClause - from airflow.models.baseoperator import BaseOperator + from airflow.sdk.definitions.baseoperator import BaseOperator from airflow.sdk.types import OutletEventAccessorsProtocol # NOTE: Please keep this in sync with the following: diff --git a/airflow/utils/dag_edges.py b/airflow/utils/dag_edges.py index bd1ad268aefed..aafecbf308232 100644 --- a/airflow/utils/dag_edges.py +++ b/airflow/utils/dag_edges.py @@ -18,11 +18,11 @@ from typing import TYPE_CHECKING -from airflow.models.abstractoperator import AbstractOperator +from airflow.sdk.definitions._internal.abstractoperator import AbstractOperator if TYPE_CHECKING: - from airflow.models import Operator - from airflow.models.dag import DAG + from airflow.sdk import DAG + from airflow.sdk.types import Operator def dag_edges(dag: DAG): diff --git a/airflow/utils/dot_renderer.py b/airflow/utils/dot_renderer.py index 877d16450d700..c4f1e45dc1f5e 100644 --- a/airflow/utils/dot_renderer.py +++ b/airflow/utils/dot_renderer.py @@ -36,8 +36,8 @@ graphviz = None from airflow.exceptions import AirflowException -from airflow.models.baseoperator import BaseOperator -from airflow.models.mappedoperator import MappedOperator +from airflow.sdk import BaseOperator +from airflow.sdk.definitions.mappedoperator import MappedOperator from airflow.utils.dag_edges import dag_edges from airflow.utils.state import State from airflow.utils.task_group import TaskGroup diff --git a/airflow/utils/log/file_task_handler.py b/airflow/utils/log/file_task_handler.py index a8abfe0ea8638..73ee79126a9ee 100644 --- a/airflow/utils/log/file_task_handler.py +++ b/airflow/utils/log/file_task_handler.py @@ -282,7 +282,8 @@ def _render_filename(self, ti: TaskInstance, try_number: int, session=NEW_SESSIO if str_tpl: if ti.task is not None and ti.task.dag is not None: dag = ti.task.dag - data_interval = dag.get_run_data_interval(dag_run) + # TODO: TaskSDK: why do we need this on the DAG! Where is this render fn called from. 
Revisit + data_interval = dag.get_run_data_interval(dag_run) # type: ignore[attr-defined] else: from airflow.timetables.base import DataInterval diff --git a/airflow/utils/setup_teardown.py b/airflow/utils/setup_teardown.py index 37e1a03457971..3108657d30ac2 100644 --- a/airflow/utils/setup_teardown.py +++ b/airflow/utils/setup_teardown.py @@ -22,9 +22,9 @@ from airflow.exceptions import AirflowException if TYPE_CHECKING: - from airflow.models.abstractoperator import AbstractOperator from airflow.models.taskmixin import DependencyMixin from airflow.models.xcom_arg import PlainXComArg + from airflow.sdk.definitions._internal.abstractoperator import AbstractOperator class BaseSetupTeardownContext: diff --git a/airflow/utils/task_group.py b/airflow/utils/task_group.py index dafd87ab7df02..b74921f75d533 100644 --- a/airflow/utils/task_group.py +++ b/airflow/utils/task_group.py @@ -29,29 +29,28 @@ if TYPE_CHECKING: from sqlalchemy.orm import Session - from airflow.models.operator import Operator from airflow.typing_compat import TypeAlias TaskGroup: TypeAlias = airflow.sdk.definitions.taskgroup.TaskGroup -class MappedTaskGroup(airflow.sdk.definitions.taskgroup.MappedTaskGroup): - """ - A mapped task group. - - This doesn't really do anything special, just holds some additional metadata - for expansion later. +class MappedTaskGroup(airflow.sdk.definitions.taskgroup.MappedTaskGroup): # noqa: D101 + # TODO: Rename this to SerializedMappedTaskGroup perhaps? - Don't instantiate this class directly; call *expand* or *expand_kwargs* on - a ``@task_group`` function instead. - """ + def iter_mapped_task_groups(self) -> Iterator[MappedTaskGroup]: + """ + Return mapped task groups in the hierarchy. - def iter_mapped_dependencies(self) -> Iterator[Operator]: - """Upstream dependencies that provide XComs used by this mapped task group.""" - from airflow.models.xcom_arg import XComArg + Groups are returned from the closest to the outmost. If *self* is a + mapped task group, it is returned first. - for op, _ in XComArg.iter_xcom_references(self._expand_input): - yield op + :meta private: + """ + group: TaskGroup | None = self + while group is not None: + if isinstance(group, MappedTaskGroup): + yield group + group = group.parent_group def get_mapped_ti_count(self, run_id: str, *, session: Session) -> int: """ @@ -79,9 +78,9 @@ def get_mapped_ti_count(self, run_id: str, *, session: Session) -> int: def task_group_to_dict(task_item_or_group): """Create a nested dict representation of this TaskGroup and its children used to construct the Graph.""" - from airflow.models.abstractoperator import AbstractOperator - from airflow.models.baseoperator import BaseOperator - from airflow.models.mappedoperator import MappedOperator + from airflow.sdk.definitions._internal.abstractoperator import AbstractOperator + from airflow.sdk.definitions.baseoperator import BaseOperator + from airflow.sdk.definitions.mappedoperator import MappedOperator if isinstance(task := task_item_or_group, AbstractOperator): setup_teardown_type = {} diff --git a/airflow/utils/types.py b/airflow/utils/types.py index 12870210efb82..46f295c4ee21a 100644 --- a/airflow/utils/types.py +++ b/airflow/utils/types.py @@ -31,32 +31,6 @@ NOTSET = airflow.sdk.definitions._internal.types.NOTSET -class AttributeRemoved: - """ - Sentinel type to signal when attribute removed on serialization. 
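[Editorial aside, not part of the patch] The trigger-rule dep change above illustrates the calling convention this patch moves toward: the expansion count is computed by the database-aware BaseOperator, which is handed the Task SDK task explicitly instead of being invoked as a bound method on the task. A hedged sketch of that call shape (the alias is only for clarity; a task instance `ti` and an open `session` are assumed to be in scope):

from airflow.models.baseoperator import BaseOperator as DBBaseOperator

def expanded_ti_count(ti, session) -> int:
    # ti.task may be a Task SDK (Mapped)Operator; the ORM-level class owns the query.
    return DBBaseOperator.get_mapped_ti_count(ti.task, ti.run_id, session=session)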
- - :meta private: - """ - - def __init__(self, attribute_name: str): - self.attribute_name = attribute_name - - def __getattr__(self, item): - if item == "attribute_name": - return super().__getattribute__(item) - raise RuntimeError( - f"Attribute {self.attribute_name} was removed on " - f"serialization and must be set again - found when accessing {item}." - ) - - -""" -Sentinel value for attributes removed on serialization. - -:meta private: -""" - - class DagRunType(str, enum.Enum): """Class with DagRun types.""" diff --git a/airflow/www/views.py b/airflow/www/views.py index bad206e183089..dd5bc856f74a3 100644 --- a/airflow/www/views.py +++ b/airflow/www/views.py @@ -440,7 +440,8 @@ def set_overall_state(record): "id": item.task_id, "instances": instances, "label": item.label, - "extra_links": item.extra_links, + # TODO: Task-SDK: MappedOperator doesn't support extra links right now + "extra_links": getattr(item, "extra_links", []), "is_mapped": item_is_mapped, "has_outlet_assets": any(isinstance(i, (Asset, AssetAlias)) for i in (item.outlets or [])), "operator": item.operator_name, diff --git a/providers/src/airflow/providers/elasticsearch/log/es_task_handler.py b/providers/src/airflow/providers/elasticsearch/log/es_task_handler.py index 9881f1ac5dca7..15904e7ebf3b4 100644 --- a/providers/src/airflow/providers/elasticsearch/log/es_task_handler.py +++ b/providers/src/airflow/providers/elasticsearch/log/es_task_handler.py @@ -241,7 +241,8 @@ def _render_log_id(self, ti: TaskInstance | TaskInstanceKey, try_number: int) -> else: if TYPE_CHECKING: assert dag is not None - data_interval = dag.get_run_data_interval(dag_run) + # TODO: Task-SDK: Where should this function be? + data_interval = dag.get_run_data_interval(dag_run) # type: ignore[attr-defined] if self.json_format: data_interval_start = self._clean_date(data_interval[0]) diff --git a/providers/src/airflow/providers/openlineage/utils/selective_enable.py b/providers/src/airflow/providers/openlineage/utils/selective_enable.py index a3c16a1e18da3..3b6331a7cd6d7 100644 --- a/providers/src/airflow/providers/openlineage/utils/selective_enable.py +++ b/providers/src/airflow/providers/openlineage/utils/selective_enable.py @@ -18,11 +18,20 @@ from __future__ import annotations import logging -from typing import TypeVar +from typing import TYPE_CHECKING, TypeVar -from airflow.models import DAG, Operator, Param +from airflow.models import Param from airflow.models.xcom_arg import XComArg +if TYPE_CHECKING: + from airflow.sdk import DAG + from airflow.sdk.definitions._internal.abstractoperator import Operator +else: + try: + from airflow.sdk import DAG + except ImportError: + from airflow.models import DAG + ENABLE_OL_PARAM_NAME = "_selective_enable_ol" ENABLE_OL_PARAM = Param(True, const=True) DISABLE_OL_PARAM = Param(False, const=False) diff --git a/providers/src/airflow/providers/openlineage/utils/utils.py b/providers/src/airflow/providers/openlineage/utils/utils.py index 4408a833fba68..734437f44adc0 100644 --- a/providers/src/airflow/providers/openlineage/utils/utils.py +++ b/providers/src/airflow/providers/openlineage/utils/utils.py @@ -32,7 +32,7 @@ from airflow import __version__ as AIRFLOW_VERSION # TODO: move this maybe to Airflow's logic? 
-from airflow.models import DAG, BaseOperator, DagRun, MappedOperator, TaskReschedule +from airflow.models import DagRun, TaskReschedule from airflow.providers.openlineage import ( __version__ as OPENLINEAGE_PROVIDER_VERSION, conf, @@ -70,13 +70,19 @@ from airflow.models import TaskInstance from airflow.providers.common.compat.assets import Asset + from airflow.sdk import DAG, BaseOperator, MappedOperator from airflow.utils.state import DagRunState, TaskInstanceState else: + try: + from airflow.sdk import DAG, BaseOperator, MappedOperator + except ImportError: + from airflow.models import DAG, BaseOperator, MappedOperator + try: from airflow.providers.common.compat.assets import Asset except ImportError: if AIRFLOW_V_3_0_PLUS: - from airflow.sdk.definitions.asset import Asset + from airflow.sdk import Asset else: # dataset is renamed to asset since Airflow 3.0 from airflow.datasets import Dataset as Asset @@ -565,8 +571,8 @@ def _emits_ol_events(task: BaseOperator | MappedOperator) -> bool: is_skipped_as_empty_operator = all( ( task.inherits_from_empty_operator, - not task.on_execute_callback, - not task.on_success_callback, + not getattr(task, "on_execute_callback", None), + not getattr(task, "on_success_callback", None), not task.outlets, ) ) diff --git a/providers/src/airflow/providers/opensearch/log/os_task_handler.py b/providers/src/airflow/providers/opensearch/log/os_task_handler.py index 8e784076c5f8c..e1a0a083e7291 100644 --- a/providers/src/airflow/providers/opensearch/log/os_task_handler.py +++ b/providers/src/airflow/providers/opensearch/log/os_task_handler.py @@ -294,7 +294,8 @@ def _render_log_id(self, ti: TaskInstance | TaskInstanceKey, try_number: int) -> else: if TYPE_CHECKING: assert dag is not None - data_interval = dag.get_run_data_interval(dag_run) + # TODO: Task-SDK: Where should this function be? 
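[Editorial aside, not part of the patch] The OpenLineage changes above rely on a small compatibility shim worth calling out: type checkers always resolve names against the Task SDK, while at runtime the SDK import is attempted first and Airflow 2.x's airflow.models is the fallback. A minimal standalone sketch of the same pattern:

from __future__ import annotations

from typing import TYPE_CHECKING

if TYPE_CHECKING:
    from airflow.sdk import DAG, BaseOperator
else:
    try:
        from airflow.sdk import DAG, BaseOperator
    except ImportError:  # Airflow 2.x: the task SDK package is not available
        from airflow.models import DAG, BaseOperator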
+ data_interval = dag.get_run_data_interval(dag_run) # type: ignore[attr-defined] if self.json_format: data_interval_start = self._clean_date(data_interval[0]) diff --git a/providers/src/airflow/providers/standard/operators/python.py b/providers/src/airflow/providers/standard/operators/python.py index 86f5f0156d243..7afaaecfaee3d 100644 --- a/providers/src/airflow/providers/standard/operators/python.py +++ b/providers/src/airflow/providers/standard/operators/python.py @@ -1165,7 +1165,7 @@ def my_task(): def _get_current_context() -> Mapping[str, Any]: # Airflow 2.x # TODO: To be removed when Airflow 2 support is dropped - from airflow.models.taskinstance import _CURRENT_CONTEXT + from airflow.models.taskinstance import _CURRENT_CONTEXT # type: ignore[attr-defined] if not _CURRENT_CONTEXT: raise RuntimeError( diff --git a/providers/tests/openlineage/utils/test_utils.py b/providers/tests/openlineage/utils/test_utils.py index 28be0b6306751..3f653d9025ef9 100644 --- a/providers/tests/openlineage/utils/test_utils.py +++ b/providers/tests/openlineage/utils/test_utils.py @@ -25,7 +25,6 @@ from airflow.decorators import task from airflow.models.baseoperator import BaseOperator from airflow.models.dagrun import DagRun -from airflow.models.mappedoperator import MappedOperator from airflow.models.taskinstance import TaskInstance, TaskInstanceState from airflow.operators.empty import EmptyOperator from airflow.providers.openlineage.plugins.facets import AirflowDagRunFacet, AirflowJobFacet @@ -186,7 +185,6 @@ def test_get_fully_qualified_class_name_serialized_operator(): def test_get_fully_qualified_class_name_mapped_operator(): mapped = MockOperator.partial(task_id="task_2").expand(arg2=["a", "b", "c"]) - assert isinstance(mapped, MappedOperator) mapped_op_path = get_fully_qualified_class_name(mapped) assert mapped_op_path == "tests_common.test_utils.mock_operators.MockOperator" @@ -216,7 +214,6 @@ def test_get_operator_class(): def test_get_operator_class_mapped_operator(): mapped = MockOperator.partial(task_id="task").expand(arg2=["a", "b", "c"]) - assert isinstance(mapped, MappedOperator) op_class = get_operator_class(mapped) assert op_class == MockOperator diff --git a/providers/tests/standard/operators/test_python.py b/providers/tests/standard/operators/test_python.py index c43c00dd0e814..039522e98340c 100644 --- a/providers/tests/standard/operators/test_python.py +++ b/providers/tests/standard/operators/test_python.py @@ -925,6 +925,7 @@ def test_virtualenv_serializable_context_fields(self, create_task_instance): "prev_execution_date", "prev_execution_date_success", "conf", + "expanded_ti_count", } else: declared_keys.remove("triggering_asset_events") diff --git a/scripts/ci/pre_commit/base_operator_partial_arguments.py b/scripts/ci/pre_commit/base_operator_partial_arguments.py index b50705331700e..070ffe767cccc 100755 --- a/scripts/ci/pre_commit/base_operator_partial_arguments.py +++ b/scripts/ci/pre_commit/base_operator_partial_arguments.py @@ -28,7 +28,9 @@ BASEOPERATOR_PY = ROOT_DIR.joinpath("airflow", "models", "baseoperator.py") SDK_BASEOPERATOR_PY = ROOT_DIR.joinpath("task_sdk", "src", "airflow", "sdk", "definitions", "baseoperator.py") -MAPPEDOPERATOR_PY = ROOT_DIR.joinpath("airflow", "models", "mappedoperator.py") +SDK_MAPPEDOPERATOR_PY = ROOT_DIR.joinpath( + "task_sdk", "src", "airflow", "sdk", "definitions", "mappedoperator.py" +) IGNORED = { # These are only used in the worker and thus mappable. 
@@ -72,7 +74,7 @@ BO_MOD = ast.parse(BASEOPERATOR_PY.read_text("utf-8"), str(BASEOPERATOR_PY)) SDK_BO_MOD = ast.parse(SDK_BASEOPERATOR_PY.read_text("utf-8"), str(SDK_BASEOPERATOR_PY)) -MO_MOD = ast.parse(MAPPEDOPERATOR_PY.read_text("utf-8"), str(MAPPEDOPERATOR_PY)) +SDK_MO_MOD = ast.parse(SDK_MAPPEDOPERATOR_PY.read_text("utf-8"), str(SDK_MAPPEDOPERATOR_PY)) # TODO: Task-SDK: Look at the BaseOperator init functions in both airflow.models.baseoperator and combine # them, until we fully remove BaseOperator class from core. @@ -100,19 +102,19 @@ ) # We now define the signature in a type checking block, the runtime impl uses **kwargs -BO_TYPE_CHECKING_BLOCKS = ( +SDK_BO_TYPE_CHECKING_BLOCKS = ( node - for node in ast.iter_child_nodes(BO_MOD) + for node in ast.iter_child_nodes(SDK_BO_MOD) if isinstance(node, ast.If) and node.test.id == "TYPE_CHECKING" # type: ignore[attr-defined] ) BO_PARTIAL = next( node - for node in itertools.chain.from_iterable(map(ast.iter_child_nodes, BO_TYPE_CHECKING_BLOCKS)) + for node in itertools.chain.from_iterable(map(ast.iter_child_nodes, SDK_BO_TYPE_CHECKING_BLOCKS)) if isinstance(node, ast.FunctionDef) and node.name == "partial" ) MO_CLS = next( node - for node in ast.iter_child_nodes(MO_MOD) + for node in ast.iter_child_nodes(SDK_MO_MOD) if isinstance(node, ast.ClassDef) and node.name == "MappedOperator" ) diff --git a/scripts/ci/pre_commit/template_context_key_sync.py b/scripts/ci/pre_commit/template_context_key_sync.py index 2f6c6021b2ed1..501801be55d97 100755 --- a/scripts/ci/pre_commit/template_context_key_sync.py +++ b/scripts/ci/pre_commit/template_context_key_sync.py @@ -102,6 +102,9 @@ def _compare_keys(retn_keys: set[str], decl_keys: set[str], hint_keys: set[str], retn_keys.add("templates_dict") docs_keys.add("templates_dict") + # Compat shim for task-sdk, not actually designed for user use + retn_keys.add("expanded_ti_count") + # Only present in callbacks. Not listed in templates-ref (that doc is for task execution). 
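[Editorial aside, not part of the patch] The pre-commit script change above keeps working after the move because it only needs to locate a class definition in the relocated module via the ast module. A self-contained sketch of that lookup pattern (the path shown is just an example):

from __future__ import annotations

import ast
from pathlib import Path

def find_class_def(path: Path, name: str) -> ast.ClassDef:
    module = ast.parse(path.read_text("utf-8"), str(path))
    return next(
        node
        for node in ast.iter_child_nodes(module)
        if isinstance(node, ast.ClassDef) and node.name == name
    )

# e.g. find_class_def(Path("task_sdk/src/airflow/sdk/definitions/mappedoperator.py"), "MappedOperator")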
retn_keys.update(("exception", "reason", "try_number")) docs_keys.update(("exception", "reason", "try_number")) diff --git a/task_sdk/src/airflow/sdk/__init__.py b/task_sdk/src/airflow/sdk/__init__.py index 1bd0358a63c7e..b8d6b6609dba7 100644 --- a/task_sdk/src/airflow/sdk/__init__.py +++ b/task_sdk/src/airflow/sdk/__init__.py @@ -20,15 +20,17 @@ __all__ = [ "BaseOperator", + "Connection", "DAG", "EdgeModifier", "Label", + "MappedOperator", "TaskGroup", + "XComArg", + "__version__", "dag", - "Connection", "get_current_context", "get_parsing_context", - "__version__", ] __version__ = "1.0.0.dev1" @@ -39,17 +41,21 @@ from airflow.sdk.definitions.context import get_current_context, get_parsing_context from airflow.sdk.definitions.dag import DAG, dag from airflow.sdk.definitions.edges import EdgeModifier, Label + from airflow.sdk.definitions.mappedoperator import MappedOperator from airflow.sdk.definitions.taskgroup import TaskGroup + from airflow.sdk.definitions.xcom_arg import XComArg __lazy_imports: dict[str, str] = { - "DAG": ".definitions.dag", - "dag": ".definitions.dag", "BaseOperator": ".definitions.baseoperator", - "TaskGroup": ".definitions.taskgroup", + "Connection": ".definitions.connection", + "DAG": ".definitions.dag", "EdgeModifier": ".definitions.edges", "Label": ".definitions.edges", - "Connection": ".definitions.connection", + "MappedOperator": ".definitions.mappedoperator", + "TaskGroup": ".definitions.taskgroup", "Variable": ".definitions.variable", + "XComArg": ".definitions.xcom_arg", + "dag": ".definitions.dag", "get_current_context": ".definitions.context", "get_parsing_context": ".definitions.context", } diff --git a/task_sdk/src/airflow/sdk/definitions/_internal/abstractoperator.py b/task_sdk/src/airflow/sdk/definitions/_internal/abstractoperator.py index b3435f3091e83..481fb2ee773d7 100644 --- a/task_sdk/src/airflow/sdk/definitions/_internal/abstractoperator.py +++ b/task_sdk/src/airflow/sdk/definitions/_internal/abstractoperator.py @@ -21,18 +21,20 @@ import logging from abc import abstractmethod from collections.abc import ( + Callable, Collection, Iterable, + Iterator, + Mapping, ) -from typing import ( - TYPE_CHECKING, - Any, - ClassVar, -) +from typing import TYPE_CHECKING, Any, ClassVar + +import methodtools from airflow.sdk.definitions._internal.mixins import DependencyMixin from airflow.sdk.definitions._internal.node import DAGNode from airflow.sdk.definitions._internal.templater import Templater +from airflow.utils.setup_teardown import SetupTeardownContext from airflow.utils.trigger_rule import TriggerRule from airflow.utils.weight_rule import WeightRule @@ -40,13 +42,18 @@ import jinja2 from airflow.models.baseoperatorlink import BaseOperatorLink - from airflow.models.operator import Operator from airflow.sdk.definitions.baseoperator import BaseOperator from airflow.sdk.definitions.context import Context from airflow.sdk.definitions.dag import DAG + from airflow.sdk.definitions.mappedoperator import MappedOperator + from airflow.sdk.definitions.taskgroup import MappedTaskGroup + from airflow.sdk.types import Operator + +TaskStateChangeCallback = Callable[[Mapping[str, Any]], None] DEFAULT_OWNER: str = "airflow" DEFAULT_POOL_SLOTS: int = 1 +DEFAULT_POOL_NAME = "default_pool" DEFAULT_PRIORITY_WEIGHT: int = 1 # Databases do not support arbitrary precision integers, so we need to limit the range of priority weights. 
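[Editorial aside, not part of the patch] The __lazy_imports table added to airflow/sdk/__init__.py above maps exported names to their defining submodules. The module-level hook that consumes it is outside this hunk; the sketch below shows the usual PEP 562 shape such a hook takes, purely as an assumption about how the mapping is resolved:

from __future__ import annotations

import importlib

__lazy_imports: dict[str, str] = {"TaskGroup": ".definitions.taskgroup"}

def __getattr__(name: str):
    # Belongs in a package __init__; imports the submodule on first attribute access.
    try:
        module_path = __lazy_imports[name]
    except KeyError:
        raise AttributeError(f"module {__name__!r} has no attribute {name!r}") from None
    module = importlib.import_module(module_path, __package__)
    return getattr(module, name)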
# postgres: -2147483648 to +2147483647 (see https://www.postgresql.org/docs/current/datatype-numeric.html) @@ -93,7 +100,7 @@ class AbstractOperator(Templater, DAGNode): priority_weight: int # Defines the operator level extra links. - operator_extra_links: Collection[BaseOperatorLink] + operator_extra_links: Collection[BaseOperatorLink] = () owner: str task_id: str @@ -163,6 +170,10 @@ def node_id(self) -> str: @abstractmethod def task_display_name(self) -> str: ... + @property + def is_mapped(self): + return self._is_mapped + @property def label(self) -> str | None: if self.task_display_name and self.task_display_name != self.task_id: @@ -174,6 +185,29 @@ def label(self) -> str | None: return self.task_id[len(tg.node_id) + 1 :] return self.task_id + @property + def on_failure_fail_dagrun(self): + """ + Whether the operator should fail the dagrun on failure. + + :meta private: + """ + return self._on_failure_fail_dagrun + + @on_failure_fail_dagrun.setter + def on_failure_fail_dagrun(self, value): + """ + Setter for on_failure_fail_dagrun property. + + :meta private: + """ + if value is True and self.is_teardown is not True: + raise ValueError( + f"Cannot set task on_failure_fail_dagrun for " + f"'{self.task_id}' because it is not a teardown task." + ) + self._on_failure_fail_dagrun = value + def as_setup(self): self.is_setup = True return self @@ -195,6 +229,15 @@ def as_teardown( s >> self return self + def __enter__(self): + if not self.is_setup and not self.is_teardown: + raise RuntimeError("Only setup/teardown tasks can be used as context managers.") + SetupTeardownContext.push_setup_teardown_task(self) + return SetupTeardownContext + + def __exit__(self, exc_type, exc_val, exc_tb): + SetupTeardownContext.set_work_task_roots_and_leaves() + def get_flat_relative_ids(self, *, upstream: bool = False) -> set[str]: """ Get a flat set of relative IDs, upstream or downstream. @@ -339,3 +382,111 @@ def _do_render_template_fields( raise else: setattr(parent, attr_name, rendered_content) + + def _iter_all_mapped_downstreams(self) -> Iterator[MappedOperator | MappedTaskGroup]: + """ + Return mapped nodes that are direct dependencies of the current task. + + For now, this walks the entire DAG to find mapped nodes that has this + current task as an upstream. We cannot use ``downstream_list`` since it + only contains operators, not task groups. In the future, we should + provide a way to record an DAG node's all downstream nodes instead. + + Note that this does not guarantee the returned tasks actually use the + current task for task mapping, but only checks those task are mapped + operators, and are downstreams of the current task. + + To get a list of tasks that uses the current task for task mapping, use + :meth:`iter_mapped_dependants` instead. + """ + from airflow.sdk.definitions.mappedoperator import MappedOperator + from airflow.sdk.definitions.taskgroup import MappedTaskGroup, TaskGroup + + def _walk_group(group: TaskGroup) -> Iterable[tuple[str, DAGNode]]: + """ + Recursively walk children in a task group. + + This yields all direct children (including both tasks and task + groups), and all children of any task groups. 
+ """ + for key, child in group.children.items(): + yield key, child + if isinstance(child, TaskGroup): + yield from _walk_group(child) + + dag = self.get_dag() + if not dag: + raise RuntimeError("Cannot check for mapped dependants when not attached to a DAG") + for key, child in _walk_group(dag.task_group): + if key == self.node_id: + continue + if not isinstance(child, (MappedOperator, MappedTaskGroup)): + continue + if self.node_id in child.upstream_task_ids: + yield child + + def iter_mapped_dependants(self) -> Iterator[MappedOperator | MappedTaskGroup]: + """ + Return mapped nodes that depend on the current task the expansion. + + For now, this walks the entire DAG to find mapped nodes that has this + current task as an upstream. We cannot use ``downstream_list`` since it + only contains operators, not task groups. In the future, we should + provide a way to record an DAG node's all downstream nodes instead. + """ + return ( + downstream + for downstream in self._iter_all_mapped_downstreams() + if any(p.node_id == self.node_id for p in downstream.iter_mapped_dependencies()) + ) + + def iter_mapped_task_groups(self) -> Iterator[MappedTaskGroup]: + """ + Return mapped task groups this task belongs to. + + Groups are returned from the innermost to the outmost. + + :meta private: + """ + if (group := self.task_group) is None: + return + yield from group.iter_mapped_task_groups() + + def get_closest_mapped_task_group(self) -> MappedTaskGroup | None: + """ + Get the mapped task group "closest" to this task in the DAG. + + :meta private: + """ + return next(self.iter_mapped_task_groups(), None) + + def get_needs_expansion(self) -> bool: + """ + Return true if the task is MappedOperator or is in a mapped task group. + + :meta private: + """ + if self._needs_expansion is None: + if self.get_closest_mapped_task_group() is not None: + self._needs_expansion = True + else: + self._needs_expansion = False + return self._needs_expansion + + @methodtools.lru_cache(maxsize=None) + def get_parse_time_mapped_ti_count(self) -> int: + """ + Return the number of mapped task instances that can be created on DAG run creation. + + This only considers literal mapped arguments, and would return *None* + when any non-literal values are used for mapping. + + :raise NotFullyPopulated: If non-literal mapped arguments are encountered. + :raise NotMapped: If the operator is neither mapped, nor has any parent + mapped task groups. + :return: Total number of mapped TIs this task should have. 
+ """ + group = self.get_closest_mapped_task_group() + if group is None: + raise NotMapped() + return group.get_parse_time_mapped_ti_count() diff --git a/task_sdk/src/airflow/sdk/definitions/_internal/mixins.py b/task_sdk/src/airflow/sdk/definitions/_internal/mixins.py index bdd2a5c08c755..fcd68ba20b2c6 100644 --- a/task_sdk/src/airflow/sdk/definitions/_internal/mixins.py +++ b/task_sdk/src/airflow/sdk/definitions/_internal/mixins.py @@ -22,7 +22,7 @@ from typing import TYPE_CHECKING, Any if TYPE_CHECKING: - from airflow.models.operator import Operator + from airflow.sdk.definitions._internal.abstractoperator import Operator from airflow.sdk.definitions.baseoperator import BaseOperator from airflow.sdk.definitions.context import Context from airflow.sdk.definitions.edges import EdgeModifier diff --git a/task_sdk/src/airflow/sdk/definitions/_internal/node.py b/task_sdk/src/airflow/sdk/definitions/_internal/node.py index b8c0260911837..fb69c7b926016 100644 --- a/task_sdk/src/airflow/sdk/definitions/_internal/node.py +++ b/task_sdk/src/airflow/sdk/definitions/_internal/node.py @@ -30,8 +30,8 @@ from airflow.sdk.definitions._internal.mixins import DependencyMixin if TYPE_CHECKING: - from airflow.models.operator import Operator from airflow.sdk.definitions._internal.types import Logger + from airflow.sdk.definitions.abstractoperator import Operator from airflow.sdk.definitions.dag import DAG from airflow.sdk.definitions.edges import EdgeModifier from airflow.sdk.definitions.taskgroup import TaskGroup @@ -99,8 +99,8 @@ def dag_id(self) -> str: return self.dag.dag_id return "_in_memory_dag_" + @methodtools.lru_cache() # type: ignore[misc] @property - @methodtools.lru_cache() def log(self) -> Logger: typ = type(self) name = f"{typ.__module__}.{typ.__qualname__}" @@ -123,8 +123,8 @@ def _set_relatives( edge_modifier: EdgeModifier | None = None, ) -> None: """Set relatives for the task or task list.""" - from airflow.models.mappedoperator import MappedOperator from airflow.sdk.definitions.baseoperator import BaseOperator + from airflow.sdk.definitions.mappedoperator import MappedOperator if not isinstance(task_or_task_list, Sequence): task_or_task_list = [task_or_task_list] diff --git a/task_sdk/src/airflow/sdk/definitions/_internal/templater.py b/task_sdk/src/airflow/sdk/definitions/_internal/templater.py index 58a2b450eb3d8..d7028d4d6bca7 100644 --- a/task_sdk/src/airflow/sdk/definitions/_internal/templater.py +++ b/task_sdk/src/airflow/sdk/definitions/_internal/templater.py @@ -177,9 +177,9 @@ def render_template( return self._render(template, context) if isinstance(value, ObjectStoragePath): return self._render_object_storage_path(value, context, jinja_env) - if isinstance(value, ResolveMixin): - # TODO: Task-SDK: Tidy up the typing on template context - return value.resolve(context, include_xcom=True) # type: ignore[arg-type] + + if resolve := getattr(value, "resolve", None): + return resolve(context, include_xcom=True) # Fast path for common built-in collections. 
if value.__class__ is tuple: diff --git a/task_sdk/src/airflow/sdk/definitions/asset/__init__.py b/task_sdk/src/airflow/sdk/definitions/asset/__init__.py index 51d4abbeda45d..91ebc4ab6cb8b 100644 --- a/task_sdk/src/airflow/sdk/definitions/asset/__init__.py +++ b/task_sdk/src/airflow/sdk/definitions/asset/__init__.py @@ -111,14 +111,8 @@ def normalize_noop(parts: SplitResult) -> SplitResult: def _get_uri_normalizer(scheme: str) -> Callable[[SplitResult], SplitResult] | None: if scheme == "file": return normalize_noop - from packaging.version import Version - - from airflow import __version__ as AIRFLOW_VERSION from airflow.providers_manager import ProvidersManager - AIRFLOW_V_2 = Version(AIRFLOW_VERSION).base_version < Version("3.0.0").base_version - if AIRFLOW_V_2: - return ProvidersManager().dataset_uri_handlers.get(scheme) # type: ignore[attr-defined] return ProvidersManager().asset_uri_handlers.get(scheme) diff --git a/task_sdk/src/airflow/sdk/definitions/baseoperator.py b/task_sdk/src/airflow/sdk/definitions/baseoperator.py index 2258524525193..e7ecec69411ba 100644 --- a/task_sdk/src/airflow/sdk/definitions/baseoperator.py +++ b/task_sdk/src/airflow/sdk/definitions/baseoperator.py @@ -24,7 +24,7 @@ import inspect import sys import warnings -from collections.abc import Collection, Iterable, Sequence +from collections.abc import Callable, Collection, Iterable, Sequence from dataclasses import dataclass, field from datetime import datetime, timedelta from functools import total_ordering, wraps @@ -37,6 +37,7 @@ from airflow.sdk.definitions._internal.abstractoperator import ( DEFAULT_IGNORE_FIRST_DEPENDS_ON_PAST, DEFAULT_OWNER, + DEFAULT_POOL_NAME, DEFAULT_POOL_SLOTS, DEFAULT_PRIORITY_WEIGHT, DEFAULT_QUEUE, @@ -47,10 +48,12 @@ DEFAULT_WAIT_FOR_PAST_DEPENDS_BEFORE_SKIPPING, DEFAULT_WEIGHT_RULE, AbstractOperator, + TaskStateChangeCallback, ) from airflow.sdk.definitions._internal.decorators import fixup_decorator_warning_stack from airflow.sdk.definitions._internal.node import validate_key -from airflow.sdk.definitions._internal.types import NOTSET, validate_instance_args +from airflow.sdk.definitions._internal.types import NOTSET, ArgNotSet, validate_instance_args +from airflow.sdk.definitions.mappedoperator import OperatorPartial, validate_mapping_kwargs from airflow.task.priority_strategy import ( PriorityWeightStrategy, airflow_priority_weight_strategies, @@ -59,12 +62,13 @@ from airflow.utils import timezone from airflow.utils.setup_teardown import SetupTeardownContext from airflow.utils.trigger_rule import TriggerRule -from airflow.utils.types import AttributeRemoved from airflow.utils.weight_rule import db_safe_priority T = TypeVar("T", bound=FunctionType) if TYPE_CHECKING: + from types import ClassMethodDescriptorType + import jinja2 from airflow.models.xcom_arg import XComArg @@ -72,6 +76,7 @@ from airflow.sdk.definitions.dag import DAG from airflow.sdk.definitions.taskgroup import TaskGroup from airflow.serialization.enums import DagAttributeTypes + from airflow.task.priority_strategy import PriorityWeightStrategy from airflow.typing_compat import Self from airflow.utils.operator_resources import Resources @@ -111,6 +116,189 @@ def get_merged_defaults( return args, params +class _PartialDescriptor: + """A descriptor that guards against ``.partial`` being called on Task objects.""" + + class_method: ClassMethodDescriptorType | None = None + + def __get__( + self, obj: BaseOperator, cls: type[BaseOperator] | None = None + ) -> Callable[..., OperatorPartial]: + # Call this "partial" 
so it looks nicer in stack traces. + def partial(**kwargs): + raise TypeError("partial can only be called on Operator classes, not Tasks themselves") + + if obj is not None: + return partial + return self.class_method.__get__(cls, cls) + + +_PARTIAL_DEFAULTS: dict[str, Any] = { + "map_index_template": None, + "owner": DEFAULT_OWNER, + "trigger_rule": DEFAULT_TRIGGER_RULE, + "depends_on_past": False, + "ignore_first_depends_on_past": DEFAULT_IGNORE_FIRST_DEPENDS_ON_PAST, + "wait_for_past_depends_before_skipping": DEFAULT_WAIT_FOR_PAST_DEPENDS_BEFORE_SKIPPING, + "wait_for_downstream": False, + "retries": DEFAULT_RETRIES, + # "executor": DEFAULT_EXECUTOR, + "queue": DEFAULT_QUEUE, + "pool_slots": DEFAULT_POOL_SLOTS, + "execution_timeout": DEFAULT_TASK_EXECUTION_TIMEOUT, + "retry_delay": DEFAULT_RETRY_DELAY, + "retry_exponential_backoff": False, + "priority_weight": DEFAULT_PRIORITY_WEIGHT, + "weight_rule": DEFAULT_WEIGHT_RULE, + "inlets": [], + "outlets": [], + "allow_nested_operators": True, +} + + +# This is what handles the actual mapping. + +if TYPE_CHECKING: + + def partial( + operator_class: type[BaseOperator], + *, + task_id: str, + dag: DAG | None = None, + task_group: TaskGroup | None = None, + start_date: datetime | ArgNotSet = NOTSET, + end_date: datetime | ArgNotSet = NOTSET, + owner: str | ArgNotSet = NOTSET, + email: None | str | Iterable[str] | ArgNotSet = NOTSET, + params: collections.abc.MutableMapping | None = None, + resources: dict[str, Any] | None | ArgNotSet = NOTSET, + trigger_rule: str | ArgNotSet = NOTSET, + depends_on_past: bool | ArgNotSet = NOTSET, + ignore_first_depends_on_past: bool | ArgNotSet = NOTSET, + wait_for_past_depends_before_skipping: bool | ArgNotSet = NOTSET, + wait_for_downstream: bool | ArgNotSet = NOTSET, + retries: int | None | ArgNotSet = NOTSET, + queue: str | ArgNotSet = NOTSET, + pool: str | ArgNotSet = NOTSET, + pool_slots: int | ArgNotSet = NOTSET, + execution_timeout: timedelta | None | ArgNotSet = NOTSET, + max_retry_delay: None | timedelta | float | ArgNotSet = NOTSET, + retry_delay: timedelta | float | ArgNotSet = NOTSET, + retry_exponential_backoff: bool | ArgNotSet = NOTSET, + priority_weight: int | ArgNotSet = NOTSET, + weight_rule: str | PriorityWeightStrategy | ArgNotSet = NOTSET, + sla: timedelta | None | ArgNotSet = NOTSET, + map_index_template: str | None | ArgNotSet = NOTSET, + max_active_tis_per_dag: int | None | ArgNotSet = NOTSET, + max_active_tis_per_dagrun: int | None | ArgNotSet = NOTSET, + on_execute_callback: None + | TaskStateChangeCallback + | list[TaskStateChangeCallback] + | ArgNotSet = NOTSET, + on_failure_callback: None + | TaskStateChangeCallback + | list[TaskStateChangeCallback] + | ArgNotSet = NOTSET, + on_success_callback: None + | TaskStateChangeCallback + | list[TaskStateChangeCallback] + | ArgNotSet = NOTSET, + on_retry_callback: None + | TaskStateChangeCallback + | list[TaskStateChangeCallback] + | ArgNotSet = NOTSET, + on_skipped_callback: None + | TaskStateChangeCallback + | list[TaskStateChangeCallback] + | ArgNotSet = NOTSET, + run_as_user: str | None | ArgNotSet = NOTSET, + executor: str | None | ArgNotSet = NOTSET, + executor_config: dict | None | ArgNotSet = NOTSET, + inlets: Any | None | ArgNotSet = NOTSET, + outlets: Any | None | ArgNotSet = NOTSET, + doc: str | None | ArgNotSet = NOTSET, + doc_md: str | None | ArgNotSet = NOTSET, + doc_json: str | None | ArgNotSet = NOTSET, + doc_yaml: str | None | ArgNotSet = NOTSET, + doc_rst: str | None | ArgNotSet = NOTSET, + task_display_name: str | None | 
ArgNotSet = NOTSET, + logger_name: str | None | ArgNotSet = NOTSET, + allow_nested_operators: bool = True, + **kwargs, + ) -> OperatorPartial: ... +else: + + def partial( + operator_class: type[BaseOperator], + *, + task_id: str, + dag: DAG | None = None, + task_group: TaskGroup | None = None, + params: collections.abc.MutableMapping | None = None, + **kwargs, + ): + from airflow.sdk.definitions._internal.contextmanager import DagContext, TaskGroupContext + + validate_mapping_kwargs(operator_class, "partial", kwargs) + + dag = dag or DagContext.get_current() + if dag: + task_group = task_group or TaskGroupContext.get_current(dag) + if task_group: + task_id = task_group.child_id(task_id) + + # Merge DAG and task group level defaults into user-supplied values. + dag_default_args, partial_params = get_merged_defaults( + dag=dag, + task_group=task_group, + task_params=params, + task_default_args=kwargs.pop("default_args", None), + ) + + # Create partial_kwargs from args and kwargs + partial_kwargs: dict[str, Any] = { + "task_id": task_id, + "dag": dag, + "task_group": task_group, + **kwargs, + } + + # Inject DAG-level default args into args provided to this function. + partial_kwargs.update( + (k, v) for k, v in dag_default_args.items() if partial_kwargs.get(k, NOTSET) is NOTSET + ) + + # Fill fields not provided by the user with default values. + for k, v in _PARTIAL_DEFAULTS.items(): + partial_kwargs.setdefault(k, v) + + # Post-process arguments. Should be kept in sync with _TaskDecorator.expand(). + if "task_concurrency" in kwargs: # Reject deprecated option. + raise TypeError("unexpected argument: task_concurrency") + if start_date := partial_kwargs.get("start_date", None): + partial_kwargs["start_date"] = timezone.convert_to_utc(start_date) + if end_date := partial_kwargs.get("end_date", None): + partial_kwargs["end_date"] = timezone.convert_to_utc(end_date) + if partial_kwargs["pool_slots"] < 1: + dag_str = "" + if dag: + dag_str = f" in dag {dag.dag_id}" + raise ValueError(f"pool slots for {task_id}{dag_str} cannot be less than 1") + if retries := partial_kwargs.get("retries"): + partial_kwargs["retries"] = BaseOperator._convert_retries(retries) + partial_kwargs["retry_delay"] = BaseOperator._convert_retry_delay(partial_kwargs["retry_delay"]) + partial_kwargs["max_retry_delay"] = BaseOperator._convert_max_retry_delay( + partial_kwargs.get("max_retry_delay", None) + ) + partial_kwargs.setdefault("executor_config", {}) + + return OperatorPartial( + operator_class=operator_class, + kwargs=partial_kwargs, + params=partial_params, + ) + + class BaseOperatorMeta(abc.ABCMeta): """Metaclass of BaseOperator.""" @@ -224,11 +412,9 @@ def __new__(cls, name, bases, namespace, **kwargs): with contextlib.suppress(KeyError): # Update the partial descriptor with the class method, so it calls the actual function # (but let subclasses override it if they need to) - # TODO: Task-SDK - # partial_desc = vars(new_cls)["partial"] - # if isinstance(partial_desc, _PartialDescriptor): - # partial_desc.class_method = classmethod(partial) - ... + partial_desc = vars(new_cls)["partial"] + if isinstance(partial_desc, _PartialDescriptor): + partial_desc.class_method = classmethod(partial) # We patch `__init__` only if the class defines it. 
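[Editorial aside, not part of the patch] With the metaclass now wiring up the partial descriptor above, the parse-time flow is: Operator.partial() returns an OperatorPartial, and .expand() turns it into a MappedOperator. A usage sketch under the imports this patch establishes; MyOperator and the dag_id are made-up examples, not taken from the patch:

from __future__ import annotations

from airflow.sdk import DAG, BaseOperator

class MyOperator(BaseOperator):
    def __init__(self, *, x: int, y: int = 0, **kwargs):
        super().__init__(**kwargs)
        self.x = x
        self.y = y

    def execute(self, context):
        return self.x + self.y

with DAG(dag_id="example_partial_expand"):
    # One mapped task instance per element of x; y is shared via partial_kwargs.
    mapped = MyOperator.partial(task_id="add", y=10).expand(x=[1, 2, 3])

    # Parse-time validation (validate_mapping_kwargs) rejects bad expansions, e.g.:
    # MyOperator.partial(task_id="bad").expand(x=1)        # ValueError: unmappable value
    # MyOperator.partial(task_id="bad").expand(nope=[1])   # TypeError: unexpected keyword argument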
if inspect.getmro(new_cls)[1].__init__ is not new_cls.__init__: @@ -533,7 +719,7 @@ def say_hello_world(**context): default_factory=airflow_priority_weight_strategies[DEFAULT_WEIGHT_RULE] ) queue: str = DEFAULT_QUEUE - pool: str = "default" + pool: str = DEFAULT_POOL_NAME pool_slots: int = DEFAULT_POOL_SLOTS execution_timeout: timedelta | None = DEFAULT_TASK_EXECUTION_TIMEOUT # on_execute_callback: None | TaskStateChangeCallback | list[TaskStateChangeCallback] = None @@ -579,8 +765,7 @@ def say_hello_world(**context): ui_color: str = "#fff" ui_fgcolor: str = "#000" - # TODO: Task-SDK Mapping - # partial: Callable[..., OperatorPartial] = _PartialDescriptor() # type: ignore + partial: Callable[..., OperatorPartial] = _PartialDescriptor() # type: ignore _dag: DAG | None = field(init=False, default=None) @@ -777,8 +962,7 @@ def __init__( # self.retries = parse_retries(retries) self.retries = retries self.queue = queue - # TODO: Task-SDK: pull this default name from Pool constant? - self.pool = "default_pool" if pool is None else pool + self.pool = DEFAULT_POOL_NAME if pool is None else pool self.pool_slots = pool_slots if self.pool_slots < 1: dag_str = f" in dag {dag.dag_id}" if dag else "" @@ -1006,25 +1190,17 @@ def dag(self) -> DAG: raise RuntimeError(f"Operator {self} has not been assigned to a DAG yet") @dag.setter - def dag(self, dag: DAG | None | AttributeRemoved) -> None: + def dag(self, dag: DAG | None) -> None: """Operators can be assigned to one DAG, one time. Repeat assignments to that same DAG are ok.""" - # TODO: Task-SDK: Remove the AttributeRemoved and this type ignore once we remove AIP-44 code - self._dag = dag # type: ignore[assignment] + self._dag = dag - def _convert__dag(self, dag: DAG | None | AttributeRemoved) -> DAG | None | AttributeRemoved: + def _convert__dag(self, dag: DAG | None) -> DAG | None: # Called automatically by __setattr__ method from airflow.sdk.definitions.dag import DAG if dag is None: return dag - # if set to removed, then just set and exit - if type(self._dag) is AttributeRemoved: - return dag - # if setting to removed, then just set and exit - if type(dag) is AttributeRemoved: - return AttributeRemoved("_dag") # type: ignore[assignment] - if not isinstance(dag, DAG): raise TypeError(f"Expected DAG; received {dag.__class__.__name__}") elif self._dag is not None and self._dag is not dag: @@ -1033,8 +1209,7 @@ def _convert__dag(self, dag: DAG | None | AttributeRemoved) -> DAG | None | Attr if self.__from_mapped: pass # Don't add to DAG -- the mapped task takes the place. 
elif dag.task_dict.get(self.task_id) is not self: - # TODO: Task-SDK: Remove this type ignore - dag.add_task(self) # type: ignore[arg-type] + dag.add_task(self) return dag @staticmethod diff --git a/task_sdk/src/airflow/sdk/definitions/dag.py b/task_sdk/src/airflow/sdk/definitions/dag.py index f4b71ec99584b..cd5217c8111d0 100644 --- a/task_sdk/src/airflow/sdk/definitions/dag.py +++ b/task_sdk/src/airflow/sdk/definitions/dag.py @@ -74,7 +74,7 @@ from pendulum.tz.timezone import FixedTimezone, Timezone from airflow.decorators import TaskDecoratorCollection - from airflow.models.operator import Operator + from airflow.sdk.definitions.abstractoperator import Operator from airflow.sdk.definitions.taskgroup import TaskGroup from airflow.typing_compat import Self diff --git a/task_sdk/src/airflow/sdk/definitions/mappedoperator.py b/task_sdk/src/airflow/sdk/definitions/mappedoperator.py new file mode 100644 index 0000000000000..0fc0a7fa1896a --- /dev/null +++ b/task_sdk/src/airflow/sdk/definitions/mappedoperator.py @@ -0,0 +1,900 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from __future__ import annotations + +import contextlib +import copy +import warnings +from collections.abc import Collection, Iterable, Iterator, Mapping, Sequence +from typing import TYPE_CHECKING, Any, ClassVar, Union + +import attrs +import methodtools + +from airflow.models.abstractoperator import NotMapped +from airflow.models.expandinput import ( + DictOfListsExpandInput, + ListOfDictsExpandInput, + is_mappable, +) +from airflow.sdk.definitions._internal.abstractoperator import ( + DEFAULT_EXECUTOR, + DEFAULT_IGNORE_FIRST_DEPENDS_ON_PAST, + DEFAULT_OWNER, + DEFAULT_POOL_NAME, + DEFAULT_POOL_SLOTS, + DEFAULT_PRIORITY_WEIGHT, + DEFAULT_QUEUE, + DEFAULT_RETRIES, + DEFAULT_RETRY_DELAY, + DEFAULT_TRIGGER_RULE, + DEFAULT_WAIT_FOR_PAST_DEPENDS_BEFORE_SKIPPING, + DEFAULT_WEIGHT_RULE, + AbstractOperator, +) +from airflow.sdk.definitions._internal.types import NOTSET +from airflow.serialization.enums import DagAttributeTypes +from airflow.task.priority_strategy import PriorityWeightStrategy, validate_and_load_priority_weight_strategy +from airflow.triggers.base import StartTriggerArgs +from airflow.typing_compat import Literal +from airflow.utils.helpers import is_container, prevent_duplicates +from airflow.utils.task_instance_session import get_current_task_instance_session +from airflow.utils.xcom import XCOM_RETURN_KEY + +if TYPE_CHECKING: + import datetime + + import jinja2 # Slow import. 
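[Editorial aside, not part of the patch] The _PartialDescriptor wired up in baseoperator.py earlier in this patch also guards instance access: .partial() is class-level API, and calling it on an already-instantiated task raises. A small sketch (EmptyOperator is used only because it needs no extra arguments; the dag_id is made up):

from airflow.operators.empty import EmptyOperator
from airflow.sdk import DAG

with DAG(dag_id="partial_guard_example"):
    assert callable(EmptyOperator.partial)  # class access resolves to the real classmethod
    task = EmptyOperator(task_id="plain")
    # task.partial(task_id="nope")  # instance access raises:
    # TypeError: partial can only be called on Operator classes, not Tasks themselves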
+ import pendulum + from sqlalchemy.orm.session import Session + + from airflow.models.abstractoperator import ( + TaskStateChangeCallback, + ) + from airflow.models.baseoperatorlink import BaseOperatorLink + from airflow.models.expandinput import ( + ExpandInput, + OperatorExpandArgument, + OperatorExpandKwargsArgument, + ) + from airflow.models.param import ParamsDict + from airflow.models.xcom_arg import XComArg + from airflow.sdk.definitions.baseoperator import BaseOperator + from airflow.sdk.definitions.dag import DAG + from airflow.sdk.types import Operator + from airflow.ti_deps.deps.base_ti_dep import BaseTIDep + from airflow.utils.context import Context + from airflow.utils.operator_resources import Resources + from airflow.utils.task_group import TaskGroup + from airflow.utils.trigger_rule import TriggerRule + + TaskStateChangeCallbackAttrType = Union[None, TaskStateChangeCallback, list[TaskStateChangeCallback]] + +ValidationSource = Union[Literal["expand"], Literal["partial"]] + + +def validate_mapping_kwargs(op: type[BaseOperator], func: ValidationSource, value: dict[str, Any]) -> None: + # use a dict so order of args is same as code order + unknown_args = value.copy() + for klass in op.mro(): + init = klass.__init__ # type: ignore[misc] + try: + param_names = init._BaseOperatorMeta__param_names + except AttributeError: + continue + for name in param_names: + value = unknown_args.pop(name, NOTSET) + if func != "expand": + continue + if value is NOTSET: + continue + if is_mappable(value): + continue + type_name = type(value).__name__ + error = f"{op.__name__}.expand() got an unexpected type {type_name!r} for keyword argument {name}" + raise ValueError(error) + if not unknown_args: + return # If we have no args left to check: stop looking at the MRO chain. + + if len(unknown_args) == 1: + error = f"an unexpected keyword argument {unknown_args.popitem()[0]!r}" + else: + names = ", ".join(repr(n) for n in unknown_args) + error = f"unexpected keyword arguments {names}" + raise TypeError(f"{op.__name__}.{func}() got {error}") + + +def ensure_xcomarg_return_value(arg: Any) -> None: + from airflow.sdk.definitions.xcom_arg import XComArg + + if isinstance(arg, XComArg): + for operator, key in arg.iter_references(): + if key != XCOM_RETURN_KEY: + raise ValueError(f"cannot map over XCom with custom key {key!r} from {operator}") + elif not is_container(arg): + return + elif isinstance(arg, Mapping): + for v in arg.values(): + ensure_xcomarg_return_value(v) + elif isinstance(arg, Iterable): + for v in arg: + ensure_xcomarg_return_value(v) + + +@attrs.define(kw_only=True, repr=False) +class OperatorPartial: + """ + An "intermediate state" returned by ``BaseOperator.partial()``. + + This only exists at DAG-parsing time; the only intended usage is for the + user to call ``.expand()`` on it at some point (usually in a method chain) to + create a ``MappedOperator`` to add into the DAG. + """ + + operator_class: type[BaseOperator] + kwargs: dict[str, Any] + params: ParamsDict | dict + + _expand_called: bool = False # Set when expand() is called to ease user debugging. 
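[Editorial aside, not part of the patch] ensure_xcomarg_return_value above walks nested containers and rejects any XCom reference whose key is not the default return value. A simplified stand-in mirror of that recursion (FakeXComArg is made up; the real code also handles XComArg subclasses and uses is_container rather than these isinstance checks):

from __future__ import annotations

from collections.abc import Iterable, Mapping

XCOM_RETURN_KEY = "return_value"

class FakeXComArg:
    def __init__(self, key: str = XCOM_RETURN_KEY):
        self.key = key

def ensure_return_value(arg) -> None:
    if isinstance(arg, FakeXComArg):
        if arg.key != XCOM_RETURN_KEY:
            raise ValueError(f"cannot map over XCom with custom key {arg.key!r}")
    elif isinstance(arg, Mapping):
        for v in arg.values():
            ensure_return_value(v)
    elif isinstance(arg, Iterable) and not isinstance(arg, (str, bytes)):
        for v in arg:
            ensure_return_value(v)

ensure_return_value({"a": [FakeXComArg()]})        # fine: default return_value key
# ensure_return_value([FakeXComArg(key="other")])  # would raise ValueError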
+ + def __attrs_post_init__(self): + validate_mapping_kwargs(self.operator_class, "partial", self.kwargs) + + def __repr__(self) -> str: + args = ", ".join(f"{k}={v!r}" for k, v in self.kwargs.items()) + return f"{self.operator_class.__name__}.partial({args})" + + def __del__(self): + if not self._expand_called: + try: + task_id = repr(self.kwargs["task_id"]) + except KeyError: + task_id = f"at {hex(id(self))}" + warnings.warn(f"Task {task_id} was never mapped!", category=UserWarning, stacklevel=1) + + def expand(self, **mapped_kwargs: OperatorExpandArgument) -> MappedOperator: + if not mapped_kwargs: + raise TypeError("no arguments to expand against") + validate_mapping_kwargs(self.operator_class, "expand", mapped_kwargs) + prevent_duplicates(self.kwargs, mapped_kwargs, fail_reason="unmappable or already specified") + # Since the input is already checked at parse time, we can set strict + # to False to skip the checks on execution. + return self._expand(DictOfListsExpandInput(mapped_kwargs), strict=False) + + def expand_kwargs(self, kwargs: OperatorExpandKwargsArgument, *, strict: bool = True) -> MappedOperator: + from airflow.models.xcom_arg import XComArg + + if isinstance(kwargs, Sequence): + for item in kwargs: + if not isinstance(item, (XComArg, Mapping)): + raise TypeError(f"expected XComArg or list[dict], not {type(kwargs).__name__}") + elif not isinstance(kwargs, XComArg): + raise TypeError(f"expected XComArg or list[dict], not {type(kwargs).__name__}") + return self._expand(ListOfDictsExpandInput(kwargs), strict=strict) + + def _expand(self, expand_input: ExpandInput, *, strict: bool) -> MappedOperator: + from airflow.operators.empty import EmptyOperator + from airflow.sensors.base import BaseSensorOperator + + self._expand_called = True + ensure_xcomarg_return_value(expand_input.value) + + partial_kwargs = self.kwargs.copy() + task_id = partial_kwargs.pop("task_id") + dag = partial_kwargs.pop("dag") + task_group = partial_kwargs.pop("task_group") + start_date = partial_kwargs.pop("start_date", None) + end_date = partial_kwargs.pop("end_date", None) + + try: + operator_name = self.operator_class.custom_operator_name # type: ignore + except AttributeError: + operator_name = self.operator_class.__name__ + + op = MappedOperator( + operator_class=self.operator_class, + expand_input=expand_input, + partial_kwargs=partial_kwargs, + task_id=task_id, + params=self.params, + operator_extra_links=self.operator_class.operator_extra_links, + template_ext=self.operator_class.template_ext, + template_fields=self.operator_class.template_fields, + template_fields_renderers=self.operator_class.template_fields_renderers, + ui_color=self.operator_class.ui_color, + ui_fgcolor=self.operator_class.ui_fgcolor, + is_empty=issubclass(self.operator_class, EmptyOperator), + is_sensor=issubclass(self.operator_class, BaseSensorOperator), + task_module=self.operator_class.__module__, + task_type=self.operator_class.__name__, + operator_name=operator_name, + dag=dag, + task_group=task_group, + start_date=start_date, + end_date=end_date, + disallow_kwargs_override=strict, + # For classic operators, this points to expand_input because kwargs + # to BaseOperator.expand() contribute to operator arguments. 
+ expand_input_attr="expand_input", + # TODO: Move these to task SDK's BaseOperator and remove getattr + start_trigger_args=getattr(self.operator_class, "start_trigger_args", None), + start_from_trigger=bool(getattr(self.operator_class, "start_from_trigger", False)), + ) + return op + + +@attrs.define( + kw_only=True, + # Disable custom __getstate__ and __setstate__ generation since it interacts + # badly with Airflow's DAG serialization and pickling. When a mapped task is + # deserialized, subclasses are coerced into MappedOperator, but when it goes + # through DAG pickling, all attributes defined in the subclasses are dropped + # by attrs's custom state management. Since attrs does not do anything too + # special here (the logic is only important for slots=True), we use Python's + # built-in implementation, which works (as proven by good old BaseOperator). + getstate_setstate=False, +) +class MappedOperator(AbstractOperator): + """Object representing a mapped operator in a DAG.""" + + # This attribute serves double purpose. For a "normal" operator instance + # loaded from DAG, this holds the underlying non-mapped operator class that + # can be used to create an unmapped operator for execution. For an operator + # recreated from a serialized DAG, however, this holds the serialized data + # that can be used to unmap this into a SerializedBaseOperator. + operator_class: type[BaseOperator] | dict[str, Any] + + _is_mapped: bool = attrs.field(init=False, default=True) + + expand_input: ExpandInput + partial_kwargs: dict[str, Any] + + # Needed for serialization. + task_id: str + params: ParamsDict | dict + deps: frozenset[BaseTIDep] = attrs.field(init=False) + operator_extra_links: Collection[BaseOperatorLink] + template_ext: Sequence[str] + template_fields: Collection[str] + template_fields_renderers: dict[str, str] + ui_color: str + ui_fgcolor: str + _is_empty: bool = attrs.field(alias="is_empty") + _is_sensor: bool = attrs.field(alias="is_sensor", default=False) + _task_module: str + _task_type: str + _operator_name: str + start_trigger_args: StartTriggerArgs | None + start_from_trigger: bool + _needs_expansion: bool = True + + dag: DAG | None + task_group: TaskGroup | None + start_date: pendulum.DateTime | None + end_date: pendulum.DateTime | None + upstream_task_ids: set[str] = attrs.field(factory=set, init=False) + downstream_task_ids: set[str] = attrs.field(factory=set, init=False) + + _disallow_kwargs_override: bool + """Whether execution fails if ``expand_input`` has duplicates to ``partial_kwargs``. + + If *False*, values from ``expand_input`` under duplicate keys override those + under corresponding keys in ``partial_kwargs``. + """ + + _expand_input_attr: str + """Where to get kwargs to calculate expansion length against. + + This should be a name to call ``getattr()`` on. 
+ """ + + supports_lineage: bool = False + + HIDE_ATTRS_FROM_UI: ClassVar[frozenset[str]] = AbstractOperator.HIDE_ATTRS_FROM_UI | frozenset( + ("parse_time_mapped_ti_count", "operator_class", "start_trigger_args", "start_from_trigger") + ) + + @deps.default + def _deps(self): + from airflow.models.baseoperator import BaseOperator + + return BaseOperator.deps + + def __hash__(self): + return id(self) + + def __repr__(self): + return f"" + + def __attrs_post_init__(self): + from airflow.models.xcom_arg import XComArg + + if self.get_closest_mapped_task_group() is not None: + raise NotImplementedError("operator expansion in an expanded task group is not yet supported") + + if self.task_group: + self.task_group.add(self) + if self.dag: + self.dag.add_task(self) + XComArg.apply_upstream_relationship(self, self.expand_input.value) + for k, v in self.partial_kwargs.items(): + if k in self.template_fields: + XComArg.apply_upstream_relationship(self, v) + + @methodtools.lru_cache(maxsize=None) + @classmethod + def get_serialized_fields(cls): + # Not using 'cls' here since we only want to serialize base fields. + return (frozenset(attrs.fields_dict(MappedOperator)) | {"task_type"}) - { + "_task_type", + "dag", + "deps", + "expand_input", # This is needed to be able to accept XComArg. + "task_group", + "upstream_task_ids", + "supports_lineage", + "_is_setup", + "_is_teardown", + "_on_failure_fail_dagrun", + } + + @property + def task_type(self) -> str: + """Implementing Operator.""" + return self._task_type + + @property + def operator_name(self) -> str: + return self._operator_name + + @property + def inherits_from_empty_operator(self) -> bool: + """Implementing Operator.""" + return self._is_empty + + @property + def roots(self) -> Sequence[AbstractOperator]: + """Implementing DAGNode.""" + return [self] + + @property + def leaves(self) -> Sequence[AbstractOperator]: + """Implementing DAGNode.""" + return [self] + + @property + def task_display_name(self) -> str: + return self.partial_kwargs.get("task_display_name") or self.task_id + + @property + def owner(self) -> str: # type: ignore[override] + return self.partial_kwargs.get("owner", DEFAULT_OWNER) + + @property + def email(self) -> None | str | Iterable[str]: + return self.partial_kwargs.get("email") + + @property + def map_index_template(self) -> None | str: + return self.partial_kwargs.get("map_index_template") + + @map_index_template.setter + def map_index_template(self, value: str | None) -> None: + self.partial_kwargs["map_index_template"] = value + + @property + def trigger_rule(self) -> TriggerRule: + return self.partial_kwargs.get("trigger_rule", DEFAULT_TRIGGER_RULE) + + @trigger_rule.setter + def trigger_rule(self, value): + self.partial_kwargs["trigger_rule"] = value + + @property + def is_setup(self) -> bool: + return bool(self.partial_kwargs.get("is_setup")) + + @is_setup.setter + def is_setup(self, value: bool) -> None: + self.partial_kwargs["is_setup"] = value + + @property + def is_teardown(self) -> bool: + return bool(self.partial_kwargs.get("is_teardown")) + + @is_teardown.setter + def is_teardown(self, value: bool) -> None: + self.partial_kwargs["is_teardown"] = value + + @property + def depends_on_past(self) -> bool: + return bool(self.partial_kwargs.get("depends_on_past")) + + @depends_on_past.setter + def depends_on_past(self, value: bool) -> None: + self.partial_kwargs["depends_on_past"] = value + + @property + def ignore_first_depends_on_past(self) -> bool: + value = 
self.partial_kwargs.get("ignore_first_depends_on_past", DEFAULT_IGNORE_FIRST_DEPENDS_ON_PAST) + return bool(value) + + @ignore_first_depends_on_past.setter + def ignore_first_depends_on_past(self, value: bool) -> None: + self.partial_kwargs["ignore_first_depends_on_past"] = value + + @property + def wait_for_past_depends_before_skipping(self) -> bool: + value = self.partial_kwargs.get( + "wait_for_past_depends_before_skipping", DEFAULT_WAIT_FOR_PAST_DEPENDS_BEFORE_SKIPPING + ) + return bool(value) + + @wait_for_past_depends_before_skipping.setter + def wait_for_past_depends_before_skipping(self, value: bool) -> None: + self.partial_kwargs["wait_for_past_depends_before_skipping"] = value + + @property + def wait_for_downstream(self) -> bool: + return bool(self.partial_kwargs.get("wait_for_downstream")) + + @wait_for_downstream.setter + def wait_for_downstream(self, value: bool) -> None: + self.partial_kwargs["wait_for_downstream"] = value + + @property + def retries(self) -> int: + return self.partial_kwargs.get("retries", DEFAULT_RETRIES) + + @retries.setter + def retries(self, value: int) -> None: + self.partial_kwargs["retries"] = value + + @property + def queue(self) -> str: + return self.partial_kwargs.get("queue", DEFAULT_QUEUE) + + @queue.setter + def queue(self, value: str) -> None: + self.partial_kwargs["queue"] = value + + @property + def pool(self) -> str: + return self.partial_kwargs.get("pool", DEFAULT_POOL_NAME) + + @pool.setter + def pool(self, value: str) -> None: + self.partial_kwargs["pool"] = value + + @property + def pool_slots(self) -> int: + return self.partial_kwargs.get("pool_slots", DEFAULT_POOL_SLOTS) + + @pool_slots.setter + def pool_slots(self, value: int) -> None: + self.partial_kwargs["pool_slots"] = value + + @property + def execution_timeout(self) -> datetime.timedelta | None: + return self.partial_kwargs.get("execution_timeout") + + @execution_timeout.setter + def execution_timeout(self, value: datetime.timedelta | None) -> None: + self.partial_kwargs["execution_timeout"] = value + + @property + def max_retry_delay(self) -> datetime.timedelta | None: + return self.partial_kwargs.get("max_retry_delay") + + @max_retry_delay.setter + def max_retry_delay(self, value: datetime.timedelta | None) -> None: + self.partial_kwargs["max_retry_delay"] = value + + @property + def retry_delay(self) -> datetime.timedelta: + return self.partial_kwargs.get("retry_delay", DEFAULT_RETRY_DELAY) + + @retry_delay.setter + def retry_delay(self, value: datetime.timedelta) -> None: + self.partial_kwargs["retry_delay"] = value + + @property + def retry_exponential_backoff(self) -> bool: + return bool(self.partial_kwargs.get("retry_exponential_backoff")) + + @retry_exponential_backoff.setter + def retry_exponential_backoff(self, value: bool) -> None: + self.partial_kwargs["retry_exponential_backoff"] = value + + @property + def priority_weight(self) -> int: # type: ignore[override] + return self.partial_kwargs.get("priority_weight", DEFAULT_PRIORITY_WEIGHT) + + @priority_weight.setter + def priority_weight(self, value: int) -> None: + self.partial_kwargs["priority_weight"] = value + + @property + def weight_rule(self) -> PriorityWeightStrategy: # type: ignore[override] + return validate_and_load_priority_weight_strategy( + self.partial_kwargs.get("weight_rule", DEFAULT_WEIGHT_RULE) + ) + + @weight_rule.setter + def weight_rule(self, value: str | PriorityWeightStrategy) -> None: + self.partial_kwargs["weight_rule"] = validate_and_load_priority_weight_strategy(value) + + @property + def 
max_active_tis_per_dag(self) -> int | None: + return self.partial_kwargs.get("max_active_tis_per_dag") + + @max_active_tis_per_dag.setter + def max_active_tis_per_dag(self, value: int | None) -> None: + self.partial_kwargs["max_active_tis_per_dag"] = value + + @property + def max_active_tis_per_dagrun(self) -> int | None: + return self.partial_kwargs.get("max_active_tis_per_dagrun") + + @max_active_tis_per_dagrun.setter + def max_active_tis_per_dagrun(self, value: int | None) -> None: + self.partial_kwargs["max_active_tis_per_dagrun"] = value + + @property + def resources(self) -> Resources | None: + return self.partial_kwargs.get("resources") + + @property + def on_execute_callback(self) -> TaskStateChangeCallbackAttrType: + return self.partial_kwargs.get("on_execute_callback") + + @on_execute_callback.setter + def on_execute_callback(self, value: TaskStateChangeCallbackAttrType) -> None: + self.partial_kwargs["on_execute_callback"] = value + + @property + def on_failure_callback(self) -> TaskStateChangeCallbackAttrType: + return self.partial_kwargs.get("on_failure_callback") + + @on_failure_callback.setter + def on_failure_callback(self, value: TaskStateChangeCallbackAttrType) -> None: + self.partial_kwargs["on_failure_callback"] = value + + @property + def on_retry_callback(self) -> TaskStateChangeCallbackAttrType: + return self.partial_kwargs.get("on_retry_callback") + + @on_retry_callback.setter + def on_retry_callback(self, value: TaskStateChangeCallbackAttrType) -> None: + self.partial_kwargs["on_retry_callback"] = value + + @property + def on_success_callback(self) -> TaskStateChangeCallbackAttrType: + return self.partial_kwargs.get("on_success_callback") + + @on_success_callback.setter + def on_success_callback(self, value: TaskStateChangeCallbackAttrType) -> None: + self.partial_kwargs["on_success_callback"] = value + + @property + def on_skipped_callback(self) -> TaskStateChangeCallbackAttrType: + return self.partial_kwargs.get("on_skipped_callback") + + @on_skipped_callback.setter + def on_skipped_callback(self, value: TaskStateChangeCallbackAttrType) -> None: + self.partial_kwargs["on_skipped_callback"] = value + + @property + def run_as_user(self) -> str | None: + return self.partial_kwargs.get("run_as_user") + + @property + def executor(self) -> str | None: + return self.partial_kwargs.get("executor", DEFAULT_EXECUTOR) + + @property + def executor_config(self) -> dict: + return self.partial_kwargs.get("executor_config", {}) + + @property # type: ignore[override] + def inlets(self) -> list[Any]: # type: ignore[override] + return self.partial_kwargs.get("inlets", []) + + @inlets.setter + def inlets(self, value: list[Any]) -> None: # type: ignore[override] + self.partial_kwargs["inlets"] = value + + @property # type: ignore[override] + def outlets(self) -> list[Any]: # type: ignore[override] + return self.partial_kwargs.get("outlets", []) + + @outlets.setter + def outlets(self, value: list[Any]) -> None: # type: ignore[override] + self.partial_kwargs["outlets"] = value + + @property + def doc(self) -> str | None: + return self.partial_kwargs.get("doc") + + @property + def doc_md(self) -> str | None: + return self.partial_kwargs.get("doc_md") + + @property + def doc_json(self) -> str | None: + return self.partial_kwargs.get("doc_json") + + @property + def doc_yaml(self) -> str | None: + return self.partial_kwargs.get("doc_yaml") + + @property + def doc_rst(self) -> str | None: + return self.partial_kwargs.get("doc_rst") + + @property + def allow_nested_operators(self) -> bool: + 
return bool(self.partial_kwargs.get("allow_nested_operators")) + + def get_dag(self) -> DAG | None: + """Implement Operator.""" + return self.dag + + @property + def output(self) -> XComArg: + """Return reference to XCom pushed by current operator.""" + from airflow.models.xcom_arg import XComArg + + return XComArg(operator=self) + + def serialize_for_task_group(self) -> tuple[DagAttributeTypes, Any]: + """Implement DAGNode.""" + return DagAttributeTypes.OP, self.task_id + + def _expand_mapped_kwargs( + self, context: Mapping[str, Any], session: Session, *, include_xcom: bool + ) -> tuple[Mapping[str, Any], set[int]]: + """ + Get the kwargs to create the unmapped operator. + + This exists because taskflow operators expand against op_kwargs, not the + entire operator kwargs dict. + """ + return self._get_specified_expand_input().resolve(context, session, include_xcom=include_xcom) + + def _get_unmap_kwargs(self, mapped_kwargs: Mapping[str, Any], *, strict: bool) -> dict[str, Any]: + """ + Get init kwargs to unmap the underlying operator class. + + :param mapped_kwargs: The dict returned by ``_expand_mapped_kwargs``. + """ + if strict: + prevent_duplicates( + self.partial_kwargs, + mapped_kwargs, + fail_reason="unmappable or already specified", + ) + + # If params appears in the mapped kwargs, we need to merge it into the + # partial params, overriding existing keys. + params = copy.copy(self.params) + with contextlib.suppress(KeyError): + params.update(mapped_kwargs["params"]) + + # Ordering is significant; mapped kwargs should override partial ones, + # and the specially handled params should be respected. + return { + "task_id": self.task_id, + "dag": self.dag, + "task_group": self.task_group, + "start_date": self.start_date, + "end_date": self.end_date, + **self.partial_kwargs, + **mapped_kwargs, + "params": params, + } + + def expand_start_from_trigger(self, *, context: Context, session: Session) -> bool: + """ + Get the start_from_trigger value of the current abstract operator. + + MappedOperator uses this to unmap start_from_trigger to decide whether to start the task + execution directly from triggerer. + + :meta private: + """ + # start_from_trigger only makes sense when start_trigger_args exists. + if not self.start_trigger_args: + return False + + mapped_kwargs, _ = self._expand_mapped_kwargs(context, session, include_xcom=False) + if self._disallow_kwargs_override: + prevent_duplicates( + self.partial_kwargs, + mapped_kwargs, + fail_reason="unmappable or already specified", + ) + + # Ordering is significant; mapped kwargs should override partial ones. + return mapped_kwargs.get( + "start_from_trigger", self.partial_kwargs.get("start_from_trigger", self.start_from_trigger) + ) + + def expand_start_trigger_args(self, *, context: Context, session: Session) -> StartTriggerArgs | None: + """ + Get the kwargs to create the unmapped start_trigger_args. + + This method is for allowing mapped operator to start execution from triggerer. + """ + if not self.start_trigger_args: + return None + + mapped_kwargs, _ = self._expand_mapped_kwargs(context, session, include_xcom=False) + if self._disallow_kwargs_override: + prevent_duplicates( + self.partial_kwargs, + mapped_kwargs, + fail_reason="unmappable or already specified", + ) + + # Ordering is significant; mapped kwargs should override partial ones. 
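+        # Resolution order below: expanded (mapped) kwargs first, then partial kwargs, then the
+        # values captured in start_trigger_args at parse time.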
+ trigger_kwargs = mapped_kwargs.get( + "trigger_kwargs", + self.partial_kwargs.get("trigger_kwargs", self.start_trigger_args.trigger_kwargs), + ) + next_kwargs = mapped_kwargs.get( + "next_kwargs", + self.partial_kwargs.get("next_kwargs", self.start_trigger_args.next_kwargs), + ) + timeout = mapped_kwargs.get( + "trigger_timeout", self.partial_kwargs.get("trigger_timeout", self.start_trigger_args.timeout) + ) + return StartTriggerArgs( + trigger_cls=self.start_trigger_args.trigger_cls, + trigger_kwargs=trigger_kwargs, + next_method=self.start_trigger_args.next_method, + next_kwargs=next_kwargs, + timeout=timeout, + ) + + def unmap(self, resolve: None | Mapping[str, Any] | tuple[Context, Session]) -> BaseOperator: + """ + Get the "normal" Operator after applying the current mapping. + + The *resolve* argument is only used if ``operator_class`` is a real + class, i.e. if this operator is not serialized. If ``operator_class`` is + not a class (i.e. this DAG has been deserialized), this returns a + SerializedBaseOperator that "looks like" the actual unmapping result. + + If *resolve* is a two-tuple (context, session), the information is used + to resolve the mapped arguments into init arguments. If it is a mapping, + no resolving happens, the mapping directly provides those init arguments + resolved from mapped kwargs. + + :meta private: + """ + if isinstance(self.operator_class, type): + if isinstance(resolve, Mapping): + kwargs = resolve + elif resolve is not None: + kwargs, _ = self._expand_mapped_kwargs(*resolve, include_xcom=True) + else: + raise RuntimeError("cannot unmap a non-serialized operator without context") + kwargs = self._get_unmap_kwargs(kwargs, strict=self._disallow_kwargs_override) + is_setup = kwargs.pop("is_setup", False) + is_teardown = kwargs.pop("is_teardown", False) + on_failure_fail_dagrun = kwargs.pop("on_failure_fail_dagrun", False) + op = self.operator_class(**kwargs, _airflow_from_mapped=True) + # We need to overwrite task_id here because BaseOperator further + # mangles the task_id based on the task hierarchy (namely, group_id + # is prepended, and '__N' appended to deduplicate). This is hacky, + # but better than duplicating the whole mangling logic. + op.task_id = self.task_id + op.is_setup = is_setup + op.is_teardown = is_teardown + op.on_failure_fail_dagrun = on_failure_fail_dagrun + return op + + # TODO: TaskSDK: This probably doesn't need to live in definition time as the next section of code is + # for unmapping a deserialized DAG -- i.e. in the scheduler. + + # After a mapped operator is serialized, there's no real way to actually + # unmap it since we've lost access to the underlying operator class. + # This tries its best to simply "forward" all the attributes on this + # mapped operator to a new SerializedBaseOperator instance. + from airflow.serialization.serialized_objects import SerializedBaseOperator + + op = SerializedBaseOperator(task_id=self.task_id, params=self.params, _airflow_from_mapped=True) + for partial_attr, value in self.partial_kwargs.items(): + setattr(op, partial_attr, value) + SerializedBaseOperator.populate_operator(op, self.operator_class) + if self.dag is not None: # For Mypy; we only serialize tasks in a DAG so the check always satisfies. 
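+            # Give the stand-in SerializedBaseOperator a reference back to this task's DAG so it
+            # can be treated like any other deserialized task.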
+ SerializedBaseOperator.set_task_dag_references(op, self.dag) # type: ignore[arg-type] + return op + + def _get_specified_expand_input(self) -> ExpandInput: + """Input received from the expand call on the operator.""" + return getattr(self, self._expand_input_attr) + + def prepare_for_execution(self) -> MappedOperator: + # Since a mapped operator cannot be used for execution, and an unmapped + # BaseOperator needs to be created later (see render_template_fields), + # we don't need to create a copy of the MappedOperator here. + return self + + def iter_mapped_dependencies(self) -> Iterator[Operator]: + """Upstream dependencies that provide XComs used by this task for task mapping.""" + from airflow.models.xcom_arg import XComArg + + for operator, _ in XComArg.iter_xcom_references(self._get_specified_expand_input()): + yield operator + + @methodtools.lru_cache(maxsize=None) + def get_parse_time_mapped_ti_count(self) -> int: + current_count = self._get_specified_expand_input().get_parse_time_mapped_ti_count() + try: + # The use of `methodtools` interferes with the zero-arg super + parent_count = super(MappedOperator, self).get_parse_time_mapped_ti_count() # noqa: UP008 + except NotMapped: + return current_count + return parent_count * current_count + + def render_template_fields( + self, + context: Context, + jinja_env: jinja2.Environment | None = None, + ) -> None: + """ + Template all attributes listed in *self.template_fields*. + + This updates *context* to reference the map-expanded task and relevant + information, without modifying the mapped operator. The expanded task + in *context* is then rendered in-place. + + :param context: Context dict with values to apply on content. + :param jinja_env: Jinja environment to use for rendering. + """ + from airflow.utils.context import context_update_for_unmapped + + if not jinja_env: + jinja_env = self.get_template_env() + + # We retrieve the session here, stored by _run_raw_task in set_current_task_session + # context manager - we cannot pass the session via @provide_session because the signature + # of render_template_fields is defined by BaseOperator and there are already many subclasses + # overriding it, so changing the signature is not an option. However render_template_fields is + # always executed within "_run_raw_task" so we make sure that _run_raw_task uses the + # set_current_task_session context manager to store the session in the current task. + session = get_current_task_instance_session() + + mapped_kwargs, seen_oids = self._expand_mapped_kwargs(context, session, include_xcom=True) + unmapped_task = self.unmap(mapped_kwargs) + # TODO: Task-SDK: remove arg-type ignore once Kaxil's PR lands + context_update_for_unmapped(context, unmapped_task) # type: ignore[arg-type] + + # Since the operators that extend `BaseOperator` are not subclasses of + # `MappedOperator`, we need to call `_do_render_template_fields` from + # the unmapped task in order to call the operator method when we override + # it to customize the parsing of nested fields. 
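+        # ``seen_oids`` is forwarded so values already resolved from the expand input are not
+        # templated a second time.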
+ unmapped_task._do_render_template_fields( + parent=unmapped_task, + template_fields=self.template_fields, + context=context, + jinja_env=jinja_env, + seen_oids=seen_oids, + ) diff --git a/task_sdk/src/airflow/sdk/definitions/taskgroup.py b/task_sdk/src/airflow/sdk/definitions/taskgroup.py index cb5ee3eeece08..f609c5aa21c48 100644 --- a/task_sdk/src/airflow/sdk/definitions/taskgroup.py +++ b/task_sdk/src/airflow/sdk/definitions/taskgroup.py @@ -46,6 +46,7 @@ from airflow.sdk.definitions.baseoperator import BaseOperator from airflow.sdk.definitions.dag import DAG from airflow.sdk.definitions.edges import EdgeModifier + from airflow.sdk.types import Operator from airflow.serialization.enums import DagAttributeTypes @@ -105,7 +106,9 @@ class TaskGroup(DAGNode): """ _group_id: str | None = attrs.field( - validator=attrs.validators.optional(attrs.validators.instance_of(str)) + validator=attrs.validators.optional(attrs.validators.instance_of(str)), + # This is the default behaviour for attrs, but by specifying this it makes IDEs happier + alias="group_id", ) group_display_name: str = attrs.field(default="", validator=attrs.validators.instance_of(str)) prefix_group_id: bool = attrs.field(default=True) @@ -561,7 +564,7 @@ def iter_mapped_task_groups(self) -> Iterator[MappedTaskGroup]: def iter_tasks(self) -> Iterator[AbstractOperator]: """Return an iterator of the child tasks.""" - from airflow.models.abstractoperator import AbstractOperator + from airflow.sdk.definitions._internal.abstractoperator import AbstractOperator groups_to_visit = [self] @@ -575,7 +578,7 @@ def iter_tasks(self) -> Iterator[AbstractOperator]: groups_to_visit.append(child) else: raise ValueError( - f"Encountered a DAGNode that is not a TaskGroup or an AbstractOperator: {type(child)}" + f"Encountered a DAGNode that is not a TaskGroup or an AbstractOperator: {type(child).__module__}.{type(child)}" ) @@ -604,13 +607,6 @@ def __iter__(self): ) yield from self._iter_child(child) - def iter_mapped_dependencies(self) -> Iterator[DAGNode]: - """Upstream dependencies that provide XComs used by this mapped task group.""" - from airflow.models.xcom_arg import XComArg - - for op, _ in XComArg.iter_xcom_references(self._expand_input): - yield op - @methodtools.lru_cache(maxsize=None) def get_parse_time_mapped_ti_count(self) -> int: """ @@ -637,11 +633,18 @@ def __exit__(self, exc_type, exc_val, exc_tb): self.set_upstream(op) super().__exit__(exc_type, exc_val, exc_tb) + def iter_mapped_dependencies(self) -> Iterator[Operator]: + """Upstream dependencies that provide XComs used by this mapped task group.""" + from airflow.models.xcom_arg import XComArg + + for op, _ in XComArg.iter_xcom_references(self._expand_input): + yield op + def task_group_to_dict(task_item_or_group): """Create a nested dict representation of this TaskGroup and its children used to construct the Graph.""" - from airflow.models.abstractoperator import AbstractOperator - from airflow.models.mappedoperator import MappedOperator + from airflow.sdk.definitions.abstractoperator import AbstractOperator + from airflow.sdk.definitions.mappedoperator import MappedOperator from airflow.sensors.base import BaseSensorOperator if isinstance(task := task_item_or_group, AbstractOperator): diff --git a/task_sdk/src/airflow/sdk/definitions/xcom_arg.py b/task_sdk/src/airflow/sdk/definitions/xcom_arg.py new file mode 100644 index 0000000000000..436cd9d005012 --- /dev/null +++ b/task_sdk/src/airflow/sdk/definitions/xcom_arg.py @@ -0,0 +1,639 @@ +# Licensed to the Apache Software 
Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +from __future__ import annotations + +import contextlib +import inspect +import itertools +from collections.abc import Iterable, Iterator, Mapping, Sequence +from typing import TYPE_CHECKING, Any, Callable, Union, overload + +from airflow.exceptions import AirflowException, XComNotFound +from airflow.sdk.definitions._internal.abstractoperator import AbstractOperator +from airflow.sdk.definitions._internal.mixins import DependencyMixin, ResolveMixin +from airflow.sdk.definitions._internal.types import NOTSET, ArgNotSet +from airflow.utils.session import NEW_SESSION, provide_session +from airflow.utils.setup_teardown import SetupTeardownContext +from airflow.utils.trigger_rule import TriggerRule +from airflow.utils.xcom import XCOM_RETURN_KEY + +if TYPE_CHECKING: + from sqlalchemy.orm import Session + + from airflow.sdk.definitions.baseoperator import BaseOperator + from airflow.sdk.definitions.dag import DAG + from airflow.sdk.types import Operator + from airflow.utils.edgemodifier import EdgeModifier + +# Callable objects contained by MapXComArg. We only accept callables from +# the user, but deserialize them into strings in a serialized XComArg for +# safety (those callables are arbitrary user code). +MapCallables = Sequence[Union[Callable[[Any], Any], str]] + + +class XComArg(ResolveMixin, DependencyMixin): + """ + Reference to an XCom value pushed from another operator. + + The implementation supports:: + + xcomarg >> op + xcomarg << op + op >> xcomarg # By BaseOperator code + op << xcomarg # By BaseOperator code + + **Example**: The moment you get a result from any operator (decorated or regular) you can :: + + any_op = AnyOperator() + xcomarg = XComArg(any_op) + # or equivalently + xcomarg = any_op.output + my_op = MyOperator() + my_op >> xcomarg + + This object can be used in legacy Operators via Jinja. + + **Example**: You can make this result to be part of any generated string:: + + any_op = AnyOperator() + xcomarg = any_op.output + op1 = MyOperator(my_text_message=f"the value is {xcomarg}") + op2 = MyOperator(my_text_message=f"the value is {xcomarg['topic']}") + + :param operator: Operator instance to which the XComArg references. + :param key: Key used to pull the XCom value. Defaults to *XCOM_RETURN_KEY*, + i.e. the referenced operator's return value. 
+ """ + + @overload + def __new__(cls: type[XComArg], operator: Operator, key: str = XCOM_RETURN_KEY) -> XComArg: + """Execute when the user writes ``XComArg(...)`` directly.""" + + @overload + def __new__(cls: type[XComArg]) -> XComArg: + """Execute by Python internals from subclasses.""" + + def __new__(cls, *args, **kwargs) -> XComArg: + if cls is XComArg: + return PlainXComArg(*args, **kwargs) + return super().__new__(cls) + + def iter_references(self) -> Iterator[tuple[Operator, str]]: + raise NotImplementedError() + + @staticmethod + def iter_xcom_references(arg: Any) -> Iterator[tuple[Operator, str]]: + """ + Return XCom references in an arbitrary value. + + Recursively traverse ``arg`` and look for XComArg instances in any + collection objects, and instances with ``template_fields`` set. + """ + if isinstance(arg, ResolveMixin): + yield from arg.iter_references() + elif isinstance(arg, (tuple, set, list)): + for elem in arg: + yield from XComArg.iter_xcom_references(elem) + elif isinstance(arg, dict): + for elem in arg.values(): + yield from XComArg.iter_xcom_references(elem) + elif isinstance(arg, AbstractOperator): + for attr in arg.template_fields: + yield from XComArg.iter_xcom_references(getattr(arg, attr)) + + @staticmethod + def apply_upstream_relationship(op: DependencyMixin, arg: Any): + """ + Set dependency for XComArgs. + + This looks for XComArg objects in ``arg`` "deeply" (looking inside + collections objects and classes decorated with ``template_fields``), and + sets the relationship to ``op`` on any found. + """ + for operator, _ in XComArg.iter_xcom_references(arg): + op.set_upstream(operator) + + @property + def roots(self) -> list[Operator]: + """Required by DependencyMixin.""" + return [op for op, _ in self.iter_references()] + + @property + def leaves(self) -> list[Operator]: + """Required by DependencyMixin.""" + return [op for op, _ in self.iter_references()] + + def set_upstream( + self, + task_or_task_list: DependencyMixin | Sequence[DependencyMixin], + edge_modifier: EdgeModifier | None = None, + ): + """Proxy to underlying operator set_upstream method. Required by DependencyMixin.""" + for operator, _ in self.iter_references(): + operator.set_upstream(task_or_task_list, edge_modifier) + + def set_downstream( + self, + task_or_task_list: DependencyMixin | Sequence[DependencyMixin], + edge_modifier: EdgeModifier | None = None, + ): + """Proxy to underlying operator set_downstream method. Required by DependencyMixin.""" + for operator, _ in self.iter_references(): + operator.set_downstream(task_or_task_list, edge_modifier) + + def _serialize(self) -> dict[str, Any]: + """ + Serialize an XComArg. + + The implementation should be the inverse function to ``deserialize``, + returning a data dict converted from this XComArg derivative. DAG + serialization does not call this directly, but ``serialize_xcom_arg`` + instead, which adds additional information to dispatch deserialization + to the correct class. + """ + raise NotImplementedError() + + @classmethod + def _deserialize(cls, data: dict[str, Any], dag: DAG) -> XComArg: + """ + Deserialize an XComArg. + + The implementation should be the inverse function to ``serialize``, + implementing given a data dict converted from this XComArg derivative, + how the original XComArg should be created. 
DAG serialization relies on + additional information added in ``serialize_xcom_arg`` to dispatch data + dicts to the correct ``_deserialize`` information, so this function does + not need to validate whether the incoming data contains correct keys. + """ + raise NotImplementedError() + + def map(self, f: Callable[[Any], Any]) -> MapXComArg: + return MapXComArg(self, [f]) + + def zip(self, *others: XComArg, fillvalue: Any = NOTSET) -> ZipXComArg: + return ZipXComArg([self, *others], fillvalue=fillvalue) + + def concat(self, *others: XComArg) -> ConcatXComArg: + return ConcatXComArg([self, *others]) + + def resolve( + self, context: Mapping[str, Any], session: Session | None = None, *, include_xcom: bool = True + ) -> Any: + raise NotImplementedError() + + def __enter__(self): + if not self.operator.is_setup and not self.operator.is_teardown: + raise AirflowException("Only setup/teardown tasks can be used as context managers.") + SetupTeardownContext.push_setup_teardown_task(self.operator) + return SetupTeardownContext + + def __exit__(self, exc_type, exc_val, exc_tb): + SetupTeardownContext.set_work_task_roots_and_leaves() + + +class PlainXComArg(XComArg): + """ + Reference to one single XCom without any additional semantics. + + This class should not be accessed directly, but only through XComArg. The + class inheritance chain and ``__new__`` is implemented in this slightly + convoluted way because we want to + + a. Allow the user to continue using XComArg directly for the simple + semantics (see documentation of the base class for details). + b. Make ``isinstance(thing, XComArg)`` be able to detect all kinds of XCom + references. + c. Not allow many properties of PlainXComArg (including ``__getitem__`` and + ``__str__``) to exist on other kinds of XComArg implementations since + they don't make sense. + + :meta private: + """ + + def __init__(self, operator: Operator, key: str = XCOM_RETURN_KEY): + self.operator = operator + self.key = key + + def __eq__(self, other: Any) -> bool: + if not isinstance(other, PlainXComArg): + return NotImplemented + return self.operator == other.operator and self.key == other.key + + def __getitem__(self, item: str) -> XComArg: + """Implement xcomresult['some_result_key'].""" + if not isinstance(item, str): + raise ValueError(f"XComArg only supports str lookup, received {type(item).__name__}") + return PlainXComArg(operator=self.operator, key=item) + + def __iter__(self): + """ + Override iterable protocol to raise error explicitly. + + The default ``__iter__`` implementation in Python calls ``__getitem__`` + with 0, 1, 2, etc. until it hits an ``IndexError``. This does not work + well with our custom ``__getitem__`` implementation, and results in poor + DAG-writing experience since a misplaced ``*`` expansion would create an + infinite loop consuming the entire DAG parser. + + This override catches the error eagerly, so an incorrectly implemented + DAG fails fast and avoids wasting resources on nonsensical iterating. + """ + raise TypeError("'XComArg' object is not iterable") + + def __repr__(self) -> str: + if self.key == XCOM_RETURN_KEY: + return f"XComArg({self.operator!r})" + return f"XComArg({self.operator!r}, {self.key!r})" + + def __str__(self) -> str: + """ + Backward compatibility for old-style jinja used in Airflow Operators. + + **Example**: to use XComArg at BashOperator:: + + BashOperator(cmd=f"... 
{ xcomarg } ...") + + :return: + """ + xcom_pull_kwargs = [ + f"task_ids='{self.operator.task_id}'", + f"dag_id='{self.operator.dag_id}'", + ] + if self.key is not None: + xcom_pull_kwargs.append(f"key='{self.key}'") + + xcom_pull_str = ", ".join(xcom_pull_kwargs) + # {{{{ are required for escape {{ in f-string + xcom_pull = f"{{{{ task_instance.xcom_pull({xcom_pull_str}) }}}}" + return xcom_pull + + def _serialize(self) -> dict[str, Any]: + return {"task_id": self.operator.task_id, "key": self.key} + + @classmethod + def _deserialize(cls, data: dict[str, Any], dag: DAG) -> XComArg: + return cls(dag.get_task(data["task_id"]), data["key"]) + + @property + def is_setup(self) -> bool: + return self.operator.is_setup + + @is_setup.setter + def is_setup(self, val: bool): + self.operator.is_setup = val + + @property + def is_teardown(self) -> bool: + return self.operator.is_teardown + + @is_teardown.setter + def is_teardown(self, val: bool): + self.operator.is_teardown = val + + @property + def on_failure_fail_dagrun(self) -> bool: + return self.operator.on_failure_fail_dagrun + + @on_failure_fail_dagrun.setter + def on_failure_fail_dagrun(self, val: bool): + self.operator.on_failure_fail_dagrun = val + + def as_setup(self) -> DependencyMixin: + for operator, _ in self.iter_references(): + operator.is_setup = True + return self + + def as_teardown( + self, + *, + setups: BaseOperator | Iterable[BaseOperator] | None = None, + on_failure_fail_dagrun: bool | None = None, + ): + for operator, _ in self.iter_references(): + operator.is_teardown = True + operator.trigger_rule = TriggerRule.ALL_DONE_SETUP_SUCCESS + if on_failure_fail_dagrun is not None: + operator.on_failure_fail_dagrun = on_failure_fail_dagrun + if setups is not None: + setups = [setups] if isinstance(setups, DependencyMixin) else setups + for s in setups: + s.is_setup = True + s >> operator + return self + + def iter_references(self) -> Iterator[tuple[Operator, str]]: + yield self.operator, self.key + + def map(self, f: Callable[[Any], Any]) -> MapXComArg: + if self.key != XCOM_RETURN_KEY: + raise ValueError("cannot map against non-return XCom") + return super().map(f) + + def zip(self, *others: XComArg, fillvalue: Any = NOTSET) -> ZipXComArg: + if self.key != XCOM_RETURN_KEY: + raise ValueError("cannot map against non-return XCom") + return super().zip(*others, fillvalue=fillvalue) + + def concat(self, *others: XComArg) -> ConcatXComArg: + if self.key != XCOM_RETURN_KEY: + raise ValueError("cannot concatenate non-return XCom") + return super().concat(*others) + + # TODO: Task-SDK: Remove session argument once everything is ported over to Task SDK + def resolve( + self, context: Mapping[str, Any], session: Session | None = None, *, include_xcom: bool = True + ) -> Any: + ti = context["ti"] + task_id = self.operator.task_id + map_indexes = context.get("_upstream_map_indexes", {}).get(task_id) + + result = ti.xcom_pull( + task_ids=task_id, + map_indexes=map_indexes, + key=self.key, + default=NOTSET, + ) + if not isinstance(result, ArgNotSet): + return result + if self.key == XCOM_RETURN_KEY: + return None + if getattr(self.operator, "multiple_outputs", False): + # If the operator is set to have multiple outputs and it was not executed, + # we should return "None" instead of showing an error. This is because when + # multiple outputs XComs are created, the XCom keys associated with them will have + # different names than the predefined "XCOM_RETURN_KEY" and won't be found. 
+ # Therefore, it's better to return "None" like we did above where self.key==XCOM_RETURN_KEY. + return None + raise XComNotFound(ti.dag_id, task_id, self.key) + + +def _get_callable_name(f: Callable | str) -> str: + """Try to "describe" a callable by getting its name.""" + if callable(f): + return f.__name__ + # Parse the source to find whatever is behind "def". For safety, we don't + # want to evaluate the code in any meaningful way! + with contextlib.suppress(Exception): + kw, name, _ = f.lstrip().split(None, 2) + if kw == "def": + return name + return "" + + +class _MapResult(Sequence): + def __init__(self, value: Sequence | dict, callables: MapCallables) -> None: + self.value = value + self.callables = callables + + def __getitem__(self, index: Any) -> Any: + value = self.value[index] + + # In the worker, we can access all actual callables. Call them. + callables = [f for f in self.callables if callable(f)] + if len(callables) == len(self.callables): + for f in callables: + value = f(value) + return value + + # In the scheduler, we don't have access to the actual callables, nor do + # we want to run it since it's arbitrary code. This builds a string to + # represent the call chain in the UI or logs instead. + for v in self.callables: + value = f"{_get_callable_name(v)}({value})" + return value + + def __len__(self) -> int: + return len(self.value) + + +class MapXComArg(XComArg): + """ + An XCom reference with ``map()`` call(s) applied. + + This is based on an XComArg, but also applies a series of "transforms" that + convert the pulled XCom value. + + :meta private: + """ + + def __init__(self, arg: XComArg, callables: MapCallables) -> None: + for c in callables: + if getattr(c, "_airflow_is_task_decorator", False): + raise ValueError("map() argument must be a plain function, not a @task operator") + self.arg = arg + self.callables = callables + + def __repr__(self) -> str: + map_calls = "".join(f".map({_get_callable_name(f)})" for f in self.callables) + return f"{self.arg!r}{map_calls}" + + def _serialize(self) -> dict[str, Any]: + return { + "arg": serialize_xcom_arg(self.arg), + "callables": [inspect.getsource(c) if callable(c) else c for c in self.callables], + } + + @classmethod + def _deserialize(cls, data: dict[str, Any], dag: DAG) -> XComArg: + # We are deliberately NOT deserializing the callables. These are shown + # in the UI, and displaying a function object is useless. + return cls(deserialize_xcom_arg(data["arg"], dag), data["callables"]) + + def iter_references(self) -> Iterator[tuple[Operator, str]]: + yield from self.arg.iter_references() + + def map(self, f: Callable[[Any], Any]) -> MapXComArg: + # Flatten arg.map(f1).map(f2) into one MapXComArg. 
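+        # The transforms keep their original order, so they are still applied left to right when
+        # the XCom value is resolved.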
+ return MapXComArg(self.arg, [*self.callables, f]) + + # TODO: Task-SDK: Remove session argument once everything is ported over to Task SDK + @provide_session + def resolve( + self, context: Mapping[str, Any], session: Session = NEW_SESSION, *, include_xcom: bool = True + ) -> Any: + value = self.arg.resolve(context, session=session, include_xcom=include_xcom) + if not isinstance(value, (Sequence, dict)): + raise ValueError(f"XCom map expects sequence or dict, not {type(value).__name__}") + return _MapResult(value, self.callables) + + +class _ZipResult(Sequence): + def __init__(self, values: Sequence[Sequence | dict], *, fillvalue: Any = NOTSET) -> None: + self.values = values + self.fillvalue = fillvalue + + @staticmethod + def _get_or_fill(container: Sequence | dict, index: Any, fillvalue: Any) -> Any: + try: + return container[index] + except (IndexError, KeyError): + return fillvalue + + def __getitem__(self, index: Any) -> Any: + if index >= len(self): + raise IndexError(index) + return tuple(self._get_or_fill(value, index, self.fillvalue) for value in self.values) + + def __len__(self) -> int: + lengths = (len(v) for v in self.values) + if isinstance(self.fillvalue, ArgNotSet): + return min(lengths) + return max(lengths) + + +class ZipXComArg(XComArg): + """ + An XCom reference with ``zip()`` applied. + + This is constructed from multiple XComArg instances, and presents an + iterable that "zips" them together like the built-in ``zip()`` (and + ``itertools.zip_longest()`` if ``fillvalue`` is provided). + """ + + def __init__(self, args: Sequence[XComArg], *, fillvalue: Any = NOTSET) -> None: + if not args: + raise ValueError("At least one input is required") + self.args = args + self.fillvalue = fillvalue + + def __repr__(self) -> str: + args_iter = iter(self.args) + first = repr(next(args_iter)) + rest = ", ".join(repr(arg) for arg in args_iter) + if isinstance(self.fillvalue, ArgNotSet): + return f"{first}.zip({rest})" + return f"{first}.zip({rest}, fillvalue={self.fillvalue!r})" + + def _serialize(self) -> dict[str, Any]: + args = [serialize_xcom_arg(arg) for arg in self.args] + if isinstance(self.fillvalue, ArgNotSet): + return {"args": args} + return {"args": args, "fillvalue": self.fillvalue} + + @classmethod + def _deserialize(cls, data: dict[str, Any], dag: DAG) -> XComArg: + return cls( + [deserialize_xcom_arg(arg, dag) for arg in data["args"]], + fillvalue=data.get("fillvalue", NOTSET), + ) + + def iter_references(self) -> Iterator[tuple[Operator, str]]: + for arg in self.args: + yield from arg.iter_references() + + @provide_session + def resolve( + self, context: Mapping[str, Any], session: Session = NEW_SESSION, *, include_xcom: bool = True + ) -> Any: + values = [arg.resolve(context, session=session, include_xcom=include_xcom) for arg in self.args] + for value in values: + if not isinstance(value, (Sequence, dict)): + raise ValueError(f"XCom zip expects sequence or dict, not {type(value).__name__}") + return _ZipResult(values, fillvalue=self.fillvalue) + + +class _ConcatResult(Sequence): + def __init__(self, values: Sequence[Sequence | dict]) -> None: + self.values = values + + def __getitem__(self, index: Any) -> Any: + if index >= 0: + i = index + else: + i = len(self) + index + for value in self.values: + if i < 0: + break + elif i >= (curlen := len(value)): + i -= curlen + elif isinstance(value, Sequence): + return value[i] + else: + return next(itertools.islice(iter(value), i, None)) + raise IndexError("list index out of range") + + def __len__(self) -> int: + 
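+        # Unlike zip, concat has no fill value; the combined length is simply the sum of the
+        # constituent lengths.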
return sum(len(v) for v in self.values) + + +class ConcatXComArg(XComArg): + """ + Concatenating multiple XCom references into one. + + This is done by calling ``concat()`` on an XComArg to combine it with + others. The effect is similar to Python's :func:`itertools.chain`, but the + return value also supports index access. + """ + + def __init__(self, args: Sequence[XComArg]) -> None: + if not args: + raise ValueError("At least one input is required") + self.args = args + + def __repr__(self) -> str: + args_iter = iter(self.args) + first = repr(next(args_iter)) + rest = ", ".join(repr(arg) for arg in args_iter) + return f"{first}.concat({rest})" + + def _serialize(self) -> dict[str, Any]: + return {"args": [serialize_xcom_arg(arg) for arg in self.args]} + + @classmethod + def _deserialize(cls, data: dict[str, Any], dag: DAG) -> XComArg: + return cls([deserialize_xcom_arg(arg, dag) for arg in data["args"]]) + + def iter_references(self) -> Iterator[tuple[Operator, str]]: + for arg in self.args: + yield from arg.iter_references() + + def concat(self, *others: XComArg) -> ConcatXComArg: + # Flatten foo.concat(x).concat(y) into one call. + return ConcatXComArg([*self.args, *others]) + + @provide_session + def resolve( + self, context: Mapping[str, Any], session: Session = NEW_SESSION, *, include_xcom: bool = True + ) -> Any: + values = [arg.resolve(context, session=session, include_xcom=include_xcom) for arg in self.args] + for value in values: + if not isinstance(value, (Sequence, dict)): + raise ValueError(f"XCom concat expects sequence or dict, not {type(value).__name__}") + return _ConcatResult(values) + + +_XCOM_ARG_TYPES: Mapping[str, type[XComArg]] = { + "": PlainXComArg, + "concat": ConcatXComArg, + "map": MapXComArg, + "zip": ZipXComArg, +} + + +def serialize_xcom_arg(value: XComArg) -> dict[str, Any]: + """DAG serialization interface.""" + key = next(k for k, v in _XCOM_ARG_TYPES.items() if isinstance(value, v)) + if key: + return {"type": key, **value._serialize()} + return value._serialize() + + +def deserialize_xcom_arg(data: dict[str, Any], dag: DAG) -> XComArg: + """DAG serialization interface.""" + klass = _XCOM_ARG_TYPES[data.get("type", "")] + return klass._deserialize(data, dag) diff --git a/task_sdk/src/airflow/sdk/types.py b/task_sdk/src/airflow/sdk/types.py index 35ee9f8e38cab..f9ec150ce3ae2 100644 --- a/task_sdk/src/airflow/sdk/types.py +++ b/task_sdk/src/airflow/sdk/types.py @@ -17,7 +17,7 @@ from __future__ import annotations -from typing import TYPE_CHECKING, Any, Protocol +from typing import TYPE_CHECKING, Any, Protocol, Union if TYPE_CHECKING: from collections.abc import Iterator @@ -25,6 +25,9 @@ from airflow.sdk.definitions.asset import Asset, AssetAlias, AssetAliasEvent, BaseAssetUniqueKey from airflow.sdk.definitions.baseoperator import BaseOperator + from airflow.sdk.definitions.mappedoperator import MappedOperator + + Operator = Union[BaseOperator, MappedOperator] class DagRunProtocol(Protocol): diff --git a/task_sdk/tests/conftest.py b/task_sdk/tests/conftest.py index 50429d91b018a..3fc7fc18015c7 100644 --- a/task_sdk/tests/conftest.py +++ b/task_sdk/tests/conftest.py @@ -20,6 +20,7 @@ import os from pathlib import Path from typing import TYPE_CHECKING, Any, NoReturn, Protocol +from unittest import mock import pytest @@ -215,3 +216,11 @@ def _make_context_dict( return context.model_dump(exclude_unset=True, mode="json") return _make_context_dict + + +@pytest.fixture +def mock_supervisor_comms(): + with mock.patch( + 
"airflow.sdk.execution_time.task_runner.SUPERVISOR_COMMS", create=True + ) as supervisor_comms: + yield supervisor_comms diff --git a/task_sdk/tests/defintions/test_baseoperator.py b/task_sdk/tests/defintions/test_baseoperator.py index 35f33818dc198..af6bf592f5373 100644 --- a/task_sdk/tests/defintions/test_baseoperator.py +++ b/task_sdk/tests/defintions/test_baseoperator.py @@ -621,3 +621,26 @@ def _do_render(): assert expected_log in caplog.text if not_expected_log: assert not_expected_log not in caplog.text + + +def test_find_mapped_dependants_in_another_group(): + from airflow.decorators import task as task_decorator + from airflow.sdk import TaskGroup + + @task_decorator + def gen(x): + return list(range(x)) + + @task_decorator + def add(x, y): + return x + y + + with DAG(dag_id="test"): + with TaskGroup(group_id="g1"): + gen_result = gen(3) + with TaskGroup(group_id="g2"): + add_result = add.partial(y=1).expand(x=gen_result) + + # breakpoint() + dependants = list(gen_result.operator.iter_mapped_dependants()) + assert dependants == [add_result.operator] diff --git a/task_sdk/tests/defintions/test_mappedoperator.py b/task_sdk/tests/defintions/test_mappedoperator.py new file mode 100644 index 0000000000000..aba7523b5ad39 --- /dev/null +++ b/task_sdk/tests/defintions/test_mappedoperator.py @@ -0,0 +1,301 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from __future__ import annotations + +from datetime import datetime, timedelta + +import pendulum +import pytest + +from airflow.models.param import ParamsDict +from airflow.sdk.definitions.baseoperator import BaseOperator +from airflow.sdk.definitions.dag import DAG +from airflow.sdk.definitions.mappedoperator import MappedOperator +from airflow.sdk.definitions.xcom_arg import XComArg +from airflow.utils.trigger_rule import TriggerRule + +from tests_common.test_utils.mapping import expand_mapped_task # noqa: F401 +from tests_common.test_utils.mock_operators import ( + MockOperator, +) + +DEFAULT_DATE = datetime(2016, 1, 1) + + +def test_task_mapping_with_dag(): + with DAG("test-dag") as dag: + task1 = BaseOperator(task_id="op1") + literal = ["a", "b", "c"] + mapped = MockOperator.partial(task_id="task_2").expand(arg2=literal) + finish = MockOperator(task_id="finish") + + task1 >> mapped >> finish + + assert task1.downstream_list == [mapped] + assert mapped in dag.tasks + assert mapped.task_group == dag.task_group + # At parse time there should only be three tasks! 
+ assert len(dag.tasks) == 3 + + assert finish.upstream_list == [mapped] + assert mapped.downstream_list == [finish] + + +# TODO: +# test_task_mapping_with_dag_and_list_of_pandas_dataframe + + +def test_task_mapping_without_dag_context(): + with DAG("test-dag") as dag: + task1 = BaseOperator(task_id="op1") + literal = ["a", "b", "c"] + mapped = MockOperator.partial(task_id="task_2").expand(arg2=literal) + + task1 >> mapped + + assert isinstance(mapped, MappedOperator) + assert mapped in dag.tasks + assert task1.downstream_list == [mapped] + assert mapped in dag.tasks + # At parse time there should only be two tasks! + assert len(dag.tasks) == 2 + + +def test_task_mapping_default_args(): + default_args = {"start_date": DEFAULT_DATE.now(), "owner": "test"} + with DAG("test-dag", schedule=None, start_date=DEFAULT_DATE, default_args=default_args): + task1 = BaseOperator(task_id="op1") + literal = ["a", "b", "c"] + mapped = MockOperator.partial(task_id="task_2").expand(arg2=literal) + + task1 >> mapped + + assert mapped.partial_kwargs["owner"] == "test" + assert mapped.start_date == pendulum.instance(default_args["start_date"]) + + +def test_task_mapping_override_default_args(): + default_args = {"retries": 2, "start_date": DEFAULT_DATE.now()} + with DAG("test-dag", schedule=None, start_date=DEFAULT_DATE, default_args=default_args): + literal = ["a", "b", "c"] + mapped = MockOperator.partial(task_id="task", retries=1).expand(arg2=literal) + + # retries should be 1 because it is provided as a partial arg + assert mapped.partial_kwargs["retries"] == 1 + # start_date should be equal to default_args["start_date"] because it is not provided as partial arg + assert mapped.start_date == pendulum.instance(default_args["start_date"]) + # owner should be equal to Airflow default owner (airflow) because it is not provided at all + assert mapped.owner == "airflow" + + +def test_map_unknown_arg_raises(): + with pytest.raises(TypeError, match=r"argument 'file'"): + BaseOperator.partial(task_id="a").expand(file=[1, 2, {"a": "b"}]) + + +def test_map_xcom_arg(): + """Test that dependencies are correct when mapping with an XComArg""" + with DAG("test-dag"): + task1 = BaseOperator(task_id="op1") + mapped = MockOperator.partial(task_id="task_2").expand(arg2=task1.output) + finish = MockOperator(task_id="finish") + + mapped >> finish + + assert task1.downstream_list == [mapped] + + +# def test_map_xcom_arg_multiple_upstream_xcoms(dag_maker, session): + + +def test_partial_on_instance() -> None: + """`.partial` on an instance should fail -- it's only designed to be called on classes""" + with pytest.raises(TypeError): + MockOperator(task_id="a").partial() + + +def test_partial_on_class() -> None: + # Test that we accept args for superclasses too + op = MockOperator.partial(task_id="a", arg1="a", trigger_rule=TriggerRule.ONE_FAILED) + assert op.kwargs["arg1"] == "a" + assert op.kwargs["trigger_rule"] == TriggerRule.ONE_FAILED + + +def test_partial_on_class_invalid_ctor_args() -> None: + """Test that when we pass invalid args to partial(). + + I.e. if an arg is not known on the class or any of its parent classes we error at parse time + """ + with pytest.raises(TypeError, match=r"arguments 'foo', 'bar'"): + MockOperator.partial(task_id="a", foo="bar", bar=2) + + +def test_partial_on_invalid_pool_slots_raises() -> None: + """Test that when we pass an invalid value to pool_slots in partial(), + + i.e. 
if the value is not an integer, an error is raised at import time.""" + + with pytest.raises(TypeError, match="'<' not supported between instances of 'str' and 'int'"): + MockOperator.partial(task_id="pool_slots_test", pool="test", pool_slots="a").expand(arg1=[1, 2, 3]) + + +# def test_expand_mapped_task_instance(dag_maker, session, num_existing_tis, expected): + + +# def test_expand_mapped_task_failed_state_in_db(dag_maker, session): + + +# def test_expand_mapped_task_instance_skipped_on_zero(dag_maker, session): + + +def test_mapped_task_applies_default_args_classic(): + with DAG("test", default_args={"execution_timeout": timedelta(minutes=30)}) as dag: + MockOperator(task_id="simple", arg1=None, arg2=0) + MockOperator.partial(task_id="mapped").expand(arg1=[1], arg2=[2, 3]) + + assert dag.get_task("simple").execution_timeout == timedelta(minutes=30) + assert dag.get_task("mapped").execution_timeout == timedelta(minutes=30) + + +def test_mapped_task_applies_default_args_taskflow(): + with DAG("test", default_args={"execution_timeout": timedelta(minutes=30)}) as dag: + + @dag.task + def simple(arg): + pass + + @dag.task + def mapped(arg): + pass + + simple(arg=0) + mapped.expand(arg=[1, 2]) + + assert dag.get_task("simple").execution_timeout == timedelta(minutes=30) + assert dag.get_task("mapped").execution_timeout == timedelta(minutes=30) + + +@pytest.mark.parametrize( + "dag_params, task_params, expected_partial_params", + [ + pytest.param(None, None, ParamsDict(), id="none"), + pytest.param({"a": -1}, None, ParamsDict({"a": -1}), id="dag"), + pytest.param(None, {"b": -2}, ParamsDict({"b": -2}), id="task"), + pytest.param({"a": -1}, {"b": -2}, ParamsDict({"a": -1, "b": -2}), id="merge"), + ], +) +def test_mapped_expand_against_params(dag_params, task_params, expected_partial_params): + with DAG("test", params=dag_params) as dag: + MockOperator.partial(task_id="t", params=task_params).expand(params=[{"c": "x"}, {"d": 1}]) + + t = dag.get_task("t") + assert isinstance(t, MappedOperator) + assert t.params == expected_partial_params + assert t.expand_input.value == {"params": [{"c": "x"}, {"d": 1}]} + + +# def test_mapped_render_template_fields_validating_operator(dag_maker, session, tmp_path): + + +# def test_mapped_expand_kwargs_render_template_fields_validating_operator(dag_maker, session, tmp_path): + + +# def test_mapped_render_nested_template_fields(dag_maker, session): + + +# def test_expand_kwargs_mapped_task_instance(dag_maker, session, num_existing_tis, expected): + + +# def test_expand_mapped_task_instance_with_named_index( + + +# def test_expand_mapped_task_task_instance_mutation_hook(dag_maker, session, create_mapped_task) -> None: + + +# def test_expand_kwargs_render_template_fields_validating_operator(dag_maker, session, map_index, expected): + + +def test_xcomarg_property_of_mapped_operator(): + with DAG("test_xcomarg_property_of_mapped_operator"): + op_a = MockOperator.partial(task_id="a").expand(arg1=["x", "y", "z"]) + + assert op_a.output == XComArg(op_a) + + +def test_set_xcomarg_dependencies_with_mapped_operator(): + with DAG("test_set_xcomargs_dependencies_with_mapped_operator"): + op1 = MockOperator.partial(task_id="op1").expand(arg1=[1, 2, 3]) + op2 = MockOperator.partial(task_id="op2").expand(arg2=["a", "b", "c"]) + op3 = MockOperator(task_id="op3", arg1=op1.output) + op4 = MockOperator(task_id="op4", arg1=[op1.output, op2.output]) + op5 = MockOperator(task_id="op5", arg1={"op1": op1.output, "op2": op2.output}) + + assert op1 in op3.upstream_list + assert op1 in 
op4.upstream_list + assert op2 in op4.upstream_list + assert op1 in op5.upstream_list + assert op2 in op5.upstream_list + + +# def test_all_xcomargs_from_mapped_tasks_are_consumable(dag_maker, session): + + +def test_task_mapping_with_task_group_context(): + from airflow.sdk.definitions.taskgroup import TaskGroup + + with DAG("test-dag") as dag: + task1 = BaseOperator(task_id="op1") + finish = MockOperator(task_id="finish") + + with TaskGroup("test-group") as group: + literal = ["a", "b", "c"] + mapped = MockOperator.partial(task_id="task_2").expand(arg2=literal) + + task1 >> group >> finish + + assert task1.downstream_list == [mapped] + assert mapped.upstream_list == [task1] + + assert mapped in dag.tasks + assert mapped.task_group == group + + assert finish.upstream_list == [mapped] + assert mapped.downstream_list == [finish] + + +def test_task_mapping_with_explicit_task_group(): + from airflow.sdk.definitions.taskgroup import TaskGroup + + with DAG("test-dag") as dag: + task1 = BaseOperator(task_id="op1") + finish = MockOperator(task_id="finish") + + group = TaskGroup("test-group") + literal = ["a", "b", "c"] + mapped = MockOperator.partial(task_id="task_2", task_group=group).expand(arg2=literal) + + task1 >> group >> finish + + assert task1.downstream_list == [mapped] + assert mapped.upstream_list == [task1] + + assert mapped in dag.tasks + assert mapped.task_group == group + + assert finish.upstream_list == [mapped] + assert mapped.downstream_list == [finish] diff --git a/tests/models/test_taskmixin.py b/task_sdk/tests/defintions/test_minxins.py similarity index 75% rename from tests/models/test_taskmixin.py rename to task_sdk/tests/defintions/test_minxins.py index 7cfd58732889f..83b4d6eabefdf 100644 --- a/tests/models/test_taskmixin.py +++ b/task_sdk/tests/defintions/test_minxins.py @@ -22,10 +22,8 @@ import pytest from airflow.decorators import setup, task, teardown -from airflow.models.baseoperator import BaseOperator -from airflow.operators.empty import EmptyOperator - -pytestmark = pytest.mark.db_test +from airflow.sdk.definitions.baseoperator import BaseOperator +from airflow.sdk.definitions.dag import DAG def cleared_tasks(dag, task_id): @@ -73,14 +71,14 @@ def my_task(): @pytest.mark.parametrize( "setup_type, work_type, teardown_type", itertools.product(["classic", "taskflow"], repeat=3) ) -def test_as_teardown(dag_maker, setup_type, work_type, teardown_type): +def test_as_teardown(setup_type, work_type, teardown_type): """ Check that as_teardown works properly as implemented in PlainXComArg It should mark the teardown as teardown, and if a task is provided, it should mark that as setup and set it as a direct upstream. """ - with dag_maker() as dag: + with DAG("test") as dag: s1 = make_task(name="s1", type_=setup_type) w1 = make_task(name="w1", type_=work_type) t1 = make_task(name="t1", type_=teardown_type) @@ -106,7 +104,7 @@ def test_as_teardown(dag_maker, setup_type, work_type, teardown_type): @pytest.mark.parametrize( "setup_type, work_type, teardown_type", itertools.product(["classic", "taskflow"], repeat=3) ) -def test_as_teardown_oneline(dag_maker, setup_type, work_type, teardown_type): +def test_as_teardown_oneline(setup_type, work_type, teardown_type): """ Check that as_teardown implementations work properly. Tests all combinations of taskflow and classic. @@ -114,7 +112,7 @@ def test_as_teardown_oneline(dag_maker, setup_type, work_type, teardown_type): and set it as a direct upstream. 
""" - with dag_maker() as dag: + with DAG("test") as dag: s1 = make_task(name="s1", type_=setup_type) w1 = make_task(name="w1", type_=work_type) t1 = make_task(name="t1", type_=teardown_type) @@ -156,10 +154,10 @@ def test_as_teardown_oneline(dag_maker, setup_type, work_type, teardown_type): @pytest.mark.parametrize("type_", ["classic", "taskflow"]) -def test_cannot_be_both_setup_and_teardown(dag_maker, type_): +def test_cannot_be_both_setup_and_teardown(type_): # can't change a setup task to a teardown task or vice versa for first, second in [("setup", "teardown"), ("teardown", "setup")]: - with dag_maker(): + with DAG("test"): s1 = make_task(name="s1", type_=type_) getattr(s1, f"as_{first}")() with pytest.raises( @@ -168,8 +166,8 @@ def test_cannot_be_both_setup_and_teardown(dag_maker, type_): getattr(s1, f"as_{second}")() -def test_cannot_set_on_failure_fail_dagrun_unless_teardown_classic(dag_maker): - with dag_maker(): +def test_cannot_set_on_failure_fail_dagrun_unless_teardown_classic(): + with DAG("test"): t = make_task(name="t", type_="classic") assert t.is_teardown is False with pytest.raises( @@ -179,7 +177,7 @@ def test_cannot_set_on_failure_fail_dagrun_unless_teardown_classic(dag_maker): t.on_failure_fail_dagrun = True -def test_cannot_set_on_failure_fail_dagrun_unless_teardown_taskflow(dag_maker): +def test_cannot_set_on_failure_fail_dagrun_unless_teardown_taskflow(): @task(on_failure_fail_dagrun=True) def my_bad_task(): pass @@ -188,7 +186,7 @@ def my_bad_task(): def my_ok_task(): pass - with dag_maker(): + with DAG("test"): with pytest.raises( ValueError, match="Cannot set task on_failure_fail_dagrun for " @@ -218,12 +216,12 @@ def my_ok_task(): class TestDependencyMixin: - def test_set_upstream(self, dag_maker): - with dag_maker("test_set_upstream"): - op_a = EmptyOperator(task_id="a") - op_b = EmptyOperator(task_id="b") - op_c = EmptyOperator(task_id="c") - op_d = EmptyOperator(task_id="d") + def test_set_upstream(self): + with DAG("test_set_upstream"): + op_a = BaseOperator(task_id="a") + op_b = BaseOperator(task_id="b") + op_c = BaseOperator(task_id="c") + op_d = BaseOperator(task_id="d") op_d << op_c << op_b << op_a @@ -231,12 +229,12 @@ def test_set_upstream(self, dag_maker): assert [op_b] == op_c.upstream_list assert [op_c] == op_d.upstream_list - def test_set_downstream(self, dag_maker): - with dag_maker("test_set_downstream"): - op_a = EmptyOperator(task_id="a") - op_b = EmptyOperator(task_id="b") - op_c = EmptyOperator(task_id="c") - op_d = EmptyOperator(task_id="d") + def test_set_downstream(self): + with DAG("test_set_downstream"): + op_a = BaseOperator(task_id="a") + op_b = BaseOperator(task_id="b") + op_c = BaseOperator(task_id="c") + op_d = BaseOperator(task_id="d") op_a >> op_b >> op_c >> op_d @@ -244,12 +242,12 @@ def test_set_downstream(self, dag_maker): assert [op_b] == op_c.upstream_list assert [op_c] == op_d.upstream_list - def test_set_upstream_list(self, dag_maker): - with dag_maker("test_set_upstream_list"): - op_a = EmptyOperator(task_id="a") - op_b = EmptyOperator(task_id="b") - op_c = EmptyOperator(task_id="c") - op_d = EmptyOperator(task_id="d") + def test_set_upstream_list(self): + with DAG("test_set_upstream_list"): + op_a = BaseOperator(task_id="a") + op_b = BaseOperator(task_id="b") + op_c = BaseOperator(task_id="c") + op_d = BaseOperator(task_id="d") [op_d, op_c << op_b] << op_a @@ -257,12 +255,12 @@ def test_set_upstream_list(self, dag_maker): assert [op_a] == op_d.upstream_list assert [op_b] == op_c.upstream_list - def 
test_set_downstream_list(self, dag_maker): - with dag_maker("test_set_downstream_list"): - op_a = EmptyOperator(task_id="a") - op_b = EmptyOperator(task_id="b") - op_c = EmptyOperator(task_id="c") - op_d = EmptyOperator(task_id="d") + def test_set_downstream_list(self): + with DAG("test_set_downstream_list"): + op_a = BaseOperator(task_id="a") + op_b = BaseOperator(task_id="b") + op_c = BaseOperator(task_id="c") + op_d = BaseOperator(task_id="d") op_a >> [op_b >> op_c, op_d] @@ -270,12 +268,12 @@ def test_set_downstream_list(self, dag_maker): assert [op_a] == op_d.upstream_list assert {op_a, op_b} == set(op_c.upstream_list) - def test_set_upstream_inner_list(self, dag_maker): - with dag_maker("test_set_upstream_inner_list"): - op_a = EmptyOperator(task_id="a") - op_b = EmptyOperator(task_id="b") - op_c = EmptyOperator(task_id="c") - op_d = EmptyOperator(task_id="d") + def test_set_upstream_inner_list(self): + with DAG("test_set_upstream_inner_list"): + op_a = BaseOperator(task_id="a") + op_b = BaseOperator(task_id="b") + op_c = BaseOperator(task_id="c") + op_d = BaseOperator(task_id="d") with pytest.raises(AttributeError) as e_info: [op_d << [op_c, op_b]] << op_a @@ -285,12 +283,12 @@ def test_set_upstream_inner_list(self, dag_maker): assert op_c.upstream_list == [] assert {op_b, op_c} == set(op_d.upstream_list) - def test_set_downstream_inner_list(self, dag_maker): - with dag_maker("test_set_downstream_inner_list"): - op_a = EmptyOperator(task_id="a") - op_b = EmptyOperator(task_id="b") - op_c = EmptyOperator(task_id="c") - op_d = EmptyOperator(task_id="d") + def test_set_downstream_inner_list(self): + with DAG("test_set_downstream_inner_list"): + op_a = BaseOperator(task_id="a") + op_b = BaseOperator(task_id="b") + op_c = BaseOperator(task_id="c") + op_d = BaseOperator(task_id="d") op_a >> [[op_b, op_c] >> op_d] @@ -298,13 +296,13 @@ def test_set_downstream_inner_list(self, dag_maker): assert op_c.upstream_list == [] assert {op_b, op_c, op_a} == set(op_d.upstream_list) - def test_set_upstream_list_subarray(self, dag_maker): - with dag_maker("test_set_upstream_list"): - op_a = EmptyOperator(task_id="a") - op_b_1 = EmptyOperator(task_id="b_1") - op_b_2 = EmptyOperator(task_id="b_2") - op_c = EmptyOperator(task_id="c") - op_d = EmptyOperator(task_id="d") + def test_set_upstream_list_subarray(self): + with DAG("test_set_upstream_list"): + op_a = BaseOperator(task_id="a") + op_b_1 = BaseOperator(task_id="b_1") + op_b_2 = BaseOperator(task_id="b_2") + op_c = BaseOperator(task_id="c") + op_d = BaseOperator(task_id="d") with pytest.raises(AttributeError) as e_info: [op_d, op_c << [op_b_1, op_b_2]] << op_a @@ -316,13 +314,13 @@ def test_set_upstream_list_subarray(self, dag_maker): assert op_d.upstream_list == [] assert {op_b_1, op_b_2} == set(op_c.upstream_list) - def test_set_downstream_list_subarray(self, dag_maker): - with dag_maker("test_set_downstream_list"): - op_a = EmptyOperator(task_id="a") - op_b_1 = EmptyOperator(task_id="b_1") - op_b_2 = EmptyOperator(task_id="b2") - op_c = EmptyOperator(task_id="c") - op_d = EmptyOperator(task_id="d") + def test_set_downstream_list_subarray(self): + with DAG("test_set_downstream_list"): + op_a = BaseOperator(task_id="a") + op_b_1 = BaseOperator(task_id="b_1") + op_b_2 = BaseOperator(task_id="b2") + op_c = BaseOperator(task_id="c") + op_d = BaseOperator(task_id="d") op_a >> [[op_b_1, op_b_2] >> op_c, op_d] diff --git a/task_sdk/tests/execution_time/conftest.py b/task_sdk/tests/execution_time/conftest.py index d2a961a5307da..832f2b60ca351 100644 --- 
a/task_sdk/tests/execution_time/conftest.py +++ b/task_sdk/tests/execution_time/conftest.py @@ -19,7 +19,6 @@ import sys from typing import TYPE_CHECKING -from unittest import mock if TYPE_CHECKING: from datetime import datetime @@ -42,14 +41,6 @@ def disable_capturing(): sys.stdin, sys.stdout, sys.stderr = old_in, old_out, old_err -@pytest.fixture -def mock_supervisor_comms(): - with mock.patch( - "airflow.sdk.execution_time.task_runner.SUPERVISOR_COMMS", create=True - ) as supervisor_comms: - yield supervisor_comms - - @pytest.fixture def mocked_parse(spy_agency): """ diff --git a/tests/api_connexion/endpoints/test_mapped_task_instance_endpoint.py b/tests/api_connexion/endpoints/test_mapped_task_instance_endpoint.py index 7a4d802f6cc15..79dcde6cbd2a5 100644 --- a/tests/api_connexion/endpoints/test_mapped_task_instance_endpoint.py +++ b/tests/api_connexion/endpoints/test_mapped_task_instance_endpoint.py @@ -129,7 +129,7 @@ def create_dag_runs_with_mapped_tasks(self, dag_maker, session, dags=None): self.app.dag_bag.sync_to_db("dags-folder", None) session.flush() - mapped.expand_mapped_task(dr.run_id, session=session) + TaskMap.expand_mapped_task(mapped, dr.run_id, session=session) @pytest.fixture def one_task_with_mapped_tis(self, dag_maker, session): diff --git a/tests/api_fastapi/core_api/routes/public/test_task_instances.py b/tests/api_fastapi/core_api/routes/public/test_task_instances.py index 24572c37906dc..f09f6293e9082 100644 --- a/tests/api_fastapi/core_api/routes/public/test_task_instances.py +++ b/tests/api_fastapi/core_api/routes/public/test_task_instances.py @@ -529,7 +529,7 @@ def create_dag_runs_with_mapped_tasks(self, dag_maker, session, dags=None): dagbag.sync_to_db("dags-folder", None) session.flush() - mapped.expand_mapped_task(dr.run_id, session=session) + TaskMap.expand_mapped_task(mapped, dr.run_id, session=session) @pytest.fixture def one_task_with_mapped_tis(self, dag_maker, session): diff --git a/tests/decorators/test_python.py b/tests/decorators/test_python.py index 6ffe935348e6d..bbaa4236f3100 100644 --- a/tests/decorators/test_python.py +++ b/tests/decorators/test_python.py @@ -26,16 +26,14 @@ from airflow.decorators import setup, task as task_decorator, teardown from airflow.decorators.base import DecoratedMappedOperator from airflow.exceptions import AirflowException, XComNotFound -from airflow.models.baseoperator import BaseOperator -from airflow.models.dag import DAG from airflow.models.expandinput import DictOfListsExpandInput -from airflow.models.mappedoperator import MappedOperator from airflow.models.taskinstance import TaskInstance from airflow.models.taskmap import TaskMap -from airflow.models.xcom_arg import PlainXComArg, XComArg +from airflow.models.xcom_arg import PlainXComArg +from airflow.sdk import DAG, BaseOperator, TaskGroup, XComArg +from airflow.sdk.definitions.mappedoperator import MappedOperator from airflow.utils import timezone from airflow.utils.state import State -from airflow.utils.task_group import TaskGroup from airflow.utils.task_instance_session import set_current_task_instance_session from airflow.utils.trigger_rule import TriggerRule from airflow.utils.types import DagRunType diff --git a/tests/models/test_baseoperator.py b/tests/models/test_baseoperator.py index 0ff8896746a4c..13dad2bc8a4c3 100644 --- a/tests/models/test_baseoperator.py +++ b/tests/models/test_baseoperator.py @@ -411,28 +411,6 @@ def task0(): copy.deepcopy(dag) -@pytest.mark.db_test -def test_find_mapped_dependants_in_another_group(dag_maker): - from 
airflow.utils.task_group import TaskGroup - - @task_decorator - def gen(x): - return list(range(x)) - - @task_decorator - def add(x, y): - return x + y - - with dag_maker(): - with TaskGroup(group_id="g1"): - gen_result = gen(3) - with TaskGroup(group_id="g2"): - add_result = add.partial(y=1).expand(x=gen_result) - - dependants = list(gen_result.operator.iter_mapped_dependants()) - assert dependants == [add_result.operator] - - def get_states(dr): """ For a given dag run, get a dict of states. diff --git a/tests/models/test_dag.py b/tests/models/test_dag.py index fb3519b708928..77b749156af3c 100644 --- a/tests/models/test_dag.py +++ b/tests/models/test_dag.py @@ -1514,26 +1514,24 @@ def test_clear_set_dagrun_state(self, dag_run_state): assert dagrun.state == dag_run_state @pytest.mark.parametrize("dag_run_state", [DagRunState.QUEUED, DagRunState.RUNNING]) - def test_clear_set_dagrun_state_for_mapped_task(self, dag_run_state): + @pytest.mark.need_serialized_dag + def test_clear_set_dagrun_state_for_mapped_task(self, dag_maker, dag_run_state): dag_id = "test_clear_set_dagrun_state" self._clean_up(dag_id) task_id = "t1" - dag = DAG(dag_id, schedule=None, start_date=DEFAULT_DATE, max_active_runs=1) + with dag_maker(dag_id, schedule=None, start_date=DEFAULT_DATE, max_active_runs=1) as dag: - @dag.task - def make_arg_lists(): - return [[1], [2], [{"a": "b"}]] + @task_decorator + def make_arg_lists(): + return [[1], [2], [{"a": "b"}]] - def consumer(value): - print(value) + def consumer(value): + print(value) - mapped = PythonOperator.partial(task_id=task_id, dag=dag, python_callable=consumer).expand( - op_args=make_arg_lists() - ) + PythonOperator.partial(task_id=task_id, python_callable=consumer).expand(op_args=make_arg_lists()) - session = settings.Session() - triggered_by_kwargs = {"triggered_by": DagRunTriggeredByType.TEST} if AIRFLOW_V_3_0_PLUS else {} + session = dag_maker.session dagrun_1 = dag.create_dagrun( run_id="backfill", run_type=DagRunType.BACKFILL_JOB, @@ -1542,8 +1540,10 @@ def consumer(value): logical_date=DEFAULT_DATE, session=session, data_interval=(DEFAULT_DATE, DEFAULT_DATE), - **triggered_by_kwargs, + triggered_by=DagRunTriggeredByType.TEST, ) + # Get the (de)serialized MappedOperator + mapped = dag.get_task(task_id) expand_mapped_task(mapped, dagrun_1.run_id, "make_arg_lists", length=2, session=session) upstream_ti = dagrun_1.get_task_instance("make_arg_lists", session=session) @@ -2732,20 +2732,21 @@ def get_ti_from_db(task): } +@pytest.mark.need_serialized_dag def test_set_task_instance_state_mapped(dag_maker, session): """Test that when setting an individual mapped TI that the other TIs are not affected""" task_id = "t1" with dag_maker(session=session) as dag: - @dag.task + @task_decorator def make_arg_lists(): return [[1], [2], [{"a": "b"}]] def consumer(value): print(value) - mapped = PythonOperator.partial(task_id=task_id, dag=dag, python_callable=consumer).expand( + mapped = PythonOperator.partial(task_id=task_id, python_callable=consumer).expand( op_args=make_arg_lists() ) @@ -2755,6 +2756,8 @@ def consumer(value): run_type=DagRunType.SCHEDULED, state=DagRunState.FAILED, ) + + mapped = dag.get_task(task_id) expand_mapped_task(mapped, dr1.run_id, "make_arg_lists", length=2, session=session) # set_state(future=True) only applies to scheduled runs diff --git a/tests/models/test_mappedoperator.py b/tests/models/test_mappedoperator.py index ff107d8204d72..92130bd3471c8 100644 --- a/tests/models/test_mappedoperator.py +++ b/tests/models/test_mappedoperator.py @@ -18,12 
+18,10 @@ from __future__ import annotations from collections import defaultdict -from datetime import timedelta from typing import TYPE_CHECKING from unittest import mock from unittest.mock import patch -import pendulum import pytest from sqlalchemy import select @@ -31,16 +29,13 @@ from airflow.exceptions import AirflowSkipException from airflow.models.baseoperator import BaseOperator from airflow.models.dag import DAG -from airflow.models.mappedoperator import MappedOperator -from airflow.models.param import ParamsDict from airflow.models.taskinstance import TaskInstance from airflow.models.taskmap import TaskMap -from airflow.models.xcom_arg import XComArg from airflow.providers.standard.operators.python import PythonOperator +from airflow.sdk.definitions.mappedoperator import MappedOperator from airflow.utils.state import TaskInstanceState from airflow.utils.task_group import TaskGroup from airflow.utils.task_instance_session import set_current_task_instance_session -from airflow.utils.trigger_rule import TriggerRule from airflow.utils.xcom import XCOM_RETURN_KEY from tests.models import DEFAULT_DATE @@ -57,25 +52,6 @@ from airflow.sdk.definitions.context import Context -def test_task_mapping_with_dag(): - with DAG("test-dag", schedule=None, start_date=DEFAULT_DATE) as dag: - task1 = BaseOperator(task_id="op1") - literal = ["a", "b", "c"] - mapped = MockOperator.partial(task_id="task_2").expand(arg2=literal) - finish = MockOperator(task_id="finish") - - task1 >> mapped >> finish - - assert task1.downstream_list == [mapped] - assert mapped in dag.tasks - assert mapped.task_group == dag.task_group - # At parse time there should only be three tasks! - assert len(dag.tasks) == 3 - - assert finish.upstream_list == [mapped] - assert mapped.downstream_list == [finish] - - @patch("airflow.models.abstractoperator.AbstractOperator.render_template") def test_task_mapping_with_dag_and_list_of_pandas_dataframe(mock_render_template, caplog): class UnrenderableClass: @@ -105,66 +81,6 @@ def execute(self, context: Context): mock_render_template.assert_called() -def test_task_mapping_without_dag_context(): - with DAG("test-dag", schedule=None, start_date=DEFAULT_DATE) as dag: - task1 = BaseOperator(task_id="op1") - literal = ["a", "b", "c"] - mapped = MockOperator.partial(task_id="task_2").expand(arg2=literal) - - task1 >> mapped - - assert isinstance(mapped, MappedOperator) - assert mapped in dag.tasks - assert task1.downstream_list == [mapped] - assert mapped in dag.tasks - # At parse time there should only be two tasks! 
- assert len(dag.tasks) == 2 - - -def test_task_mapping_default_args(): - default_args = {"start_date": DEFAULT_DATE.now(), "owner": "test"} - with DAG("test-dag", schedule=None, start_date=DEFAULT_DATE, default_args=default_args): - task1 = BaseOperator(task_id="op1") - literal = ["a", "b", "c"] - mapped = MockOperator.partial(task_id="task_2").expand(arg2=literal) - - task1 >> mapped - - assert mapped.partial_kwargs["owner"] == "test" - assert mapped.start_date == pendulum.instance(default_args["start_date"]) - - -def test_task_mapping_override_default_args(): - default_args = {"retries": 2, "start_date": DEFAULT_DATE.now()} - with DAG("test-dag", schedule=None, start_date=DEFAULT_DATE, default_args=default_args): - literal = ["a", "b", "c"] - mapped = MockOperator.partial(task_id="task", retries=1).expand(arg2=literal) - - # retries should be 1 because it is provided as a partial arg - assert mapped.partial_kwargs["retries"] == 1 - # start_date should be equal to default_args["start_date"] because it is not provided as partial arg - assert mapped.start_date == pendulum.instance(default_args["start_date"]) - # owner should be equal to Airflow default owner (airflow) because it is not provided at all - assert mapped.owner == "airflow" - - -def test_map_unknown_arg_raises(): - with pytest.raises(TypeError, match=r"argument 'file'"): - BaseOperator.partial(task_id="a").expand(file=[1, 2, {"a": "b"}]) - - -def test_map_xcom_arg(): - """Test that dependencies are correct when mapping with an XComArg""" - with DAG("test-dag", schedule=None, start_date=DEFAULT_DATE): - task1 = BaseOperator(task_id="op1") - mapped = MockOperator.partial(task_id="task_2").expand(arg2=task1.output) - finish = MockOperator(task_id="finish") - - mapped >> finish - - assert task1.downstream_list == [mapped] - - def test_map_xcom_arg_multiple_upstream_xcoms(dag_maker, session): """Test that the correct number of downstream tasks are generated when mapping with an XComArg""" @@ -189,12 +105,12 @@ def execute(self, context): ti_1 = dr.get_task_instance("task_1", session) ti_1.run() - ti_2s, _ = task2.expand_mapped_task(dr.run_id, session=session) + ti_2s, _ = TaskMap.expand_mapped_task(task2, dr.run_id, session=session) for ti in ti_2s: ti.refresh_from_task(dag.get_task("task_2")) ti.run() - ti_3s, _ = task3.expand_mapped_task(dr.run_id, session=session) + ti_3s, _ = TaskMap.expand_mapped_task(task3, dr.run_id, session=session) for ti in ti_3s: ti.refresh_from_task(dag.get_task("task_3")) ti.run() @@ -202,37 +118,6 @@ def execute(self, context): assert len(ti_3s) == len(ti_2s) == len(upstream_return) -def test_partial_on_instance() -> None: - """`.partial` on an instance should fail -- it's only designed to be called on classes""" - with pytest.raises(TypeError): - MockOperator(task_id="a").partial() - - -def test_partial_on_class() -> None: - # Test that we accept args for superclasses too - op = MockOperator.partial(task_id="a", arg1="a", trigger_rule=TriggerRule.ONE_FAILED) - assert op.kwargs["arg1"] == "a" - assert op.kwargs["trigger_rule"] == TriggerRule.ONE_FAILED - - -def test_partial_on_class_invalid_ctor_args() -> None: - """Test that when we pass invalid args to partial(). - - I.e. 
if an arg is not known on the class or any of its parent classes we error at parse time - """ - with pytest.raises(TypeError, match=r"arguments 'foo', 'bar'"): - MockOperator.partial(task_id="a", foo="bar", bar=2) - - -def test_partial_on_invalid_pool_slots_raises() -> None: - """Test that when we pass an invalid value to pool_slots in partial(), - - i.e. if the value is not an integer, an error is raised at import time.""" - - with pytest.raises(TypeError, match="'<' not supported between instances of 'str' and 'int'"): - MockOperator.partial(task_id="pool_slots_test", pool="test", pool_slots="a").expand(arg1=[1, 2, 3]) - - @pytest.mark.parametrize( ["num_existing_tis", "expected"], ( @@ -288,7 +173,7 @@ def test_expand_mapped_task_instance(dag_maker, session, num_existing_tis, expec session.add(ti) session.flush() - mapped.expand_mapped_task(dr.run_id, session=session) + TaskMap.expand_mapped_task(mapped, dr.run_id, session=session) indices = ( session.query(TaskInstance.map_index, TaskInstance.state) @@ -339,7 +224,7 @@ def test_expand_mapped_task_failed_state_in_db(dag_maker, session): # Make sure we have the faulty state in the database assert indices == [(-1, None), (0, "success"), (1, "success")] - mapped.expand_mapped_task(dr.run_id, session=session) + TaskMap.expand_mapped_task(mapped, dr.run_id, session=session) indices = ( session.query(TaskInstance.map_index, TaskInstance.state) @@ -370,52 +255,6 @@ def test_expand_mapped_task_instance_skipped_on_zero(dag_maker, session): assert indices == [(-1, TaskInstanceState.SKIPPED)] -def test_mapped_task_applies_default_args_classic(dag_maker): - with dag_maker(default_args={"execution_timeout": timedelta(minutes=30)}) as dag: - MockOperator(task_id="simple", arg1=None, arg2=0) - MockOperator.partial(task_id="mapped").expand(arg1=[1], arg2=[2, 3]) - - assert dag.get_task("simple").execution_timeout == timedelta(minutes=30) - assert dag.get_task("mapped").execution_timeout == timedelta(minutes=30) - - -def test_mapped_task_applies_default_args_taskflow(dag_maker): - with dag_maker(default_args={"execution_timeout": timedelta(minutes=30)}) as dag: - - @dag.task - def simple(arg): - pass - - @dag.task - def mapped(arg): - pass - - simple(arg=0) - mapped.expand(arg=[1, 2]) - - assert dag.get_task("simple").execution_timeout == timedelta(minutes=30) - assert dag.get_task("mapped").execution_timeout == timedelta(minutes=30) - - -@pytest.mark.parametrize( - "dag_params, task_params, expected_partial_params", - [ - pytest.param(None, None, ParamsDict(), id="none"), - pytest.param({"a": -1}, None, ParamsDict({"a": -1}), id="dag"), - pytest.param(None, {"b": -2}, ParamsDict({"b": -2}), id="task"), - pytest.param({"a": -1}, {"b": -2}, ParamsDict({"a": -1, "b": -2}), id="merge"), - ], -) -def test_mapped_expand_against_params(dag_maker, dag_params, task_params, expected_partial_params): - with dag_maker(params=dag_params) as dag: - MockOperator.partial(task_id="t", params=task_params).expand(params=[{"c": "x"}, {"d": 1}]) - - t = dag.get_task("t") - assert isinstance(t, MappedOperator) - assert t.params == expected_partial_params - assert t.expand_input.value == {"params": [{"c": "x"}, {"d": 1}]} - - def test_mapped_render_template_fields_validating_operator(dag_maker, session, tmp_path): file_template_dir = tmp_path / "path" / "to" file_template_dir.mkdir(parents=True, exist_ok=True) @@ -609,7 +448,7 @@ def test_expand_kwargs_mapped_task_instance(dag_maker, session, num_existing_tis session.add(ti) session.flush() - 
mapped.expand_mapped_task(dr.run_id, session=session) + TaskMap.expand_mapped_task(mapped, dr.run_id, session=session) indices = ( session.query(TaskInstance.map_index, TaskInstance.state) @@ -785,29 +624,6 @@ def test_expand_kwargs_render_template_fields_validating_operator(dag_maker, ses assert ti.task.arg2 == "a" -def test_xcomarg_property_of_mapped_operator(dag_maker): - with dag_maker("test_xcomarg_property_of_mapped_operator"): - op_a = MockOperator.partial(task_id="a").expand(arg1=["x", "y", "z"]) - dag_maker.create_dagrun() - - assert op_a.output == XComArg(op_a) - - -def test_set_xcomarg_dependencies_with_mapped_operator(dag_maker): - with dag_maker("test_set_xcomargs_dependencies_with_mapped_operator"): - op1 = MockOperator.partial(task_id="op1").expand(arg1=[1, 2, 3]) - op2 = MockOperator.partial(task_id="op2").expand(arg2=["a", "b", "c"]) - op3 = MockOperator(task_id="op3", arg1=op1.output) - op4 = MockOperator(task_id="op4", arg1=[op1.output, op2.output]) - op5 = MockOperator(task_id="op5", arg1={"op1": op1.output, "op2": op2.output}) - - assert op1 in op3.upstream_list - assert op1 in op4.upstream_list - assert op2 in op4.upstream_list - assert op1 in op5.upstream_list - assert op2 in op5.upstream_list - - def test_all_xcomargs_from_mapped_tasks_are_consumable(dag_maker, session): class PushXcomOperator(MockOperator): def __init__(self, arg1, **kwargs): @@ -830,48 +646,6 @@ def execute(self, context): ti.run() -def test_task_mapping_with_task_group_context(): - with DAG("test-dag", schedule=None, start_date=DEFAULT_DATE) as dag: - task1 = BaseOperator(task_id="op1") - finish = MockOperator(task_id="finish") - - with TaskGroup("test-group") as group: - literal = ["a", "b", "c"] - mapped = MockOperator.partial(task_id="task_2").expand(arg2=literal) - - task1 >> group >> finish - - assert task1.downstream_list == [mapped] - assert mapped.upstream_list == [task1] - - assert mapped in dag.tasks - assert mapped.task_group == group - - assert finish.upstream_list == [mapped] - assert mapped.downstream_list == [finish] - - -def test_task_mapping_with_explicit_task_group(): - with DAG("test-dag", schedule=None, start_date=DEFAULT_DATE) as dag: - task1 = BaseOperator(task_id="op1") - finish = MockOperator(task_id="finish") - - group = TaskGroup("test-group") - literal = ["a", "b", "c"] - mapped = MockOperator.partial(task_id="task_2", task_group=group).expand(arg2=literal) - - task1 >> group >> finish - - assert task1.downstream_list == [mapped] - assert mapped.upstream_list == [task1] - - assert mapped in dag.tasks - assert mapped.task_group == group - - assert finish.upstream_list == [mapped] - assert mapped.downstream_list == [finish] - - class TestMappedSetupTeardown: @staticmethod def get_states(dr): diff --git a/tests/models/test_renderedtifields.py b/tests/models/test_renderedtifields.py index 3bdf9810a1e87..dfb7b652252b4 100644 --- a/tests/models/test_renderedtifields.py +++ b/tests/models/test_renderedtifields.py @@ -33,6 +33,7 @@ from airflow.decorators import task as task_decorator from airflow.models import DagRun, Variable from airflow.models.renderedtifields import RenderedTaskInstanceFields as RTIF +from airflow.models.taskmap import TaskMap from airflow.providers.standard.operators.bash import BashOperator from airflow.providers.standard.operators.python import PythonOperator from airflow.utils.task_instance_session import set_current_task_instance_session @@ -289,7 +290,7 @@ def test_delete_old_records_mapped( run_id=f"run_{num}", logical_date=dag.start_date + 
timedelta(days=num) ) - mapped.expand_mapped_task(dr.run_id, session=dag_maker.session) + TaskMap.expand_mapped_task(mapped, dr.run_id, session=dag_maker.session) session.refresh(dr) for ti in dr.task_instances: ti.task = dag.get_task(ti.task_id) diff --git a/tests/models/test_taskinstance.py b/tests/models/test_taskinstance.py index fca20b4bed00c..6494236a6cc20 100644 --- a/tests/models/test_taskinstance.py +++ b/tests/models/test_taskinstance.py @@ -4055,7 +4055,6 @@ def test_operator_field_with_serialization(self, create_task_instance): deserialized_op = SerializedBaseOperator.deserialize_operator(serialized_op) assert deserialized_op.task_type == "EmptyOperator" # Verify that ti.operator field renders correctly "with" Serialization - deserialized_op.dag = ti.task.dag ser_ti = TI(task=deserialized_op, run_id=None) assert ser_ti.operator == "EmptyOperator" assert ser_ti.task.operator_name == "EmptyOperator" @@ -4904,7 +4903,7 @@ def show(value): emit_ti.run() show_task = dag.get_task("show") - mapped_tis, max_map_index = show_task.expand_mapped_task(dag_run.run_id, session=session) + mapped_tis, max_map_index = TaskMap.expand_mapped_task(show_task, dag_run.run_id, session=session) assert max_map_index + 1 == len(mapped_tis) == len(upstream_return) for ti in sorted(mapped_tis, key=operator.attrgetter("map_index")): @@ -4938,7 +4937,7 @@ def show(number, letter): ti.run() show_task = dag.get_task("show") - mapped_tis, max_map_index = show_task.expand_mapped_task(dag_run.run_id, session=session) + mapped_tis, max_map_index = TaskMap.expand_mapped_task(show_task, dag_run.run_id, session=session) assert max_map_index + 1 == len(mapped_tis) == 6 for ti in sorted(mapped_tis, key=operator.attrgetter("map_index")): @@ -4978,7 +4977,7 @@ def show(a, b): show_task = dag.get_task("show") with pytest.raises(NotFullyPopulated): assert show_task.get_parse_time_mapped_ti_count() - mapped_tis, max_map_index = show_task.expand_mapped_task(dag_run.run_id, session=session) + mapped_tis, max_map_index = TaskMap.expand_mapped_task(show_task, dag_run.run_id, session=session) assert max_map_index + 1 == len(mapped_tis) == 4 for ti in sorted(mapped_tis, key=operator.attrgetter("map_index")): @@ -5002,7 +5001,7 @@ def show(a, b): show_task = dag.get_task("show") assert show_task.get_parse_time_mapped_ti_count() == 6 - mapped_tis, max_map_index = show_task.expand_mapped_task(dag_run.run_id, session=session) + mapped_tis, max_map_index = TaskMap.expand_mapped_task(show_task, dag_run.run_id, session=session) assert len(mapped_tis) == 0 # Expanded at parse! assert max_map_index == 5 @@ -5050,7 +5049,9 @@ def cmds(): ti.run() bash_task = dag.get_task("dynamic.bash") - mapped_bash_tis, max_map_index = bash_task.expand_mapped_task(dag_run.run_id, session=session) + mapped_bash_tis, max_map_index = TaskMap.expand_mapped_task( + bash_task, dag_run.run_id, session=session + ) assert max_map_index == 3 # 2 * 2 mapped tasks. 
for ti in sorted(mapped_bash_tis, key=operator.attrgetter("map_index")): ti.refresh_from_task(bash_task) @@ -5170,7 +5171,7 @@ def add_one(x): ti.run() task_345 = dag.get_task("add_one__1") - for ti in task_345.expand_mapped_task(dagrun.run_id, session=session)[0]: + for ti in TaskMap.expand_mapped_task(task_345, dagrun.run_id, session=session)[0]: ti.refresh_from_task(task_345) ti.run() diff --git a/tests/models/test_xcom_arg_map.py b/tests/models/test_xcom_arg_map.py index b2e885e940833..a654ef1dd4a2e 100644 --- a/tests/models/test_xcom_arg_map.py +++ b/tests/models/test_xcom_arg_map.py @@ -365,7 +365,7 @@ def convert_zipped(zipped): def test_xcom_concat(dag_maker, session): - from airflow.models.xcom_arg import _ConcatResult + from airflow.sdk.definitions.xcom_arg import _ConcatResult agg_results = set() all_results = None diff --git a/tests/serialization/test_dag_serialization.py b/tests/serialization/test_dag_serialization.py index 2b5d4cce4c7bb..84a63674e5119 100644 --- a/tests/serialization/test_dag_serialization.py +++ b/tests/serialization/test_dag_serialization.py @@ -657,6 +657,9 @@ def validate_deserialized_task( task, ): """Verify non-Airflow operators are casted to BaseOperator or MappedOperator.""" + from airflow.sdk import BaseOperator + from airflow.sdk.definitions.mappedoperator import MappedOperator + assert not isinstance(task, SerializedBaseOperator) assert isinstance(task, (BaseOperator, MappedOperator)) @@ -2650,9 +2653,9 @@ def test_sensor_expand_deserialized_unmap(): deser_unmapped = deser_mapped.unmap(None) ser_normal = SerializedBaseOperator.serialize(normal) deser_normal = SerializedBaseOperator.deserialize(ser_normal) - deser_normal.dag = dag comps = set(BashSensor._comps) comps.remove("task_id") + comps.remove("dag_id") assert all(getattr(deser_unmapped, c, None) == getattr(deser_normal, c, None) for c in comps) @@ -2896,7 +2899,7 @@ def x(arg1, arg2, arg3): def test_mapped_task_group_serde(): from airflow.decorators.task_group import task_group from airflow.models.expandinput import DictOfListsExpandInput - from airflow.utils.task_group import MappedTaskGroup + from airflow.sdk.definitions.taskgroup import MappedTaskGroup with DAG("test-dag", schedule=None, start_date=datetime(2020, 1, 1)) as dag: diff --git a/tests_common/pytest_plugin.py b/tests_common/pytest_plugin.py index 413c83eb51bdb..1b68f039eaa17 100644 --- a/tests_common/pytest_plugin.py +++ b/tests_common/pytest_plugin.py @@ -784,6 +784,8 @@ def __init__(self): self.dagbag = DagBag(os.devnull, include_examples=False, read_dags_from_db=False) def __enter__(self): + self.serialized_model = None + self.dag.__enter__() if self.want_serialized: return lazy_object_proxy.Proxy(self._serialized_dag) diff --git a/tests_common/test_utils/compat.py b/tests_common/test_utils/compat.py index 3bd4b89dfc1c4..7757a879a7c1a 100644 --- a/tests_common/test_utils/compat.py +++ b/tests_common/test_utils/compat.py @@ -21,7 +21,6 @@ from typing import TYPE_CHECKING, Any, cast from airflow.exceptions import AirflowOptionalProviderFeatureException -from airflow.models import Connection, Operator from airflow.utils.helpers import prune_dict from tests_common.test_utils.version_compat import AIRFLOW_V_2_10_PLUS @@ -60,6 +59,7 @@ if TYPE_CHECKING: + from airflow.models import Connection from airflow.models.asset import ( AssetAliasModel, AssetDagRunQueue, @@ -68,6 +68,7 @@ DagScheduleAssetReference, TaskOutletAssetReference, ) + from airflow.sdk.types import Operator else: try: from airflow.models.asset import ( @@ -103,7 
+104,7 @@ def deserialize_operator(serialized_operator: dict[str, Any]) -> Operator: # are updated to airflow 2.10+. from airflow.serialization.serialized_objects import BaseSerialization - return cast(Operator, BaseSerialization.deserialize(serialized_operator)) + return BaseSerialization.deserialize(serialized_operator) else: from airflow.serialization.serialized_objects import SerializedBaseOperator diff --git a/tests_common/test_utils/mapping.py b/tests_common/test_utils/mapping.py index 15a57679cf024..dbcb0ecc364e6 100644 --- a/tests_common/test_utils/mapping.py +++ b/tests_common/test_utils/mapping.py @@ -23,7 +23,7 @@ if TYPE_CHECKING: from sqlalchemy.orm import Session - from airflow.models.mappedoperator import MappedOperator + from airflow.sdk.definitions.mappedoperator import MappedOperator def expand_mapped_task( @@ -41,4 +41,4 @@ def expand_mapped_task( ) session.flush() - mapped.expand_mapped_task(run_id, session=session) + TaskMap.expand_mapped_task(mapped, run_id, session=session) diff --git a/tests_common/test_utils/system_tests.py b/tests_common/test_utils/system_tests.py index 9be67c06822ed..5428eae9114ab 100644 --- a/tests_common/test_utils/system_tests.py +++ b/tests_common/test_utils/system_tests.py @@ -25,6 +25,7 @@ from airflow.utils.state import DagRunState if TYPE_CHECKING: + from airflow.models.dagrun import DagRun from airflow.sdk.definitions.context import Context logger = logging.getLogger(__name__) @@ -32,6 +33,8 @@ def get_test_run(dag, **test_kwargs): def callback(context: Context): + if TYPE_CHECKING: + assert isinstance(context["dag_run"], DagRun) ti = context["dag_run"].get_task_instances() if not ti: logger.warning("could not retrieve tasks that ran in the DAG, cannot display a summary")
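Across the test changes above, dynamic task expansion is no longer invoked as an instance method on the mapped operator; the tests now call the helper on TaskMap and pass the operator in explicitly. A minimal sketch of that calling convention, using placeholder names (mapped, run_id, session) taken from the surrounding tests rather than any one specific test:

    from airflow.models.taskmap import TaskMap

    # Old style, removed throughout this diff:
    #     mapped.expand_mapped_task(run_id, session=session)
    # New style: the mapped operator is passed as the first argument, and the
    # call still returns the expanded task instances plus the max map index.
    mapped_tis, max_map_index = TaskMap.expand_mapped_task(mapped, run_id, session=session)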
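The dependency-mixin tests moved under task_sdk also stop using the db-backed dag_maker fixture and EmptyOperator, building DAGs directly from the SDK definitions so only parse-time wiring is exercised. A self-contained sketch of that pattern, with illustrative task ids rather than ones from a particular test:

    from airflow.sdk.definitions.baseoperator import BaseOperator
    from airflow.sdk.definitions.dag import DAG

    # No database session or dag_maker fixture: the DAG and operators are plain
    # SDK definition objects, and only the upstream/downstream lists are checked.
    with DAG("sketch_set_downstream"):
        op_a = BaseOperator(task_id="a")
        op_b = BaseOperator(task_id="b")
        op_a >> op_b

    assert op_b.upstream_list == [op_a]
    assert op_a.downstream_list == [op_b]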