From d1b2a4465387e9414e6c15f8df85591136a7784b Mon Sep 17 00:00:00 2001
From: Ash Berlin-Taylor
Date: Tue, 21 Jan 2025 11:05:19 +0000
Subject: [PATCH] Start porting mapped task to SDK (#45627)

This PR restructures the Mapped Operator and Mapped Task Group code to live in the Task SDK at definition time. The big thing this change _does not do_ is make it possible to execute mapped tasks via the Task Execution API server etc. -- that is up next.

There were some unavoidable changes to the scheduler/expansion part of mapped tasks here. Of note:

`BaseOperator.get_mapped_ti_count` has moved from being an instance method on BaseOperator to a class method. The reason is that, as more and more of the "definition time" code moves into the Task SDK BaseOperator and AbstractOperator, it is no longer possible to add DB-accessing code to a base class and have it apply to the subclasses. (i.e. `airflow.models.abstractoperator.AbstractOperator` is now _not always_ in the MRO for tasks. Eventually that class will be deleted, but not yet.)

In a similar vein, XComArg's `get_task_map_length` has also moved to a singledispatch class method on the TaskMap model, since the definition-time objects now live in the Task SDK and there is no realistic way to get a per-type subclass with DB logic (i.e. it's very complex to end up with a PlainDBXComArg, a MapDBXComArg, etc. that we could attach the method to).

For those who aren't aware, singledispatch (and singledispatchmethod) are part of the standard library; the type of the first argument is used to determine which implementation to call. If you are familiar with C++ or Java, this is very similar to method overloading; the one caveat is that it _only_ examines the type of the first argument, not the full signature.
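For illustration only, here is a minimal sketch of the dispatch pattern this PR uses (not the actual Airflow code; Node, Mapped and Counter are hypothetical stand-ins, and it assumes a Python version where singledispatchmethod composes with classmethod -- on Python 3.9 the type must be passed to register() explicitly, which is the workaround the patch uses):

    from functools import singledispatchmethod

    class Node:            # stand-in for a DAGNode-like base type
        pass

    class Mapped(Node):    # stand-in for a mapped-operator subtype
        pass

    class Counter:
        @singledispatchmethod
        @classmethod
        def count(cls, node: Node) -> int:
            # Fallback when no implementation is registered for type(node).
            raise NotImplementedError(f"Not implemented for {type(node)}")

        # Pass the type to register() explicitly instead of relying on the
        # annotation (the Python 3.9 caveat referenced in the patch).
        @count.register(Mapped)
        @classmethod
        def _(cls, node: Mapped) -> int:
            return 1

    Counter.count(Mapped())  # dispatches on type(node) -> 1
    Counter.count(Node())    # no registered impl -> NotImplementedError

The real `get_mapped_ti_count` in this PR follows the same shape, registering implementations for the Task SDK AbstractOperator, MappedOperator and TaskGroup types, so callers now invoke it as a class method (e.g. `BaseOperator.get_mapped_ti_count(task, run_id, session=session)`) instead of calling it on the task instance.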
--- airflow/api/common/mark_tasks.py | 3 + .../api_connexion/schemas/common_schema.py | 4 +- airflow/api_connexion/schemas/task_schema.py | 7 +- .../api_fastapi/core_api/datamodels/tasks.py | 23 +- .../api_fastapi/core_api/services/ui/grid.py | 9 +- .../commands/remote_commands/task_command.py | 5 + airflow/decorators/base.py | 6 +- airflow/models/abstractoperator.py | 212 +---- airflow/models/baseoperator.py | 324 ++----- airflow/models/dag.py | 5 +- airflow/models/dagrun.py | 24 +- airflow/models/expandinput.py | 16 +- airflow/models/mappedoperator.py | 763 +-------------- airflow/models/param.py | 2 +- airflow/models/renderedtifields.py | 2 +- airflow/models/taskinstance.py | 84 +- airflow/models/taskmap.py | 141 ++- airflow/models/xcom_arg.py | 774 ++------------- airflow/serialization/serialized_objects.py | 29 +- .../ti_deps/deps/mapped_task_upstream_dep.py | 2 +- airflow/ti_deps/deps/prev_dagrun_dep.py | 2 +- airflow/ti_deps/deps/trigger_rule_dep.py | 4 +- airflow/utils/context.py | 2 +- airflow/utils/dag_edges.py | 6 +- airflow/utils/dot_renderer.py | 4 +- airflow/utils/log/file_task_handler.py | 3 +- airflow/utils/setup_teardown.py | 2 +- airflow/utils/task_group.py | 35 +- airflow/utils/types.py | 26 - airflow/www/views.py | 3 +- .../elasticsearch/log/es_task_handler.py | 3 +- .../openlineage/utils/selective_enable.py | 13 +- .../providers/openlineage/utils/utils.py | 14 +- .../opensearch/log/os_task_handler.py | 3 +- .../providers/standard/operators/python.py | 2 +- .../tests/openlineage/utils/test_utils.py | 3 - .../tests/standard/operators/test_python.py | 1 + .../base_operator_partial_arguments.py | 14 +- .../pre_commit/template_context_key_sync.py | 3 + task_sdk/src/airflow/sdk/__init__.py | 18 +- .../definitions/_internal/abstractoperator.py | 165 +++- .../sdk/definitions/_internal/mixins.py | 2 +- .../airflow/sdk/definitions/_internal/node.py | 6 +- .../sdk/definitions/_internal/templater.py | 6 +- .../airflow/sdk/definitions/asset/__init__.py | 6 - .../airflow/sdk/definitions/baseoperator.py | 227 ++++- task_sdk/src/airflow/sdk/definitions/dag.py | 2 +- .../airflow/sdk/definitions/mappedoperator.py | 900 ++++++++++++++++++ .../src/airflow/sdk/definitions/taskgroup.py | 27 +- .../src/airflow/sdk/definitions/xcom_arg.py | 639 +++++++++++++ task_sdk/src/airflow/sdk/types.py | 5 +- task_sdk/tests/conftest.py | 9 + .../tests/defintions/test_baseoperator.py | 23 + .../tests/defintions/test_mappedoperator.py | 301 ++++++ .../tests/defintions/test_minxins.py | 126 ++- task_sdk/tests/execution_time/conftest.py | 9 - .../test_mapped_task_instance_endpoint.py | 2 +- .../routes/public/test_task_instances.py | 2 +- tests/decorators/test_python.py | 8 +- tests/models/test_baseoperator.py | 22 - tests/models/test_dag.py | 33 +- tests/models/test_mappedoperator.py | 238 +---- tests/models/test_renderedtifields.py | 3 +- tests/models/test_taskinstance.py | 15 +- tests/models/test_xcom_arg_map.py | 2 +- tests/serialization/test_dag_serialization.py | 7 +- tests_common/pytest_plugin.py | 2 + tests_common/test_utils/compat.py | 5 +- tests_common/test_utils/mapping.py | 4 +- tests_common/test_utils/system_tests.py | 3 + 70 files changed, 2911 insertions(+), 2484 deletions(-) create mode 100644 task_sdk/src/airflow/sdk/definitions/mappedoperator.py create mode 100644 task_sdk/src/airflow/sdk/definitions/xcom_arg.py create mode 100644 task_sdk/tests/defintions/test_mappedoperator.py rename tests/models/test_taskmixin.py => task_sdk/tests/defintions/test_minxins.py (75%) diff --git 
a/airflow/api/common/mark_tasks.py b/airflow/api/common/mark_tasks.py index 76988f0f82aee..f249199651849 100644 --- a/airflow/api/common/mark_tasks.py +++ b/airflow/api/common/mark_tasks.py @@ -83,6 +83,9 @@ def set_state( if not run_id: raise ValueError("Received tasks with no run_id") + if TYPE_CHECKING: + assert isinstance(dag, DAG) + dag_run_ids = get_run_ids(dag, run_id, future, past, session=session) task_id_map_index_list = list(find_task_relatives(tasks, downstream, upstream)) # now look for the task instances that are affected diff --git a/airflow/api_connexion/schemas/common_schema.py b/airflow/api_connexion/schemas/common_schema.py index 569a745a62f52..9c08d6932af28 100644 --- a/airflow/api_connexion/schemas/common_schema.py +++ b/airflow/api_connexion/schemas/common_schema.py @@ -25,7 +25,7 @@ from dateutil import relativedelta from marshmallow import Schema, fields, validate -from airflow.models.mappedoperator import MappedOperator +from airflow.sdk.definitions.mappedoperator import MappedOperator from airflow.serialization.serialized_objects import SerializedBaseOperator @@ -128,7 +128,7 @@ def _get_module(self, obj): def _get_class_name(self, obj): if isinstance(obj, (MappedOperator, SerializedBaseOperator)): - return obj._task_type + return obj.task_type if isinstance(obj, type): return obj.__name__ return type(obj).__name__ diff --git a/airflow/api_connexion/schemas/task_schema.py b/airflow/api_connexion/schemas/task_schema.py index 1329a351d7b76..648187c0b70cc 100644 --- a/airflow/api_connexion/schemas/task_schema.py +++ b/airflow/api_connexion/schemas/task_schema.py @@ -26,7 +26,6 @@ TimeDeltaSchema, WeightRuleField, ) -from airflow.models.mappedoperator import MappedOperator if TYPE_CHECKING: from airflow.models.operator import Operator @@ -62,7 +61,7 @@ class TaskSchema(Schema): template_fields = fields.List(fields.String(), dump_only=True) downstream_task_ids = fields.List(fields.String(), dump_only=True) params = fields.Method("_get_params", dump_only=True) - is_mapped = fields.Method("_get_is_mapped", dump_only=True) + is_mapped = fields.Boolean(dump_only=True) doc_md = fields.String(dump_only=True) @staticmethod @@ -80,10 +79,6 @@ def _get_params(obj): params = obj.params return {k: v.dump() for k, v in params.items()} - @staticmethod - def _get_is_mapped(obj): - return isinstance(obj, MappedOperator) - class TaskCollection(NamedTuple): """List of Tasks with metadata.""" diff --git a/airflow/api_fastapi/core_api/datamodels/tasks.py b/airflow/api_fastapi/core_api/datamodels/tasks.py index 0806d4453c49a..13a9af7043680 100644 --- a/airflow/api_fastapi/core_api/datamodels/tasks.py +++ b/airflow/api_fastapi/core_api/datamodels/tasks.py @@ -26,29 +26,18 @@ from airflow.api_fastapi.common.types import TimeDeltaWithValidation from airflow.api_fastapi.core_api.base import BaseModel -from airflow.models.mappedoperator import MappedOperator -from airflow.serialization.serialized_objects import SerializedBaseOperator, encode_priority_weight_strategy +from airflow.serialization.serialized_objects import encode_priority_weight_strategy from airflow.task.priority_strategy import PriorityWeightStrategy def _get_class_ref(obj) -> dict[str, str | None]: """Return the class_ref dict for obj.""" - is_mapped_or_serialized = isinstance(obj, (MappedOperator, SerializedBaseOperator)) - - module_path = None - if is_mapped_or_serialized: - module_path = obj._task_module - else: + module_path = getattr(obj, "_task_module", None) + if module_path is None: module_type = inspect.getmodule(obj) 
module_path = module_type.__name__ if module_type else None - class_name = None - if is_mapped_or_serialized: - class_name = obj._task_type - elif obj.__class__ is type: - class_name = obj.__name__ - else: - class_name = type(obj).__name__ + class_name = obj.task_type return { "module_path": module_path, @@ -89,9 +78,7 @@ class TaskResponse(BaseModel): @model_validator(mode="before") @classmethod def validate_model(cls, task: Any) -> Any: - task.__dict__.update( - {"class_ref": _get_class_ref(task), "is_mapped": isinstance(task, MappedOperator)} - ) + task.__dict__.update({"class_ref": _get_class_ref(task), "is_mapped": task.is_mapped}) return task @field_validator("weight_rule", mode="before") diff --git a/airflow/api_fastapi/core_api/services/ui/grid.py b/airflow/api_fastapi/core_api/services/ui/grid.py index 739cf25b15153..74c44318a5a7a 100644 --- a/airflow/api_fastapi/core_api/services/ui/grid.py +++ b/airflow/api_fastapi/core_api/services/ui/grid.py @@ -32,11 +32,12 @@ ) from airflow.configuration import conf from airflow.exceptions import AirflowConfigException -from airflow.models import MappedOperator -from airflow.models.baseoperator import BaseOperator +from airflow.models.baseoperator import BaseOperator as DBBaseOperator from airflow.models.taskmap import TaskMap +from airflow.sdk import BaseOperator +from airflow.sdk.definitions.mappedoperator import MappedOperator +from airflow.sdk.definitions.taskgroup import MappedTaskGroup, TaskGroup from airflow.utils.state import TaskInstanceState -from airflow.utils.task_group import MappedTaskGroup, TaskGroup @cache @@ -139,7 +140,7 @@ def _get_total_task_count( node if isinstance(node, int) else ( - node.get_mapped_ti_count(run_id=run_id, session=session) + DBBaseOperator.get_mapped_ti_count(node, run_id=run_id, session=session) or 0 if isinstance(node, (MappedTaskGroup, MappedOperator)) else node ) diff --git a/airflow/cli/commands/remote_commands/task_command.py b/airflow/cli/commands/remote_commands/task_command.py index 51198af74961a..28ecb9d7bd798 100644 --- a/airflow/cli/commands/remote_commands/task_command.py +++ b/airflow/cli/commands/remote_commands/task_command.py @@ -206,6 +206,11 @@ def _get_ti( dag = task.dag if dag is None: raise ValueError("Cannot get task instance for a task not assigned to a DAG") + if not isinstance(dag, DAG): + # TODO: Task-SDK: Shouldn't really happen, and this command will go away before 3.0 + raise ValueError( + f"We need a {DAG.__module__}.DAG, but we got {type(dag).__module__}.{type(dag).__name__}!" 
+ ) # this check is imperfect because diff dags could have tasks with same name # but in a task, dag_id is a property that accesses its dag, and we don't diff --git a/airflow/decorators/base.py b/airflow/decorators/base.py index 3e4589b3c78db..9db6d058adeb5 100644 --- a/airflow/decorators/base.py +++ b/airflow/decorators/base.py @@ -41,11 +41,11 @@ ListOfDictsExpandInput, is_mappable, ) -from airflow.models.mappedoperator import MappedOperator, ensure_xcomarg_return_value -from airflow.models.xcom_arg import XComArg from airflow.sdk.definitions._internal.contextmanager import DagContext, TaskGroupContext from airflow.sdk.definitions.asset import Asset from airflow.sdk.definitions.baseoperator import BaseOperator as TaskSDKBaseOperator +from airflow.sdk.definitions.mappedoperator import MappedOperator, ensure_xcomarg_return_value +from airflow.sdk.definitions.xcom_arg import XComArg from airflow.typing_compat import ParamSpec from airflow.utils import timezone from airflow.utils.context import KNOWN_CONTEXT_KEYS @@ -62,9 +62,9 @@ OperatorExpandArgument, OperatorExpandKwargsArgument, ) - from airflow.models.mappedoperator import ValidationSource from airflow.sdk.definitions.context import Context from airflow.sdk.definitions.dag import DAG + from airflow.sdk.definitions.mappedoperator import ValidationSource from airflow.utils.task_group import TaskGroup diff --git a/airflow/models/abstractoperator.py b/airflow/models/abstractoperator.py index 732180d372946..b8fb54f6966fd 100644 --- a/airflow/models/abstractoperator.py +++ b/airflow/models/abstractoperator.py @@ -19,40 +19,36 @@ import datetime import inspect -from collections.abc import Iterable, Iterator, Sequence +from collections.abc import Iterable, Sequence from functools import cached_property from typing import TYPE_CHECKING, Any, Callable -import methodtools from sqlalchemy import select from airflow.configuration import conf from airflow.exceptions import AirflowException from airflow.models.expandinput import NotFullyPopulated -from airflow.sdk.definitions._internal.abstractoperator import AbstractOperator as TaskSDKAbstractOperator +from airflow.sdk.definitions._internal.abstractoperator import ( + AbstractOperator as TaskSDKAbstractOperator, + NotMapped as NotMapped, # Re-export this for compat +) from airflow.sdk.definitions.context import Context from airflow.utils.db import exists_query from airflow.utils.log.logging_mixin import LoggingMixin -from airflow.utils.setup_teardown import SetupTeardownContext from airflow.utils.sqlalchemy import with_row_locks from airflow.utils.state import State, TaskInstanceState -from airflow.utils.task_group import MappedTaskGroup from airflow.utils.trigger_rule import TriggerRule -from airflow.utils.weight_rule import WeightRule, db_safe_priority +from airflow.utils.weight_rule import db_safe_priority if TYPE_CHECKING: - import jinja2 # Slow imports. 
from sqlalchemy.orm import Session from airflow.models.baseoperatorlink import BaseOperatorLink from airflow.models.dag import DAG as SchedulerDAG - from airflow.models.mappedoperator import MappedOperator from airflow.models.taskinstance import TaskInstance - from airflow.sdk.definitions._internal.node import DAGNode from airflow.sdk.definitions.baseoperator import BaseOperator from airflow.task.priority_strategy import PriorityWeightStrategy from airflow.triggers.base import StartTriggerArgs - from airflow.utils.task_group import TaskGroup TaskStateChangeCallback = Callable[[Context], None] @@ -71,19 +67,12 @@ ) MAX_RETRY_DELAY: int = conf.getint("core", "max_task_retry_delay", fallback=24 * 60 * 60) -DEFAULT_WEIGHT_RULE: WeightRule = WeightRule( - conf.get("core", "default_task_weight_rule", fallback=WeightRule.DOWNSTREAM) -) DEFAULT_TRIGGER_RULE: TriggerRule = TriggerRule.ALL_SUCCESS DEFAULT_TASK_EXECUTION_TIMEOUT: datetime.timedelta | None = conf.gettimedelta( "core", "default_task_execution_timeout" ) -class NotMapped(Exception): - """Raise if a task is neither mapped nor has any parent mapped groups.""" - - class AbstractOperator(LoggingMixin, TaskSDKAbstractOperator): """ Common implementation for operators, including unmapped and mapped. @@ -98,124 +87,8 @@ class AbstractOperator(LoggingMixin, TaskSDKAbstractOperator): :meta private: """ - trigger_rule: TriggerRule weight_rule: PriorityWeightStrategy - @property - def on_failure_fail_dagrun(self): - """ - Whether the operator should fail the dagrun on failure. - - :meta private: - """ - return self._on_failure_fail_dagrun - - @on_failure_fail_dagrun.setter - def on_failure_fail_dagrun(self, value): - """ - Setter for on_failure_fail_dagrun property. - - :meta private: - """ - if value is True and self.is_teardown is not True: - raise ValueError( - f"Cannot set task on_failure_fail_dagrun for " - f"'{self.task_id}' because it is not a teardown task." - ) - self._on_failure_fail_dagrun = value - - def _iter_all_mapped_downstreams(self) -> Iterator[MappedOperator | MappedTaskGroup]: - """ - Return mapped nodes that are direct dependencies of the current task. - - For now, this walks the entire DAG to find mapped nodes that has this - current task as an upstream. We cannot use ``downstream_list`` since it - only contains operators, not task groups. In the future, we should - provide a way to record an DAG node's all downstream nodes instead. - - Note that this does not guarantee the returned tasks actually use the - current task for task mapping, but only checks those task are mapped - operators, and are downstreams of the current task. - - To get a list of tasks that uses the current task for task mapping, use - :meth:`iter_mapped_dependants` instead. - """ - from airflow.models.mappedoperator import MappedOperator - from airflow.utils.task_group import TaskGroup - - def _walk_group(group: TaskGroup) -> Iterable[tuple[str, DAGNode]]: - """ - Recursively walk children in a task group. - - This yields all direct children (including both tasks and task - groups), and all children of any task groups. 
- """ - for key, child in group.children.items(): - yield key, child - if isinstance(child, TaskGroup): - yield from _walk_group(child) - - dag = self.get_dag() - if not dag: - raise RuntimeError("Cannot check for mapped dependants when not attached to a DAG") - for key, child in _walk_group(dag.task_group): - if key == self.node_id: - continue - if not isinstance(child, (MappedOperator, MappedTaskGroup)): - continue - if self.node_id in child.upstream_task_ids: - yield child - - def iter_mapped_dependants(self) -> Iterator[MappedOperator | MappedTaskGroup]: - """ - Return mapped nodes that depend on the current task the expansion. - - For now, this walks the entire DAG to find mapped nodes that has this - current task as an upstream. We cannot use ``downstream_list`` since it - only contains operators, not task groups. In the future, we should - provide a way to record an DAG node's all downstream nodes instead. - """ - return ( - downstream - for downstream in self._iter_all_mapped_downstreams() - if any(p.node_id == self.node_id for p in downstream.iter_mapped_dependencies()) - ) - - def iter_mapped_task_groups(self) -> Iterator[MappedTaskGroup]: - """ - Return mapped task groups this task belongs to. - - Groups are returned from the innermost to the outmost. - - :meta private: - """ - if (group := self.task_group) is None: - return - # TODO: Task-SDK: this type ignore shouldn't be necessary, revisit once mapping support is fully in the - # SDK - yield from group.iter_mapped_task_groups() # type: ignore[misc] - - def get_closest_mapped_task_group(self) -> MappedTaskGroup | None: - """ - Get the mapped task group "closest" to this task in the DAG. - - :meta private: - """ - return next(self.iter_mapped_task_groups(), None) - - def get_needs_expansion(self) -> bool: - """ - Return true if the task is MappedOperator or is in a mapped task group. - - :meta private: - """ - if self._needs_expansion is None: - if self.get_closest_mapped_task_group() is not None: - self._needs_expansion = True - else: - self._needs_expansion = False - return self._needs_expansion - def unmap(self, resolve: None | dict[str, Any] | tuple[Context, Session]) -> BaseOperator: """ Get the "normal" operator from current abstract operator. @@ -343,43 +216,6 @@ def get_extra_links(self, ti: TaskInstance, link_name: str) -> str | None: return link.get_link(self.unmap(None), ti.dag_run.logical_date) # type: ignore[misc] return link.get_link(self.unmap(None), ti_key=ti.key) - @methodtools.lru_cache(maxsize=None) - def get_parse_time_mapped_ti_count(self) -> int: - """ - Return the number of mapped task instances that can be created on DAG run creation. - - This only considers literal mapped arguments, and would return *None* - when any non-literal values are used for mapping. - - :raise NotFullyPopulated: If non-literal mapped arguments are encountered. - :raise NotMapped: If the operator is neither mapped, nor has any parent - mapped task groups. - :return: Total number of mapped TIs this task should have. - """ - group = self.get_closest_mapped_task_group() - if group is None: - raise NotMapped - return group.get_parse_time_mapped_ti_count() - - def get_mapped_ti_count(self, run_id: str, *, session: Session) -> int: - """ - Return the number of mapped TaskInstances that can be created at run time. - - This considers both literal and non-literal mapped arguments, and the - result is therefore available when all depended tasks have finished. 
The - return value should be identical to ``parse_time_mapped_ti_count`` if - all mapped arguments are literal. - - :raise NotFullyPopulated: If upstream tasks are not all complete yet. - :raise NotMapped: If the operator is neither mapped, nor has any parent - mapped task groups. - :return: Total number of mapped TIs this task should have. - """ - group = self.get_closest_mapped_task_group() - if group is None: - raise NotMapped - return group.get_mapped_ti_count(run_id, session=session) - def expand_mapped_task(self, run_id: str, *, session: Session) -> tuple[Sequence[TaskInstance], int]: """ Create the mapped task instances for mapped task. @@ -390,16 +226,20 @@ def expand_mapped_task(self, run_id: str, *, session: Session) -> tuple[Sequence """ from sqlalchemy import func, or_ - from airflow.models.mappedoperator import MappedOperator from airflow.models.taskinstance import TaskInstance from airflow.sdk.definitions.baseoperator import BaseOperator + from airflow.sdk.definitions.mappedoperator import MappedOperator from airflow.settings import task_instance_mutation_hook if not isinstance(self, (BaseOperator, MappedOperator)): - raise RuntimeError(f"cannot expand unrecognized operator type {type(self).__name__}") + raise RuntimeError( + f"cannot expand unrecognized operator type {type(self).__module__}.{type(self).__name__}" + ) + + from airflow.models.baseoperator import BaseOperator as DBBaseOperator try: - total_length: int | None = self.get_mapped_ti_count(run_id, session=session) + total_length: int | None = DBBaseOperator.get_mapped_ti_count(self, run_id, session=session) except NotFullyPopulated as e: # It's possible that the upstream tasks are not yet done, but we # don't have upstream of upstreams in partial DAGs (possible in the @@ -509,29 +349,3 @@ def expand_mapped_task(self, run_id: str, *, session: Session) -> tuple[Sequence ti.state = TaskInstanceState.REMOVED session.flush() return all_expanded_tis, total_expanded_ti_count - 1 - - def render_template_fields( - self, - context: Context, - jinja_env: jinja2.Environment | None = None, - ) -> None: - """ - Template all attributes listed in *self.template_fields*. - - If the operator is mapped, this should return the unmapped, fully - rendered, and map-expanded operator. The mapped operator should not be - modified. However, *context* may be modified in-place to reference the - unmapped operator for template rendering. - - If the operator is not mapped, this should modify the operator in-place. 
- """ - raise NotImplementedError() - - def __enter__(self): - if not self.is_setup and not self.is_teardown: - raise AirflowException("Only setup/teardown tasks can be used as context managers.") - SetupTeardownContext.push_setup_teardown_task(self) - return SetupTeardownContext - - def __exit__(self, exc_type, exc_val, exc_tb): - SetupTeardownContext.set_work_task_roots_and_leaves() diff --git a/airflow/models/baseoperator.py b/airflow/models/baseoperator.py index 88a656eb7b1ed..b4931ac16d324 100644 --- a/airflow/models/baseoperator.py +++ b/airflow/models/baseoperator.py @@ -23,13 +23,12 @@ from __future__ import annotations -import collections.abc -import contextlib import functools import logging -from collections.abc import Collection, Iterable, Sequence +import operator +from collections.abc import Collection, Iterable, Iterator, Sequence from datetime import datetime, timedelta -from functools import wraps +from functools import singledispatchmethod, wraps from threading import local from types import FunctionType from typing import ( @@ -53,36 +52,27 @@ TaskDeferred, ) from airflow.lineage import apply_lineage, prepare_lineage + +# Keeping this file at all is a temp thing as we migrate the repo to the task sdk as the base, but to keep +# main working and useful for others to develop against we use the TaskSDK here but keep this file around from airflow.models.abstractoperator import ( - DEFAULT_EXECUTOR, - DEFAULT_IGNORE_FIRST_DEPENDS_ON_PAST, - DEFAULT_OWNER, - DEFAULT_POOL_SLOTS, - DEFAULT_PRIORITY_WEIGHT, - DEFAULT_QUEUE, - DEFAULT_RETRIES, - DEFAULT_RETRY_DELAY, - DEFAULT_TASK_EXECUTION_TIMEOUT, - DEFAULT_TRIGGER_RULE, - DEFAULT_WAIT_FOR_PAST_DEPENDS_BEFORE_SKIPPING, - DEFAULT_WEIGHT_RULE, AbstractOperator, + NotMapped, ) from airflow.models.base import _sentinel -from airflow.models.mappedoperator import OperatorPartial, validate_mapping_kwargs from airflow.models.taskinstance import TaskInstance, clear_task_instances from airflow.models.taskmixin import DependencyMixin from airflow.models.trigger import TRIGGER_FAIL_REPR, TriggerFailureReason +from airflow.sdk.definitions._internal.abstractoperator import AbstractOperator as TaskSDKAbstractOperator from airflow.sdk.definitions.baseoperator import ( BaseOperatorMeta as TaskSDKBaseOperatorMeta, - get_merged_defaults, + get_merged_defaults as get_merged_defaults, # Re-export for compat ) - -# Keeping this file at all is a temp thing as we migrate the repo to the task sdk as the base, but to keep -# main working and useful for others to develop against we use the TaskSDK here but keep this file around from airflow.sdk.definitions.context import Context -from airflow.sdk.definitions.dag import DAG, BaseOperator as TaskSDKBaseOperator +from airflow.sdk.definitions.dag import BaseOperator as TaskSDKBaseOperator from airflow.sdk.definitions.edges import EdgeModifier as TaskSDKEdgeModifier +from airflow.sdk.definitions.mappedoperator import MappedOperator +from airflow.sdk.definitions.taskgroup import MappedTaskGroup, TaskGroup from airflow.serialization.enums import DagAttributeTypes from airflow.ti_deps.deps.mapped_task_upstream_dep import MappedTaskUpstreamDep from airflow.ti_deps.deps.not_in_retry_period_dep import NotInRetryPeriodDep @@ -95,27 +85,19 @@ from airflow.utils.operator_helpers import ExecutionCallableRunner from airflow.utils.operator_resources import Resources from airflow.utils.session import NEW_SESSION, provide_session -from airflow.utils.types import NOTSET, DagRunTriggeredByType +from airflow.utils.types import 
DagRunTriggeredByType from airflow.utils.xcom import XCOM_RETURN_KEY if TYPE_CHECKING: - from types import ClassMethodDescriptorType - from sqlalchemy.orm import Session from airflow.models.abstractoperator import TaskStateChangeCallback from airflow.models.baseoperatorlink import BaseOperatorLink from airflow.models.dag import DAG as SchedulerDAG from airflow.models.operator import Operator - from airflow.task.priority_strategy import PriorityWeightStrategy + from airflow.sdk.definitions.node import DAGNode from airflow.ti_deps.deps.base_ti_dep import BaseTIDep from airflow.triggers.base import BaseTrigger, StartTriggerArgs - from airflow.utils.types import ArgNotSet - - -# Todo: AIP-44: Once we get rid of AIP-44 we can remove this. But without this here pydantic fails to resolve -# types for serialization -from airflow.utils.task_group import TaskGroup # noqa: TC001 TaskPreExecuteHook = Callable[[Context], None] TaskPostExecuteHook = Callable[[Context, Any], None] @@ -153,193 +135,6 @@ def coerce_resources(resources: dict[str, Any] | None) -> Resources | None: return Resources(**resources) -class _PartialDescriptor: - """A descriptor that guards against ``.partial`` being called on Task objects.""" - - class_method: ClassMethodDescriptorType | None = None - - def __get__( - self, obj: BaseOperator, cls: type[BaseOperator] | None = None - ) -> Callable[..., OperatorPartial]: - # Call this "partial" so it looks nicer in stack traces. - def partial(**kwargs): - raise TypeError("partial can only be called on Operator classes, not Tasks themselves") - - if obj is not None: - return partial - return self.class_method.__get__(cls, cls) - - -_PARTIAL_DEFAULTS: dict[str, Any] = { - "map_index_template": None, - "owner": DEFAULT_OWNER, - "trigger_rule": DEFAULT_TRIGGER_RULE, - "depends_on_past": False, - "ignore_first_depends_on_past": DEFAULT_IGNORE_FIRST_DEPENDS_ON_PAST, - "wait_for_past_depends_before_skipping": DEFAULT_WAIT_FOR_PAST_DEPENDS_BEFORE_SKIPPING, - "wait_for_downstream": False, - "retries": DEFAULT_RETRIES, - "executor": DEFAULT_EXECUTOR, - "queue": DEFAULT_QUEUE, - "pool_slots": DEFAULT_POOL_SLOTS, - "execution_timeout": DEFAULT_TASK_EXECUTION_TIMEOUT, - "retry_delay": DEFAULT_RETRY_DELAY, - "retry_exponential_backoff": False, - "priority_weight": DEFAULT_PRIORITY_WEIGHT, - "weight_rule": DEFAULT_WEIGHT_RULE, - "inlets": [], - "outlets": [], - "allow_nested_operators": True, -} - - -# This is what handles the actual mapping. 
- -if TYPE_CHECKING: - - def partial( - operator_class: type[BaseOperator], - *, - task_id: str, - dag: DAG | None = None, - task_group: TaskGroup | None = None, - start_date: datetime | ArgNotSet = NOTSET, - end_date: datetime | ArgNotSet = NOTSET, - owner: str | ArgNotSet = NOTSET, - email: None | str | Iterable[str] | ArgNotSet = NOTSET, - params: collections.abc.MutableMapping | None = None, - resources: dict[str, Any] | None | ArgNotSet = NOTSET, - trigger_rule: str | ArgNotSet = NOTSET, - depends_on_past: bool | ArgNotSet = NOTSET, - ignore_first_depends_on_past: bool | ArgNotSet = NOTSET, - wait_for_past_depends_before_skipping: bool | ArgNotSet = NOTSET, - wait_for_downstream: bool | ArgNotSet = NOTSET, - retries: int | None | ArgNotSet = NOTSET, - queue: str | ArgNotSet = NOTSET, - pool: str | ArgNotSet = NOTSET, - pool_slots: int | ArgNotSet = NOTSET, - execution_timeout: timedelta | None | ArgNotSet = NOTSET, - max_retry_delay: None | timedelta | float | ArgNotSet = NOTSET, - retry_delay: timedelta | float | ArgNotSet = NOTSET, - retry_exponential_backoff: bool | ArgNotSet = NOTSET, - priority_weight: int | ArgNotSet = NOTSET, - weight_rule: str | PriorityWeightStrategy | ArgNotSet = NOTSET, - sla: timedelta | None | ArgNotSet = NOTSET, - map_index_template: str | None | ArgNotSet = NOTSET, - max_active_tis_per_dag: int | None | ArgNotSet = NOTSET, - max_active_tis_per_dagrun: int | None | ArgNotSet = NOTSET, - on_execute_callback: None - | TaskStateChangeCallback - | list[TaskStateChangeCallback] - | ArgNotSet = NOTSET, - on_failure_callback: None - | TaskStateChangeCallback - | list[TaskStateChangeCallback] - | ArgNotSet = NOTSET, - on_success_callback: None - | TaskStateChangeCallback - | list[TaskStateChangeCallback] - | ArgNotSet = NOTSET, - on_retry_callback: None - | TaskStateChangeCallback - | list[TaskStateChangeCallback] - | ArgNotSet = NOTSET, - on_skipped_callback: None - | TaskStateChangeCallback - | list[TaskStateChangeCallback] - | ArgNotSet = NOTSET, - run_as_user: str | None | ArgNotSet = NOTSET, - executor: str | None | ArgNotSet = NOTSET, - executor_config: dict | None | ArgNotSet = NOTSET, - inlets: Any | None | ArgNotSet = NOTSET, - outlets: Any | None | ArgNotSet = NOTSET, - doc: str | None | ArgNotSet = NOTSET, - doc_md: str | None | ArgNotSet = NOTSET, - doc_json: str | None | ArgNotSet = NOTSET, - doc_yaml: str | None | ArgNotSet = NOTSET, - doc_rst: str | None | ArgNotSet = NOTSET, - task_display_name: str | None | ArgNotSet = NOTSET, - logger_name: str | None | ArgNotSet = NOTSET, - allow_nested_operators: bool = True, - **kwargs, - ) -> OperatorPartial: ... -else: - - def partial( - operator_class: type[BaseOperator], - *, - task_id: str, - dag: DAG | None = None, - task_group: TaskGroup | None = None, - params: collections.abc.MutableMapping | None = None, - **kwargs, - ): - from airflow.sdk.definitions._internal.contextmanager import DagContext, TaskGroupContext - - validate_mapping_kwargs(operator_class, "partial", kwargs) - - dag = dag or DagContext.get_current() - if dag: - task_group = task_group or TaskGroupContext.get_current(dag) - if task_group: - task_id = task_group.child_id(task_id) - - # Merge DAG and task group level defaults into user-supplied values. 
- dag_default_args, partial_params = get_merged_defaults( - dag=dag, - task_group=task_group, - task_params=params, - task_default_args=kwargs.pop("default_args", None), - ) - - # Create partial_kwargs from args and kwargs - partial_kwargs: dict[str, Any] = { - "task_id": task_id, - "dag": dag, - "task_group": task_group, - **kwargs, - } - - # Inject DAG-level default args into args provided to this function. - partial_kwargs.update( - (k, v) for k, v in dag_default_args.items() if partial_kwargs.get(k, NOTSET) is NOTSET - ) - - # Fill fields not provided by the user with default values. - for k, v in _PARTIAL_DEFAULTS.items(): - partial_kwargs.setdefault(k, v) - - # Post-process arguments. Should be kept in sync with _TaskDecorator.expand(). - if "task_concurrency" in kwargs: # Reject deprecated option. - raise TypeError("unexpected argument: task_concurrency") - if wait := partial_kwargs.get("wait_for_downstream", False): - partial_kwargs["depends_on_past"] = wait - if start_date := partial_kwargs.get("start_date", None): - partial_kwargs["start_date"] = timezone.convert_to_utc(start_date) - if end_date := partial_kwargs.get("end_date", None): - partial_kwargs["end_date"] = timezone.convert_to_utc(end_date) - if partial_kwargs["pool_slots"] < 1: - dag_str = "" - if dag: - dag_str = f" in dag {dag.dag_id}" - raise ValueError(f"pool slots for {task_id}{dag_str} cannot be less than 1") - if retries := partial_kwargs.get("retries"): - partial_kwargs["retries"] = parse_retries(retries) - partial_kwargs["retry_delay"] = coerce_timedelta(partial_kwargs["retry_delay"], key="retry_delay") - if partial_kwargs.get("max_retry_delay", None) is not None: - partial_kwargs["max_retry_delay"] = coerce_timedelta( - partial_kwargs["max_retry_delay"], - key="max_retry_delay", - ) - partial_kwargs.setdefault("executor_config", {}) - - return OperatorPartial( - operator_class=operator_class, - kwargs=partial_kwargs, - params=partial_params, - ) - - class ExecutorSafeguard: """ The ExecutorSafeguard decorator. @@ -388,12 +183,6 @@ def __new__(cls, name, bases, namespace, **kwargs): if callable(execute_method) and not getattr(execute_method, "__isabstractmethod__", False): namespace["execute"] = ExecutorSafeguard().decorator(execute_method) new_cls = super().__new__(cls, name, bases, namespace, **kwargs) - with contextlib.suppress(KeyError): - # Update the partial descriptor with the class method, so it calls the actual function - # (but let subclasses override it if they need to) - partial_desc = vars(new_cls)["partial"] - if isinstance(partial_desc, _PartialDescriptor): - partial_desc.class_method = classmethod(partial) return new_cls @@ -653,8 +442,6 @@ def dag(self, val: SchedulerDAG): # For type checking only ... - partial: Callable[..., OperatorPartial] = _PartialDescriptor() # type: ignore - @classmethod @methodtools.lru_cache(maxsize=None) def get_serialized_fields(cls): @@ -1018,6 +805,89 @@ def expand_start_trigger_args(self, *, context: Context, session: Session) -> St """ return self.start_trigger_args + if TYPE_CHECKING: + + @classmethod + def get_mapped_ti_count( + cls, node: DAGNode | MappedTaskGroup, run_id: str, *, session: Session + ) -> int: + """ + Return the number of mapped TaskInstances that can be created at run time. + + This considers both literal and non-literal mapped arguments, and the + result is therefore available when all depended tasks have finished. The + return value should be identical to ``parse_time_mapped_ti_count`` if + all mapped arguments are literal. 
+ + :raise NotFullyPopulated: If upstream tasks are not all complete yet. + :raise NotMapped: If the operator is neither mapped, nor has any parent + mapped task groups. + :return: Total number of mapped TIs this task should have. + """ + else: + + @singledispatchmethod + @classmethod + def get_mapped_ti_count(cls, task: DAGNode, run_id: str, *, session: Session) -> int: + raise NotImplementedError(f"Not implemented for {type(task)}") + + # https://github.com/python/cpython/issues/86153 + # WHile we support Python 3.9 we can't rely on the type hint, we need to pass the type explicitly to + # register. + @get_mapped_ti_count.register(TaskSDKAbstractOperator) + @classmethod + def _(cls, task: TaskSDKAbstractOperator, run_id: str, *, session: Session) -> int: + group = task.get_closest_mapped_task_group() + if group is None: + raise NotMapped() + return cls.get_mapped_ti_count(group, run_id, session=session) + + @get_mapped_ti_count.register(MappedOperator) + @classmethod + def _(cls, task: MappedOperator, run_id: str, *, session: Session) -> int: + from airflow.serialization.serialized_objects import _ExpandInputRef + + exp_input = task._get_specified_expand_input() + if isinstance(exp_input, _ExpandInputRef): + exp_input = exp_input.deref(task.dag) + current_count = exp_input.get_total_map_length(run_id, session=session) + + group = task.get_closest_mapped_task_group() + if group is None: + return current_count + parent_count = cls.get_mapped_ti_count(group, run_id, session=session) + return parent_count * current_count + + @get_mapped_ti_count.register(TaskGroup) + @classmethod + def _(cls, group: TaskGroup, run_id: str, *, session: Session) -> int: + """ + Return the number of instances a task in this group should be mapped to at run time. + + This considers both literal and non-literal mapped arguments, and the + result is therefore available when all depended tasks have finished. The + return value should be identical to ``parse_time_mapped_ti_count`` if + all mapped arguments are literal. + + If this group is inside mapped task groups, all the nested counts are + multiplied and accounted. + + :raise NotFullyPopulated: If upstream tasks are not all complete yet. + :return: Total number of mapped TIs this task should have. 
+ """ + + def iter_mapped_task_groups(group) -> Iterator[MappedTaskGroup]: + while group is not None: + if isinstance(group, MappedTaskGroup): + yield group + group = group.parent_group + + groups = iter_mapped_task_groups(group) + return functools.reduce( + operator.mul, + (g._expand_input.get_total_map_length(run_id, session=session) for g in groups), + ) + def chain(*tasks: DependencyMixin | Sequence[DependencyMixin]) -> None: r""" diff --git a/airflow/models/dag.py b/airflow/models/dag.py index 7da56bca624ab..9b88e6d70f71c 100644 --- a/airflow/models/dag.py +++ b/airflow/models/dag.py @@ -1600,6 +1600,7 @@ def test( :param mark_success_pattern: regex of task_ids to mark as success instead of running :param session: database connection (optional) """ + from airflow.serialization.serialized_objects import SerializedDAG def add_logger_if_needed(ti: TaskInstance): """ @@ -1642,8 +1643,10 @@ def add_logger_if_needed(ti: TaskInstance): self.log.debug("Getting dagrun for dag %s", self.dag_id) logical_date = timezone.coerce_datetime(logical_date) data_interval = self.timetable.infer_manual_data_interval(run_after=logical_date) + scheduler_dag = SerializedDAG.deserialize_dag(SerializedDAG.serialize_dag(self)) + dr: DagRun = _get_or_create_dagrun( - dag=self, + dag=scheduler_dag, start_date=logical_date, logical_date=logical_date, run_id=DagRun.generate_run_id(DagRunType.MANUAL, logical_date), diff --git a/airflow/models/dagrun.py b/airflow/models/dagrun.py index 8060901f948a3..8795c0507087b 100644 --- a/airflow/models/dagrun.py +++ b/airflow/models/dagrun.py @@ -66,6 +66,7 @@ from airflow.models.expandinput import NotFullyPopulated from airflow.models.taskinstance import TaskInstance as TI from airflow.models.tasklog import LogTemplate +from airflow.models.taskmap import TaskMap from airflow.stats import Stats from airflow.ti_deps.dep_context import DepContext from airflow.ti_deps.dependencies_states import SCHEDULEABLE_STATES @@ -85,6 +86,7 @@ from sqlalchemy.orm import Query, Session + from airflow.models.baseoperator import BaseOperator from airflow.models.dag import DAG from airflow.models.operator import Operator from airflow.typing_compat import Literal @@ -1134,15 +1136,15 @@ def _expand_mapped_task_if_needed(ti: TI) -> Iterable[TI] | None: if ti.map_index >= 0: # Already expanded, we're good. return None - from airflow.models.mappedoperator import MappedOperator + from airflow.sdk.definitions.mappedoperator import MappedOperator as TaskSDKMappedOperator - if isinstance(ti.task, MappedOperator): + if isinstance(ti.task, TaskSDKMappedOperator): # If we get here, it could be that we are moving from non-mapped to mapped # after task instance clearing or this ti is not yet expanded. Safe to clear # the db references. ti.clear_db_references(session=session) try: - expanded_tis, _ = ti.task.expand_mapped_task(self.run_id, session=session) + expanded_tis, _ = TaskMap.expand_mapped_task(ti.task, self.run_id, session=session) except NotMapped: # Not a mapped task, nothing needed. 
return None if expanded_tis: @@ -1155,7 +1157,7 @@ def _expand_mapped_task_if_needed(ti: TI) -> Iterable[TI] | None: revised_map_index_task_ids = set() for schedulable in itertools.chain(schedulable_tis, additional_tis): if TYPE_CHECKING: - assert schedulable.task + assert isinstance(schedulable.task, BaseOperator) old_state = schedulable.state if not schedulable.are_dependencies_met(session=session, dep_context=dep_context): old_states[schedulable.key] = old_state @@ -1327,6 +1329,8 @@ def _check_for_removed_or_restored_tasks( :return: Task IDs in the DAG run """ + from airflow.models.baseoperator import BaseOperator + tis = self.get_task_instances(session=session) # check for removed or restored tasks @@ -1362,7 +1366,7 @@ def _check_for_removed_or_restored_tasks( except NotFullyPopulated: # What if it is _now_ dynamically mapped, but wasn't before? try: - total_length = task.get_mapped_ti_count(self.run_id, session=session) + total_length = BaseOperator.get_mapped_ti_count(task, self.run_id, session=session) except NotFullyPopulated: # Not all upstreams finished, so we can't tell what should be here. Remove everything. if ti.map_index >= 0: @@ -1462,10 +1466,13 @@ def _create_tasks( :param tasks: Tasks to create jobs for in the DAG run :param task_creator: Function to create task instances """ + from airflow.models.baseoperator import BaseOperator + map_indexes: Iterable[int] for task in tasks: try: - count = task.get_mapped_ti_count(self.run_id, session=session) + breakpoint + count = BaseOperator.get_mapped_ti_count(task, self.run_id, session=session) except (NotMapped, NotFullyPopulated): map_indexes = (-1,) else: @@ -1531,10 +1538,11 @@ def _revise_map_indexes_if_mapped(self, task: Operator, *, session: Session) -> we delay expansion to the "last resort". See comments at the call site for more details. """ + from airflow.models.baseoperator import BaseOperator from airflow.settings import task_instance_mutation_hook try: - total_length = task.get_mapped_ti_count(self.run_id, session=session) + total_length = BaseOperator.get_mapped_ti_count(task, self.run_id, session=session) except NotMapped: return # Not a mapped task, don't need to do anything. except NotFullyPopulated: @@ -1613,7 +1621,7 @@ def schedule_tis( schedulable_ti_ids = [] for ti in schedulable_tis: if TYPE_CHECKING: - assert ti.task + assert isinstance(ti.task, BaseOperator) if ( ti.task.inherits_from_empty_operator and not ti.task.on_execute_callback diff --git a/airflow/models/expandinput.py b/airflow/models/expandinput.py index fcbb55dc3d24d..8fb35f7032965 100644 --- a/airflow/models/expandinput.py +++ b/airflow/models/expandinput.py @@ -31,8 +31,8 @@ if TYPE_CHECKING: from sqlalchemy.orm import Session - from airflow.models.operator import Operator from airflow.models.xcom_arg import XComArg + from airflow.sdk.types import Operator from airflow.serialization.serialized_objects import _ExpandInputRef from airflow.typing_compat import TypeGuard @@ -59,11 +59,6 @@ class MappedArgument(ResolveMixin): _input: ExpandInput _key: str - def get_task_map_length(self, run_id: str, *, session: Session) -> int | None: - # TODO (AIP-42): Implement run-time task map length inspection. This is - # needed when we implement task mapping inside a mapped task group. 
- raise NotImplementedError() - def iter_references(self) -> Iterable[tuple[Operator, str]]: yield from self._input.iter_references() @@ -145,8 +140,11 @@ def _get_map_lengths(self, run_id: str, *, session: Session) -> dict[str, int]: # TODO: This initiates one database call for each XComArg. Would it be # more efficient to do one single db call and unpack the value here? def _get_length(v: OperatorExpandArgument) -> int | None: + from airflow.models.xcom_arg import get_task_map_length + if _needs_run_time_resolution(v): - return v.get_task_map_length(run_id, session=session) + return get_task_map_length(v, run_id, session=session) + # Unfortunately a user-defined TypeGuard cannot apply negative type # narrowing. https://github.com/python/typing/discussions/1013 if TYPE_CHECKING: @@ -243,9 +241,11 @@ def get_parse_time_mapped_ti_count(self) -> int: raise NotFullyPopulated({"expand_kwargs() argument"}) def get_total_map_length(self, run_id: str, *, session: Session) -> int: + from airflow.models.xcom_arg import get_task_map_length + if isinstance(self.value, collections.abc.Sized): return len(self.value) - length = self.value.get_task_map_length(run_id, session=session) + length = get_task_map_length(self.value, run_id, session=session) if length is None: raise NotFullyPopulated({"expand_kwargs() argument"}) return length diff --git a/airflow/models/mappedoperator.py b/airflow/models/mappedoperator.py index 8c61bdc42fd11..e7352ad1323d3 100644 --- a/airflow/models/mappedoperator.py +++ b/airflow/models/mappedoperator.py @@ -17,226 +17,22 @@ # under the License. from __future__ import annotations -import collections.abc import contextlib import copy -import warnings -from collections.abc import Collection, Iterable, Iterator, Mapping, Sequence -from typing import TYPE_CHECKING, Any, ClassVar, Union +from collections.abc import Mapping +from typing import TYPE_CHECKING, Any import attrs -import methodtools -from airflow.models.abstractoperator import ( - DEFAULT_EXECUTOR, - DEFAULT_IGNORE_FIRST_DEPENDS_ON_PAST, - DEFAULT_OWNER, - DEFAULT_POOL_SLOTS, - DEFAULT_PRIORITY_WEIGHT, - DEFAULT_QUEUE, - DEFAULT_RETRIES, - DEFAULT_RETRY_DELAY, - DEFAULT_TRIGGER_RULE, - DEFAULT_WAIT_FOR_PAST_DEPENDS_BEFORE_SKIPPING, - DEFAULT_WEIGHT_RULE, - AbstractOperator, - NotMapped, -) -from airflow.models.expandinput import ( - DictOfListsExpandInput, - ListOfDictsExpandInput, - is_mappable, -) -from airflow.models.pool import Pool -from airflow.serialization.enums import DagAttributeTypes -from airflow.task.priority_strategy import PriorityWeightStrategy, validate_and_load_priority_weight_strategy +from airflow.models.abstractoperator import AbstractOperator +from airflow.sdk.definitions.mappedoperator import MappedOperator as TaskSDKMappedOperator from airflow.triggers.base import StartTriggerArgs -from airflow.utils.context import context_update_for_unmapped -from airflow.utils.helpers import is_container, prevent_duplicates -from airflow.utils.task_instance_session import get_current_task_instance_session -from airflow.utils.types import NOTSET -from airflow.utils.xcom import XCOM_RETURN_KEY +from airflow.utils.helpers import prevent_duplicates if TYPE_CHECKING: - import datetime - from typing import Literal - - import jinja2 # Slow import. 
- import pendulum from sqlalchemy.orm.session import Session - from airflow.models.abstractoperator import ( - TaskStateChangeCallback, - ) - from airflow.models.baseoperator import BaseOperator - from airflow.models.baseoperatorlink import BaseOperatorLink - from airflow.models.dag import DAG - from airflow.models.expandinput import ( - ExpandInput, - OperatorExpandArgument, - OperatorExpandKwargsArgument, - ) - from airflow.models.operator import Operator - from airflow.models.param import ParamsDict - from airflow.models.xcom_arg import XComArg from airflow.sdk.definitions.context import Context - from airflow.ti_deps.deps.base_ti_dep import BaseTIDep - from airflow.utils.operator_resources import Resources - from airflow.utils.task_group import TaskGroup - from airflow.utils.trigger_rule import TriggerRule - - TaskStateChangeCallbackAttrType = Union[None, TaskStateChangeCallback, list[TaskStateChangeCallback]] - - ValidationSource = Literal["expand", "partial"] - - -def validate_mapping_kwargs(op: type[BaseOperator], func: ValidationSource, value: dict[str, Any]) -> None: - # use a dict so order of args is same as code order - unknown_args = value.copy() - for klass in op.mro(): - init = klass.__init__ # type: ignore[misc] - try: - param_names = init._BaseOperatorMeta__param_names - except AttributeError: - continue - for name in param_names: - value = unknown_args.pop(name, NOTSET) - if func != "expand": - continue - if value is NOTSET: - continue - if is_mappable(value): - continue - type_name = type(value).__name__ - error = f"{op.__name__}.expand() got an unexpected type {type_name!r} for keyword argument {name}" - raise ValueError(error) - if not unknown_args: - return # If we have no args left to check: stop looking at the MRO chain. - - if len(unknown_args) == 1: - error = f"an unexpected keyword argument {unknown_args.popitem()[0]!r}" - else: - names = ", ".join(repr(n) for n in unknown_args) - error = f"unexpected keyword arguments {names}" - raise TypeError(f"{op.__name__}.{func}() got {error}") - - -def ensure_xcomarg_return_value(arg: Any) -> None: - from airflow.models.xcom_arg import XComArg - - if isinstance(arg, XComArg): - for operator, key in arg.iter_references(): - if key != XCOM_RETURN_KEY: - raise ValueError(f"cannot map over XCom with custom key {key!r} from {operator}") - elif not is_container(arg): - return - elif isinstance(arg, collections.abc.Mapping): - for v in arg.values(): - ensure_xcomarg_return_value(v) - elif isinstance(arg, collections.abc.Iterable): - for v in arg: - ensure_xcomarg_return_value(v) - - -@attrs.define(kw_only=True, repr=False) -class OperatorPartial: - """ - An "intermediate state" returned by ``BaseOperator.partial()``. - - This only exists at DAG-parsing time; the only intended usage is for the - user to call ``.expand()`` on it at some point (usually in a method chain) to - create a ``MappedOperator`` to add into the DAG. - """ - - operator_class: type[BaseOperator] - kwargs: dict[str, Any] - params: ParamsDict | dict - - _expand_called: bool = False # Set when expand() is called to ease user debugging. 
- - def __attrs_post_init__(self): - validate_mapping_kwargs(self.operator_class, "partial", self.kwargs) - - def __repr__(self) -> str: - args = ", ".join(f"{k}={v!r}" for k, v in self.kwargs.items()) - return f"{self.operator_class.__name__}.partial({args})" - - def __del__(self): - if not self._expand_called: - try: - task_id = repr(self.kwargs["task_id"]) - except KeyError: - task_id = f"at {hex(id(self))}" - warnings.warn(f"Task {task_id} was never mapped!", category=UserWarning, stacklevel=1) - - def expand(self, **mapped_kwargs: OperatorExpandArgument) -> MappedOperator: - if not mapped_kwargs: - raise TypeError("no arguments to expand against") - validate_mapping_kwargs(self.operator_class, "expand", mapped_kwargs) - prevent_duplicates(self.kwargs, mapped_kwargs, fail_reason="unmappable or already specified") - # Since the input is already checked at parse time, we can set strict - # to False to skip the checks on execution. - return self._expand(DictOfListsExpandInput(mapped_kwargs), strict=False) - - def expand_kwargs(self, kwargs: OperatorExpandKwargsArgument, *, strict: bool = True) -> MappedOperator: - from airflow.models.xcom_arg import XComArg - - if isinstance(kwargs, collections.abc.Sequence): - for item in kwargs: - if not isinstance(item, (XComArg, collections.abc.Mapping)): - raise TypeError(f"expected XComArg or list[dict], not {type(kwargs).__name__}") - elif not isinstance(kwargs, XComArg): - raise TypeError(f"expected XComArg or list[dict], not {type(kwargs).__name__}") - return self._expand(ListOfDictsExpandInput(kwargs), strict=strict) - - def _expand(self, expand_input: ExpandInput, *, strict: bool) -> MappedOperator: - from airflow.operators.empty import EmptyOperator - from airflow.sensors.base import BaseSensorOperator - - self._expand_called = True - ensure_xcomarg_return_value(expand_input.value) - - partial_kwargs = self.kwargs.copy() - task_id = partial_kwargs.pop("task_id") - dag = partial_kwargs.pop("dag") - task_group = partial_kwargs.pop("task_group") - start_date = partial_kwargs.pop("start_date", None) - end_date = partial_kwargs.pop("end_date", None) - - try: - operator_name = self.operator_class.custom_operator_name # type: ignore - except AttributeError: - operator_name = self.operator_class.__name__ - - op = MappedOperator( - operator_class=self.operator_class, - expand_input=expand_input, - partial_kwargs=partial_kwargs, - task_id=task_id, - params=self.params, - operator_extra_links=self.operator_class.operator_extra_links, - template_ext=self.operator_class.template_ext, - template_fields=self.operator_class.template_fields, - template_fields_renderers=self.operator_class.template_fields_renderers, - ui_color=self.operator_class.ui_color, - ui_fgcolor=self.operator_class.ui_fgcolor, - is_empty=issubclass(self.operator_class, EmptyOperator), - is_sensor=issubclass(self.operator_class, BaseSensorOperator), - task_module=self.operator_class.__module__, - task_type=self.operator_class.__name__, - operator_name=operator_name, - dag=dag, - task_group=task_group, - start_date=start_date, - end_date=end_date, - disallow_kwargs_override=strict, - # For classic operators, this points to expand_input because kwargs - # to BaseOperator.expand() contribute to operator arguments. 
- expand_input_attr="expand_input", - start_trigger_args=self.operator_class.start_trigger_args, - start_from_trigger=self.operator_class.start_from_trigger, - ) - return op @attrs.define( @@ -249,422 +45,12 @@ def _expand(self, expand_input: ExpandInput, *, strict: bool) -> MappedOperator: # special here (the logic is only important for slots=True), we use Python's # built-in implementation, which works (as proven by good old BaseOperator). getstate_setstate=False, + repr=False, ) -class MappedOperator(AbstractOperator): +# TODO: Task-SDK: Multiple inheritance is a crime. There must be a better way +class MappedOperator(TaskSDKMappedOperator, AbstractOperator): # type: ignore[misc] # It complains about weight_rule being different """Object representing a mapped operator in a DAG.""" - # This attribute serves double purpose. For a "normal" operator instance - # loaded from DAG, this holds the underlying non-mapped operator class that - # can be used to create an unmapped operator for execution. For an operator - # recreated from a serialized DAG, however, this holds the serialized data - # that can be used to unmap this into a SerializedBaseOperator. - operator_class: type[BaseOperator] | dict[str, Any] - - expand_input: ExpandInput - partial_kwargs: dict[str, Any] - - # Needed for serialization. - task_id: str - params: ParamsDict | dict - deps: frozenset[BaseTIDep] = attrs.field(init=False) - operator_extra_links: Collection[BaseOperatorLink] - template_ext: Sequence[str] - template_fields: Collection[str] - template_fields_renderers: dict[str, str] - ui_color: str - ui_fgcolor: str - _is_empty: bool - _is_sensor: bool = False - _task_module: str - _task_type: str - _operator_name: str - start_trigger_args: StartTriggerArgs | None - start_from_trigger: bool - _needs_expansion: bool = True - - dag: DAG | None - task_group: TaskGroup | None - start_date: pendulum.DateTime | None - end_date: pendulum.DateTime | None - upstream_task_ids: set[str] = attrs.field(factory=set, init=False) - downstream_task_ids: set[str] = attrs.field(factory=set, init=False) - - _disallow_kwargs_override: bool - """Whether execution fails if ``expand_input`` has duplicates to ``partial_kwargs``. - - If *False*, values from ``expand_input`` under duplicate keys override those - under corresponding keys in ``partial_kwargs``. - """ - - _expand_input_attr: str - """Where to get kwargs to calculate expansion length against. - - This should be a name to call ``getattr()`` on. 
- """ - - supports_lineage: bool = False - - HIDE_ATTRS_FROM_UI: ClassVar[frozenset[str]] = AbstractOperator.HIDE_ATTRS_FROM_UI | frozenset( - ("parse_time_mapped_ti_count", "operator_class", "start_trigger_args", "start_from_trigger") - ) - - @deps.default - def _deps(self): - from airflow.models.baseoperator import BaseOperator - - return BaseOperator.deps - - def __hash__(self): - return id(self) - - def __repr__(self): - return f"" - - def __attrs_post_init__(self): - from airflow.models.xcom_arg import XComArg - - if self.get_closest_mapped_task_group() is not None: - raise NotImplementedError("operator expansion in an expanded task group is not yet supported") - - if self.task_group: - self.task_group.add(self) - if self.dag: - self.dag.add_task(self) - XComArg.apply_upstream_relationship(self, self.expand_input.value) - for k, v in self.partial_kwargs.items(): - if k in self.template_fields: - XComArg.apply_upstream_relationship(self, v) - - @methodtools.lru_cache(maxsize=None) - @classmethod - def get_serialized_fields(cls): - # Not using 'cls' here since we only want to serialize base fields. - return (frozenset(attrs.fields_dict(MappedOperator)) | {"task_type"}) - { - "_task_type", - "dag", - "deps", - "expand_input", # This is needed to be able to accept XComArg. - "task_group", - "upstream_task_ids", - "supports_lineage", - "_is_setup", - "_is_teardown", - "_on_failure_fail_dagrun", - } - - @property - def task_type(self) -> str: - """Implementing Operator.""" - return self._task_type - - @property - def operator_name(self) -> str: - return self._operator_name - - @property - def inherits_from_empty_operator(self) -> bool: - """Implementing Operator.""" - return self._is_empty - - @property - def roots(self) -> Sequence[AbstractOperator]: - """Implementing DAGNode.""" - return [self] - - @property - def leaves(self) -> Sequence[AbstractOperator]: - """Implementing DAGNode.""" - return [self] - - @property - def task_display_name(self) -> str: - return self.partial_kwargs.get("task_display_name") or self.task_id - - @property - def owner(self) -> str: # type: ignore[override] - return self.partial_kwargs.get("owner", DEFAULT_OWNER) - - @property - def email(self) -> None | str | Iterable[str]: - return self.partial_kwargs.get("email") - - @property - def map_index_template(self) -> None | str: - return self.partial_kwargs.get("map_index_template") - - @map_index_template.setter - def map_index_template(self, value: str | None) -> None: - self.partial_kwargs["map_index_template"] = value - - @property - def trigger_rule(self) -> TriggerRule: - return self.partial_kwargs.get("trigger_rule", DEFAULT_TRIGGER_RULE) - - @trigger_rule.setter - def trigger_rule(self, value): - self.partial_kwargs["trigger_rule"] = value - - @property - def is_setup(self) -> bool: - return bool(self.partial_kwargs.get("is_setup")) - - @is_setup.setter - def is_setup(self, value: bool) -> None: - self.partial_kwargs["is_setup"] = value - - @property - def is_teardown(self) -> bool: - return bool(self.partial_kwargs.get("is_teardown")) - - @is_teardown.setter - def is_teardown(self, value: bool) -> None: - self.partial_kwargs["is_teardown"] = value - - @property - def depends_on_past(self) -> bool: - return bool(self.partial_kwargs.get("depends_on_past")) - - @depends_on_past.setter - def depends_on_past(self, value: bool) -> None: - self.partial_kwargs["depends_on_past"] = value - - @property - def ignore_first_depends_on_past(self) -> bool: - value = 
self.partial_kwargs.get("ignore_first_depends_on_past", DEFAULT_IGNORE_FIRST_DEPENDS_ON_PAST) - return bool(value) - - @ignore_first_depends_on_past.setter - def ignore_first_depends_on_past(self, value: bool) -> None: - self.partial_kwargs["ignore_first_depends_on_past"] = value - - @property - def wait_for_past_depends_before_skipping(self) -> bool: - value = self.partial_kwargs.get( - "wait_for_past_depends_before_skipping", DEFAULT_WAIT_FOR_PAST_DEPENDS_BEFORE_SKIPPING - ) - return bool(value) - - @wait_for_past_depends_before_skipping.setter - def wait_for_past_depends_before_skipping(self, value: bool) -> None: - self.partial_kwargs["wait_for_past_depends_before_skipping"] = value - - @property - def wait_for_downstream(self) -> bool: - return bool(self.partial_kwargs.get("wait_for_downstream")) - - @wait_for_downstream.setter - def wait_for_downstream(self, value: bool) -> None: - self.partial_kwargs["wait_for_downstream"] = value - - @property - def retries(self) -> int: - return self.partial_kwargs.get("retries", DEFAULT_RETRIES) - - @retries.setter - def retries(self, value: int) -> None: - self.partial_kwargs["retries"] = value - - @property - def queue(self) -> str: - return self.partial_kwargs.get("queue", DEFAULT_QUEUE) - - @queue.setter - def queue(self, value: str) -> None: - self.partial_kwargs["queue"] = value - - @property - def pool(self) -> str: - return self.partial_kwargs.get("pool", Pool.DEFAULT_POOL_NAME) - - @pool.setter - def pool(self, value: str) -> None: - self.partial_kwargs["pool"] = value - - @property - def pool_slots(self) -> int: - return self.partial_kwargs.get("pool_slots", DEFAULT_POOL_SLOTS) - - @pool_slots.setter - def pool_slots(self, value: int) -> None: - self.partial_kwargs["pool_slots"] = value - - @property - def execution_timeout(self) -> datetime.timedelta | None: - return self.partial_kwargs.get("execution_timeout") - - @execution_timeout.setter - def execution_timeout(self, value: datetime.timedelta | None) -> None: - self.partial_kwargs["execution_timeout"] = value - - @property - def max_retry_delay(self) -> datetime.timedelta | None: - return self.partial_kwargs.get("max_retry_delay") - - @max_retry_delay.setter - def max_retry_delay(self, value: datetime.timedelta | None) -> None: - self.partial_kwargs["max_retry_delay"] = value - - @property - def retry_delay(self) -> datetime.timedelta: - return self.partial_kwargs.get("retry_delay", DEFAULT_RETRY_DELAY) - - @retry_delay.setter - def retry_delay(self, value: datetime.timedelta) -> None: - self.partial_kwargs["retry_delay"] = value - - @property - def retry_exponential_backoff(self) -> bool: - return bool(self.partial_kwargs.get("retry_exponential_backoff")) - - @retry_exponential_backoff.setter - def retry_exponential_backoff(self, value: bool) -> None: - self.partial_kwargs["retry_exponential_backoff"] = value - - @property - def priority_weight(self) -> int: # type: ignore[override] - return self.partial_kwargs.get("priority_weight", DEFAULT_PRIORITY_WEIGHT) - - @priority_weight.setter - def priority_weight(self, value: int) -> None: - self.partial_kwargs["priority_weight"] = value - - @property - def weight_rule(self) -> PriorityWeightStrategy: # type: ignore[override] - return validate_and_load_priority_weight_strategy( - self.partial_kwargs.get("weight_rule", DEFAULT_WEIGHT_RULE) - ) - - @weight_rule.setter - def weight_rule(self, value: str | PriorityWeightStrategy) -> None: - self.partial_kwargs["weight_rule"] = validate_and_load_priority_weight_strategy(value) - - @property - 
def max_active_tis_per_dag(self) -> int | None: - return self.partial_kwargs.get("max_active_tis_per_dag") - - @max_active_tis_per_dag.setter - def max_active_tis_per_dag(self, value: int | None) -> None: - self.partial_kwargs["max_active_tis_per_dag"] = value - - @property - def max_active_tis_per_dagrun(self) -> int | None: - return self.partial_kwargs.get("max_active_tis_per_dagrun") - - @max_active_tis_per_dagrun.setter - def max_active_tis_per_dagrun(self, value: int | None) -> None: - self.partial_kwargs["max_active_tis_per_dagrun"] = value - - @property - def resources(self) -> Resources | None: - return self.partial_kwargs.get("resources") - - @property - def on_execute_callback(self) -> TaskStateChangeCallbackAttrType: - return self.partial_kwargs.get("on_execute_callback") - - @on_execute_callback.setter - def on_execute_callback(self, value: TaskStateChangeCallbackAttrType) -> None: - self.partial_kwargs["on_execute_callback"] = value - - @property - def on_failure_callback(self) -> TaskStateChangeCallbackAttrType: - return self.partial_kwargs.get("on_failure_callback") - - @on_failure_callback.setter - def on_failure_callback(self, value: TaskStateChangeCallbackAttrType) -> None: - self.partial_kwargs["on_failure_callback"] = value - - @property - def on_retry_callback(self) -> TaskStateChangeCallbackAttrType: - return self.partial_kwargs.get("on_retry_callback") - - @on_retry_callback.setter - def on_retry_callback(self, value: TaskStateChangeCallbackAttrType) -> None: - self.partial_kwargs["on_retry_callback"] = value - - @property - def on_success_callback(self) -> TaskStateChangeCallbackAttrType: - return self.partial_kwargs.get("on_success_callback") - - @on_success_callback.setter - def on_success_callback(self, value: TaskStateChangeCallbackAttrType) -> None: - self.partial_kwargs["on_success_callback"] = value - - @property - def on_skipped_callback(self) -> TaskStateChangeCallbackAttrType: - return self.partial_kwargs.get("on_skipped_callback") - - @on_skipped_callback.setter - def on_skipped_callback(self, value: TaskStateChangeCallbackAttrType) -> None: - self.partial_kwargs["on_skipped_callback"] = value - - @property - def run_as_user(self) -> str | None: - return self.partial_kwargs.get("run_as_user") - - @property - def executor(self) -> str | None: - return self.partial_kwargs.get("executor", DEFAULT_EXECUTOR) - - @property - def executor_config(self) -> dict: - return self.partial_kwargs.get("executor_config", {}) - - @property # type: ignore[override] - def inlets(self) -> list[Any]: # type: ignore[override] - return self.partial_kwargs.get("inlets", []) - - @inlets.setter - def inlets(self, value: list[Any]) -> None: # type: ignore[override] - self.partial_kwargs["inlets"] = value - - @property # type: ignore[override] - def outlets(self) -> list[Any]: # type: ignore[override] - return self.partial_kwargs.get("outlets", []) - - @outlets.setter - def outlets(self, value: list[Any]) -> None: # type: ignore[override] - self.partial_kwargs["outlets"] = value - - @property - def doc(self) -> str | None: - return self.partial_kwargs.get("doc") - - @property - def doc_md(self) -> str | None: - return self.partial_kwargs.get("doc_md") - - @property - def doc_json(self) -> str | None: - return self.partial_kwargs.get("doc_json") - - @property - def doc_yaml(self) -> str | None: - return self.partial_kwargs.get("doc_yaml") - - @property - def doc_rst(self) -> str | None: - return self.partial_kwargs.get("doc_rst") - - @property - def allow_nested_operators(self) -> 
bool: - return bool(self.partial_kwargs.get("allow_nested_operators")) - - def get_dag(self) -> DAG | None: - """Implement Operator.""" - return self.dag - - @property - def output(self) -> XComArg: - """Return reference to XCom pushed by current operator.""" - from airflow.models.xcom_arg import XComArg - - return XComArg(operator=self) - - def serialize_for_task_group(self) -> tuple[DagAttributeTypes, Any]: - """Implement DAGNode.""" - return DagAttributeTypes.OP, self.task_id - def _expand_mapped_kwargs( self, context: Mapping[str, Any], session: Session, *, include_xcom: bool ) -> tuple[Mapping[str, Any], set[int]]: @@ -770,136 +156,3 @@ def expand_start_trigger_args(self, *, context: Context, session: Session) -> St next_kwargs=next_kwargs, timeout=timeout, ) - - def unmap(self, resolve: None | Mapping[str, Any] | tuple[Context, Session]) -> BaseOperator: - """ - Get the "normal" Operator after applying the current mapping. - - The *resolve* argument is only used if ``operator_class`` is a real - class, i.e. if this operator is not serialized. If ``operator_class`` is - not a class (i.e. this DAG has been deserialized), this returns a - SerializedBaseOperator that "looks like" the actual unmapping result. - - If *resolve* is a two-tuple (context, session), the information is used - to resolve the mapped arguments into init arguments. If it is a mapping, - no resolving happens, the mapping directly provides those init arguments - resolved from mapped kwargs. - - :meta private: - """ - if isinstance(self.operator_class, type): - if isinstance(resolve, collections.abc.Mapping): - kwargs = resolve - elif resolve is not None: - kwargs, _ = self._expand_mapped_kwargs(*resolve, include_xcom=True) - else: - raise RuntimeError("cannot unmap a non-serialized operator without context") - kwargs = self._get_unmap_kwargs(kwargs, strict=self._disallow_kwargs_override) - is_setup = kwargs.pop("is_setup", False) - is_teardown = kwargs.pop("is_teardown", False) - on_failure_fail_dagrun = kwargs.pop("on_failure_fail_dagrun", False) - op = self.operator_class(**kwargs, _airflow_from_mapped=True) - # We need to overwrite task_id here because BaseOperator further - # mangles the task_id based on the task hierarchy (namely, group_id - # is prepended, and '__N' appended to deduplicate). This is hacky, - # but better than duplicating the whole mangling logic. - op.task_id = self.task_id - op.is_setup = is_setup - op.is_teardown = is_teardown - op.on_failure_fail_dagrun = on_failure_fail_dagrun - return op - - # After a mapped operator is serialized, there's no real way to actually - # unmap it since we've lost access to the underlying operator class. - # This tries its best to simply "forward" all the attributes on this - # mapped operator to a new SerializedBaseOperator instance. - from airflow.serialization.serialized_objects import SerializedBaseOperator - - op = SerializedBaseOperator(task_id=self.task_id, params=self.params, _airflow_from_mapped=True) - for partial_attr, value in self.partial_kwargs.items(): - setattr(op, partial_attr, value) - SerializedBaseOperator.populate_operator(op, self.operator_class) - if self.dag is not None: # For Mypy; we only serialize tasks in a DAG so the check always satisfies. 
- SerializedBaseOperator.set_task_dag_references(op, self.dag) - return op - - def _get_specified_expand_input(self) -> ExpandInput: - """Input received from the expand call on the operator.""" - return getattr(self, self._expand_input_attr) - - def prepare_for_execution(self) -> MappedOperator: - # Since a mapped operator cannot be used for execution, and an unmapped - # BaseOperator needs to be created later (see render_template_fields), - # we don't need to create a copy of the MappedOperator here. - return self - - def iter_mapped_dependencies(self) -> Iterator[Operator]: - """Upstream dependencies that provide XComs used by this task for task mapping.""" - from airflow.models.xcom_arg import XComArg - - for operator, _ in XComArg.iter_xcom_references(self._get_specified_expand_input()): - yield operator - - @methodtools.lru_cache(maxsize=None) - def get_parse_time_mapped_ti_count(self) -> int: - current_count = self._get_specified_expand_input().get_parse_time_mapped_ti_count() - try: - parent_count = super().get_parse_time_mapped_ti_count() - except NotMapped: - return current_count - return parent_count * current_count - - def get_mapped_ti_count(self, run_id: str, *, session: Session) -> int: - from airflow.serialization.serialized_objects import _ExpandInputRef - - exp_input = self._get_specified_expand_input() - if isinstance(exp_input, _ExpandInputRef): - exp_input = exp_input.deref(self.dag) - current_count = exp_input.get_total_map_length(run_id, session=session) - try: - parent_count = super().get_mapped_ti_count(run_id, session=session) - except NotMapped: - return current_count - return parent_count * current_count - - def render_template_fields( - self, - context: Context, - jinja_env: jinja2.Environment | None = None, - ) -> None: - """ - Template all attributes listed in *self.template_fields*. - - This updates *context* to reference the map-expanded task and relevant - information, without modifying the mapped operator. The expanded task - in *context* is then rendered in-place. - - :param context: Context dict with values to apply on content. - :param jinja_env: Jinja environment to use for rendering. - """ - if not jinja_env: - jinja_env = self.get_template_env() - - # We retrieve the session here, stored by _run_raw_task in set_current_task_session - # context manager - we cannot pass the session via @provide_session because the signature - # of render_template_fields is defined by BaseOperator and there are already many subclasses - # overriding it, so changing the signature is not an option. However render_template_fields is - # always executed within "_run_raw_task" so we make sure that _run_raw_task uses the - # set_current_task_session context manager to store the session in the current task. - session = get_current_task_instance_session() - - mapped_kwargs, seen_oids = self._expand_mapped_kwargs(context, session, include_xcom=True) - unmapped_task = self.unmap(mapped_kwargs) - context_update_for_unmapped(context, unmapped_task) - - # Since the operators that extend `BaseOperator` are not subclasses of - # `MappedOperator`, we need to call `_do_render_template_fields` from - # the unmapped task in order to call the operator method when we override - # it to customize the parsing of nested fields. 
- unmapped_task._do_render_template_fields( - parent=unmapped_task, - template_fields=self.template_fields, - context=context, - jinja_env=jinja_env, - seen_oids=seen_oids, - ) diff --git a/airflow/models/param.py b/airflow/models/param.py index a25d4bba8ee76..cd3ccec26a48a 100644 --- a/airflow/models/param.py +++ b/airflow/models/param.py @@ -28,9 +28,9 @@ from airflow.utils.types import NOTSET, ArgNotSet if TYPE_CHECKING: - from airflow.models.operator import Operator from airflow.sdk.definitions.context import Context from airflow.sdk.definitions.dag import DAG + from airflow.sdk.types import Operator logger = logging.getLogger(__name__) diff --git a/airflow/models/renderedtifields.py b/airflow/models/renderedtifields.py index c56517f0815ca..f2d7d83920fff 100644 --- a/airflow/models/renderedtifields.py +++ b/airflow/models/renderedtifields.py @@ -47,8 +47,8 @@ from sqlalchemy.orm import Session from sqlalchemy.sql import FromClause - from airflow.models import Operator from airflow.models.taskinstance import TaskInstance + from airflow.sdk.types import Operator def get_serialized_template_fields(task: Operator): diff --git a/airflow/models/taskinstance.py b/airflow/models/taskinstance.py index 5a605f40cd518..d61331dd620d3 100644 --- a/airflow/models/taskinstance.py +++ b/airflow/models/taskinstance.py @@ -106,9 +106,9 @@ from airflow.models.taskreschedule import TaskReschedule from airflow.models.xcom import LazyXComSelectSequence, XCom from airflow.plugins_manager import integrate_macros_plugins -from airflow.sdk.definitions._internal.contextmanager import _CURRENT_CONTEXT from airflow.sdk.definitions._internal.templater import SandboxedEnvironment from airflow.sdk.definitions.asset import Asset, AssetAlias, AssetNameRef, AssetUniqueKey, AssetUriRef +from airflow.sdk.definitions.taskgroup import MappedTaskGroup from airflow.sentry import Sentry from airflow.settings import task_instance_mutation_hook from airflow.stats import Stats @@ -135,10 +135,8 @@ from airflow.utils.session import NEW_SESSION, create_session, provide_session from airflow.utils.sqlalchemy import ExecutorConfigType, ExtendedJSON, UtcDateTime from airflow.utils.state import DagRunState, State, TaskInstanceState -from airflow.utils.task_group import MappedTaskGroup from airflow.utils.task_instance_session import set_current_task_instance_session from airflow.utils.timeout import timeout -from airflow.utils.types import AttributeRemoved from airflow.utils.xcom import XCOM_RETURN_KEY TR = TaskReschedule @@ -161,7 +159,7 @@ from airflow.models.baseoperator import BaseOperator from airflow.models.dag import DAG as SchedulerDAG, DagModel from airflow.models.dagrun import DagRun - from airflow.models.operator import Operator + from airflow.sdk.definitions._internal.abstractoperator import Operator from airflow.sdk.definitions.dag import DAG from airflow.sdk.types import OutletEventAccessorsProtocol, RuntimeTaskInstanceProtocol from airflow.timetables.base import DataInterval @@ -232,7 +230,7 @@ def _run_raw_task( :param session: SQLAlchemy ORM Session """ if TYPE_CHECKING: - assert ti.task + assert isinstance(ti.task, BaseOperator) ti.test_mode = test_mode ti.refresh_from_task(ti.task, pool_override=pool) @@ -374,6 +372,8 @@ def set_current_context(context: Context) -> Generator[Context, None, None]: This method should be called once per Task execution, before calling operator.execute. 
""" + from airflow.sdk.definitions._internal.contextmanager import _CURRENT_CONTEXT + _CURRENT_CONTEXT.append(context) try: yield context @@ -677,12 +677,14 @@ def _execute_task(task_instance: TaskInstance, context: Context, task_orig: Oper :meta private: """ - from airflow.models.mappedoperator import MappedOperator + from airflow.sdk.definitions.mappedoperator import MappedOperator task_to_execute = task_instance.task if TYPE_CHECKING: - assert task_to_execute + # TODO: TaskSDK this function will need 100% re-writing + # This only works with a "rich" BaseOperator, not the SDK version + assert isinstance(task_to_execute, BaseOperator) if isinstance(task_to_execute, MappedOperator): raise AirflowException("MappedOperator cannot be executed.") @@ -925,6 +927,7 @@ def _get_template_context( from airflow import macros from airflow.models.abstractoperator import NotMapped + from airflow.models.baseoperator import BaseOperator integrate_macros_plugins() @@ -934,10 +937,6 @@ def _get_template_context( assert task assert task.dag - if task.dag.__class__ is AttributeRemoved: - # TODO: Task-SDK: Remove this after AIP-44 code is removed - task.dag = dag # type: ignore[assignment] # required after deserialization - dag_run = task_instance.get_dagrun(session) data_interval = dag.get_run_data_interval(dag_run) @@ -1002,11 +1001,6 @@ def get_triggering_events() -> dict[str, list[AssetEvent]]: return triggering_events - try: - expanded_ti_count: int | None = task.get_mapped_ti_count(task_instance.run_id, session=session) - except NotMapped: - expanded_ti_count = None - # NOTE: If you add to this dict, make sure to also update the following: # * Context in task_sdk/src/airflow/sdk/definitions/context.py # * KNOWN_CONTEXT_KEYS in airflow/utils/context.py @@ -1019,7 +1013,6 @@ def get_triggering_events() -> dict[str, list[AssetEvent]]: "outlet_events": OutletEventAccessors(), "ds": ds, "ds_nodash": ds_nodash, - "expanded_ti_count": expanded_ti_count, "inlets": task.inlets, "inlet_events": InletEventsAccessors(task.inlets, session=session), "logical_date": logical_date, @@ -1032,7 +1025,7 @@ def get_triggering_events() -> dict[str, list[AssetEvent]]: "prev_start_date_success": get_prev_start_date_success(), "prev_end_date_success": get_prev_end_date_success(), "run_id": task_instance.run_id, - "task": task, + "task": task, # type: ignore[typeddict-item] "task_instance": task_instance, "task_instance_key_str": f"{task.dag_id}__{task.task_id}__{ds_nodash}", "test_mode": task_instance.test_mode, @@ -1047,6 +1040,24 @@ def get_triggering_events() -> dict[str, list[AssetEvent]]: }, "conn": ConnectionAccessor(), } + + try: + expanded_ti_count: int | None = BaseOperator.get_mapped_ti_count( + task, task_instance.run_id, session=session + ) + context["expanded_ti_count"] = expanded_ti_count + if expanded_ti_count: + context["_upstream_map_indexes"] = { # type: ignore[typeddict-unknown-key] + upstream.task_id: task_instance.get_relevant_upstream_map_indexes( + upstream, + expanded_ti_count, + session=session, + ) + for upstream in task.upstream_list + } + except NotMapped: + pass + # Mypy doesn't like turning existing dicts in to a TypeDict -- and we "lie" in the type stub to say it # is one, but in practice it isn't. 
See https://github.com/python/mypy/issues/8890 return context @@ -1205,12 +1216,7 @@ def _record_task_map_for_downstreams( :meta private: """ - from airflow.models.mappedoperator import MappedOperator - - # TODO: Task-SDK: Remove this after AIP-44 code is removed - if task.dag.__class__ is AttributeRemoved: - # required after deserialization - task.dag = dag # type: ignore[assignment] + from airflow.sdk.definitions.mappedoperator import MappedOperator if next(task.iter_mapped_dependants(), None) is None: # No mapped dependants, no need to validate. return @@ -1253,6 +1259,8 @@ def _get_previous_dagrun( if dag is None: return None + if TYPE_CHECKING: + assert isinstance(dag, SchedulerDAG) dr = task_instance.get_dagrun(session=session) dr.dag = dag @@ -1538,6 +1546,10 @@ def _defer_task( ) -> TaskInstance: from airflow.models.trigger import Trigger + # TODO: TaskSDK add start_trigger_args to SDK definitions + if TYPE_CHECKING: + assert ti.task is None or isinstance(ti.task, BaseOperator) + timeout: timedelta | None if exception is not None: trigger_row = Trigger.from_object(exception.trigger) @@ -1910,6 +1922,7 @@ def _command_as_list( if hasattr(ti, "task") and getattr(ti.task, "dag", None) is not None: if TYPE_CHECKING: assert ti.task + assert isinstance(ti.task.dag, SchedulerDAG) dag = ti.task.dag else: dag = ti.dag_model @@ -2351,7 +2364,7 @@ def are_dependencies_met( def get_failed_dep_statuses(self, dep_context: DepContext | None = None, session: Session = NEW_SESSION): """Get failed Dependencies.""" if TYPE_CHECKING: - assert self.task + assert isinstance(self.task, BaseOperator) dep_context = dep_context or DepContext() for dep in dep_context.deps | self.task.deps: @@ -2449,6 +2462,7 @@ def get_dagrun(self, session: Session = NEW_SESSION) -> DagRun: if getattr(self, "task", None) is not None: if TYPE_CHECKING: assert self.task + assert isinstance(self.task.dag, SchedulerDAG) dr.dag = self.task.dag # Record it in the instance for next time. 
This means that `self.logical_date` will work correctly set_committed_value(self, "dag_run", dr) @@ -2461,7 +2475,7 @@ def ensure_dag(cls, task_instance: TaskInstance, session: Session = NEW_SESSION) """Ensure that task has a dag object associated, might have been removed by serialization.""" if TYPE_CHECKING: assert task_instance.task - if task_instance.task.dag is None or task_instance.task.dag.__class__ is AttributeRemoved: + if task_instance.task.dag is None: task_instance.task.dag = DagBag(read_dags_from_db=True).get_dag( dag_id=task_instance.dag_id, session=session ) @@ -3123,7 +3137,7 @@ def fetch_handle_failure_context( try: if getattr(ti, "task", None) and context: if TYPE_CHECKING: - assert ti.task + assert isinstance(ti.task, BaseOperator) task = ti.task.unmap((context, session)) except Exception: cls.logger().error("Unable to unmap task to determine if we need to send an alert email") @@ -3220,7 +3234,7 @@ def get_template_context( """ if TYPE_CHECKING: assert self.task - assert self.task.dag + assert isinstance(self.task.dag, SchedulerDAG) return _get_template_context( task_instance=self, dag=self.task.dag, @@ -3238,7 +3252,7 @@ def get_rendered_template_fields(self, session: Session = NEW_SESSION) -> None: from airflow.models.renderedtifields import RenderedTaskInstanceFields if TYPE_CHECKING: - assert self.task + assert isinstance(self.task, BaseOperator) rendered_task_instance_fields = RenderedTaskInstanceFields.get_templated_fields(self, session=session) if rendered_task_instance_fields: @@ -3279,7 +3293,7 @@ def render_templates( the unmapped, fully rendered BaseOperator. The original ``self.task`` before replacement is returned. """ - from airflow.models.mappedoperator import MappedOperator + from airflow.sdk.definitions.mappedoperator import MappedOperator if not context: context = self.get_template_context() @@ -3292,9 +3306,6 @@ def render_templates( assert self.task assert ti.task - if ti.task.dag.__class__ is AttributeRemoved: - ti.task.dag = self.task.dag # type: ignore[assignment] - # If self.task is mapped, this call replaces self.task to point to the # unmapped BaseOperator created by this function! This is because the # MappedOperator is useless for template rendering, and we need to be @@ -3590,6 +3601,8 @@ def tg2(inp): :return: Specific map index or map indexes to pull, or ``None`` if we want to "whole" return value (i.e. no mapped task groups involved). """ + from airflow.models.baseoperator import BaseOperator + if TYPE_CHECKING: assert self.task @@ -3612,7 +3625,8 @@ def tg2(inp): # At this point we know the two tasks share a mapped task group, and we # should use a "partial" value. Let's break down the mapped ti count # between the ancestor and further expansion happened inside it. 
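# Worked illustration with hypothetical numbers: suppose the mapped task group expands
# over 3 values and the task is further expanded over 2 values inside it, so
# ti_count == 6 and ancestor_ti_count == 3. Then map_index 4 belongs to ancestor index
# 4 * 3 // 6 == 2 -- the integer division below recovers which expansion of the common
# ancestor a given map index falls under.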
- ancestor_ti_count = common_ancestor.get_mapped_ti_count(self.run_id, session=session) + + ancestor_ti_count = BaseOperator.get_mapped_ti_count(common_ancestor, self.run_id, session=session) ancestor_map_index = self.map_index * ancestor_ti_count // ti_count # If the task is NOT further expanded inside the common ancestor, we @@ -3732,7 +3746,7 @@ def _find_common_ancestor_mapped_group(node1: Operator, node2: Operator) -> Mapp def _is_further_mapped_inside(operator: Operator, container: TaskGroup) -> bool: """Whether given operator is *further* mapped inside a task group.""" - from airflow.models.mappedoperator import MappedOperator + from airflow.sdk.definitions.mappedoperator import MappedOperator if isinstance(operator, MappedOperator): return True diff --git a/airflow/models/taskmap.py b/airflow/models/taskmap.py index 2702b906df034..fdd37f1f5b55c 100644 --- a/airflow/models/taskmap.py +++ b/airflow/models/taskmap.py @@ -21,15 +21,20 @@ import collections.abc import enum -from collections.abc import Collection +from collections.abc import Collection, Iterable, Sequence from typing import TYPE_CHECKING, Any -from sqlalchemy import CheckConstraint, Column, ForeignKeyConstraint, Integer, String +from sqlalchemy import CheckConstraint, Column, ForeignKeyConstraint, Integer, String, func, or_, select from airflow.models.base import COLLATION_ARGS, ID_LEN, TaskInstanceDependencies -from airflow.utils.sqlalchemy import ExtendedJSON +from airflow.utils.db import exists_query +from airflow.utils.sqlalchemy import ExtendedJSON, with_row_locks +from airflow.utils.state import State, TaskInstanceState if TYPE_CHECKING: + from sqlalchemy.orm import Session + + from airflow.models.dag import DAG as SchedulerDAG from airflow.models.taskinstance import TaskInstance @@ -114,3 +119,133 @@ def variant(self) -> TaskMapVariant: if self.keys is None: return TaskMapVariant.LIST return TaskMapVariant.DICT + + @classmethod + def expand_mapped_task(cls, task, run_id: str, *, session: Session) -> tuple[Sequence[TaskInstance], int]: + """ + Create the mapped task instances for mapped task. + + :raise NotMapped: If this task does not need expansion. + :return: The newly created mapped task instances (if any) in ascending + order by map index, and the maximum map index value. 
+ """ + from airflow.models.baseoperator import BaseOperator as DBBaseOperator + from airflow.models.expandinput import NotFullyPopulated + from airflow.models.taskinstance import TaskInstance + from airflow.sdk.definitions.baseoperator import BaseOperator + from airflow.sdk.definitions.mappedoperator import MappedOperator + from airflow.settings import task_instance_mutation_hook + + if not isinstance(task, (BaseOperator, MappedOperator)): + raise RuntimeError( + f"cannot expand unrecognized operator type {type(task).__module__}.{type(task).__name__}" + ) + + try: + total_length: int | None = DBBaseOperator.get_mapped_ti_count(task, run_id, session=session) + except NotFullyPopulated as e: + if not task.dag or not task.dag.partial: + task.log.error( + "Cannot expand %r for run %s; missing upstream values: %s", + task, + run_id, + sorted(e.missing), + ) + total_length = None + + state: TaskInstanceState | None = None + unmapped_ti: TaskInstance | None = session.scalars( + select(TaskInstance).where( + TaskInstance.dag_id == task.dag_id, + TaskInstance.task_id == task.task_id, + TaskInstance.run_id == run_id, + TaskInstance.map_index == -1, + or_(TaskInstance.state.in_(State.unfinished), TaskInstance.state.is_(None)), + ) + ).one_or_none() + + all_expanded_tis: list[TaskInstance] = [] + + if unmapped_ti: + if TYPE_CHECKING: + assert task.dag is None or isinstance(task.dag, SchedulerDAG) + + # The unmapped task instance still exists and is unfinished, i.e. we + # haven't tried to run it before. + if total_length is None: + # If the DAG is partial, it's likely that the upstream tasks + # are not done yet, so the task can't fail yet. + if not task.dag or not task.dag.partial: + unmapped_ti.state = TaskInstanceState.UPSTREAM_FAILED + elif total_length < 1: + # If the upstream maps this to a zero-length value, simply mark + # the unmapped task instance as SKIPPED (if needed). + task.log.info( + "Marking %s as SKIPPED since the map has %d values to expand", + unmapped_ti, + total_length, + ) + unmapped_ti.state = TaskInstanceState.SKIPPED + else: + zero_index_ti_exists = exists_query( + TaskInstance.dag_id == task.dag_id, + TaskInstance.task_id == task.task_id, + TaskInstance.run_id == run_id, + TaskInstance.map_index == 0, + session=session, + ) + if not zero_index_ti_exists: + # Otherwise convert this into the first mapped index, and create + # TaskInstance for other indexes. + unmapped_ti.map_index = 0 + task.log.debug("Updated in place to become %s", unmapped_ti) + all_expanded_tis.append(unmapped_ti) + # execute hook for task instance map index 0 + task_instance_mutation_hook(unmapped_ti) + session.flush() + else: + task.log.debug("Deleting the original task instance: %s", unmapped_ti) + session.delete(unmapped_ti) + state = unmapped_ti.state + + if total_length is None or total_length < 1: + # Nothing to fixup. + indexes_to_map: Iterable[int] = () + else: + # Only create "missing" ones. + current_max_mapping = session.scalar( + select(func.max(TaskInstance.map_index)).where( + TaskInstance.dag_id == task.dag_id, + TaskInstance.task_id == task.task_id, + TaskInstance.run_id == run_id, + ) + ) + indexes_to_map = range(current_max_mapping + 1, total_length) + + for index in indexes_to_map: + # TODO: Make more efficient with bulk_insert_mappings/bulk_save_mappings. 
+ ti = TaskInstance(task, run_id=run_id, map_index=index, state=state) + task.log.debug("Expanding TIs upserted %s", ti) + task_instance_mutation_hook(ti) + ti = session.merge(ti) + ti.refresh_from_task(task) # session.merge() loses task information. + all_expanded_tis.append(ti) + + # Coerce the None case to 0 -- these two are almost treated identically, + # except the unmapped ti (if exists) is marked to different states. + total_expanded_ti_count = total_length or 0 + + # Any (old) task instances with inapplicable indexes (>= the total + # number we need) are set to "REMOVED". + query = select(TaskInstance).where( + TaskInstance.dag_id == task.dag_id, + TaskInstance.task_id == task.task_id, + TaskInstance.run_id == run_id, + TaskInstance.map_index >= total_expanded_ti_count, + ) + query = with_row_locks(query, of=TaskInstance, session=session, skip_locked=True) + to_update = session.scalars(query) + for ti in to_update: + ti.state = TaskInstanceState.REMOVED + session.flush() + return all_expanded_tis, total_expanded_ti_count - 1 diff --git a/airflow/models/xcom_arg.py b/airflow/models/xcom_arg.py index 4bf91a68bee53..078a9e6ff5223 100644 --- a/airflow/models/xcom_arg.py +++ b/airflow/models/xcom_arg.py @@ -17,714 +17,100 @@ from __future__ import annotations -import contextlib -import inspect -import itertools -from collections.abc import Iterable, Iterator, Mapping, Sequence -from typing import TYPE_CHECKING, Any, Callable, Union, overload +from functools import singledispatch +from typing import TYPE_CHECKING from sqlalchemy import func, or_, select - -from airflow.exceptions import AirflowException, XComNotFound -from airflow.models import MappedOperator, TaskInstance -from airflow.models.abstractoperator import AbstractOperator -from airflow.models.taskmixin import DependencyMixin -from airflow.sdk.definitions._internal.mixins import ResolveMixin -from airflow.sdk.definitions._internal.types import NOTSET, ArgNotSet +from sqlalchemy.orm import Session + +from airflow.sdk.definitions._internal.types import ArgNotSet +from airflow.sdk.definitions.mappedoperator import MappedOperator +from airflow.sdk.definitions.xcom_arg import ( + ConcatXComArg, + MapXComArg, + PlainXComArg, + XComArg, + ZipXComArg, +) from airflow.utils.db import exists_query -from airflow.utils.session import NEW_SESSION, provide_session -from airflow.utils.setup_teardown import SetupTeardownContext from airflow.utils.state import State -from airflow.utils.trigger_rule import TriggerRule from airflow.utils.xcom import XCOM_RETURN_KEY -if TYPE_CHECKING: - from sqlalchemy.orm import Session - - from airflow.models.operator import Operator - from airflow.sdk.definitions.baseoperator import BaseOperator - from airflow.sdk.definitions.dag import DAG - from airflow.utils.edgemodifier import EdgeModifier - -# Callable objects contained by MapXComArg. We only accept callables from -# the user, but deserialize them into strings in a serialized XComArg for -# safety (those callables are arbitrary user code). -MapCallables = Sequence[Union[Callable[[Any], Any], str]] - - -class XComArg(ResolveMixin, DependencyMixin): - """ - Reference to an XCom value pushed from another operator. 
- - The implementation supports:: - - xcomarg >> op - xcomarg << op - op >> xcomarg # By BaseOperator code - op << xcomarg # By BaseOperator code - - **Example**: The moment you get a result from any operator (decorated or regular) you can :: - - any_op = AnyOperator() - xcomarg = XComArg(any_op) - # or equivalently - xcomarg = any_op.output - my_op = MyOperator() - my_op >> xcomarg - - This object can be used in legacy Operators via Jinja. - - **Example**: You can make this result to be part of any generated string:: - - any_op = AnyOperator() - xcomarg = any_op.output - op1 = MyOperator(my_text_message=f"the value is {xcomarg}") - op2 = MyOperator(my_text_message=f"the value is {xcomarg['topic']}") - - :param operator: Operator instance to which the XComArg references. - :param key: Key used to pull the XCom value. Defaults to *XCOM_RETURN_KEY*, - i.e. the referenced operator's return value. - """ - - @overload - def __new__(cls: type[XComArg], operator: Operator, key: str = XCOM_RETURN_KEY) -> XComArg: - """Execute when the user writes ``XComArg(...)`` directly.""" - - @overload - def __new__(cls: type[XComArg]) -> XComArg: - """Execute by Python internals from subclasses.""" - - def __new__(cls, *args, **kwargs) -> XComArg: - if cls is XComArg: - return PlainXComArg(*args, **kwargs) - return super().__new__(cls) - - @staticmethod - def iter_xcom_references(arg: Any) -> Iterator[tuple[Operator, str]]: - """ - Return XCom references in an arbitrary value. - - Recursively traverse ``arg`` and look for XComArg instances in any - collection objects, and instances with ``template_fields`` set. - """ - if isinstance(arg, ResolveMixin): - yield from arg.iter_references() - elif isinstance(arg, (tuple, set, list)): - for elem in arg: - yield from XComArg.iter_xcom_references(elem) - elif isinstance(arg, dict): - for elem in arg.values(): - yield from XComArg.iter_xcom_references(elem) - elif isinstance(arg, AbstractOperator): - for attr in arg.template_fields: - yield from XComArg.iter_xcom_references(getattr(arg, attr)) - - @staticmethod - def apply_upstream_relationship(op: DependencyMixin, arg: Any): - """ - Set dependency for XComArgs. - - This looks for XComArg objects in ``arg`` "deeply" (looking inside - collections objects and classes decorated with ``template_fields``), and - sets the relationship to ``op`` on any found. - """ - for operator, _ in XComArg.iter_xcom_references(arg): - op.set_upstream(operator) - - @property - def roots(self) -> list[Operator]: - """Required by DependencyMixin.""" - return [op for op, _ in self.iter_references()] - - @property - def leaves(self) -> list[Operator]: - """Required by DependencyMixin.""" - return [op for op, _ in self.iter_references()] - - def set_upstream( - self, - task_or_task_list: DependencyMixin | Sequence[DependencyMixin], - edge_modifier: EdgeModifier | None = None, - ): - """Proxy to underlying operator set_upstream method. Required by DependencyMixin.""" - for operator, _ in self.iter_references(): - operator.set_upstream(task_or_task_list, edge_modifier) - - def set_downstream( - self, - task_or_task_list: DependencyMixin | Sequence[DependencyMixin], - edge_modifier: EdgeModifier | None = None, - ): - """Proxy to underlying operator set_downstream method. Required by DependencyMixin.""" - for operator, _ in self.iter_references(): - operator.set_downstream(task_or_task_list, edge_modifier) - - def _serialize(self) -> dict[str, Any]: - """ - Serialize an XComArg. 
- - The implementation should be the inverse function to ``deserialize``, - returning a data dict converted from this XComArg derivative. DAG - serialization does not call this directly, but ``serialize_xcom_arg`` - instead, which adds additional information to dispatch deserialization - to the correct class. - """ - raise NotImplementedError() - - @classmethod - def _deserialize(cls, data: dict[str, Any], dag: DAG) -> XComArg: - """ - Deserialize an XComArg. - - The implementation should be the inverse function to ``serialize``, - implementing given a data dict converted from this XComArg derivative, - how the original XComArg should be created. DAG serialization relies on - additional information added in ``serialize_xcom_arg`` to dispatch data - dicts to the correct ``_deserialize`` information, so this function does - not need to validate whether the incoming data contains correct keys. - """ - raise NotImplementedError() - - def map(self, f: Callable[[Any], Any]) -> MapXComArg: - return MapXComArg(self, [f]) - - def zip(self, *others: XComArg, fillvalue: Any = NOTSET) -> ZipXComArg: - return ZipXComArg([self, *others], fillvalue=fillvalue) - - def concat(self, *others: XComArg) -> ConcatXComArg: - return ConcatXComArg([self, *others]) - - def get_task_map_length(self, run_id: str, *, session: Session) -> int | None: - """ - Inspect length of pushed value for task-mapping. - - This is used to determine how many task instances the scheduler should - create for a downstream using this XComArg for task-mapping. - - *None* may be returned if the depended XCom has not been pushed. - """ - raise NotImplementedError() - - def resolve( - self, context: Mapping[str, Any], session: Session | None = None, *, include_xcom: bool = True - ) -> Any: - """ - Pull XCom value. - - This should only be called during ``op.execute()`` with an appropriate - context (e.g. generated from ``TaskInstance.get_template_context()``). - Although the ``ResolveMixin`` parent mixin also has a ``resolve`` - protocol, this adds the optional ``session`` argument that some of the - subclasses need. - - :meta private: - """ - raise NotImplementedError() - - def __enter__(self): - if not self.operator.is_setup and not self.operator.is_teardown: - raise AirflowException("Only setup/teardown tasks can be used as context managers.") - SetupTeardownContext.push_setup_teardown_task(self.operator) - return SetupTeardownContext - - def __exit__(self, exc_type, exc_val, exc_tb): - SetupTeardownContext.set_work_task_roots_and_leaves() - - -class PlainXComArg(XComArg): - """ - Reference to one single XCom without any additional semantics. - - This class should not be accessed directly, but only through XComArg. The - class inheritance chain and ``__new__`` is implemented in this slightly - convoluted way because we want to - - a. Allow the user to continue using XComArg directly for the simple - semantics (see documentation of the base class for details). - b. Make ``isinstance(thing, XComArg)`` be able to detect all kinds of XCom - references. - c. Not allow many properties of PlainXComArg (including ``__getitem__`` and - ``__str__``) to exist on other kinds of XComArg implementations since - they don't make sense. 
- - :meta private: - """ - - def __init__(self, operator: Operator, key: str = XCOM_RETURN_KEY): - self.operator = operator - self.key = key - - def __eq__(self, other: Any) -> bool: - if not isinstance(other, PlainXComArg): - return NotImplemented - return self.operator == other.operator and self.key == other.key - - def __getitem__(self, item: str) -> XComArg: - """Implement xcomresult['some_result_key'].""" - if not isinstance(item, str): - raise ValueError(f"XComArg only supports str lookup, received {type(item).__name__}") - return PlainXComArg(operator=self.operator, key=item) - - def __iter__(self): - """ - Override iterable protocol to raise error explicitly. - - The default ``__iter__`` implementation in Python calls ``__getitem__`` - with 0, 1, 2, etc. until it hits an ``IndexError``. This does not work - well with our custom ``__getitem__`` implementation, and results in poor - DAG-writing experience since a misplaced ``*`` expansion would create an - infinite loop consuming the entire DAG parser. - - This override catches the error eagerly, so an incorrectly implemented - DAG fails fast and avoids wasting resources on nonsensical iterating. - """ - raise TypeError("'XComArg' object is not iterable") - - def __repr__(self) -> str: - if self.key == XCOM_RETURN_KEY: - return f"XComArg({self.operator!r})" - return f"XComArg({self.operator!r}, {self.key!r})" - - def __str__(self) -> str: - """ - Backward compatibility for old-style jinja used in Airflow Operators. - - **Example**: to use XComArg at BashOperator:: - - BashOperator(cmd=f"... { xcomarg } ...") - - :return: - """ - xcom_pull_kwargs = [ - f"task_ids='{self.operator.task_id}'", - f"dag_id='{self.operator.dag_id}'", - ] - if self.key is not None: - xcom_pull_kwargs.append(f"key='{self.key}'") - - xcom_pull_str = ", ".join(xcom_pull_kwargs) - # {{{{ are required for escape {{ in f-string - xcom_pull = f"{{{{ task_instance.xcom_pull({xcom_pull_str}) }}}}" - return xcom_pull - - def _serialize(self) -> dict[str, Any]: - return {"task_id": self.operator.task_id, "key": self.key} - - @classmethod - def _deserialize(cls, data: dict[str, Any], dag: DAG) -> XComArg: - return cls(dag.get_task(data["task_id"]), data["key"]) - - @property - def is_setup(self) -> bool: - return self.operator.is_setup - - @is_setup.setter - def is_setup(self, val: bool): - self.operator.is_setup = val - - @property - def is_teardown(self) -> bool: - return self.operator.is_teardown - - @is_teardown.setter - def is_teardown(self, val: bool): - self.operator.is_teardown = val +__all__ = ["XComArg", "get_task_map_length"] - @property - def on_failure_fail_dagrun(self) -> bool: - return self.operator.on_failure_fail_dagrun - - @on_failure_fail_dagrun.setter - def on_failure_fail_dagrun(self, val: bool): - self.operator.on_failure_fail_dagrun = val - - def as_setup(self) -> DependencyMixin: - for operator, _ in self.iter_references(): - operator.is_setup = True - return self - - def as_teardown( - self, - *, - setups: BaseOperator | Iterable[BaseOperator] | None = None, - on_failure_fail_dagrun: bool | None = None, - ): - for operator, _ in self.iter_references(): - operator.is_teardown = True - operator.trigger_rule = TriggerRule.ALL_DONE_SETUP_SUCCESS - if on_failure_fail_dagrun is not None: - operator.on_failure_fail_dagrun = on_failure_fail_dagrun - if setups is not None: - setups = [setups] if isinstance(setups, DependencyMixin) else setups - for s in setups: - s.is_setup = True - s >> operator - return self - - def iter_references(self) -> 
Iterator[tuple[Operator, str]]: - yield self.operator, self.key - - def map(self, f: Callable[[Any], Any]) -> MapXComArg: - if self.key != XCOM_RETURN_KEY: - raise ValueError("cannot map against non-return XCom") - return super().map(f) - - def zip(self, *others: XComArg, fillvalue: Any = NOTSET) -> ZipXComArg: - if self.key != XCOM_RETURN_KEY: - raise ValueError("cannot map against non-return XCom") - return super().zip(*others, fillvalue=fillvalue) - - def concat(self, *others: XComArg) -> ConcatXComArg: - if self.key != XCOM_RETURN_KEY: - raise ValueError("cannot concatenate non-return XCom") - return super().concat(*others) - - def get_task_map_length(self, run_id: str, *, session: Session) -> int | None: - from airflow.models.taskinstance import TaskInstance - from airflow.models.taskmap import TaskMap - from airflow.models.xcom import XCom - - dag_id = self.operator.dag_id - task_id = self.operator.task_id - is_mapped = isinstance(self.operator, MappedOperator) - - if is_mapped: - unfinished_ti_exists = exists_query( - TaskInstance.dag_id == dag_id, - TaskInstance.run_id == run_id, - TaskInstance.task_id == task_id, - # Special NULL treatment is needed because 'state' can be NULL. - # The "IN" part would produce "NULL NOT IN ..." and eventually - # "NULl = NULL", which is a big no-no in SQL. - or_( - TaskInstance.state.is_(None), - TaskInstance.state.in_(s.value for s in State.unfinished if s is not None), - ), - session=session, - ) - if unfinished_ti_exists: - return None # Not all of the expanded tis are done yet. - query = select(func.count(XCom.map_index)).where( - XCom.dag_id == dag_id, - XCom.run_id == run_id, - XCom.task_id == task_id, - XCom.map_index >= 0, - XCom.key == XCOM_RETURN_KEY, - ) - else: - query = select(TaskMap.length).where( - TaskMap.dag_id == dag_id, - TaskMap.run_id == run_id, - TaskMap.task_id == task_id, - TaskMap.map_index < 0, - ) - return session.scalar(query) - - # TODO: Task-SDK: Remove session argument once everything is ported over to Task SDK - @provide_session - def resolve( - self, context: Mapping[str, Any], session: Session = NEW_SESSION, *, include_xcom: bool = True - ) -> Any: - ti = context["ti"] - if TYPE_CHECKING: - assert isinstance(ti, TaskInstance) - task_id = self.operator.task_id - map_indexes = ti.get_relevant_upstream_map_indexes( - self.operator, - context["expanded_ti_count"], +if TYPE_CHECKING: + from airflow.models.expandinput import OperatorExpandArgument + + +@singledispatch +def get_task_map_length(xcom_arg: OperatorExpandArgument, run_id: str, *, session: Session) -> int | None: + # The base implementation -- specific XComArg subclasses have specialised implementations + raise NotImplementedError() + + +@get_task_map_length.register +def _(xcom_arg: PlainXComArg, run_id: str, *, session: Session): + from airflow.models.taskinstance import TaskInstance + from airflow.models.taskmap import TaskMap + from airflow.models.xcom import XCom + + dag_id = xcom_arg.operator.dag_id + task_id = xcom_arg.operator.task_id + is_mapped = isinstance(xcom_arg.operator, MappedOperator) + + if is_mapped: + unfinished_ti_exists = exists_query( + TaskInstance.dag_id == dag_id, + TaskInstance.run_id == run_id, + TaskInstance.task_id == task_id, + # Special NULL treatment is needed because 'state' can be NULL. + # The "IN" part would produce "NULL NOT IN ..." and eventually + # "NULl = NULL", which is a big no-no in SQL. 
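# For example (general SQL three-valued logic, with made-up state values): a predicate
# like "state IN ('running', 'queued')" evaluates to NULL rather than TRUE or FALSE when
# state IS NULL, so rows with a NULL state would silently fail the filter; the explicit
# IS NULL branch in the or_() below keeps those rows counted as unfinished.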
+ or_( + TaskInstance.state.is_(None), + TaskInstance.state.in_(s.value for s in State.unfinished if s is not None), + ), session=session, ) - - result = ti.xcom_pull( - task_ids=task_id, - map_indexes=map_indexes, - key=self.key, - default=NOTSET, + if unfinished_ti_exists: + return None # Not all of the expanded tis are done yet. + query = select(func.count(XCom.map_index)).where( + XCom.dag_id == dag_id, + XCom.run_id == run_id, + XCom.task_id == task_id, + XCom.map_index >= 0, + XCom.key == XCOM_RETURN_KEY, ) - if not isinstance(result, ArgNotSet): - return result - if self.key == XCOM_RETURN_KEY: - return None - if getattr(self.operator, "multiple_outputs", False): - # If the operator is set to have multiple outputs and it was not executed, - # we should return "None" instead of showing an error. This is because when - # multiple outputs XComs are created, the XCom keys associated with them will have - # different names than the predefined "XCOM_RETURN_KEY" and won't be found. - # Therefore, it's better to return "None" like we did above where self.key==XCOM_RETURN_KEY. - return None - raise XComNotFound(ti.dag_id, task_id, self.key) - - -def _get_callable_name(f: Callable | str) -> str: - """Try to "describe" a callable by getting its name.""" - if callable(f): - return f.__name__ - # Parse the source to find whatever is behind "def". For safety, we don't - # want to evaluate the code in any meaningful way! - with contextlib.suppress(Exception): - kw, name, _ = f.lstrip().split(None, 2) - if kw == "def": - return name - return "" - - -class _MapResult(Sequence): - def __init__(self, value: Sequence | dict, callables: MapCallables) -> None: - self.value = value - self.callables = callables - - def __getitem__(self, index: Any) -> Any: - value = self.value[index] - - # In the worker, we can access all actual callables. Call them. - callables = [f for f in self.callables if callable(f)] - if len(callables) == len(self.callables): - for f in callables: - value = f(value) - return value - - # In the scheduler, we don't have access to the actual callables, nor do - # we want to run it since it's arbitrary code. This builds a string to - # represent the call chain in the UI or logs instead. - for v in self.callables: - value = f"{_get_callable_name(v)}({value})" - return value - - def __len__(self) -> int: - return len(self.value) - - -class MapXComArg(XComArg): - """ - An XCom reference with ``map()`` call(s) applied. - - This is based on an XComArg, but also applies a series of "transforms" that - convert the pulled XCom value. - - :meta private: - """ - - def __init__(self, arg: XComArg, callables: MapCallables) -> None: - for c in callables: - if getattr(c, "_airflow_is_task_decorator", False): - raise ValueError("map() argument must be a plain function, not a @task operator") - self.arg = arg - self.callables = callables - - def __repr__(self) -> str: - map_calls = "".join(f".map({_get_callable_name(f)})" for f in self.callables) - return f"{self.arg!r}{map_calls}" - - def _serialize(self) -> dict[str, Any]: - return { - "arg": serialize_xcom_arg(self.arg), - "callables": [inspect.getsource(c) if callable(c) else c for c in self.callables], - } - - @classmethod - def _deserialize(cls, data: dict[str, Any], dag: DAG) -> XComArg: - # We are deliberately NOT deserializing the callables. These are shown - # in the UI, and displaying a function object is useless. 
- return cls(deserialize_xcom_arg(data["arg"], dag), data["callables"]) - - def iter_references(self) -> Iterator[tuple[Operator, str]]: - yield from self.arg.iter_references() - - def map(self, f: Callable[[Any], Any]) -> MapXComArg: - # Flatten arg.map(f1).map(f2) into one MapXComArg. - return MapXComArg(self.arg, [*self.callables, f]) - - def get_task_map_length(self, run_id: str, *, session: Session) -> int | None: - return self.arg.get_task_map_length(run_id, session=session) - - @provide_session - def resolve( - self, context: Mapping[str, Any], session: Session = NEW_SESSION, *, include_xcom: bool = True - ) -> Any: - value = self.arg.resolve(context, session=session, include_xcom=include_xcom) - if not isinstance(value, (Sequence, dict)): - raise ValueError(f"XCom map expects sequence or dict, not {type(value).__name__}") - return _MapResult(value, self.callables) - - -class _ZipResult(Sequence): - def __init__(self, values: Sequence[Sequence | dict], *, fillvalue: Any = NOTSET) -> None: - self.values = values - self.fillvalue = fillvalue - - @staticmethod - def _get_or_fill(container: Sequence | dict, index: Any, fillvalue: Any) -> Any: - try: - return container[index] - except (IndexError, KeyError): - return fillvalue - - def __getitem__(self, index: Any) -> Any: - if index >= len(self): - raise IndexError(index) - return tuple(self._get_or_fill(value, index, self.fillvalue) for value in self.values) - - def __len__(self) -> int: - lengths = (len(v) for v in self.values) - if isinstance(self.fillvalue, ArgNotSet): - return min(lengths) - return max(lengths) - - -class ZipXComArg(XComArg): - """ - An XCom reference with ``zip()`` applied. - - This is constructed from multiple XComArg instances, and presents an - iterable that "zips" them together like the built-in ``zip()`` (and - ``itertools.zip_longest()`` if ``fillvalue`` is provided). - """ - - def __init__(self, args: Sequence[XComArg], *, fillvalue: Any = NOTSET) -> None: - if not args: - raise ValueError("At least one input is required") - self.args = args - self.fillvalue = fillvalue - - def __repr__(self) -> str: - args_iter = iter(self.args) - first = repr(next(args_iter)) - rest = ", ".join(repr(arg) for arg in args_iter) - if isinstance(self.fillvalue, ArgNotSet): - return f"{first}.zip({rest})" - return f"{first}.zip({rest}, fillvalue={self.fillvalue!r})" - - def _serialize(self) -> dict[str, Any]: - args = [serialize_xcom_arg(arg) for arg in self.args] - if isinstance(self.fillvalue, ArgNotSet): - return {"args": args} - return {"args": args, "fillvalue": self.fillvalue} - - @classmethod - def _deserialize(cls, data: dict[str, Any], dag: DAG) -> XComArg: - return cls( - [deserialize_xcom_arg(arg, dag) for arg in data["args"]], - fillvalue=data.get("fillvalue", NOTSET), + else: + query = select(TaskMap.length).where( + TaskMap.dag_id == dag_id, + TaskMap.run_id == run_id, + TaskMap.task_id == task_id, + TaskMap.map_index < 0, ) - - def iter_references(self) -> Iterator[tuple[Operator, str]]: - for arg in self.args: - yield from arg.iter_references() - - def get_task_map_length(self, run_id: str, *, session: Session) -> int | None: - all_lengths = (arg.get_task_map_length(run_id, session=session) for arg in self.args) - ready_lengths = [length for length in all_lengths if length is not None] - if len(ready_lengths) != len(self.args): - return None # If any of the referenced XComs is not ready, we are not ready either. 
- if isinstance(self.fillvalue, ArgNotSet): - return min(ready_lengths) - return max(ready_lengths) - - @provide_session - def resolve( - self, context: Mapping[str, Any], session: Session = NEW_SESSION, *, include_xcom: bool = True - ) -> Any: - values = [arg.resolve(context, session=session, include_xcom=include_xcom) for arg in self.args] - for value in values: - if not isinstance(value, (Sequence, dict)): - raise ValueError(f"XCom zip expects sequence or dict, not {type(value).__name__}") - return _ZipResult(values, fillvalue=self.fillvalue) - - -class _ConcatResult(Sequence): - def __init__(self, values: Sequence[Sequence | dict]) -> None: - self.values = values - - def __getitem__(self, index: Any) -> Any: - if index >= 0: - i = index - else: - i = len(self) + index - for value in self.values: - if i < 0: - break - elif i >= (curlen := len(value)): - i -= curlen - elif isinstance(value, Sequence): - return value[i] - else: - return next(itertools.islice(iter(value), i, None)) - raise IndexError("list index out of range") - - def __len__(self) -> int: - return sum(len(v) for v in self.values) - - -class ConcatXComArg(XComArg): - """ - Concatenating multiple XCom references into one. - - This is done by calling ``concat()`` on an XComArg to combine it with - others. The effect is similar to Python's :func:`itertools.chain`, but the - return value also supports index access. - """ - - def __init__(self, args: Sequence[XComArg]) -> None: - if not args: - raise ValueError("At least one input is required") - self.args = args - - def __repr__(self) -> str: - args_iter = iter(self.args) - first = repr(next(args_iter)) - rest = ", ".join(repr(arg) for arg in args_iter) - return f"{first}.concat({rest})" - - def _serialize(self) -> dict[str, Any]: - return {"args": [serialize_xcom_arg(arg) for arg in self.args]} - - @classmethod - def _deserialize(cls, data: dict[str, Any], dag: DAG) -> XComArg: - return cls([deserialize_xcom_arg(arg, dag) for arg in data["args"]]) - - def iter_references(self) -> Iterator[tuple[Operator, str]]: - for arg in self.args: - yield from arg.iter_references() - - def concat(self, *others: XComArg) -> ConcatXComArg: - # Flatten foo.concat(x).concat(y) into one call. - return ConcatXComArg([*self.args, *others]) - - def get_task_map_length(self, run_id: str, *, session: Session) -> int | None: - all_lengths = (arg.get_task_map_length(run_id, session=session) for arg in self.args) - ready_lengths = [length for length in all_lengths if length is not None] - if len(ready_lengths) != len(self.args): - return None # If any of the referenced XComs is not ready, we are not ready either. 
- return sum(ready_lengths) - - @provide_session - def resolve( - self, context: Mapping[str, Any], session: Session = NEW_SESSION, *, include_xcom: bool = True - ) -> Any: - values = [arg.resolve(context, session=session, include_xcom=include_xcom) for arg in self.args] - for value in values: - if not isinstance(value, (Sequence, dict)): - raise ValueError(f"XCom concat expects sequence or dict, not {type(value).__name__}") - return _ConcatResult(values) + return session.scalar(query) -_XCOM_ARG_TYPES: Mapping[str, type[XComArg]] = { - "": PlainXComArg, - "concat": ConcatXComArg, - "map": MapXComArg, - "zip": ZipXComArg, -} +@get_task_map_length.register +def _(xcom_arg: MapXComArg, run_id: str, *, session: Session): + return get_task_map_length(xcom_arg.arg, run_id, session=session) -def serialize_xcom_arg(value: XComArg) -> dict[str, Any]: - """DAG serialization interface.""" - key = next(k for k, v in _XCOM_ARG_TYPES.items() if isinstance(value, v)) - if key: - return {"type": key, **value._serialize()} - return value._serialize() +@get_task_map_length.register +def _(xcom_arg: ZipXComArg, run_id: str, *, session: Session): + all_lengths = (get_task_map_length(arg, run_id, session=session) for arg in xcom_arg.args) + ready_lengths = [length for length in all_lengths if length is not None] + if len(ready_lengths) != len(xcom_arg.args): + return None # If any of the referenced XComs is not ready, we are not ready either. + if isinstance(xcom_arg.fillvalue, ArgNotSet): + return min(ready_lengths) + return max(ready_lengths) -def deserialize_xcom_arg(data: dict[str, Any], dag: DAG) -> XComArg: - """DAG serialization interface.""" - klass = _XCOM_ARG_TYPES[data.get("type", "")] - return klass._deserialize(data, dag) +@get_task_map_length.register +def _(xcom_arg: ConcatXComArg, run_id: str, *, session: Session): + all_lengths = (get_task_map_length(arg, run_id, session=session) for arg in xcom_arg.args) + ready_lengths = [length for length in all_lengths if length is not None] + if len(ready_lengths) != len(xcom_arg.args): + return None # If any of the referenced XComs is not ready, we are not ready either. 
+ return sum(ready_lengths) diff --git a/airflow/serialization/serialized_objects.py b/airflow/serialization/serialized_objects.py index 88e0f200bb24d..b7e08a45aed74 100644 --- a/airflow/serialization/serialized_objects.py +++ b/airflow/serialization/serialized_objects.py @@ -47,11 +47,9 @@ create_expand_input, get_map_type_key, ) -from airflow.models.mappedoperator import MappedOperator from airflow.models.param import Param, ParamsDict from airflow.models.taskinstance import SimpleTaskInstance from airflow.models.taskinstancekey import TaskInstanceKey -from airflow.models.xcom_arg import XComArg, deserialize_xcom_arg, serialize_xcom_arg from airflow.providers_manager import ProvidersManager from airflow.sdk.definitions.asset import ( Asset, @@ -65,6 +63,9 @@ BaseAsset, ) from airflow.sdk.definitions.baseoperator import BaseOperator as TaskSDKBaseOperator +from airflow.sdk.definitions.mappedoperator import MappedOperator +from airflow.sdk.definitions.taskgroup import MappedTaskGroup, TaskGroup +from airflow.sdk.definitions.xcom_arg import XComArg, deserialize_xcom_arg, serialize_xcom_arg from airflow.sdk.execution_time.context import OutletEventAccessor, OutletEventAccessors from airflow.serialization.dag_dependency import DagDependency from airflow.serialization.enums import DagAttributeTypes as DAT, Encoding @@ -87,9 +88,8 @@ from airflow.utils.docs import get_docs_url from airflow.utils.module_loading import import_string, qualname from airflow.utils.operator_resources import Resources -from airflow.utils.task_group import MappedTaskGroup, TaskGroup from airflow.utils.timezone import from_timestamp, parse_timezone -from airflow.utils.types import NOTSET, ArgNotSet, AttributeRemoved +from airflow.utils.types import NOTSET, ArgNotSet if TYPE_CHECKING: from inspect import Parameter @@ -97,8 +97,8 @@ from airflow.models import DagRun from airflow.models.baseoperatorlink import BaseOperatorLink from airflow.models.expandinput import ExpandInput - from airflow.models.operator import Operator from airflow.sdk.definitions._internal.node import DAGNode + from airflow.sdk.types import Operator from airflow.serialization.json_schema import Validator from airflow.timetables.base import DagRunInfo, DataInterval, Timetable @@ -1339,7 +1339,7 @@ def populate_operator(cls, op: Operator, encoded_op: dict[str, Any]) -> None: if v is False: raise RuntimeError("_is_sensor=False should never have been serialized!") - object.__setattr__(op, "deps", op.deps | {ReadyToRescheduleDep()}) + object.__setattr__(op, "deps", op.deps | {ReadyToRescheduleDep()}) # type: ignore[union-attr] continue elif ( k in cls._decorated_fields @@ -1408,13 +1408,18 @@ def deserialize_operator(cls, encoded_op: dict[str, Any]) -> Operator: op: Operator if encoded_op.get("_is_mapped", False): # Most of these will be loaded later, these are just some stand-ins. 
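The `@get_task_map_length.register` overloads above are the single-dispatch replacement for the old per-subclass `get_task_map_length` methods: one generic entry point, with the implementation chosen by the runtime type of the XComArg passed in. A minimal, self-contained sketch of that dispatch shape, using stand-in classes rather than the real XComArg types:

from functools import singledispatch


class Plain:
    def __init__(self, length):
        self.length = length


class Zip:
    def __init__(self, *args):
        self.args = args


@singledispatch
def map_length(arg):
    raise NotImplementedError(f"no map-length implementation for {type(arg).__name__}")


@map_length.register
def _(arg: Plain):
    return arg.length  # None means "length not known yet"


@map_length.register
def _(arg: Zip):
    lengths = [map_length(a) for a in arg.args]
    if any(length is None for length in lengths):
        return None  # Any unready input makes the whole zip unready.
    return min(lengths)  # Like zip(), truncate to the shortest input.


assert map_length(Zip(Plain(3), Plain(5))) == 3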
- op_data = {k: v for k, v in encoded_op.items() if k in BaseOperator.get_serialized_fields()} + op_data = { + k: v for k, v in encoded_op.items() if k in TaskSDKBaseOperator.get_serialized_fields() + } + + from airflow.models.mappedoperator import MappedOperator as MappedOperatorWithDB + try: operator_name = encoded_op["_operator_name"] except KeyError: operator_name = encoded_op["task_type"] - op = MappedOperator( + op = MappedOperatorWithDB( operator_class=op_data, expand_input=EXPAND_INPUT_EMPTY, partial_kwargs={}, @@ -1442,7 +1447,6 @@ def deserialize_operator(cls, encoded_op: dict[str, Any]) -> Operator: ) else: op = SerializedBaseOperator(task_id=encoded_op["task_id"]) - op.dag = AttributeRemoved("dag") # type: ignore[assignment] cls.populate_operator(op, encoded_op) return op @@ -1455,12 +1459,7 @@ def detect_dependencies(cls, op: Operator) -> set[DagDependency]: @classmethod def _is_excluded(cls, var: Any, attrname: str, op: DAGNode): - if ( - var is not None - and op.has_dag() - and op.dag.__class__ is not AttributeRemoved - and attrname.endswith("_date") - ): + if var is not None and op.has_dag() and attrname.endswith("_date"): # If this date is the same as the matching field in the dag, then # don't store it again at the task level. dag_date = getattr(op.dag, attrname, None) diff --git a/airflow/ti_deps/deps/mapped_task_upstream_dep.py b/airflow/ti_deps/deps/mapped_task_upstream_dep.py index 247dc84f3b478..97531ef4257e6 100644 --- a/airflow/ti_deps/deps/mapped_task_upstream_dep.py +++ b/airflow/ti_deps/deps/mapped_task_upstream_dep.py @@ -51,7 +51,7 @@ def _get_dep_statuses( session: Session, dep_context: DepContext, ) -> Iterator[TIDepStatus]: - from airflow.models.mappedoperator import MappedOperator + from airflow.sdk.definitions.mappedoperator import MappedOperator if isinstance(ti.task, MappedOperator): mapped_dependencies = ti.task.iter_mapped_dependencies() diff --git a/airflow/ti_deps/deps/prev_dagrun_dep.py b/airflow/ti_deps/deps/prev_dagrun_dep.py index c756e6ec1c645..9ce5c1134240a 100644 --- a/airflow/ti_deps/deps/prev_dagrun_dep.py +++ b/airflow/ti_deps/deps/prev_dagrun_dep.py @@ -32,7 +32,7 @@ if TYPE_CHECKING: from sqlalchemy.orm import Session - from airflow.models.operator import Operator + from airflow.sdk.types import Operator _SUCCESSFUL_STATES = (TaskInstanceState.SKIPPED, TaskInstanceState.SUCCESS) diff --git a/airflow/ti_deps/deps/trigger_rule_dep.py b/airflow/ti_deps/deps/trigger_rule_dep.py index ccd11ab303e66..f49e62d4b4605 100644 --- a/airflow/ti_deps/deps/trigger_rule_dep.py +++ b/airflow/ti_deps/deps/trigger_rule_dep.py @@ -139,10 +139,12 @@ def _get_expanded_ti_count() -> int: This extra closure allows us to query the database only when needed, and at most once. 
""" + from airflow.models.baseoperator import BaseOperator + if TYPE_CHECKING: assert ti.task - return ti.task.get_mapped_ti_count(ti.run_id, session=session) + return BaseOperator.get_mapped_ti_count(ti.task, ti.run_id, session=session) @functools.lru_cache def _get_relevant_upstream_map_indexes(upstream_id: str) -> int | range | None: diff --git a/airflow/utils/context.py b/airflow/utils/context.py index a36202f0793ec..74479a328a1a2 100644 --- a/airflow/utils/context.py +++ b/airflow/utils/context.py @@ -62,7 +62,7 @@ from sqlalchemy.orm import Session from sqlalchemy.sql.expression import Select, TextClause - from airflow.models.baseoperator import BaseOperator + from airflow.sdk.definitions.baseoperator import BaseOperator from airflow.sdk.types import OutletEventAccessorsProtocol # NOTE: Please keep this in sync with the following: diff --git a/airflow/utils/dag_edges.py b/airflow/utils/dag_edges.py index bd1ad268aefed..aafecbf308232 100644 --- a/airflow/utils/dag_edges.py +++ b/airflow/utils/dag_edges.py @@ -18,11 +18,11 @@ from typing import TYPE_CHECKING -from airflow.models.abstractoperator import AbstractOperator +from airflow.sdk.definitions._internal.abstractoperator import AbstractOperator if TYPE_CHECKING: - from airflow.models import Operator - from airflow.models.dag import DAG + from airflow.sdk import DAG + from airflow.sdk.types import Operator def dag_edges(dag: DAG): diff --git a/airflow/utils/dot_renderer.py b/airflow/utils/dot_renderer.py index 877d16450d700..c4f1e45dc1f5e 100644 --- a/airflow/utils/dot_renderer.py +++ b/airflow/utils/dot_renderer.py @@ -36,8 +36,8 @@ graphviz = None from airflow.exceptions import AirflowException -from airflow.models.baseoperator import BaseOperator -from airflow.models.mappedoperator import MappedOperator +from airflow.sdk import BaseOperator +from airflow.sdk.definitions.mappedoperator import MappedOperator from airflow.utils.dag_edges import dag_edges from airflow.utils.state import State from airflow.utils.task_group import TaskGroup diff --git a/airflow/utils/log/file_task_handler.py b/airflow/utils/log/file_task_handler.py index a8abfe0ea8638..73ee79126a9ee 100644 --- a/airflow/utils/log/file_task_handler.py +++ b/airflow/utils/log/file_task_handler.py @@ -282,7 +282,8 @@ def _render_filename(self, ti: TaskInstance, try_number: int, session=NEW_SESSIO if str_tpl: if ti.task is not None and ti.task.dag is not None: dag = ti.task.dag - data_interval = dag.get_run_data_interval(dag_run) + # TODO: TaskSDK: why do we need this on the DAG! Where is this render fn called from. 
Revisit + data_interval = dag.get_run_data_interval(dag_run) # type: ignore[attr-defined] else: from airflow.timetables.base import DataInterval diff --git a/airflow/utils/setup_teardown.py b/airflow/utils/setup_teardown.py index 37e1a03457971..3108657d30ac2 100644 --- a/airflow/utils/setup_teardown.py +++ b/airflow/utils/setup_teardown.py @@ -22,9 +22,9 @@ from airflow.exceptions import AirflowException if TYPE_CHECKING: - from airflow.models.abstractoperator import AbstractOperator from airflow.models.taskmixin import DependencyMixin from airflow.models.xcom_arg import PlainXComArg + from airflow.sdk.definitions._internal.abstractoperator import AbstractOperator class BaseSetupTeardownContext: diff --git a/airflow/utils/task_group.py b/airflow/utils/task_group.py index dafd87ab7df02..b74921f75d533 100644 --- a/airflow/utils/task_group.py +++ b/airflow/utils/task_group.py @@ -29,29 +29,28 @@ if TYPE_CHECKING: from sqlalchemy.orm import Session - from airflow.models.operator import Operator from airflow.typing_compat import TypeAlias TaskGroup: TypeAlias = airflow.sdk.definitions.taskgroup.TaskGroup -class MappedTaskGroup(airflow.sdk.definitions.taskgroup.MappedTaskGroup): - """ - A mapped task group. - - This doesn't really do anything special, just holds some additional metadata - for expansion later. +class MappedTaskGroup(airflow.sdk.definitions.taskgroup.MappedTaskGroup): # noqa: D101 + # TODO: Rename this to SerializedMappedTaskGroup perhaps? - Don't instantiate this class directly; call *expand* or *expand_kwargs* on - a ``@task_group`` function instead. - """ + def iter_mapped_task_groups(self) -> Iterator[MappedTaskGroup]: + """ + Return mapped task groups in the hierarchy. - def iter_mapped_dependencies(self) -> Iterator[Operator]: - """Upstream dependencies that provide XComs used by this mapped task group.""" - from airflow.models.xcom_arg import XComArg + Groups are returned from the closest to the outmost. If *self* is a + mapped task group, it is returned first. - for op, _ in XComArg.iter_xcom_references(self._expand_input): - yield op + :meta private: + """ + group: TaskGroup | None = self + while group is not None: + if isinstance(group, MappedTaskGroup): + yield group + group = group.parent_group def get_mapped_ti_count(self, run_id: str, *, session: Session) -> int: """ @@ -79,9 +78,9 @@ def get_mapped_ti_count(self, run_id: str, *, session: Session) -> int: def task_group_to_dict(task_item_or_group): """Create a nested dict representation of this TaskGroup and its children used to construct the Graph.""" - from airflow.models.abstractoperator import AbstractOperator - from airflow.models.baseoperator import BaseOperator - from airflow.models.mappedoperator import MappedOperator + from airflow.sdk.definitions._internal.abstractoperator import AbstractOperator + from airflow.sdk.definitions.baseoperator import BaseOperator + from airflow.sdk.definitions.mappedoperator import MappedOperator if isinstance(task := task_item_or_group, AbstractOperator): setup_teardown_type = {} diff --git a/airflow/utils/types.py b/airflow/utils/types.py index 12870210efb82..46f295c4ee21a 100644 --- a/airflow/utils/types.py +++ b/airflow/utils/types.py @@ -31,32 +31,6 @@ NOTSET = airflow.sdk.definitions._internal.types.NOTSET -class AttributeRemoved: - """ - Sentinel type to signal when attribute removed on serialization. 
- - :meta private: - """ - - def __init__(self, attribute_name: str): - self.attribute_name = attribute_name - - def __getattr__(self, item): - if item == "attribute_name": - return super().__getattribute__(item) - raise RuntimeError( - f"Attribute {self.attribute_name} was removed on " - f"serialization and must be set again - found when accessing {item}." - ) - - -""" -Sentinel value for attributes removed on serialization. - -:meta private: -""" - - class DagRunType(str, enum.Enum): """Class with DagRun types.""" diff --git a/airflow/www/views.py b/airflow/www/views.py index bad206e183089..dd5bc856f74a3 100644 --- a/airflow/www/views.py +++ b/airflow/www/views.py @@ -440,7 +440,8 @@ def set_overall_state(record): "id": item.task_id, "instances": instances, "label": item.label, - "extra_links": item.extra_links, + # TODO: Task-SDK: MappedOperator doesn't support extra links right now + "extra_links": getattr(item, "extra_links", []), "is_mapped": item_is_mapped, "has_outlet_assets": any(isinstance(i, (Asset, AssetAlias)) for i in (item.outlets or [])), "operator": item.operator_name, diff --git a/providers/src/airflow/providers/elasticsearch/log/es_task_handler.py b/providers/src/airflow/providers/elasticsearch/log/es_task_handler.py index 9881f1ac5dca7..15904e7ebf3b4 100644 --- a/providers/src/airflow/providers/elasticsearch/log/es_task_handler.py +++ b/providers/src/airflow/providers/elasticsearch/log/es_task_handler.py @@ -241,7 +241,8 @@ def _render_log_id(self, ti: TaskInstance | TaskInstanceKey, try_number: int) -> else: if TYPE_CHECKING: assert dag is not None - data_interval = dag.get_run_data_interval(dag_run) + # TODO: Task-SDK: Where should this function be? + data_interval = dag.get_run_data_interval(dag_run) # type: ignore[attr-defined] if self.json_format: data_interval_start = self._clean_date(data_interval[0]) diff --git a/providers/src/airflow/providers/openlineage/utils/selective_enable.py b/providers/src/airflow/providers/openlineage/utils/selective_enable.py index a3c16a1e18da3..3b6331a7cd6d7 100644 --- a/providers/src/airflow/providers/openlineage/utils/selective_enable.py +++ b/providers/src/airflow/providers/openlineage/utils/selective_enable.py @@ -18,11 +18,20 @@ from __future__ import annotations import logging -from typing import TypeVar +from typing import TYPE_CHECKING, TypeVar -from airflow.models import DAG, Operator, Param +from airflow.models import Param from airflow.models.xcom_arg import XComArg +if TYPE_CHECKING: + from airflow.sdk import DAG + from airflow.sdk.definitions._internal.abstractoperator import Operator +else: + try: + from airflow.sdk import DAG + except ImportError: + from airflow.models import DAG + ENABLE_OL_PARAM_NAME = "_selective_enable_ol" ENABLE_OL_PARAM = Param(True, const=True) DISABLE_OL_PARAM = Param(False, const=False) diff --git a/providers/src/airflow/providers/openlineage/utils/utils.py b/providers/src/airflow/providers/openlineage/utils/utils.py index 4408a833fba68..734437f44adc0 100644 --- a/providers/src/airflow/providers/openlineage/utils/utils.py +++ b/providers/src/airflow/providers/openlineage/utils/utils.py @@ -32,7 +32,7 @@ from airflow import __version__ as AIRFLOW_VERSION # TODO: move this maybe to Airflow's logic? 
-from airflow.models import DAG, BaseOperator, DagRun, MappedOperator, TaskReschedule +from airflow.models import DagRun, TaskReschedule from airflow.providers.openlineage import ( __version__ as OPENLINEAGE_PROVIDER_VERSION, conf, @@ -70,13 +70,19 @@ from airflow.models import TaskInstance from airflow.providers.common.compat.assets import Asset + from airflow.sdk import DAG, BaseOperator, MappedOperator from airflow.utils.state import DagRunState, TaskInstanceState else: + try: + from airflow.sdk import DAG, BaseOperator, MappedOperator + except ImportError: + from airflow.models import DAG, BaseOperator, MappedOperator + try: from airflow.providers.common.compat.assets import Asset except ImportError: if AIRFLOW_V_3_0_PLUS: - from airflow.sdk.definitions.asset import Asset + from airflow.sdk import Asset else: # dataset is renamed to asset since Airflow 3.0 from airflow.datasets import Dataset as Asset @@ -565,8 +571,8 @@ def _emits_ol_events(task: BaseOperator | MappedOperator) -> bool: is_skipped_as_empty_operator = all( ( task.inherits_from_empty_operator, - not task.on_execute_callback, - not task.on_success_callback, + not getattr(task, "on_execute_callback", None), + not getattr(task, "on_success_callback", None), not task.outlets, ) ) diff --git a/providers/src/airflow/providers/opensearch/log/os_task_handler.py b/providers/src/airflow/providers/opensearch/log/os_task_handler.py index 8e784076c5f8c..e1a0a083e7291 100644 --- a/providers/src/airflow/providers/opensearch/log/os_task_handler.py +++ b/providers/src/airflow/providers/opensearch/log/os_task_handler.py @@ -294,7 +294,8 @@ def _render_log_id(self, ti: TaskInstance | TaskInstanceKey, try_number: int) -> else: if TYPE_CHECKING: assert dag is not None - data_interval = dag.get_run_data_interval(dag_run) + # TODO: Task-SDK: Where should this function be? 
+ data_interval = dag.get_run_data_interval(dag_run) # type: ignore[attr-defined] if self.json_format: data_interval_start = self._clean_date(data_interval[0]) diff --git a/providers/src/airflow/providers/standard/operators/python.py b/providers/src/airflow/providers/standard/operators/python.py index 86f5f0156d243..7afaaecfaee3d 100644 --- a/providers/src/airflow/providers/standard/operators/python.py +++ b/providers/src/airflow/providers/standard/operators/python.py @@ -1165,7 +1165,7 @@ def my_task(): def _get_current_context() -> Mapping[str, Any]: # Airflow 2.x # TODO: To be removed when Airflow 2 support is dropped - from airflow.models.taskinstance import _CURRENT_CONTEXT + from airflow.models.taskinstance import _CURRENT_CONTEXT # type: ignore[attr-defined] if not _CURRENT_CONTEXT: raise RuntimeError( diff --git a/providers/tests/openlineage/utils/test_utils.py b/providers/tests/openlineage/utils/test_utils.py index 28be0b6306751..3f653d9025ef9 100644 --- a/providers/tests/openlineage/utils/test_utils.py +++ b/providers/tests/openlineage/utils/test_utils.py @@ -25,7 +25,6 @@ from airflow.decorators import task from airflow.models.baseoperator import BaseOperator from airflow.models.dagrun import DagRun -from airflow.models.mappedoperator import MappedOperator from airflow.models.taskinstance import TaskInstance, TaskInstanceState from airflow.operators.empty import EmptyOperator from airflow.providers.openlineage.plugins.facets import AirflowDagRunFacet, AirflowJobFacet @@ -186,7 +185,6 @@ def test_get_fully_qualified_class_name_serialized_operator(): def test_get_fully_qualified_class_name_mapped_operator(): mapped = MockOperator.partial(task_id="task_2").expand(arg2=["a", "b", "c"]) - assert isinstance(mapped, MappedOperator) mapped_op_path = get_fully_qualified_class_name(mapped) assert mapped_op_path == "tests_common.test_utils.mock_operators.MockOperator" @@ -216,7 +214,6 @@ def test_get_operator_class(): def test_get_operator_class_mapped_operator(): mapped = MockOperator.partial(task_id="task").expand(arg2=["a", "b", "c"]) - assert isinstance(mapped, MappedOperator) op_class = get_operator_class(mapped) assert op_class == MockOperator diff --git a/providers/tests/standard/operators/test_python.py b/providers/tests/standard/operators/test_python.py index c43c00dd0e814..039522e98340c 100644 --- a/providers/tests/standard/operators/test_python.py +++ b/providers/tests/standard/operators/test_python.py @@ -925,6 +925,7 @@ def test_virtualenv_serializable_context_fields(self, create_task_instance): "prev_execution_date", "prev_execution_date_success", "conf", + "expanded_ti_count", } else: declared_keys.remove("triggering_asset_events") diff --git a/scripts/ci/pre_commit/base_operator_partial_arguments.py b/scripts/ci/pre_commit/base_operator_partial_arguments.py index b50705331700e..070ffe767cccc 100755 --- a/scripts/ci/pre_commit/base_operator_partial_arguments.py +++ b/scripts/ci/pre_commit/base_operator_partial_arguments.py @@ -28,7 +28,9 @@ BASEOPERATOR_PY = ROOT_DIR.joinpath("airflow", "models", "baseoperator.py") SDK_BASEOPERATOR_PY = ROOT_DIR.joinpath("task_sdk", "src", "airflow", "sdk", "definitions", "baseoperator.py") -MAPPEDOPERATOR_PY = ROOT_DIR.joinpath("airflow", "models", "mappedoperator.py") +SDK_MAPPEDOPERATOR_PY = ROOT_DIR.joinpath( + "task_sdk", "src", "airflow", "sdk", "definitions", "mappedoperator.py" +) IGNORED = { # These are only used in the worker and thus mappable. 
@@ -72,7 +74,7 @@ BO_MOD = ast.parse(BASEOPERATOR_PY.read_text("utf-8"), str(BASEOPERATOR_PY)) SDK_BO_MOD = ast.parse(SDK_BASEOPERATOR_PY.read_text("utf-8"), str(SDK_BASEOPERATOR_PY)) -MO_MOD = ast.parse(MAPPEDOPERATOR_PY.read_text("utf-8"), str(MAPPEDOPERATOR_PY)) +SDK_MO_MOD = ast.parse(SDK_MAPPEDOPERATOR_PY.read_text("utf-8"), str(SDK_MAPPEDOPERATOR_PY)) # TODO: Task-SDK: Look at the BaseOperator init functions in both airflow.models.baseoperator and combine # them, until we fully remove BaseOperator class from core. @@ -100,19 +102,19 @@ ) # We now define the signature in a type checking block, the runtime impl uses **kwargs -BO_TYPE_CHECKING_BLOCKS = ( +SDK_BO_TYPE_CHECKING_BLOCKS = ( node - for node in ast.iter_child_nodes(BO_MOD) + for node in ast.iter_child_nodes(SDK_BO_MOD) if isinstance(node, ast.If) and node.test.id == "TYPE_CHECKING" # type: ignore[attr-defined] ) BO_PARTIAL = next( node - for node in itertools.chain.from_iterable(map(ast.iter_child_nodes, BO_TYPE_CHECKING_BLOCKS)) + for node in itertools.chain.from_iterable(map(ast.iter_child_nodes, SDK_BO_TYPE_CHECKING_BLOCKS)) if isinstance(node, ast.FunctionDef) and node.name == "partial" ) MO_CLS = next( node - for node in ast.iter_child_nodes(MO_MOD) + for node in ast.iter_child_nodes(SDK_MO_MOD) if isinstance(node, ast.ClassDef) and node.name == "MappedOperator" ) diff --git a/scripts/ci/pre_commit/template_context_key_sync.py b/scripts/ci/pre_commit/template_context_key_sync.py index 2f6c6021b2ed1..501801be55d97 100755 --- a/scripts/ci/pre_commit/template_context_key_sync.py +++ b/scripts/ci/pre_commit/template_context_key_sync.py @@ -102,6 +102,9 @@ def _compare_keys(retn_keys: set[str], decl_keys: set[str], hint_keys: set[str], retn_keys.add("templates_dict") docs_keys.add("templates_dict") + # Compat shim for task-sdk, not actually designed for user use + retn_keys.add("expanded_ti_count") + # Only present in callbacks. Not listed in templates-ref (that doc is for task execution). 
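The pre-commit check above works purely on the AST: it parses the SDK baseoperator and mappedoperator modules, then picks out the `partial` signature and the `MappedOperator` class by walking top-level nodes. A small sketch of that lookup technique; the source string here is made up for illustration:

import ast

SOURCE = '''
def partial(*, task_id: str, retries: int = 0): ...

class MappedOperator:
    template_fields = ()
'''

module = ast.parse(SOURCE, filename="<example>")

# Find a top-level function and class by name, the same way the check
# locates `partial` and `MappedOperator` in the real modules.
partial_def = next(
    node
    for node in ast.iter_child_nodes(module)
    if isinstance(node, ast.FunctionDef) and node.name == "partial"
)
mapped_cls = next(
    node
    for node in ast.iter_child_nodes(module)
    if isinstance(node, ast.ClassDef) and node.name == "MappedOperator"
)

# Keyword-only argument names of partial(), e.g. to compare against class fields.
print([arg.arg for arg in partial_def.args.kwonlyargs])  # ['task_id', 'retries']
print(mapped_cls.name)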
retn_keys.update(("exception", "reason", "try_number")) docs_keys.update(("exception", "reason", "try_number")) diff --git a/task_sdk/src/airflow/sdk/__init__.py b/task_sdk/src/airflow/sdk/__init__.py index 1bd0358a63c7e..b8d6b6609dba7 100644 --- a/task_sdk/src/airflow/sdk/__init__.py +++ b/task_sdk/src/airflow/sdk/__init__.py @@ -20,15 +20,17 @@ __all__ = [ "BaseOperator", + "Connection", "DAG", "EdgeModifier", "Label", + "MappedOperator", "TaskGroup", + "XComArg", + "__version__", "dag", - "Connection", "get_current_context", "get_parsing_context", - "__version__", ] __version__ = "1.0.0.dev1" @@ -39,17 +41,21 @@ from airflow.sdk.definitions.context import get_current_context, get_parsing_context from airflow.sdk.definitions.dag import DAG, dag from airflow.sdk.definitions.edges import EdgeModifier, Label + from airflow.sdk.definitions.mappedoperator import MappedOperator from airflow.sdk.definitions.taskgroup import TaskGroup + from airflow.sdk.definitions.xcom_arg import XComArg __lazy_imports: dict[str, str] = { - "DAG": ".definitions.dag", - "dag": ".definitions.dag", "BaseOperator": ".definitions.baseoperator", - "TaskGroup": ".definitions.taskgroup", + "Connection": ".definitions.connection", + "DAG": ".definitions.dag", "EdgeModifier": ".definitions.edges", "Label": ".definitions.edges", - "Connection": ".definitions.connection", + "MappedOperator": ".definitions.mappedoperator", + "TaskGroup": ".definitions.taskgroup", "Variable": ".definitions.variable", + "XComArg": ".definitions.xcom_arg", + "dag": ".definitions.dag", "get_current_context": ".definitions.context", "get_parsing_context": ".definitions.context", } diff --git a/task_sdk/src/airflow/sdk/definitions/_internal/abstractoperator.py b/task_sdk/src/airflow/sdk/definitions/_internal/abstractoperator.py index b3435f3091e83..481fb2ee773d7 100644 --- a/task_sdk/src/airflow/sdk/definitions/_internal/abstractoperator.py +++ b/task_sdk/src/airflow/sdk/definitions/_internal/abstractoperator.py @@ -21,18 +21,20 @@ import logging from abc import abstractmethod from collections.abc import ( + Callable, Collection, Iterable, + Iterator, + Mapping, ) -from typing import ( - TYPE_CHECKING, - Any, - ClassVar, -) +from typing import TYPE_CHECKING, Any, ClassVar + +import methodtools from airflow.sdk.definitions._internal.mixins import DependencyMixin from airflow.sdk.definitions._internal.node import DAGNode from airflow.sdk.definitions._internal.templater import Templater +from airflow.utils.setup_teardown import SetupTeardownContext from airflow.utils.trigger_rule import TriggerRule from airflow.utils.weight_rule import WeightRule @@ -40,13 +42,18 @@ import jinja2 from airflow.models.baseoperatorlink import BaseOperatorLink - from airflow.models.operator import Operator from airflow.sdk.definitions.baseoperator import BaseOperator from airflow.sdk.definitions.context import Context from airflow.sdk.definitions.dag import DAG + from airflow.sdk.definitions.mappedoperator import MappedOperator + from airflow.sdk.definitions.taskgroup import MappedTaskGroup + from airflow.sdk.types import Operator + +TaskStateChangeCallback = Callable[[Mapping[str, Any]], None] DEFAULT_OWNER: str = "airflow" DEFAULT_POOL_SLOTS: int = 1 +DEFAULT_POOL_NAME = "default_pool" DEFAULT_PRIORITY_WEIGHT: int = 1 # Databases do not support arbitrary precision integers, so we need to limit the range of priority weights. 
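The reshuffled `__lazy_imports` table above maps public names to the submodules that define them; such a table only works because PEP 562 lets a package define a module-level `__getattr__`. A hedged sketch of the kind of hook that consumes it (the real `airflow.sdk` hook is not shown in this hunk, so treat this as an assumption about its shape):

from importlib import import_module

__lazy_imports: dict[str, str] = {
    "MappedOperator": ".definitions.mappedoperator",
    "XComArg": ".definitions.xcom_arg",
}


def __getattr__(name: str):
    try:
        module_path = __lazy_imports[name]
    except KeyError:
        raise AttributeError(f"module {__name__!r} has no attribute {name!r}") from None
    # Import relative to this package, then cache on the module so the
    # hook only fires once per name.
    value = getattr(import_module(module_path, __package__), name)
    globals()[name] = value
    return value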
# postgres: -2147483648 to +2147483647 (see https://www.postgresql.org/docs/current/datatype-numeric.html) @@ -93,7 +100,7 @@ class AbstractOperator(Templater, DAGNode): priority_weight: int # Defines the operator level extra links. - operator_extra_links: Collection[BaseOperatorLink] + operator_extra_links: Collection[BaseOperatorLink] = () owner: str task_id: str @@ -163,6 +170,10 @@ def node_id(self) -> str: @abstractmethod def task_display_name(self) -> str: ... + @property + def is_mapped(self): + return self._is_mapped + @property def label(self) -> str | None: if self.task_display_name and self.task_display_name != self.task_id: @@ -174,6 +185,29 @@ def label(self) -> str | None: return self.task_id[len(tg.node_id) + 1 :] return self.task_id + @property + def on_failure_fail_dagrun(self): + """ + Whether the operator should fail the dagrun on failure. + + :meta private: + """ + return self._on_failure_fail_dagrun + + @on_failure_fail_dagrun.setter + def on_failure_fail_dagrun(self, value): + """ + Setter for on_failure_fail_dagrun property. + + :meta private: + """ + if value is True and self.is_teardown is not True: + raise ValueError( + f"Cannot set task on_failure_fail_dagrun for " + f"'{self.task_id}' because it is not a teardown task." + ) + self._on_failure_fail_dagrun = value + def as_setup(self): self.is_setup = True return self @@ -195,6 +229,15 @@ def as_teardown( s >> self return self + def __enter__(self): + if not self.is_setup and not self.is_teardown: + raise RuntimeError("Only setup/teardown tasks can be used as context managers.") + SetupTeardownContext.push_setup_teardown_task(self) + return SetupTeardownContext + + def __exit__(self, exc_type, exc_val, exc_tb): + SetupTeardownContext.set_work_task_roots_and_leaves() + def get_flat_relative_ids(self, *, upstream: bool = False) -> set[str]: """ Get a flat set of relative IDs, upstream or downstream. @@ -339,3 +382,111 @@ def _do_render_template_fields( raise else: setattr(parent, attr_name, rendered_content) + + def _iter_all_mapped_downstreams(self) -> Iterator[MappedOperator | MappedTaskGroup]: + """ + Return mapped nodes that are direct dependencies of the current task. + + For now, this walks the entire DAG to find mapped nodes that has this + current task as an upstream. We cannot use ``downstream_list`` since it + only contains operators, not task groups. In the future, we should + provide a way to record an DAG node's all downstream nodes instead. + + Note that this does not guarantee the returned tasks actually use the + current task for task mapping, but only checks those task are mapped + operators, and are downstreams of the current task. + + To get a list of tasks that uses the current task for task mapping, use + :meth:`iter_mapped_dependants` instead. + """ + from airflow.sdk.definitions.mappedoperator import MappedOperator + from airflow.sdk.definitions.taskgroup import MappedTaskGroup, TaskGroup + + def _walk_group(group: TaskGroup) -> Iterable[tuple[str, DAGNode]]: + """ + Recursively walk children in a task group. + + This yields all direct children (including both tasks and task + groups), and all children of any task groups. 
+ """ + for key, child in group.children.items(): + yield key, child + if isinstance(child, TaskGroup): + yield from _walk_group(child) + + dag = self.get_dag() + if not dag: + raise RuntimeError("Cannot check for mapped dependants when not attached to a DAG") + for key, child in _walk_group(dag.task_group): + if key == self.node_id: + continue + if not isinstance(child, (MappedOperator, MappedTaskGroup)): + continue + if self.node_id in child.upstream_task_ids: + yield child + + def iter_mapped_dependants(self) -> Iterator[MappedOperator | MappedTaskGroup]: + """ + Return mapped nodes that depend on the current task the expansion. + + For now, this walks the entire DAG to find mapped nodes that has this + current task as an upstream. We cannot use ``downstream_list`` since it + only contains operators, not task groups. In the future, we should + provide a way to record an DAG node's all downstream nodes instead. + """ + return ( + downstream + for downstream in self._iter_all_mapped_downstreams() + if any(p.node_id == self.node_id for p in downstream.iter_mapped_dependencies()) + ) + + def iter_mapped_task_groups(self) -> Iterator[MappedTaskGroup]: + """ + Return mapped task groups this task belongs to. + + Groups are returned from the innermost to the outmost. + + :meta private: + """ + if (group := self.task_group) is None: + return + yield from group.iter_mapped_task_groups() + + def get_closest_mapped_task_group(self) -> MappedTaskGroup | None: + """ + Get the mapped task group "closest" to this task in the DAG. + + :meta private: + """ + return next(self.iter_mapped_task_groups(), None) + + def get_needs_expansion(self) -> bool: + """ + Return true if the task is MappedOperator or is in a mapped task group. + + :meta private: + """ + if self._needs_expansion is None: + if self.get_closest_mapped_task_group() is not None: + self._needs_expansion = True + else: + self._needs_expansion = False + return self._needs_expansion + + @methodtools.lru_cache(maxsize=None) + def get_parse_time_mapped_ti_count(self) -> int: + """ + Return the number of mapped task instances that can be created on DAG run creation. + + This only considers literal mapped arguments, and would return *None* + when any non-literal values are used for mapping. + + :raise NotFullyPopulated: If non-literal mapped arguments are encountered. + :raise NotMapped: If the operator is neither mapped, nor has any parent + mapped task groups. + :return: Total number of mapped TIs this task should have. 
+ """ + group = self.get_closest_mapped_task_group() + if group is None: + raise NotMapped() + return group.get_parse_time_mapped_ti_count() diff --git a/task_sdk/src/airflow/sdk/definitions/_internal/mixins.py b/task_sdk/src/airflow/sdk/definitions/_internal/mixins.py index bdd2a5c08c755..fcd68ba20b2c6 100644 --- a/task_sdk/src/airflow/sdk/definitions/_internal/mixins.py +++ b/task_sdk/src/airflow/sdk/definitions/_internal/mixins.py @@ -22,7 +22,7 @@ from typing import TYPE_CHECKING, Any if TYPE_CHECKING: - from airflow.models.operator import Operator + from airflow.sdk.definitions._internal.abstractoperator import Operator from airflow.sdk.definitions.baseoperator import BaseOperator from airflow.sdk.definitions.context import Context from airflow.sdk.definitions.edges import EdgeModifier diff --git a/task_sdk/src/airflow/sdk/definitions/_internal/node.py b/task_sdk/src/airflow/sdk/definitions/_internal/node.py index b8c0260911837..fb69c7b926016 100644 --- a/task_sdk/src/airflow/sdk/definitions/_internal/node.py +++ b/task_sdk/src/airflow/sdk/definitions/_internal/node.py @@ -30,8 +30,8 @@ from airflow.sdk.definitions._internal.mixins import DependencyMixin if TYPE_CHECKING: - from airflow.models.operator import Operator from airflow.sdk.definitions._internal.types import Logger + from airflow.sdk.definitions.abstractoperator import Operator from airflow.sdk.definitions.dag import DAG from airflow.sdk.definitions.edges import EdgeModifier from airflow.sdk.definitions.taskgroup import TaskGroup @@ -99,8 +99,8 @@ def dag_id(self) -> str: return self.dag.dag_id return "_in_memory_dag_" + @methodtools.lru_cache() # type: ignore[misc] @property - @methodtools.lru_cache() def log(self) -> Logger: typ = type(self) name = f"{typ.__module__}.{typ.__qualname__}" @@ -123,8 +123,8 @@ def _set_relatives( edge_modifier: EdgeModifier | None = None, ) -> None: """Set relatives for the task or task list.""" - from airflow.models.mappedoperator import MappedOperator from airflow.sdk.definitions.baseoperator import BaseOperator + from airflow.sdk.definitions.mappedoperator import MappedOperator if not isinstance(task_or_task_list, Sequence): task_or_task_list = [task_or_task_list] diff --git a/task_sdk/src/airflow/sdk/definitions/_internal/templater.py b/task_sdk/src/airflow/sdk/definitions/_internal/templater.py index 58a2b450eb3d8..d7028d4d6bca7 100644 --- a/task_sdk/src/airflow/sdk/definitions/_internal/templater.py +++ b/task_sdk/src/airflow/sdk/definitions/_internal/templater.py @@ -177,9 +177,9 @@ def render_template( return self._render(template, context) if isinstance(value, ObjectStoragePath): return self._render_object_storage_path(value, context, jinja_env) - if isinstance(value, ResolveMixin): - # TODO: Task-SDK: Tidy up the typing on template context - return value.resolve(context, include_xcom=True) # type: ignore[arg-type] + + if resolve := getattr(value, "resolve", None): + return resolve(context, include_xcom=True) # Fast path for common built-in collections. 
if value.__class__ is tuple: diff --git a/task_sdk/src/airflow/sdk/definitions/asset/__init__.py b/task_sdk/src/airflow/sdk/definitions/asset/__init__.py index 51d4abbeda45d..91ebc4ab6cb8b 100644 --- a/task_sdk/src/airflow/sdk/definitions/asset/__init__.py +++ b/task_sdk/src/airflow/sdk/definitions/asset/__init__.py @@ -111,14 +111,8 @@ def normalize_noop(parts: SplitResult) -> SplitResult: def _get_uri_normalizer(scheme: str) -> Callable[[SplitResult], SplitResult] | None: if scheme == "file": return normalize_noop - from packaging.version import Version - - from airflow import __version__ as AIRFLOW_VERSION from airflow.providers_manager import ProvidersManager - AIRFLOW_V_2 = Version(AIRFLOW_VERSION).base_version < Version("3.0.0").base_version - if AIRFLOW_V_2: - return ProvidersManager().dataset_uri_handlers.get(scheme) # type: ignore[attr-defined] return ProvidersManager().asset_uri_handlers.get(scheme) diff --git a/task_sdk/src/airflow/sdk/definitions/baseoperator.py b/task_sdk/src/airflow/sdk/definitions/baseoperator.py index 2258524525193..e7ecec69411ba 100644 --- a/task_sdk/src/airflow/sdk/definitions/baseoperator.py +++ b/task_sdk/src/airflow/sdk/definitions/baseoperator.py @@ -24,7 +24,7 @@ import inspect import sys import warnings -from collections.abc import Collection, Iterable, Sequence +from collections.abc import Callable, Collection, Iterable, Sequence from dataclasses import dataclass, field from datetime import datetime, timedelta from functools import total_ordering, wraps @@ -37,6 +37,7 @@ from airflow.sdk.definitions._internal.abstractoperator import ( DEFAULT_IGNORE_FIRST_DEPENDS_ON_PAST, DEFAULT_OWNER, + DEFAULT_POOL_NAME, DEFAULT_POOL_SLOTS, DEFAULT_PRIORITY_WEIGHT, DEFAULT_QUEUE, @@ -47,10 +48,12 @@ DEFAULT_WAIT_FOR_PAST_DEPENDS_BEFORE_SKIPPING, DEFAULT_WEIGHT_RULE, AbstractOperator, + TaskStateChangeCallback, ) from airflow.sdk.definitions._internal.decorators import fixup_decorator_warning_stack from airflow.sdk.definitions._internal.node import validate_key -from airflow.sdk.definitions._internal.types import NOTSET, validate_instance_args +from airflow.sdk.definitions._internal.types import NOTSET, ArgNotSet, validate_instance_args +from airflow.sdk.definitions.mappedoperator import OperatorPartial, validate_mapping_kwargs from airflow.task.priority_strategy import ( PriorityWeightStrategy, airflow_priority_weight_strategies, @@ -59,12 +62,13 @@ from airflow.utils import timezone from airflow.utils.setup_teardown import SetupTeardownContext from airflow.utils.trigger_rule import TriggerRule -from airflow.utils.types import AttributeRemoved from airflow.utils.weight_rule import db_safe_priority T = TypeVar("T", bound=FunctionType) if TYPE_CHECKING: + from types import ClassMethodDescriptorType + import jinja2 from airflow.models.xcom_arg import XComArg @@ -72,6 +76,7 @@ from airflow.sdk.definitions.dag import DAG from airflow.sdk.definitions.taskgroup import TaskGroup from airflow.serialization.enums import DagAttributeTypes + from airflow.task.priority_strategy import PriorityWeightStrategy from airflow.typing_compat import Self from airflow.utils.operator_resources import Resources @@ -111,6 +116,189 @@ def get_merged_defaults( return args, params +class _PartialDescriptor: + """A descriptor that guards against ``.partial`` being called on Task objects.""" + + class_method: ClassMethodDescriptorType | None = None + + def __get__( + self, obj: BaseOperator, cls: type[BaseOperator] | None = None + ) -> Callable[..., OperatorPartial]: + # Call this "partial" 
so it looks nicer in stack traces. + def partial(**kwargs): + raise TypeError("partial can only be called on Operator classes, not Tasks themselves") + + if obj is not None: + return partial + return self.class_method.__get__(cls, cls) + + +_PARTIAL_DEFAULTS: dict[str, Any] = { + "map_index_template": None, + "owner": DEFAULT_OWNER, + "trigger_rule": DEFAULT_TRIGGER_RULE, + "depends_on_past": False, + "ignore_first_depends_on_past": DEFAULT_IGNORE_FIRST_DEPENDS_ON_PAST, + "wait_for_past_depends_before_skipping": DEFAULT_WAIT_FOR_PAST_DEPENDS_BEFORE_SKIPPING, + "wait_for_downstream": False, + "retries": DEFAULT_RETRIES, + # "executor": DEFAULT_EXECUTOR, + "queue": DEFAULT_QUEUE, + "pool_slots": DEFAULT_POOL_SLOTS, + "execution_timeout": DEFAULT_TASK_EXECUTION_TIMEOUT, + "retry_delay": DEFAULT_RETRY_DELAY, + "retry_exponential_backoff": False, + "priority_weight": DEFAULT_PRIORITY_WEIGHT, + "weight_rule": DEFAULT_WEIGHT_RULE, + "inlets": [], + "outlets": [], + "allow_nested_operators": True, +} + + +# This is what handles the actual mapping. + +if TYPE_CHECKING: + + def partial( + operator_class: type[BaseOperator], + *, + task_id: str, + dag: DAG | None = None, + task_group: TaskGroup | None = None, + start_date: datetime | ArgNotSet = NOTSET, + end_date: datetime | ArgNotSet = NOTSET, + owner: str | ArgNotSet = NOTSET, + email: None | str | Iterable[str] | ArgNotSet = NOTSET, + params: collections.abc.MutableMapping | None = None, + resources: dict[str, Any] | None | ArgNotSet = NOTSET, + trigger_rule: str | ArgNotSet = NOTSET, + depends_on_past: bool | ArgNotSet = NOTSET, + ignore_first_depends_on_past: bool | ArgNotSet = NOTSET, + wait_for_past_depends_before_skipping: bool | ArgNotSet = NOTSET, + wait_for_downstream: bool | ArgNotSet = NOTSET, + retries: int | None | ArgNotSet = NOTSET, + queue: str | ArgNotSet = NOTSET, + pool: str | ArgNotSet = NOTSET, + pool_slots: int | ArgNotSet = NOTSET, + execution_timeout: timedelta | None | ArgNotSet = NOTSET, + max_retry_delay: None | timedelta | float | ArgNotSet = NOTSET, + retry_delay: timedelta | float | ArgNotSet = NOTSET, + retry_exponential_backoff: bool | ArgNotSet = NOTSET, + priority_weight: int | ArgNotSet = NOTSET, + weight_rule: str | PriorityWeightStrategy | ArgNotSet = NOTSET, + sla: timedelta | None | ArgNotSet = NOTSET, + map_index_template: str | None | ArgNotSet = NOTSET, + max_active_tis_per_dag: int | None | ArgNotSet = NOTSET, + max_active_tis_per_dagrun: int | None | ArgNotSet = NOTSET, + on_execute_callback: None + | TaskStateChangeCallback + | list[TaskStateChangeCallback] + | ArgNotSet = NOTSET, + on_failure_callback: None + | TaskStateChangeCallback + | list[TaskStateChangeCallback] + | ArgNotSet = NOTSET, + on_success_callback: None + | TaskStateChangeCallback + | list[TaskStateChangeCallback] + | ArgNotSet = NOTSET, + on_retry_callback: None + | TaskStateChangeCallback + | list[TaskStateChangeCallback] + | ArgNotSet = NOTSET, + on_skipped_callback: None + | TaskStateChangeCallback + | list[TaskStateChangeCallback] + | ArgNotSet = NOTSET, + run_as_user: str | None | ArgNotSet = NOTSET, + executor: str | None | ArgNotSet = NOTSET, + executor_config: dict | None | ArgNotSet = NOTSET, + inlets: Any | None | ArgNotSet = NOTSET, + outlets: Any | None | ArgNotSet = NOTSET, + doc: str | None | ArgNotSet = NOTSET, + doc_md: str | None | ArgNotSet = NOTSET, + doc_json: str | None | ArgNotSet = NOTSET, + doc_yaml: str | None | ArgNotSet = NOTSET, + doc_rst: str | None | ArgNotSet = NOTSET, + task_display_name: str | None | 
ArgNotSet = NOTSET, + logger_name: str | None | ArgNotSet = NOTSET, + allow_nested_operators: bool = True, + **kwargs, + ) -> OperatorPartial: ... +else: + + def partial( + operator_class: type[BaseOperator], + *, + task_id: str, + dag: DAG | None = None, + task_group: TaskGroup | None = None, + params: collections.abc.MutableMapping | None = None, + **kwargs, + ): + from airflow.sdk.definitions._internal.contextmanager import DagContext, TaskGroupContext + + validate_mapping_kwargs(operator_class, "partial", kwargs) + + dag = dag or DagContext.get_current() + if dag: + task_group = task_group or TaskGroupContext.get_current(dag) + if task_group: + task_id = task_group.child_id(task_id) + + # Merge DAG and task group level defaults into user-supplied values. + dag_default_args, partial_params = get_merged_defaults( + dag=dag, + task_group=task_group, + task_params=params, + task_default_args=kwargs.pop("default_args", None), + ) + + # Create partial_kwargs from args and kwargs + partial_kwargs: dict[str, Any] = { + "task_id": task_id, + "dag": dag, + "task_group": task_group, + **kwargs, + } + + # Inject DAG-level default args into args provided to this function. + partial_kwargs.update( + (k, v) for k, v in dag_default_args.items() if partial_kwargs.get(k, NOTSET) is NOTSET + ) + + # Fill fields not provided by the user with default values. + for k, v in _PARTIAL_DEFAULTS.items(): + partial_kwargs.setdefault(k, v) + + # Post-process arguments. Should be kept in sync with _TaskDecorator.expand(). + if "task_concurrency" in kwargs: # Reject deprecated option. + raise TypeError("unexpected argument: task_concurrency") + if start_date := partial_kwargs.get("start_date", None): + partial_kwargs["start_date"] = timezone.convert_to_utc(start_date) + if end_date := partial_kwargs.get("end_date", None): + partial_kwargs["end_date"] = timezone.convert_to_utc(end_date) + if partial_kwargs["pool_slots"] < 1: + dag_str = "" + if dag: + dag_str = f" in dag {dag.dag_id}" + raise ValueError(f"pool slots for {task_id}{dag_str} cannot be less than 1") + if retries := partial_kwargs.get("retries"): + partial_kwargs["retries"] = BaseOperator._convert_retries(retries) + partial_kwargs["retry_delay"] = BaseOperator._convert_retry_delay(partial_kwargs["retry_delay"]) + partial_kwargs["max_retry_delay"] = BaseOperator._convert_max_retry_delay( + partial_kwargs.get("max_retry_delay", None) + ) + partial_kwargs.setdefault("executor_config", {}) + + return OperatorPartial( + operator_class=operator_class, + kwargs=partial_kwargs, + params=partial_params, + ) + + class BaseOperatorMeta(abc.ABCMeta): """Metaclass of BaseOperator.""" @@ -224,11 +412,9 @@ def __new__(cls, name, bases, namespace, **kwargs): with contextlib.suppress(KeyError): # Update the partial descriptor with the class method, so it calls the actual function # (but let subclasses override it if they need to) - # TODO: Task-SDK - # partial_desc = vars(new_cls)["partial"] - # if isinstance(partial_desc, _PartialDescriptor): - # partial_desc.class_method = classmethod(partial) - ... + partial_desc = vars(new_cls)["partial"] + if isinstance(partial_desc, _PartialDescriptor): + partial_desc.class_method = classmethod(partial) # We patch `__init__` only if the class defines it. 
if inspect.getmro(new_cls)[1].__init__ is not new_cls.__init__: @@ -533,7 +719,7 @@ def say_hello_world(**context): default_factory=airflow_priority_weight_strategies[DEFAULT_WEIGHT_RULE] ) queue: str = DEFAULT_QUEUE - pool: str = "default" + pool: str = DEFAULT_POOL_NAME pool_slots: int = DEFAULT_POOL_SLOTS execution_timeout: timedelta | None = DEFAULT_TASK_EXECUTION_TIMEOUT # on_execute_callback: None | TaskStateChangeCallback | list[TaskStateChangeCallback] = None @@ -579,8 +765,7 @@ def say_hello_world(**context): ui_color: str = "#fff" ui_fgcolor: str = "#000" - # TODO: Task-SDK Mapping - # partial: Callable[..., OperatorPartial] = _PartialDescriptor() # type: ignore + partial: Callable[..., OperatorPartial] = _PartialDescriptor() # type: ignore _dag: DAG | None = field(init=False, default=None) @@ -777,8 +962,7 @@ def __init__( # self.retries = parse_retries(retries) self.retries = retries self.queue = queue - # TODO: Task-SDK: pull this default name from Pool constant? - self.pool = "default_pool" if pool is None else pool + self.pool = DEFAULT_POOL_NAME if pool is None else pool self.pool_slots = pool_slots if self.pool_slots < 1: dag_str = f" in dag {dag.dag_id}" if dag else "" @@ -1006,25 +1190,17 @@ def dag(self) -> DAG: raise RuntimeError(f"Operator {self} has not been assigned to a DAG yet") @dag.setter - def dag(self, dag: DAG | None | AttributeRemoved) -> None: + def dag(self, dag: DAG | None) -> None: """Operators can be assigned to one DAG, one time. Repeat assignments to that same DAG are ok.""" - # TODO: Task-SDK: Remove the AttributeRemoved and this type ignore once we remove AIP-44 code - self._dag = dag # type: ignore[assignment] + self._dag = dag - def _convert__dag(self, dag: DAG | None | AttributeRemoved) -> DAG | None | AttributeRemoved: + def _convert__dag(self, dag: DAG | None) -> DAG | None: # Called automatically by __setattr__ method from airflow.sdk.definitions.dag import DAG if dag is None: return dag - # if set to removed, then just set and exit - if type(self._dag) is AttributeRemoved: - return dag - # if setting to removed, then just set and exit - if type(dag) is AttributeRemoved: - return AttributeRemoved("_dag") # type: ignore[assignment] - if not isinstance(dag, DAG): raise TypeError(f"Expected DAG; received {dag.__class__.__name__}") elif self._dag is not None and self._dag is not dag: @@ -1033,8 +1209,7 @@ def _convert__dag(self, dag: DAG | None | AttributeRemoved) -> DAG | None | Attr if self.__from_mapped: pass # Don't add to DAG -- the mapped task takes the place. 
elif dag.task_dict.get(self.task_id) is not self: - # TODO: Task-SDK: Remove this type ignore - dag.add_task(self) # type: ignore[arg-type] + dag.add_task(self) return dag @staticmethod diff --git a/task_sdk/src/airflow/sdk/definitions/dag.py b/task_sdk/src/airflow/sdk/definitions/dag.py index f4b71ec99584b..cd5217c8111d0 100644 --- a/task_sdk/src/airflow/sdk/definitions/dag.py +++ b/task_sdk/src/airflow/sdk/definitions/dag.py @@ -74,7 +74,7 @@ from pendulum.tz.timezone import FixedTimezone, Timezone from airflow.decorators import TaskDecoratorCollection - from airflow.models.operator import Operator + from airflow.sdk.definitions.abstractoperator import Operator from airflow.sdk.definitions.taskgroup import TaskGroup from airflow.typing_compat import Self diff --git a/task_sdk/src/airflow/sdk/definitions/mappedoperator.py b/task_sdk/src/airflow/sdk/definitions/mappedoperator.py new file mode 100644 index 0000000000000..0fc0a7fa1896a --- /dev/null +++ b/task_sdk/src/airflow/sdk/definitions/mappedoperator.py @@ -0,0 +1,900 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from __future__ import annotations + +import contextlib +import copy +import warnings +from collections.abc import Collection, Iterable, Iterator, Mapping, Sequence +from typing import TYPE_CHECKING, Any, ClassVar, Union + +import attrs +import methodtools + +from airflow.models.abstractoperator import NotMapped +from airflow.models.expandinput import ( + DictOfListsExpandInput, + ListOfDictsExpandInput, + is_mappable, +) +from airflow.sdk.definitions._internal.abstractoperator import ( + DEFAULT_EXECUTOR, + DEFAULT_IGNORE_FIRST_DEPENDS_ON_PAST, + DEFAULT_OWNER, + DEFAULT_POOL_NAME, + DEFAULT_POOL_SLOTS, + DEFAULT_PRIORITY_WEIGHT, + DEFAULT_QUEUE, + DEFAULT_RETRIES, + DEFAULT_RETRY_DELAY, + DEFAULT_TRIGGER_RULE, + DEFAULT_WAIT_FOR_PAST_DEPENDS_BEFORE_SKIPPING, + DEFAULT_WEIGHT_RULE, + AbstractOperator, +) +from airflow.sdk.definitions._internal.types import NOTSET +from airflow.serialization.enums import DagAttributeTypes +from airflow.task.priority_strategy import PriorityWeightStrategy, validate_and_load_priority_weight_strategy +from airflow.triggers.base import StartTriggerArgs +from airflow.typing_compat import Literal +from airflow.utils.helpers import is_container, prevent_duplicates +from airflow.utils.task_instance_session import get_current_task_instance_session +from airflow.utils.xcom import XCOM_RETURN_KEY + +if TYPE_CHECKING: + import datetime + + import jinja2 # Slow import. 
+ import pendulum + from sqlalchemy.orm.session import Session + + from airflow.models.abstractoperator import ( + TaskStateChangeCallback, + ) + from airflow.models.baseoperatorlink import BaseOperatorLink + from airflow.models.expandinput import ( + ExpandInput, + OperatorExpandArgument, + OperatorExpandKwargsArgument, + ) + from airflow.models.param import ParamsDict + from airflow.models.xcom_arg import XComArg + from airflow.sdk.definitions.baseoperator import BaseOperator + from airflow.sdk.definitions.dag import DAG + from airflow.sdk.types import Operator + from airflow.ti_deps.deps.base_ti_dep import BaseTIDep + from airflow.utils.context import Context + from airflow.utils.operator_resources import Resources + from airflow.utils.task_group import TaskGroup + from airflow.utils.trigger_rule import TriggerRule + + TaskStateChangeCallbackAttrType = Union[None, TaskStateChangeCallback, list[TaskStateChangeCallback]] + +ValidationSource = Union[Literal["expand"], Literal["partial"]] + + +def validate_mapping_kwargs(op: type[BaseOperator], func: ValidationSource, value: dict[str, Any]) -> None: + # use a dict so order of args is same as code order + unknown_args = value.copy() + for klass in op.mro(): + init = klass.__init__ # type: ignore[misc] + try: + param_names = init._BaseOperatorMeta__param_names + except AttributeError: + continue + for name in param_names: + value = unknown_args.pop(name, NOTSET) + if func != "expand": + continue + if value is NOTSET: + continue + if is_mappable(value): + continue + type_name = type(value).__name__ + error = f"{op.__name__}.expand() got an unexpected type {type_name!r} for keyword argument {name}" + raise ValueError(error) + if not unknown_args: + return # If we have no args left to check: stop looking at the MRO chain. + + if len(unknown_args) == 1: + error = f"an unexpected keyword argument {unknown_args.popitem()[0]!r}" + else: + names = ", ".join(repr(n) for n in unknown_args) + error = f"unexpected keyword arguments {names}" + raise TypeError(f"{op.__name__}.{func}() got {error}") + + +def ensure_xcomarg_return_value(arg: Any) -> None: + from airflow.sdk.definitions.xcom_arg import XComArg + + if isinstance(arg, XComArg): + for operator, key in arg.iter_references(): + if key != XCOM_RETURN_KEY: + raise ValueError(f"cannot map over XCom with custom key {key!r} from {operator}") + elif not is_container(arg): + return + elif isinstance(arg, Mapping): + for v in arg.values(): + ensure_xcomarg_return_value(v) + elif isinstance(arg, Iterable): + for v in arg: + ensure_xcomarg_return_value(v) + + +@attrs.define(kw_only=True, repr=False) +class OperatorPartial: + """ + An "intermediate state" returned by ``BaseOperator.partial()``. + + This only exists at DAG-parsing time; the only intended usage is for the + user to call ``.expand()`` on it at some point (usually in a method chain) to + create a ``MappedOperator`` to add into the DAG. + """ + + operator_class: type[BaseOperator] + kwargs: dict[str, Any] + params: ParamsDict | dict + + _expand_called: bool = False # Set when expand() is called to ease user debugging. 
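The `_expand_called` flag on `OperatorPartial` exists purely for user feedback: if the partial object is garbage-collected without `.expand()` ever being called, `__del__` warns that the task was never mapped. A tiny stand-alone sketch of that bookkeeping pattern:

import warnings


class PendingTask:
    # Stand-in for the partial/expand bookkeeping; not the real OperatorPartial.
    def __init__(self, task_id):
        self.task_id = task_id
        self._expand_called = False

    def expand(self, **mapped_kwargs):
        self._expand_called = True
        return mapped_kwargs

    def __del__(self):
        if not self._expand_called:
            warnings.warn(f"Task {self.task_id!r} was never mapped!", UserWarning, stacklevel=1)


PendingTask("forgotten")                 # dropped without expand() -> warns at collection time
PendingTask("mapped").expand(x=[1, 2])   # no warning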
+ + def __attrs_post_init__(self): + validate_mapping_kwargs(self.operator_class, "partial", self.kwargs) + + def __repr__(self) -> str: + args = ", ".join(f"{k}={v!r}" for k, v in self.kwargs.items()) + return f"{self.operator_class.__name__}.partial({args})" + + def __del__(self): + if not self._expand_called: + try: + task_id = repr(self.kwargs["task_id"]) + except KeyError: + task_id = f"at {hex(id(self))}" + warnings.warn(f"Task {task_id} was never mapped!", category=UserWarning, stacklevel=1) + + def expand(self, **mapped_kwargs: OperatorExpandArgument) -> MappedOperator: + if not mapped_kwargs: + raise TypeError("no arguments to expand against") + validate_mapping_kwargs(self.operator_class, "expand", mapped_kwargs) + prevent_duplicates(self.kwargs, mapped_kwargs, fail_reason="unmappable or already specified") + # Since the input is already checked at parse time, we can set strict + # to False to skip the checks on execution. + return self._expand(DictOfListsExpandInput(mapped_kwargs), strict=False) + + def expand_kwargs(self, kwargs: OperatorExpandKwargsArgument, *, strict: bool = True) -> MappedOperator: + from airflow.models.xcom_arg import XComArg + + if isinstance(kwargs, Sequence): + for item in kwargs: + if not isinstance(item, (XComArg, Mapping)): + raise TypeError(f"expected XComArg or list[dict], not {type(kwargs).__name__}") + elif not isinstance(kwargs, XComArg): + raise TypeError(f"expected XComArg or list[dict], not {type(kwargs).__name__}") + return self._expand(ListOfDictsExpandInput(kwargs), strict=strict) + + def _expand(self, expand_input: ExpandInput, *, strict: bool) -> MappedOperator: + from airflow.operators.empty import EmptyOperator + from airflow.sensors.base import BaseSensorOperator + + self._expand_called = True + ensure_xcomarg_return_value(expand_input.value) + + partial_kwargs = self.kwargs.copy() + task_id = partial_kwargs.pop("task_id") + dag = partial_kwargs.pop("dag") + task_group = partial_kwargs.pop("task_group") + start_date = partial_kwargs.pop("start_date", None) + end_date = partial_kwargs.pop("end_date", None) + + try: + operator_name = self.operator_class.custom_operator_name # type: ignore + except AttributeError: + operator_name = self.operator_class.__name__ + + op = MappedOperator( + operator_class=self.operator_class, + expand_input=expand_input, + partial_kwargs=partial_kwargs, + task_id=task_id, + params=self.params, + operator_extra_links=self.operator_class.operator_extra_links, + template_ext=self.operator_class.template_ext, + template_fields=self.operator_class.template_fields, + template_fields_renderers=self.operator_class.template_fields_renderers, + ui_color=self.operator_class.ui_color, + ui_fgcolor=self.operator_class.ui_fgcolor, + is_empty=issubclass(self.operator_class, EmptyOperator), + is_sensor=issubclass(self.operator_class, BaseSensorOperator), + task_module=self.operator_class.__module__, + task_type=self.operator_class.__name__, + operator_name=operator_name, + dag=dag, + task_group=task_group, + start_date=start_date, + end_date=end_date, + disallow_kwargs_override=strict, + # For classic operators, this points to expand_input because kwargs + # to BaseOperator.expand() contribute to operator arguments. 
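+            # (Decorated @task operators are expected to point this at their own
+            # op_kwargs expand input instead; that wiring lives in the decorator
+            # machinery, so the attribute name here covers the classic-operator case.)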
+ expand_input_attr="expand_input", + # TODO: Move these to task SDK's BaseOperator and remove getattr + start_trigger_args=getattr(self.operator_class, "start_trigger_args", None), + start_from_trigger=bool(getattr(self.operator_class, "start_from_trigger", False)), + ) + return op + + +@attrs.define( + kw_only=True, + # Disable custom __getstate__ and __setstate__ generation since it interacts + # badly with Airflow's DAG serialization and pickling. When a mapped task is + # deserialized, subclasses are coerced into MappedOperator, but when it goes + # through DAG pickling, all attributes defined in the subclasses are dropped + # by attrs's custom state management. Since attrs does not do anything too + # special here (the logic is only important for slots=True), we use Python's + # built-in implementation, which works (as proven by good old BaseOperator). + getstate_setstate=False, +) +class MappedOperator(AbstractOperator): + """Object representing a mapped operator in a DAG.""" + + # This attribute serves double purpose. For a "normal" operator instance + # loaded from DAG, this holds the underlying non-mapped operator class that + # can be used to create an unmapped operator for execution. For an operator + # recreated from a serialized DAG, however, this holds the serialized data + # that can be used to unmap this into a SerializedBaseOperator. + operator_class: type[BaseOperator] | dict[str, Any] + + _is_mapped: bool = attrs.field(init=False, default=True) + + expand_input: ExpandInput + partial_kwargs: dict[str, Any] + + # Needed for serialization. + task_id: str + params: ParamsDict | dict + deps: frozenset[BaseTIDep] = attrs.field(init=False) + operator_extra_links: Collection[BaseOperatorLink] + template_ext: Sequence[str] + template_fields: Collection[str] + template_fields_renderers: dict[str, str] + ui_color: str + ui_fgcolor: str + _is_empty: bool = attrs.field(alias="is_empty") + _is_sensor: bool = attrs.field(alias="is_sensor", default=False) + _task_module: str + _task_type: str + _operator_name: str + start_trigger_args: StartTriggerArgs | None + start_from_trigger: bool + _needs_expansion: bool = True + + dag: DAG | None + task_group: TaskGroup | None + start_date: pendulum.DateTime | None + end_date: pendulum.DateTime | None + upstream_task_ids: set[str] = attrs.field(factory=set, init=False) + downstream_task_ids: set[str] = attrs.field(factory=set, init=False) + + _disallow_kwargs_override: bool + """Whether execution fails if ``expand_input`` has duplicates to ``partial_kwargs``. + + If *False*, values from ``expand_input`` under duplicate keys override those + under corresponding keys in ``partial_kwargs``. + """ + + _expand_input_attr: str + """Where to get kwargs to calculate expansion length against. + + This should be a name to call ``getattr()`` on. 
+ """ + + supports_lineage: bool = False + + HIDE_ATTRS_FROM_UI: ClassVar[frozenset[str]] = AbstractOperator.HIDE_ATTRS_FROM_UI | frozenset( + ("parse_time_mapped_ti_count", "operator_class", "start_trigger_args", "start_from_trigger") + ) + + @deps.default + def _deps(self): + from airflow.models.baseoperator import BaseOperator + + return BaseOperator.deps + + def __hash__(self): + return id(self) + + def __repr__(self): + return f"" + + def __attrs_post_init__(self): + from airflow.models.xcom_arg import XComArg + + if self.get_closest_mapped_task_group() is not None: + raise NotImplementedError("operator expansion in an expanded task group is not yet supported") + + if self.task_group: + self.task_group.add(self) + if self.dag: + self.dag.add_task(self) + XComArg.apply_upstream_relationship(self, self.expand_input.value) + for k, v in self.partial_kwargs.items(): + if k in self.template_fields: + XComArg.apply_upstream_relationship(self, v) + + @methodtools.lru_cache(maxsize=None) + @classmethod + def get_serialized_fields(cls): + # Not using 'cls' here since we only want to serialize base fields. + return (frozenset(attrs.fields_dict(MappedOperator)) | {"task_type"}) - { + "_task_type", + "dag", + "deps", + "expand_input", # This is needed to be able to accept XComArg. + "task_group", + "upstream_task_ids", + "supports_lineage", + "_is_setup", + "_is_teardown", + "_on_failure_fail_dagrun", + } + + @property + def task_type(self) -> str: + """Implementing Operator.""" + return self._task_type + + @property + def operator_name(self) -> str: + return self._operator_name + + @property + def inherits_from_empty_operator(self) -> bool: + """Implementing Operator.""" + return self._is_empty + + @property + def roots(self) -> Sequence[AbstractOperator]: + """Implementing DAGNode.""" + return [self] + + @property + def leaves(self) -> Sequence[AbstractOperator]: + """Implementing DAGNode.""" + return [self] + + @property + def task_display_name(self) -> str: + return self.partial_kwargs.get("task_display_name") or self.task_id + + @property + def owner(self) -> str: # type: ignore[override] + return self.partial_kwargs.get("owner", DEFAULT_OWNER) + + @property + def email(self) -> None | str | Iterable[str]: + return self.partial_kwargs.get("email") + + @property + def map_index_template(self) -> None | str: + return self.partial_kwargs.get("map_index_template") + + @map_index_template.setter + def map_index_template(self, value: str | None) -> None: + self.partial_kwargs["map_index_template"] = value + + @property + def trigger_rule(self) -> TriggerRule: + return self.partial_kwargs.get("trigger_rule", DEFAULT_TRIGGER_RULE) + + @trigger_rule.setter + def trigger_rule(self, value): + self.partial_kwargs["trigger_rule"] = value + + @property + def is_setup(self) -> bool: + return bool(self.partial_kwargs.get("is_setup")) + + @is_setup.setter + def is_setup(self, value: bool) -> None: + self.partial_kwargs["is_setup"] = value + + @property + def is_teardown(self) -> bool: + return bool(self.partial_kwargs.get("is_teardown")) + + @is_teardown.setter + def is_teardown(self, value: bool) -> None: + self.partial_kwargs["is_teardown"] = value + + @property + def depends_on_past(self) -> bool: + return bool(self.partial_kwargs.get("depends_on_past")) + + @depends_on_past.setter + def depends_on_past(self, value: bool) -> None: + self.partial_kwargs["depends_on_past"] = value + + @property + def ignore_first_depends_on_past(self) -> bool: + value = 
self.partial_kwargs.get("ignore_first_depends_on_past", DEFAULT_IGNORE_FIRST_DEPENDS_ON_PAST) + return bool(value) + + @ignore_first_depends_on_past.setter + def ignore_first_depends_on_past(self, value: bool) -> None: + self.partial_kwargs["ignore_first_depends_on_past"] = value + + @property + def wait_for_past_depends_before_skipping(self) -> bool: + value = self.partial_kwargs.get( + "wait_for_past_depends_before_skipping", DEFAULT_WAIT_FOR_PAST_DEPENDS_BEFORE_SKIPPING + ) + return bool(value) + + @wait_for_past_depends_before_skipping.setter + def wait_for_past_depends_before_skipping(self, value: bool) -> None: + self.partial_kwargs["wait_for_past_depends_before_skipping"] = value + + @property + def wait_for_downstream(self) -> bool: + return bool(self.partial_kwargs.get("wait_for_downstream")) + + @wait_for_downstream.setter + def wait_for_downstream(self, value: bool) -> None: + self.partial_kwargs["wait_for_downstream"] = value + + @property + def retries(self) -> int: + return self.partial_kwargs.get("retries", DEFAULT_RETRIES) + + @retries.setter + def retries(self, value: int) -> None: + self.partial_kwargs["retries"] = value + + @property + def queue(self) -> str: + return self.partial_kwargs.get("queue", DEFAULT_QUEUE) + + @queue.setter + def queue(self, value: str) -> None: + self.partial_kwargs["queue"] = value + + @property + def pool(self) -> str: + return self.partial_kwargs.get("pool", DEFAULT_POOL_NAME) + + @pool.setter + def pool(self, value: str) -> None: + self.partial_kwargs["pool"] = value + + @property + def pool_slots(self) -> int: + return self.partial_kwargs.get("pool_slots", DEFAULT_POOL_SLOTS) + + @pool_slots.setter + def pool_slots(self, value: int) -> None: + self.partial_kwargs["pool_slots"] = value + + @property + def execution_timeout(self) -> datetime.timedelta | None: + return self.partial_kwargs.get("execution_timeout") + + @execution_timeout.setter + def execution_timeout(self, value: datetime.timedelta | None) -> None: + self.partial_kwargs["execution_timeout"] = value + + @property + def max_retry_delay(self) -> datetime.timedelta | None: + return self.partial_kwargs.get("max_retry_delay") + + @max_retry_delay.setter + def max_retry_delay(self, value: datetime.timedelta | None) -> None: + self.partial_kwargs["max_retry_delay"] = value + + @property + def retry_delay(self) -> datetime.timedelta: + return self.partial_kwargs.get("retry_delay", DEFAULT_RETRY_DELAY) + + @retry_delay.setter + def retry_delay(self, value: datetime.timedelta) -> None: + self.partial_kwargs["retry_delay"] = value + + @property + def retry_exponential_backoff(self) -> bool: + return bool(self.partial_kwargs.get("retry_exponential_backoff")) + + @retry_exponential_backoff.setter + def retry_exponential_backoff(self, value: bool) -> None: + self.partial_kwargs["retry_exponential_backoff"] = value + + @property + def priority_weight(self) -> int: # type: ignore[override] + return self.partial_kwargs.get("priority_weight", DEFAULT_PRIORITY_WEIGHT) + + @priority_weight.setter + def priority_weight(self, value: int) -> None: + self.partial_kwargs["priority_weight"] = value + + @property + def weight_rule(self) -> PriorityWeightStrategy: # type: ignore[override] + return validate_and_load_priority_weight_strategy( + self.partial_kwargs.get("weight_rule", DEFAULT_WEIGHT_RULE) + ) + + @weight_rule.setter + def weight_rule(self, value: str | PriorityWeightStrategy) -> None: + self.partial_kwargs["weight_rule"] = validate_and_load_priority_weight_strategy(value) + + @property + def 
max_active_tis_per_dag(self) -> int | None: + return self.partial_kwargs.get("max_active_tis_per_dag") + + @max_active_tis_per_dag.setter + def max_active_tis_per_dag(self, value: int | None) -> None: + self.partial_kwargs["max_active_tis_per_dag"] = value + + @property + def max_active_tis_per_dagrun(self) -> int | None: + return self.partial_kwargs.get("max_active_tis_per_dagrun") + + @max_active_tis_per_dagrun.setter + def max_active_tis_per_dagrun(self, value: int | None) -> None: + self.partial_kwargs["max_active_tis_per_dagrun"] = value + + @property + def resources(self) -> Resources | None: + return self.partial_kwargs.get("resources") + + @property + def on_execute_callback(self) -> TaskStateChangeCallbackAttrType: + return self.partial_kwargs.get("on_execute_callback") + + @on_execute_callback.setter + def on_execute_callback(self, value: TaskStateChangeCallbackAttrType) -> None: + self.partial_kwargs["on_execute_callback"] = value + + @property + def on_failure_callback(self) -> TaskStateChangeCallbackAttrType: + return self.partial_kwargs.get("on_failure_callback") + + @on_failure_callback.setter + def on_failure_callback(self, value: TaskStateChangeCallbackAttrType) -> None: + self.partial_kwargs["on_failure_callback"] = value + + @property + def on_retry_callback(self) -> TaskStateChangeCallbackAttrType: + return self.partial_kwargs.get("on_retry_callback") + + @on_retry_callback.setter + def on_retry_callback(self, value: TaskStateChangeCallbackAttrType) -> None: + self.partial_kwargs["on_retry_callback"] = value + + @property + def on_success_callback(self) -> TaskStateChangeCallbackAttrType: + return self.partial_kwargs.get("on_success_callback") + + @on_success_callback.setter + def on_success_callback(self, value: TaskStateChangeCallbackAttrType) -> None: + self.partial_kwargs["on_success_callback"] = value + + @property + def on_skipped_callback(self) -> TaskStateChangeCallbackAttrType: + return self.partial_kwargs.get("on_skipped_callback") + + @on_skipped_callback.setter + def on_skipped_callback(self, value: TaskStateChangeCallbackAttrType) -> None: + self.partial_kwargs["on_skipped_callback"] = value + + @property + def run_as_user(self) -> str | None: + return self.partial_kwargs.get("run_as_user") + + @property + def executor(self) -> str | None: + return self.partial_kwargs.get("executor", DEFAULT_EXECUTOR) + + @property + def executor_config(self) -> dict: + return self.partial_kwargs.get("executor_config", {}) + + @property # type: ignore[override] + def inlets(self) -> list[Any]: # type: ignore[override] + return self.partial_kwargs.get("inlets", []) + + @inlets.setter + def inlets(self, value: list[Any]) -> None: # type: ignore[override] + self.partial_kwargs["inlets"] = value + + @property # type: ignore[override] + def outlets(self) -> list[Any]: # type: ignore[override] + return self.partial_kwargs.get("outlets", []) + + @outlets.setter + def outlets(self, value: list[Any]) -> None: # type: ignore[override] + self.partial_kwargs["outlets"] = value + + @property + def doc(self) -> str | None: + return self.partial_kwargs.get("doc") + + @property + def doc_md(self) -> str | None: + return self.partial_kwargs.get("doc_md") + + @property + def doc_json(self) -> str | None: + return self.partial_kwargs.get("doc_json") + + @property + def doc_yaml(self) -> str | None: + return self.partial_kwargs.get("doc_yaml") + + @property + def doc_rst(self) -> str | None: + return self.partial_kwargs.get("doc_rst") + + @property + def allow_nested_operators(self) -> bool: + 
return bool(self.partial_kwargs.get("allow_nested_operators")) + + def get_dag(self) -> DAG | None: + """Implement Operator.""" + return self.dag + + @property + def output(self) -> XComArg: + """Return reference to XCom pushed by current operator.""" + from airflow.models.xcom_arg import XComArg + + return XComArg(operator=self) + + def serialize_for_task_group(self) -> tuple[DagAttributeTypes, Any]: + """Implement DAGNode.""" + return DagAttributeTypes.OP, self.task_id + + def _expand_mapped_kwargs( + self, context: Mapping[str, Any], session: Session, *, include_xcom: bool + ) -> tuple[Mapping[str, Any], set[int]]: + """ + Get the kwargs to create the unmapped operator. + + This exists because taskflow operators expand against op_kwargs, not the + entire operator kwargs dict. + """ + return self._get_specified_expand_input().resolve(context, session, include_xcom=include_xcom) + + def _get_unmap_kwargs(self, mapped_kwargs: Mapping[str, Any], *, strict: bool) -> dict[str, Any]: + """ + Get init kwargs to unmap the underlying operator class. + + :param mapped_kwargs: The dict returned by ``_expand_mapped_kwargs``. + """ + if strict: + prevent_duplicates( + self.partial_kwargs, + mapped_kwargs, + fail_reason="unmappable or already specified", + ) + + # If params appears in the mapped kwargs, we need to merge it into the + # partial params, overriding existing keys. + params = copy.copy(self.params) + with contextlib.suppress(KeyError): + params.update(mapped_kwargs["params"]) + + # Ordering is significant; mapped kwargs should override partial ones, + # and the specially handled params should be respected. + return { + "task_id": self.task_id, + "dag": self.dag, + "task_group": self.task_group, + "start_date": self.start_date, + "end_date": self.end_date, + **self.partial_kwargs, + **mapped_kwargs, + "params": params, + } + + def expand_start_from_trigger(self, *, context: Context, session: Session) -> bool: + """ + Get the start_from_trigger value of the current abstract operator. + + MappedOperator uses this to unmap start_from_trigger to decide whether to start the task + execution directly from triggerer. + + :meta private: + """ + # start_from_trigger only makes sense when start_trigger_args exists. + if not self.start_trigger_args: + return False + + mapped_kwargs, _ = self._expand_mapped_kwargs(context, session, include_xcom=False) + if self._disallow_kwargs_override: + prevent_duplicates( + self.partial_kwargs, + mapped_kwargs, + fail_reason="unmappable or already specified", + ) + + # Ordering is significant; mapped kwargs should override partial ones. + return mapped_kwargs.get( + "start_from_trigger", self.partial_kwargs.get("start_from_trigger", self.start_from_trigger) + ) + + def expand_start_trigger_args(self, *, context: Context, session: Session) -> StartTriggerArgs | None: + """ + Get the kwargs to create the unmapped start_trigger_args. + + This method is for allowing mapped operator to start execution from triggerer. + """ + if not self.start_trigger_args: + return None + + mapped_kwargs, _ = self._expand_mapped_kwargs(context, session, include_xcom=False) + if self._disallow_kwargs_override: + prevent_duplicates( + self.partial_kwargs, + mapped_kwargs, + fail_reason="unmappable or already specified", + ) + + # Ordering is significant; mapped kwargs should override partial ones. 
+ trigger_kwargs = mapped_kwargs.get( + "trigger_kwargs", + self.partial_kwargs.get("trigger_kwargs", self.start_trigger_args.trigger_kwargs), + ) + next_kwargs = mapped_kwargs.get( + "next_kwargs", + self.partial_kwargs.get("next_kwargs", self.start_trigger_args.next_kwargs), + ) + timeout = mapped_kwargs.get( + "trigger_timeout", self.partial_kwargs.get("trigger_timeout", self.start_trigger_args.timeout) + ) + return StartTriggerArgs( + trigger_cls=self.start_trigger_args.trigger_cls, + trigger_kwargs=trigger_kwargs, + next_method=self.start_trigger_args.next_method, + next_kwargs=next_kwargs, + timeout=timeout, + ) + + def unmap(self, resolve: None | Mapping[str, Any] | tuple[Context, Session]) -> BaseOperator: + """ + Get the "normal" Operator after applying the current mapping. + + The *resolve* argument is only used if ``operator_class`` is a real + class, i.e. if this operator is not serialized. If ``operator_class`` is + not a class (i.e. this DAG has been deserialized), this returns a + SerializedBaseOperator that "looks like" the actual unmapping result. + + If *resolve* is a two-tuple (context, session), the information is used + to resolve the mapped arguments into init arguments. If it is a mapping, + no resolving happens, the mapping directly provides those init arguments + resolved from mapped kwargs. + + :meta private: + """ + if isinstance(self.operator_class, type): + if isinstance(resolve, Mapping): + kwargs = resolve + elif resolve is not None: + kwargs, _ = self._expand_mapped_kwargs(*resolve, include_xcom=True) + else: + raise RuntimeError("cannot unmap a non-serialized operator without context") + kwargs = self._get_unmap_kwargs(kwargs, strict=self._disallow_kwargs_override) + is_setup = kwargs.pop("is_setup", False) + is_teardown = kwargs.pop("is_teardown", False) + on_failure_fail_dagrun = kwargs.pop("on_failure_fail_dagrun", False) + op = self.operator_class(**kwargs, _airflow_from_mapped=True) + # We need to overwrite task_id here because BaseOperator further + # mangles the task_id based on the task hierarchy (namely, group_id + # is prepended, and '__N' appended to deduplicate). This is hacky, + # but better than duplicating the whole mangling logic. + op.task_id = self.task_id + op.is_setup = is_setup + op.is_teardown = is_teardown + op.on_failure_fail_dagrun = on_failure_fail_dagrun + return op + + # TODO: TaskSDK: This probably doesn't need to live in definition time as the next section of code is + # for unmapping a deserialized DAG -- i.e. in the scheduler. + + # After a mapped operator is serialized, there's no real way to actually + # unmap it since we've lost access to the underlying operator class. + # This tries its best to simply "forward" all the attributes on this + # mapped operator to a new SerializedBaseOperator instance. + from airflow.serialization.serialized_objects import SerializedBaseOperator + + op = SerializedBaseOperator(task_id=self.task_id, params=self.params, _airflow_from_mapped=True) + for partial_attr, value in self.partial_kwargs.items(): + setattr(op, partial_attr, value) + SerializedBaseOperator.populate_operator(op, self.operator_class) + if self.dag is not None: # For Mypy; we only serialize tasks in a DAG so the check always satisfies. 
+ SerializedBaseOperator.set_task_dag_references(op, self.dag) # type: ignore[arg-type] + return op + + def _get_specified_expand_input(self) -> ExpandInput: + """Input received from the expand call on the operator.""" + return getattr(self, self._expand_input_attr) + + def prepare_for_execution(self) -> MappedOperator: + # Since a mapped operator cannot be used for execution, and an unmapped + # BaseOperator needs to be created later (see render_template_fields), + # we don't need to create a copy of the MappedOperator here. + return self + + def iter_mapped_dependencies(self) -> Iterator[Operator]: + """Upstream dependencies that provide XComs used by this task for task mapping.""" + from airflow.models.xcom_arg import XComArg + + for operator, _ in XComArg.iter_xcom_references(self._get_specified_expand_input()): + yield operator + + @methodtools.lru_cache(maxsize=None) + def get_parse_time_mapped_ti_count(self) -> int: + current_count = self._get_specified_expand_input().get_parse_time_mapped_ti_count() + try: + # The use of `methodtools` interferes with the zero-arg super + parent_count = super(MappedOperator, self).get_parse_time_mapped_ti_count() # noqa: UP008 + except NotMapped: + return current_count + return parent_count * current_count + + def render_template_fields( + self, + context: Context, + jinja_env: jinja2.Environment | None = None, + ) -> None: + """ + Template all attributes listed in *self.template_fields*. + + This updates *context* to reference the map-expanded task and relevant + information, without modifying the mapped operator. The expanded task + in *context* is then rendered in-place. + + :param context: Context dict with values to apply on content. + :param jinja_env: Jinja environment to use for rendering. + """ + from airflow.utils.context import context_update_for_unmapped + + if not jinja_env: + jinja_env = self.get_template_env() + + # We retrieve the session here, stored by _run_raw_task in set_current_task_session + # context manager - we cannot pass the session via @provide_session because the signature + # of render_template_fields is defined by BaseOperator and there are already many subclasses + # overriding it, so changing the signature is not an option. However render_template_fields is + # always executed within "_run_raw_task" so we make sure that _run_raw_task uses the + # set_current_task_session context manager to store the session in the current task. + session = get_current_task_instance_session() + + mapped_kwargs, seen_oids = self._expand_mapped_kwargs(context, session, include_xcom=True) + unmapped_task = self.unmap(mapped_kwargs) + # TODO: Task-SDK: remove arg-type ignore once Kaxil's PR lands + context_update_for_unmapped(context, unmapped_task) # type: ignore[arg-type] + + # Since the operators that extend `BaseOperator` are not subclasses of + # `MappedOperator`, we need to call `_do_render_template_fields` from + # the unmapped task in order to call the operator method when we override + # it to customize the parsing of nested fields. 
+ unmapped_task._do_render_template_fields( + parent=unmapped_task, + template_fields=self.template_fields, + context=context, + jinja_env=jinja_env, + seen_oids=seen_oids, + ) diff --git a/task_sdk/src/airflow/sdk/definitions/taskgroup.py b/task_sdk/src/airflow/sdk/definitions/taskgroup.py index cb5ee3eeece08..f609c5aa21c48 100644 --- a/task_sdk/src/airflow/sdk/definitions/taskgroup.py +++ b/task_sdk/src/airflow/sdk/definitions/taskgroup.py @@ -46,6 +46,7 @@ from airflow.sdk.definitions.baseoperator import BaseOperator from airflow.sdk.definitions.dag import DAG from airflow.sdk.definitions.edges import EdgeModifier + from airflow.sdk.types import Operator from airflow.serialization.enums import DagAttributeTypes @@ -105,7 +106,9 @@ class TaskGroup(DAGNode): """ _group_id: str | None = attrs.field( - validator=attrs.validators.optional(attrs.validators.instance_of(str)) + validator=attrs.validators.optional(attrs.validators.instance_of(str)), + # This is the default behaviour for attrs, but by specifying this it makes IDEs happier + alias="group_id", ) group_display_name: str = attrs.field(default="", validator=attrs.validators.instance_of(str)) prefix_group_id: bool = attrs.field(default=True) @@ -561,7 +564,7 @@ def iter_mapped_task_groups(self) -> Iterator[MappedTaskGroup]: def iter_tasks(self) -> Iterator[AbstractOperator]: """Return an iterator of the child tasks.""" - from airflow.models.abstractoperator import AbstractOperator + from airflow.sdk.definitions._internal.abstractoperator import AbstractOperator groups_to_visit = [self] @@ -575,7 +578,7 @@ def iter_tasks(self) -> Iterator[AbstractOperator]: groups_to_visit.append(child) else: raise ValueError( - f"Encountered a DAGNode that is not a TaskGroup or an AbstractOperator: {type(child)}" + f"Encountered a DAGNode that is not a TaskGroup or an AbstractOperator: {type(child).__module__}.{type(child)}" ) @@ -604,13 +607,6 @@ def __iter__(self): ) yield from self._iter_child(child) - def iter_mapped_dependencies(self) -> Iterator[DAGNode]: - """Upstream dependencies that provide XComs used by this mapped task group.""" - from airflow.models.xcom_arg import XComArg - - for op, _ in XComArg.iter_xcom_references(self._expand_input): - yield op - @methodtools.lru_cache(maxsize=None) def get_parse_time_mapped_ti_count(self) -> int: """ @@ -637,11 +633,18 @@ def __exit__(self, exc_type, exc_val, exc_tb): self.set_upstream(op) super().__exit__(exc_type, exc_val, exc_tb) + def iter_mapped_dependencies(self) -> Iterator[Operator]: + """Upstream dependencies that provide XComs used by this mapped task group.""" + from airflow.models.xcom_arg import XComArg + + for op, _ in XComArg.iter_xcom_references(self._expand_input): + yield op + def task_group_to_dict(task_item_or_group): """Create a nested dict representation of this TaskGroup and its children used to construct the Graph.""" - from airflow.models.abstractoperator import AbstractOperator - from airflow.models.mappedoperator import MappedOperator + from airflow.sdk.definitions.abstractoperator import AbstractOperator + from airflow.sdk.definitions.mappedoperator import MappedOperator from airflow.sensors.base import BaseSensorOperator if isinstance(task := task_item_or_group, AbstractOperator): diff --git a/task_sdk/src/airflow/sdk/definitions/xcom_arg.py b/task_sdk/src/airflow/sdk/definitions/xcom_arg.py new file mode 100644 index 0000000000000..436cd9d005012 --- /dev/null +++ b/task_sdk/src/airflow/sdk/definitions/xcom_arg.py @@ -0,0 +1,639 @@ +# Licensed to the Apache Software 
Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +from __future__ import annotations + +import contextlib +import inspect +import itertools +from collections.abc import Iterable, Iterator, Mapping, Sequence +from typing import TYPE_CHECKING, Any, Callable, Union, overload + +from airflow.exceptions import AirflowException, XComNotFound +from airflow.sdk.definitions._internal.abstractoperator import AbstractOperator +from airflow.sdk.definitions._internal.mixins import DependencyMixin, ResolveMixin +from airflow.sdk.definitions._internal.types import NOTSET, ArgNotSet +from airflow.utils.session import NEW_SESSION, provide_session +from airflow.utils.setup_teardown import SetupTeardownContext +from airflow.utils.trigger_rule import TriggerRule +from airflow.utils.xcom import XCOM_RETURN_KEY + +if TYPE_CHECKING: + from sqlalchemy.orm import Session + + from airflow.sdk.definitions.baseoperator import BaseOperator + from airflow.sdk.definitions.dag import DAG + from airflow.sdk.types import Operator + from airflow.utils.edgemodifier import EdgeModifier + +# Callable objects contained by MapXComArg. We only accept callables from +# the user, but deserialize them into strings in a serialized XComArg for +# safety (those callables are arbitrary user code). +MapCallables = Sequence[Union[Callable[[Any], Any], str]] + + +class XComArg(ResolveMixin, DependencyMixin): + """ + Reference to an XCom value pushed from another operator. + + The implementation supports:: + + xcomarg >> op + xcomarg << op + op >> xcomarg # By BaseOperator code + op << xcomarg # By BaseOperator code + + **Example**: The moment you get a result from any operator (decorated or regular) you can :: + + any_op = AnyOperator() + xcomarg = XComArg(any_op) + # or equivalently + xcomarg = any_op.output + my_op = MyOperator() + my_op >> xcomarg + + This object can be used in legacy Operators via Jinja. + + **Example**: You can make this result to be part of any generated string:: + + any_op = AnyOperator() + xcomarg = any_op.output + op1 = MyOperator(my_text_message=f"the value is {xcomarg}") + op2 = MyOperator(my_text_message=f"the value is {xcomarg['topic']}") + + :param operator: Operator instance to which the XComArg references. + :param key: Key used to pull the XCom value. Defaults to *XCOM_RETURN_KEY*, + i.e. the referenced operator's return value. 
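+
+    **Example**: ``map``, ``zip`` and ``concat`` derive new references from
+    existing ones before mapping over them (a minimal sketch; the operators and
+    the ``clean`` function are hypothetical)::
+
+        cleaned = extract.output.map(clean)
+        pairs = extract.output.zip(lookup.output)
+        combined = extract.output.concat(lookup.output)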
+ """ + + @overload + def __new__(cls: type[XComArg], operator: Operator, key: str = XCOM_RETURN_KEY) -> XComArg: + """Execute when the user writes ``XComArg(...)`` directly.""" + + @overload + def __new__(cls: type[XComArg]) -> XComArg: + """Execute by Python internals from subclasses.""" + + def __new__(cls, *args, **kwargs) -> XComArg: + if cls is XComArg: + return PlainXComArg(*args, **kwargs) + return super().__new__(cls) + + def iter_references(self) -> Iterator[tuple[Operator, str]]: + raise NotImplementedError() + + @staticmethod + def iter_xcom_references(arg: Any) -> Iterator[tuple[Operator, str]]: + """ + Return XCom references in an arbitrary value. + + Recursively traverse ``arg`` and look for XComArg instances in any + collection objects, and instances with ``template_fields`` set. + """ + if isinstance(arg, ResolveMixin): + yield from arg.iter_references() + elif isinstance(arg, (tuple, set, list)): + for elem in arg: + yield from XComArg.iter_xcom_references(elem) + elif isinstance(arg, dict): + for elem in arg.values(): + yield from XComArg.iter_xcom_references(elem) + elif isinstance(arg, AbstractOperator): + for attr in arg.template_fields: + yield from XComArg.iter_xcom_references(getattr(arg, attr)) + + @staticmethod + def apply_upstream_relationship(op: DependencyMixin, arg: Any): + """ + Set dependency for XComArgs. + + This looks for XComArg objects in ``arg`` "deeply" (looking inside + collections objects and classes decorated with ``template_fields``), and + sets the relationship to ``op`` on any found. + """ + for operator, _ in XComArg.iter_xcom_references(arg): + op.set_upstream(operator) + + @property + def roots(self) -> list[Operator]: + """Required by DependencyMixin.""" + return [op for op, _ in self.iter_references()] + + @property + def leaves(self) -> list[Operator]: + """Required by DependencyMixin.""" + return [op for op, _ in self.iter_references()] + + def set_upstream( + self, + task_or_task_list: DependencyMixin | Sequence[DependencyMixin], + edge_modifier: EdgeModifier | None = None, + ): + """Proxy to underlying operator set_upstream method. Required by DependencyMixin.""" + for operator, _ in self.iter_references(): + operator.set_upstream(task_or_task_list, edge_modifier) + + def set_downstream( + self, + task_or_task_list: DependencyMixin | Sequence[DependencyMixin], + edge_modifier: EdgeModifier | None = None, + ): + """Proxy to underlying operator set_downstream method. Required by DependencyMixin.""" + for operator, _ in self.iter_references(): + operator.set_downstream(task_or_task_list, edge_modifier) + + def _serialize(self) -> dict[str, Any]: + """ + Serialize an XComArg. + + The implementation should be the inverse function to ``deserialize``, + returning a data dict converted from this XComArg derivative. DAG + serialization does not call this directly, but ``serialize_xcom_arg`` + instead, which adds additional information to dispatch deserialization + to the correct class. + """ + raise NotImplementedError() + + @classmethod + def _deserialize(cls, data: dict[str, Any], dag: DAG) -> XComArg: + """ + Deserialize an XComArg. + + The implementation should be the inverse function to ``serialize``, + implementing given a data dict converted from this XComArg derivative, + how the original XComArg should be created. 
DAG serialization relies on + additional information added in ``serialize_xcom_arg`` to dispatch data + dicts to the correct ``_deserialize`` information, so this function does + not need to validate whether the incoming data contains correct keys. + """ + raise NotImplementedError() + + def map(self, f: Callable[[Any], Any]) -> MapXComArg: + return MapXComArg(self, [f]) + + def zip(self, *others: XComArg, fillvalue: Any = NOTSET) -> ZipXComArg: + return ZipXComArg([self, *others], fillvalue=fillvalue) + + def concat(self, *others: XComArg) -> ConcatXComArg: + return ConcatXComArg([self, *others]) + + def resolve( + self, context: Mapping[str, Any], session: Session | None = None, *, include_xcom: bool = True + ) -> Any: + raise NotImplementedError() + + def __enter__(self): + if not self.operator.is_setup and not self.operator.is_teardown: + raise AirflowException("Only setup/teardown tasks can be used as context managers.") + SetupTeardownContext.push_setup_teardown_task(self.operator) + return SetupTeardownContext + + def __exit__(self, exc_type, exc_val, exc_tb): + SetupTeardownContext.set_work_task_roots_and_leaves() + + +class PlainXComArg(XComArg): + """ + Reference to one single XCom without any additional semantics. + + This class should not be accessed directly, but only through XComArg. The + class inheritance chain and ``__new__`` is implemented in this slightly + convoluted way because we want to + + a. Allow the user to continue using XComArg directly for the simple + semantics (see documentation of the base class for details). + b. Make ``isinstance(thing, XComArg)`` be able to detect all kinds of XCom + references. + c. Not allow many properties of PlainXComArg (including ``__getitem__`` and + ``__str__``) to exist on other kinds of XComArg implementations since + they don't make sense. + + :meta private: + """ + + def __init__(self, operator: Operator, key: str = XCOM_RETURN_KEY): + self.operator = operator + self.key = key + + def __eq__(self, other: Any) -> bool: + if not isinstance(other, PlainXComArg): + return NotImplemented + return self.operator == other.operator and self.key == other.key + + def __getitem__(self, item: str) -> XComArg: + """Implement xcomresult['some_result_key'].""" + if not isinstance(item, str): + raise ValueError(f"XComArg only supports str lookup, received {type(item).__name__}") + return PlainXComArg(operator=self.operator, key=item) + + def __iter__(self): + """ + Override iterable protocol to raise error explicitly. + + The default ``__iter__`` implementation in Python calls ``__getitem__`` + with 0, 1, 2, etc. until it hits an ``IndexError``. This does not work + well with our custom ``__getitem__`` implementation, and results in poor + DAG-writing experience since a misplaced ``*`` expansion would create an + infinite loop consuming the entire DAG parser. + + This override catches the error eagerly, so an incorrectly implemented + DAG fails fast and avoids wasting resources on nonsensical iterating. + """ + raise TypeError("'XComArg' object is not iterable") + + def __repr__(self) -> str: + if self.key == XCOM_RETURN_KEY: + return f"XComArg({self.operator!r})" + return f"XComArg({self.operator!r}, {self.key!r})" + + def __str__(self) -> str: + """ + Backward compatibility for old-style jinja used in Airflow Operators. + + **Example**: to use XComArg at BashOperator:: + + BashOperator(cmd=f"... 
{ xcomarg } ...") + + :return: + """ + xcom_pull_kwargs = [ + f"task_ids='{self.operator.task_id}'", + f"dag_id='{self.operator.dag_id}'", + ] + if self.key is not None: + xcom_pull_kwargs.append(f"key='{self.key}'") + + xcom_pull_str = ", ".join(xcom_pull_kwargs) + # {{{{ are required for escape {{ in f-string + xcom_pull = f"{{{{ task_instance.xcom_pull({xcom_pull_str}) }}}}" + return xcom_pull + + def _serialize(self) -> dict[str, Any]: + return {"task_id": self.operator.task_id, "key": self.key} + + @classmethod + def _deserialize(cls, data: dict[str, Any], dag: DAG) -> XComArg: + return cls(dag.get_task(data["task_id"]), data["key"]) + + @property + def is_setup(self) -> bool: + return self.operator.is_setup + + @is_setup.setter + def is_setup(self, val: bool): + self.operator.is_setup = val + + @property + def is_teardown(self) -> bool: + return self.operator.is_teardown + + @is_teardown.setter + def is_teardown(self, val: bool): + self.operator.is_teardown = val + + @property + def on_failure_fail_dagrun(self) -> bool: + return self.operator.on_failure_fail_dagrun + + @on_failure_fail_dagrun.setter + def on_failure_fail_dagrun(self, val: bool): + self.operator.on_failure_fail_dagrun = val + + def as_setup(self) -> DependencyMixin: + for operator, _ in self.iter_references(): + operator.is_setup = True + return self + + def as_teardown( + self, + *, + setups: BaseOperator | Iterable[BaseOperator] | None = None, + on_failure_fail_dagrun: bool | None = None, + ): + for operator, _ in self.iter_references(): + operator.is_teardown = True + operator.trigger_rule = TriggerRule.ALL_DONE_SETUP_SUCCESS + if on_failure_fail_dagrun is not None: + operator.on_failure_fail_dagrun = on_failure_fail_dagrun + if setups is not None: + setups = [setups] if isinstance(setups, DependencyMixin) else setups + for s in setups: + s.is_setup = True + s >> operator + return self + + def iter_references(self) -> Iterator[tuple[Operator, str]]: + yield self.operator, self.key + + def map(self, f: Callable[[Any], Any]) -> MapXComArg: + if self.key != XCOM_RETURN_KEY: + raise ValueError("cannot map against non-return XCom") + return super().map(f) + + def zip(self, *others: XComArg, fillvalue: Any = NOTSET) -> ZipXComArg: + if self.key != XCOM_RETURN_KEY: + raise ValueError("cannot map against non-return XCom") + return super().zip(*others, fillvalue=fillvalue) + + def concat(self, *others: XComArg) -> ConcatXComArg: + if self.key != XCOM_RETURN_KEY: + raise ValueError("cannot concatenate non-return XCom") + return super().concat(*others) + + # TODO: Task-SDK: Remove session argument once everything is ported over to Task SDK + def resolve( + self, context: Mapping[str, Any], session: Session | None = None, *, include_xcom: bool = True + ) -> Any: + ti = context["ti"] + task_id = self.operator.task_id + map_indexes = context.get("_upstream_map_indexes", {}).get(task_id) + + result = ti.xcom_pull( + task_ids=task_id, + map_indexes=map_indexes, + key=self.key, + default=NOTSET, + ) + if not isinstance(result, ArgNotSet): + return result + if self.key == XCOM_RETURN_KEY: + return None + if getattr(self.operator, "multiple_outputs", False): + # If the operator is set to have multiple outputs and it was not executed, + # we should return "None" instead of showing an error. This is because when + # multiple outputs XComs are created, the XCom keys associated with them will have + # different names than the predefined "XCOM_RETURN_KEY" and won't be found. 
+            # Therefore, it's better to return "None" like we did above where self.key==XCOM_RETURN_KEY.
+            return None
+        raise XComNotFound(ti.dag_id, task_id, self.key)
+
+
+def _get_callable_name(f: Callable | str) -> str:
+    """Try to "describe" a callable by getting its name."""
+    if callable(f):
+        return f.__name__
+    # Parse the source to find whatever is behind "def". For safety, we don't
+    # want to evaluate the code in any meaningful way!
+    with contextlib.suppress(Exception):
+        kw, name, _ = f.lstrip().split(None, 2)
+        if kw == "def":
+            return name
+    return "<function>"
+
+
+class _MapResult(Sequence):
+    def __init__(self, value: Sequence | dict, callables: MapCallables) -> None:
+        self.value = value
+        self.callables = callables
+
+    def __getitem__(self, index: Any) -> Any:
+        value = self.value[index]
+
+        # In the worker, we can access all actual callables. Call them.
+        callables = [f for f in self.callables if callable(f)]
+        if len(callables) == len(self.callables):
+            for f in callables:
+                value = f(value)
+            return value
+
+        # In the scheduler, we don't have access to the actual callables, nor do
+        # we want to run it since it's arbitrary code. This builds a string to
+        # represent the call chain in the UI or logs instead.
+        for v in self.callables:
+            value = f"{_get_callable_name(v)}({value})"
+        return value
+
+    def __len__(self) -> int:
+        return len(self.value)
+
+
+class MapXComArg(XComArg):
+    """
+    An XCom reference with ``map()`` call(s) applied.
+
+    This is based on an XComArg, but also applies a series of "transforms" that
+    convert the pulled XCom value.
+
+    :meta private:
+    """
+
+    def __init__(self, arg: XComArg, callables: MapCallables) -> None:
+        for c in callables:
+            if getattr(c, "_airflow_is_task_decorator", False):
+                raise ValueError("map() argument must be a plain function, not a @task operator")
+        self.arg = arg
+        self.callables = callables
+
+    def __repr__(self) -> str:
+        map_calls = "".join(f".map({_get_callable_name(f)})" for f in self.callables)
+        return f"{self.arg!r}{map_calls}"
+
+    def _serialize(self) -> dict[str, Any]:
+        return {
+            "arg": serialize_xcom_arg(self.arg),
+            "callables": [inspect.getsource(c) if callable(c) else c for c in self.callables],
+        }
+
+    @classmethod
+    def _deserialize(cls, data: dict[str, Any], dag: DAG) -> XComArg:
+        # We are deliberately NOT deserializing the callables. These are shown
+        # in the UI, and displaying a function object is useless.
+        return cls(deserialize_xcom_arg(data["arg"], dag), data["callables"])
+
+    def iter_references(self) -> Iterator[tuple[Operator, str]]:
+        yield from self.arg.iter_references()
+
+    def map(self, f: Callable[[Any], Any]) -> MapXComArg:
+        # Flatten arg.map(f1).map(f2) into one MapXComArg.
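+        # e.g. op.output.map(strip_ws).map(to_int) (hypothetical plain functions)
+        # becomes one MapXComArg with callables [strip_ws, to_int]; _MapResult
+        # then applies them in that order when the value is resolved.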
+ return MapXComArg(self.arg, [*self.callables, f]) + + # TODO: Task-SDK: Remove session argument once everything is ported over to Task SDK + @provide_session + def resolve( + self, context: Mapping[str, Any], session: Session = NEW_SESSION, *, include_xcom: bool = True + ) -> Any: + value = self.arg.resolve(context, session=session, include_xcom=include_xcom) + if not isinstance(value, (Sequence, dict)): + raise ValueError(f"XCom map expects sequence or dict, not {type(value).__name__}") + return _MapResult(value, self.callables) + + +class _ZipResult(Sequence): + def __init__(self, values: Sequence[Sequence | dict], *, fillvalue: Any = NOTSET) -> None: + self.values = values + self.fillvalue = fillvalue + + @staticmethod + def _get_or_fill(container: Sequence | dict, index: Any, fillvalue: Any) -> Any: + try: + return container[index] + except (IndexError, KeyError): + return fillvalue + + def __getitem__(self, index: Any) -> Any: + if index >= len(self): + raise IndexError(index) + return tuple(self._get_or_fill(value, index, self.fillvalue) for value in self.values) + + def __len__(self) -> int: + lengths = (len(v) for v in self.values) + if isinstance(self.fillvalue, ArgNotSet): + return min(lengths) + return max(lengths) + + +class ZipXComArg(XComArg): + """ + An XCom reference with ``zip()`` applied. + + This is constructed from multiple XComArg instances, and presents an + iterable that "zips" them together like the built-in ``zip()`` (and + ``itertools.zip_longest()`` if ``fillvalue`` is provided). + """ + + def __init__(self, args: Sequence[XComArg], *, fillvalue: Any = NOTSET) -> None: + if not args: + raise ValueError("At least one input is required") + self.args = args + self.fillvalue = fillvalue + + def __repr__(self) -> str: + args_iter = iter(self.args) + first = repr(next(args_iter)) + rest = ", ".join(repr(arg) for arg in args_iter) + if isinstance(self.fillvalue, ArgNotSet): + return f"{first}.zip({rest})" + return f"{first}.zip({rest}, fillvalue={self.fillvalue!r})" + + def _serialize(self) -> dict[str, Any]: + args = [serialize_xcom_arg(arg) for arg in self.args] + if isinstance(self.fillvalue, ArgNotSet): + return {"args": args} + return {"args": args, "fillvalue": self.fillvalue} + + @classmethod + def _deserialize(cls, data: dict[str, Any], dag: DAG) -> XComArg: + return cls( + [deserialize_xcom_arg(arg, dag) for arg in data["args"]], + fillvalue=data.get("fillvalue", NOTSET), + ) + + def iter_references(self) -> Iterator[tuple[Operator, str]]: + for arg in self.args: + yield from arg.iter_references() + + @provide_session + def resolve( + self, context: Mapping[str, Any], session: Session = NEW_SESSION, *, include_xcom: bool = True + ) -> Any: + values = [arg.resolve(context, session=session, include_xcom=include_xcom) for arg in self.args] + for value in values: + if not isinstance(value, (Sequence, dict)): + raise ValueError(f"XCom zip expects sequence or dict, not {type(value).__name__}") + return _ZipResult(values, fillvalue=self.fillvalue) + + +class _ConcatResult(Sequence): + def __init__(self, values: Sequence[Sequence | dict]) -> None: + self.values = values + + def __getitem__(self, index: Any) -> Any: + if index >= 0: + i = index + else: + i = len(self) + index + for value in self.values: + if i < 0: + break + elif i >= (curlen := len(value)): + i -= curlen + elif isinstance(value, Sequence): + return value[i] + else: + return next(itertools.islice(iter(value), i, None)) + raise IndexError("list index out of range") + + def __len__(self) -> int: + 
return sum(len(v) for v in self.values) + + +class ConcatXComArg(XComArg): + """ + Concatenating multiple XCom references into one. + + This is done by calling ``concat()`` on an XComArg to combine it with + others. The effect is similar to Python's :func:`itertools.chain`, but the + return value also supports index access. + """ + + def __init__(self, args: Sequence[XComArg]) -> None: + if not args: + raise ValueError("At least one input is required") + self.args = args + + def __repr__(self) -> str: + args_iter = iter(self.args) + first = repr(next(args_iter)) + rest = ", ".join(repr(arg) for arg in args_iter) + return f"{first}.concat({rest})" + + def _serialize(self) -> dict[str, Any]: + return {"args": [serialize_xcom_arg(arg) for arg in self.args]} + + @classmethod + def _deserialize(cls, data: dict[str, Any], dag: DAG) -> XComArg: + return cls([deserialize_xcom_arg(arg, dag) for arg in data["args"]]) + + def iter_references(self) -> Iterator[tuple[Operator, str]]: + for arg in self.args: + yield from arg.iter_references() + + def concat(self, *others: XComArg) -> ConcatXComArg: + # Flatten foo.concat(x).concat(y) into one call. + return ConcatXComArg([*self.args, *others]) + + @provide_session + def resolve( + self, context: Mapping[str, Any], session: Session = NEW_SESSION, *, include_xcom: bool = True + ) -> Any: + values = [arg.resolve(context, session=session, include_xcom=include_xcom) for arg in self.args] + for value in values: + if not isinstance(value, (Sequence, dict)): + raise ValueError(f"XCom concat expects sequence or dict, not {type(value).__name__}") + return _ConcatResult(values) + + +_XCOM_ARG_TYPES: Mapping[str, type[XComArg]] = { + "": PlainXComArg, + "concat": ConcatXComArg, + "map": MapXComArg, + "zip": ZipXComArg, +} + + +def serialize_xcom_arg(value: XComArg) -> dict[str, Any]: + """DAG serialization interface.""" + key = next(k for k, v in _XCOM_ARG_TYPES.items() if isinstance(value, v)) + if key: + return {"type": key, **value._serialize()} + return value._serialize() + + +def deserialize_xcom_arg(data: dict[str, Any], dag: DAG) -> XComArg: + """DAG serialization interface.""" + klass = _XCOM_ARG_TYPES[data.get("type", "")] + return klass._deserialize(data, dag) diff --git a/task_sdk/src/airflow/sdk/types.py b/task_sdk/src/airflow/sdk/types.py index 35ee9f8e38cab..f9ec150ce3ae2 100644 --- a/task_sdk/src/airflow/sdk/types.py +++ b/task_sdk/src/airflow/sdk/types.py @@ -17,7 +17,7 @@ from __future__ import annotations -from typing import TYPE_CHECKING, Any, Protocol +from typing import TYPE_CHECKING, Any, Protocol, Union if TYPE_CHECKING: from collections.abc import Iterator @@ -25,6 +25,9 @@ from airflow.sdk.definitions.asset import Asset, AssetAlias, AssetAliasEvent, BaseAssetUniqueKey from airflow.sdk.definitions.baseoperator import BaseOperator + from airflow.sdk.definitions.mappedoperator import MappedOperator + + Operator = Union[BaseOperator, MappedOperator] class DagRunProtocol(Protocol): diff --git a/task_sdk/tests/conftest.py b/task_sdk/tests/conftest.py index 50429d91b018a..3fc7fc18015c7 100644 --- a/task_sdk/tests/conftest.py +++ b/task_sdk/tests/conftest.py @@ -20,6 +20,7 @@ import os from pathlib import Path from typing import TYPE_CHECKING, Any, NoReturn, Protocol +from unittest import mock import pytest @@ -215,3 +216,11 @@ def _make_context_dict( return context.model_dump(exclude_unset=True, mode="json") return _make_context_dict + + +@pytest.fixture +def mock_supervisor_comms(): + with mock.patch( + 
"airflow.sdk.execution_time.task_runner.SUPERVISOR_COMMS", create=True + ) as supervisor_comms: + yield supervisor_comms diff --git a/task_sdk/tests/defintions/test_baseoperator.py b/task_sdk/tests/defintions/test_baseoperator.py index 35f33818dc198..af6bf592f5373 100644 --- a/task_sdk/tests/defintions/test_baseoperator.py +++ b/task_sdk/tests/defintions/test_baseoperator.py @@ -621,3 +621,26 @@ def _do_render(): assert expected_log in caplog.text if not_expected_log: assert not_expected_log not in caplog.text + + +def test_find_mapped_dependants_in_another_group(): + from airflow.decorators import task as task_decorator + from airflow.sdk import TaskGroup + + @task_decorator + def gen(x): + return list(range(x)) + + @task_decorator + def add(x, y): + return x + y + + with DAG(dag_id="test"): + with TaskGroup(group_id="g1"): + gen_result = gen(3) + with TaskGroup(group_id="g2"): + add_result = add.partial(y=1).expand(x=gen_result) + + # breakpoint() + dependants = list(gen_result.operator.iter_mapped_dependants()) + assert dependants == [add_result.operator] diff --git a/task_sdk/tests/defintions/test_mappedoperator.py b/task_sdk/tests/defintions/test_mappedoperator.py new file mode 100644 index 0000000000000..aba7523b5ad39 --- /dev/null +++ b/task_sdk/tests/defintions/test_mappedoperator.py @@ -0,0 +1,301 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from __future__ import annotations + +from datetime import datetime, timedelta + +import pendulum +import pytest + +from airflow.models.param import ParamsDict +from airflow.sdk.definitions.baseoperator import BaseOperator +from airflow.sdk.definitions.dag import DAG +from airflow.sdk.definitions.mappedoperator import MappedOperator +from airflow.sdk.definitions.xcom_arg import XComArg +from airflow.utils.trigger_rule import TriggerRule + +from tests_common.test_utils.mapping import expand_mapped_task # noqa: F401 +from tests_common.test_utils.mock_operators import ( + MockOperator, +) + +DEFAULT_DATE = datetime(2016, 1, 1) + + +def test_task_mapping_with_dag(): + with DAG("test-dag") as dag: + task1 = BaseOperator(task_id="op1") + literal = ["a", "b", "c"] + mapped = MockOperator.partial(task_id="task_2").expand(arg2=literal) + finish = MockOperator(task_id="finish") + + task1 >> mapped >> finish + + assert task1.downstream_list == [mapped] + assert mapped in dag.tasks + assert mapped.task_group == dag.task_group + # At parse time there should only be three tasks! 
+ assert len(dag.tasks) == 3 + + assert finish.upstream_list == [mapped] + assert mapped.downstream_list == [finish] + + +# TODO: +# test_task_mapping_with_dag_and_list_of_pandas_dataframe + + +def test_task_mapping_without_dag_context(): + with DAG("test-dag") as dag: + task1 = BaseOperator(task_id="op1") + literal = ["a", "b", "c"] + mapped = MockOperator.partial(task_id="task_2").expand(arg2=literal) + + task1 >> mapped + + assert isinstance(mapped, MappedOperator) + assert mapped in dag.tasks + assert task1.downstream_list == [mapped] + assert mapped in dag.tasks + # At parse time there should only be two tasks! + assert len(dag.tasks) == 2 + + +def test_task_mapping_default_args(): + default_args = {"start_date": DEFAULT_DATE.now(), "owner": "test"} + with DAG("test-dag", schedule=None, start_date=DEFAULT_DATE, default_args=default_args): + task1 = BaseOperator(task_id="op1") + literal = ["a", "b", "c"] + mapped = MockOperator.partial(task_id="task_2").expand(arg2=literal) + + task1 >> mapped + + assert mapped.partial_kwargs["owner"] == "test" + assert mapped.start_date == pendulum.instance(default_args["start_date"]) + + +def test_task_mapping_override_default_args(): + default_args = {"retries": 2, "start_date": DEFAULT_DATE.now()} + with DAG("test-dag", schedule=None, start_date=DEFAULT_DATE, default_args=default_args): + literal = ["a", "b", "c"] + mapped = MockOperator.partial(task_id="task", retries=1).expand(arg2=literal) + + # retries should be 1 because it is provided as a partial arg + assert mapped.partial_kwargs["retries"] == 1 + # start_date should be equal to default_args["start_date"] because it is not provided as partial arg + assert mapped.start_date == pendulum.instance(default_args["start_date"]) + # owner should be equal to Airflow default owner (airflow) because it is not provided at all + assert mapped.owner == "airflow" + + +def test_map_unknown_arg_raises(): + with pytest.raises(TypeError, match=r"argument 'file'"): + BaseOperator.partial(task_id="a").expand(file=[1, 2, {"a": "b"}]) + + +def test_map_xcom_arg(): + """Test that dependencies are correct when mapping with an XComArg""" + with DAG("test-dag"): + task1 = BaseOperator(task_id="op1") + mapped = MockOperator.partial(task_id="task_2").expand(arg2=task1.output) + finish = MockOperator(task_id="finish") + + mapped >> finish + + assert task1.downstream_list == [mapped] + + +# def test_map_xcom_arg_multiple_upstream_xcoms(dag_maker, session): + + +def test_partial_on_instance() -> None: + """`.partial` on an instance should fail -- it's only designed to be called on classes""" + with pytest.raises(TypeError): + MockOperator(task_id="a").partial() + + +def test_partial_on_class() -> None: + # Test that we accept args for superclasses too + op = MockOperator.partial(task_id="a", arg1="a", trigger_rule=TriggerRule.ONE_FAILED) + assert op.kwargs["arg1"] == "a" + assert op.kwargs["trigger_rule"] == TriggerRule.ONE_FAILED + + +def test_partial_on_class_invalid_ctor_args() -> None: + """Test that when we pass invalid args to partial(). + + I.e. if an arg is not known on the class or any of its parent classes we error at parse time + """ + with pytest.raises(TypeError, match=r"arguments 'foo', 'bar'"): + MockOperator.partial(task_id="a", foo="bar", bar=2) + + +def test_partial_on_invalid_pool_slots_raises() -> None: + """Test that when we pass an invalid value to pool_slots in partial(), + + i.e. 
if the value is not an integer, an error is raised at import time.""" + + with pytest.raises(TypeError, match="'<' not supported between instances of 'str' and 'int'"): + MockOperator.partial(task_id="pool_slots_test", pool="test", pool_slots="a").expand(arg1=[1, 2, 3]) + + +# def test_expand_mapped_task_instance(dag_maker, session, num_existing_tis, expected): + + +# def test_expand_mapped_task_failed_state_in_db(dag_maker, session): + + +# def test_expand_mapped_task_instance_skipped_on_zero(dag_maker, session): + + +def test_mapped_task_applies_default_args_classic(): + with DAG("test", default_args={"execution_timeout": timedelta(minutes=30)}) as dag: + MockOperator(task_id="simple", arg1=None, arg2=0) + MockOperator.partial(task_id="mapped").expand(arg1=[1], arg2=[2, 3]) + + assert dag.get_task("simple").execution_timeout == timedelta(minutes=30) + assert dag.get_task("mapped").execution_timeout == timedelta(minutes=30) + + +def test_mapped_task_applies_default_args_taskflow(): + with DAG("test", default_args={"execution_timeout": timedelta(minutes=30)}) as dag: + + @dag.task + def simple(arg): + pass + + @dag.task + def mapped(arg): + pass + + simple(arg=0) + mapped.expand(arg=[1, 2]) + + assert dag.get_task("simple").execution_timeout == timedelta(minutes=30) + assert dag.get_task("mapped").execution_timeout == timedelta(minutes=30) + + +@pytest.mark.parametrize( + "dag_params, task_params, expected_partial_params", + [ + pytest.param(None, None, ParamsDict(), id="none"), + pytest.param({"a": -1}, None, ParamsDict({"a": -1}), id="dag"), + pytest.param(None, {"b": -2}, ParamsDict({"b": -2}), id="task"), + pytest.param({"a": -1}, {"b": -2}, ParamsDict({"a": -1, "b": -2}), id="merge"), + ], +) +def test_mapped_expand_against_params(dag_params, task_params, expected_partial_params): + with DAG("test", params=dag_params) as dag: + MockOperator.partial(task_id="t", params=task_params).expand(params=[{"c": "x"}, {"d": 1}]) + + t = dag.get_task("t") + assert isinstance(t, MappedOperator) + assert t.params == expected_partial_params + assert t.expand_input.value == {"params": [{"c": "x"}, {"d": 1}]} + + +# def test_mapped_render_template_fields_validating_operator(dag_maker, session, tmp_path): + + +# def test_mapped_expand_kwargs_render_template_fields_validating_operator(dag_maker, session, tmp_path): + + +# def test_mapped_render_nested_template_fields(dag_maker, session): + + +# def test_expand_kwargs_mapped_task_instance(dag_maker, session, num_existing_tis, expected): + + +# def test_expand_mapped_task_instance_with_named_index( + + +# def test_expand_mapped_task_task_instance_mutation_hook(dag_maker, session, create_mapped_task) -> None: + + +# def test_expand_kwargs_render_template_fields_validating_operator(dag_maker, session, map_index, expected): + + +def test_xcomarg_property_of_mapped_operator(): + with DAG("test_xcomarg_property_of_mapped_operator"): + op_a = MockOperator.partial(task_id="a").expand(arg1=["x", "y", "z"]) + + assert op_a.output == XComArg(op_a) + + +def test_set_xcomarg_dependencies_with_mapped_operator(): + with DAG("test_set_xcomargs_dependencies_with_mapped_operator"): + op1 = MockOperator.partial(task_id="op1").expand(arg1=[1, 2, 3]) + op2 = MockOperator.partial(task_id="op2").expand(arg2=["a", "b", "c"]) + op3 = MockOperator(task_id="op3", arg1=op1.output) + op4 = MockOperator(task_id="op4", arg1=[op1.output, op2.output]) + op5 = MockOperator(task_id="op5", arg1={"op1": op1.output, "op2": op2.output}) + + assert op1 in op3.upstream_list + assert op1 in 
op4.upstream_list + assert op2 in op4.upstream_list + assert op1 in op5.upstream_list + assert op2 in op5.upstream_list + + +# def test_all_xcomargs_from_mapped_tasks_are_consumable(dag_maker, session): + + +def test_task_mapping_with_task_group_context(): + from airflow.sdk.definitions.taskgroup import TaskGroup + + with DAG("test-dag") as dag: + task1 = BaseOperator(task_id="op1") + finish = MockOperator(task_id="finish") + + with TaskGroup("test-group") as group: + literal = ["a", "b", "c"] + mapped = MockOperator.partial(task_id="task_2").expand(arg2=literal) + + task1 >> group >> finish + + assert task1.downstream_list == [mapped] + assert mapped.upstream_list == [task1] + + assert mapped in dag.tasks + assert mapped.task_group == group + + assert finish.upstream_list == [mapped] + assert mapped.downstream_list == [finish] + + +def test_task_mapping_with_explicit_task_group(): + from airflow.sdk.definitions.taskgroup import TaskGroup + + with DAG("test-dag") as dag: + task1 = BaseOperator(task_id="op1") + finish = MockOperator(task_id="finish") + + group = TaskGroup("test-group") + literal = ["a", "b", "c"] + mapped = MockOperator.partial(task_id="task_2", task_group=group).expand(arg2=literal) + + task1 >> group >> finish + + assert task1.downstream_list == [mapped] + assert mapped.upstream_list == [task1] + + assert mapped in dag.tasks + assert mapped.task_group == group + + assert finish.upstream_list == [mapped] + assert mapped.downstream_list == [finish] diff --git a/tests/models/test_taskmixin.py b/task_sdk/tests/defintions/test_minxins.py similarity index 75% rename from tests/models/test_taskmixin.py rename to task_sdk/tests/defintions/test_minxins.py index 7cfd58732889f..83b4d6eabefdf 100644 --- a/tests/models/test_taskmixin.py +++ b/task_sdk/tests/defintions/test_minxins.py @@ -22,10 +22,8 @@ import pytest from airflow.decorators import setup, task, teardown -from airflow.models.baseoperator import BaseOperator -from airflow.operators.empty import EmptyOperator - -pytestmark = pytest.mark.db_test +from airflow.sdk.definitions.baseoperator import BaseOperator +from airflow.sdk.definitions.dag import DAG def cleared_tasks(dag, task_id): @@ -73,14 +71,14 @@ def my_task(): @pytest.mark.parametrize( "setup_type, work_type, teardown_type", itertools.product(["classic", "taskflow"], repeat=3) ) -def test_as_teardown(dag_maker, setup_type, work_type, teardown_type): +def test_as_teardown(setup_type, work_type, teardown_type): """ Check that as_teardown works properly as implemented in PlainXComArg It should mark the teardown as teardown, and if a task is provided, it should mark that as setup and set it as a direct upstream. """ - with dag_maker() as dag: + with DAG("test") as dag: s1 = make_task(name="s1", type_=setup_type) w1 = make_task(name="w1", type_=work_type) t1 = make_task(name="t1", type_=teardown_type) @@ -106,7 +104,7 @@ def test_as_teardown(dag_maker, setup_type, work_type, teardown_type): @pytest.mark.parametrize( "setup_type, work_type, teardown_type", itertools.product(["classic", "taskflow"], repeat=3) ) -def test_as_teardown_oneline(dag_maker, setup_type, work_type, teardown_type): +def test_as_teardown_oneline(setup_type, work_type, teardown_type): """ Check that as_teardown implementations work properly. Tests all combinations of taskflow and classic. @@ -114,7 +112,7 @@ def test_as_teardown_oneline(dag_maker, setup_type, work_type, teardown_type): and set it as a direct upstream. 
""" - with dag_maker() as dag: + with DAG("test") as dag: s1 = make_task(name="s1", type_=setup_type) w1 = make_task(name="w1", type_=work_type) t1 = make_task(name="t1", type_=teardown_type) @@ -156,10 +154,10 @@ def test_as_teardown_oneline(dag_maker, setup_type, work_type, teardown_type): @pytest.mark.parametrize("type_", ["classic", "taskflow"]) -def test_cannot_be_both_setup_and_teardown(dag_maker, type_): +def test_cannot_be_both_setup_and_teardown(type_): # can't change a setup task to a teardown task or vice versa for first, second in [("setup", "teardown"), ("teardown", "setup")]: - with dag_maker(): + with DAG("test"): s1 = make_task(name="s1", type_=type_) getattr(s1, f"as_{first}")() with pytest.raises( @@ -168,8 +166,8 @@ def test_cannot_be_both_setup_and_teardown(dag_maker, type_): getattr(s1, f"as_{second}")() -def test_cannot_set_on_failure_fail_dagrun_unless_teardown_classic(dag_maker): - with dag_maker(): +def test_cannot_set_on_failure_fail_dagrun_unless_teardown_classic(): + with DAG("test"): t = make_task(name="t", type_="classic") assert t.is_teardown is False with pytest.raises( @@ -179,7 +177,7 @@ def test_cannot_set_on_failure_fail_dagrun_unless_teardown_classic(dag_maker): t.on_failure_fail_dagrun = True -def test_cannot_set_on_failure_fail_dagrun_unless_teardown_taskflow(dag_maker): +def test_cannot_set_on_failure_fail_dagrun_unless_teardown_taskflow(): @task(on_failure_fail_dagrun=True) def my_bad_task(): pass @@ -188,7 +186,7 @@ def my_bad_task(): def my_ok_task(): pass - with dag_maker(): + with DAG("test"): with pytest.raises( ValueError, match="Cannot set task on_failure_fail_dagrun for " @@ -218,12 +216,12 @@ def my_ok_task(): class TestDependencyMixin: - def test_set_upstream(self, dag_maker): - with dag_maker("test_set_upstream"): - op_a = EmptyOperator(task_id="a") - op_b = EmptyOperator(task_id="b") - op_c = EmptyOperator(task_id="c") - op_d = EmptyOperator(task_id="d") + def test_set_upstream(self): + with DAG("test_set_upstream"): + op_a = BaseOperator(task_id="a") + op_b = BaseOperator(task_id="b") + op_c = BaseOperator(task_id="c") + op_d = BaseOperator(task_id="d") op_d << op_c << op_b << op_a @@ -231,12 +229,12 @@ def test_set_upstream(self, dag_maker): assert [op_b] == op_c.upstream_list assert [op_c] == op_d.upstream_list - def test_set_downstream(self, dag_maker): - with dag_maker("test_set_downstream"): - op_a = EmptyOperator(task_id="a") - op_b = EmptyOperator(task_id="b") - op_c = EmptyOperator(task_id="c") - op_d = EmptyOperator(task_id="d") + def test_set_downstream(self): + with DAG("test_set_downstream"): + op_a = BaseOperator(task_id="a") + op_b = BaseOperator(task_id="b") + op_c = BaseOperator(task_id="c") + op_d = BaseOperator(task_id="d") op_a >> op_b >> op_c >> op_d @@ -244,12 +242,12 @@ def test_set_downstream(self, dag_maker): assert [op_b] == op_c.upstream_list assert [op_c] == op_d.upstream_list - def test_set_upstream_list(self, dag_maker): - with dag_maker("test_set_upstream_list"): - op_a = EmptyOperator(task_id="a") - op_b = EmptyOperator(task_id="b") - op_c = EmptyOperator(task_id="c") - op_d = EmptyOperator(task_id="d") + def test_set_upstream_list(self): + with DAG("test_set_upstream_list"): + op_a = BaseOperator(task_id="a") + op_b = BaseOperator(task_id="b") + op_c = BaseOperator(task_id="c") + op_d = BaseOperator(task_id="d") [op_d, op_c << op_b] << op_a @@ -257,12 +255,12 @@ def test_set_upstream_list(self, dag_maker): assert [op_a] == op_d.upstream_list assert [op_b] == op_c.upstream_list - def 
test_set_downstream_list(self, dag_maker): - with dag_maker("test_set_downstream_list"): - op_a = EmptyOperator(task_id="a") - op_b = EmptyOperator(task_id="b") - op_c = EmptyOperator(task_id="c") - op_d = EmptyOperator(task_id="d") + def test_set_downstream_list(self): + with DAG("test_set_downstream_list"): + op_a = BaseOperator(task_id="a") + op_b = BaseOperator(task_id="b") + op_c = BaseOperator(task_id="c") + op_d = BaseOperator(task_id="d") op_a >> [op_b >> op_c, op_d] @@ -270,12 +268,12 @@ def test_set_downstream_list(self, dag_maker): assert [op_a] == op_d.upstream_list assert {op_a, op_b} == set(op_c.upstream_list) - def test_set_upstream_inner_list(self, dag_maker): - with dag_maker("test_set_upstream_inner_list"): - op_a = EmptyOperator(task_id="a") - op_b = EmptyOperator(task_id="b") - op_c = EmptyOperator(task_id="c") - op_d = EmptyOperator(task_id="d") + def test_set_upstream_inner_list(self): + with DAG("test_set_upstream_inner_list"): + op_a = BaseOperator(task_id="a") + op_b = BaseOperator(task_id="b") + op_c = BaseOperator(task_id="c") + op_d = BaseOperator(task_id="d") with pytest.raises(AttributeError) as e_info: [op_d << [op_c, op_b]] << op_a @@ -285,12 +283,12 @@ def test_set_upstream_inner_list(self, dag_maker): assert op_c.upstream_list == [] assert {op_b, op_c} == set(op_d.upstream_list) - def test_set_downstream_inner_list(self, dag_maker): - with dag_maker("test_set_downstream_inner_list"): - op_a = EmptyOperator(task_id="a") - op_b = EmptyOperator(task_id="b") - op_c = EmptyOperator(task_id="c") - op_d = EmptyOperator(task_id="d") + def test_set_downstream_inner_list(self): + with DAG("test_set_downstream_inner_list"): + op_a = BaseOperator(task_id="a") + op_b = BaseOperator(task_id="b") + op_c = BaseOperator(task_id="c") + op_d = BaseOperator(task_id="d") op_a >> [[op_b, op_c] >> op_d] @@ -298,13 +296,13 @@ def test_set_downstream_inner_list(self, dag_maker): assert op_c.upstream_list == [] assert {op_b, op_c, op_a} == set(op_d.upstream_list) - def test_set_upstream_list_subarray(self, dag_maker): - with dag_maker("test_set_upstream_list"): - op_a = EmptyOperator(task_id="a") - op_b_1 = EmptyOperator(task_id="b_1") - op_b_2 = EmptyOperator(task_id="b_2") - op_c = EmptyOperator(task_id="c") - op_d = EmptyOperator(task_id="d") + def test_set_upstream_list_subarray(self): + with DAG("test_set_upstream_list"): + op_a = BaseOperator(task_id="a") + op_b_1 = BaseOperator(task_id="b_1") + op_b_2 = BaseOperator(task_id="b_2") + op_c = BaseOperator(task_id="c") + op_d = BaseOperator(task_id="d") with pytest.raises(AttributeError) as e_info: [op_d, op_c << [op_b_1, op_b_2]] << op_a @@ -316,13 +314,13 @@ def test_set_upstream_list_subarray(self, dag_maker): assert op_d.upstream_list == [] assert {op_b_1, op_b_2} == set(op_c.upstream_list) - def test_set_downstream_list_subarray(self, dag_maker): - with dag_maker("test_set_downstream_list"): - op_a = EmptyOperator(task_id="a") - op_b_1 = EmptyOperator(task_id="b_1") - op_b_2 = EmptyOperator(task_id="b2") - op_c = EmptyOperator(task_id="c") - op_d = EmptyOperator(task_id="d") + def test_set_downstream_list_subarray(self): + with DAG("test_set_downstream_list"): + op_a = BaseOperator(task_id="a") + op_b_1 = BaseOperator(task_id="b_1") + op_b_2 = BaseOperator(task_id="b2") + op_c = BaseOperator(task_id="c") + op_d = BaseOperator(task_id="d") op_a >> [[op_b_1, op_b_2] >> op_c, op_d] diff --git a/task_sdk/tests/execution_time/conftest.py b/task_sdk/tests/execution_time/conftest.py index d2a961a5307da..832f2b60ca351 100644 --- 
a/task_sdk/tests/execution_time/conftest.py +++ b/task_sdk/tests/execution_time/conftest.py @@ -19,7 +19,6 @@ import sys from typing import TYPE_CHECKING -from unittest import mock if TYPE_CHECKING: from datetime import datetime @@ -42,14 +41,6 @@ def disable_capturing(): sys.stdin, sys.stdout, sys.stderr = old_in, old_out, old_err -@pytest.fixture -def mock_supervisor_comms(): - with mock.patch( - "airflow.sdk.execution_time.task_runner.SUPERVISOR_COMMS", create=True - ) as supervisor_comms: - yield supervisor_comms - - @pytest.fixture def mocked_parse(spy_agency): """ diff --git a/tests/api_connexion/endpoints/test_mapped_task_instance_endpoint.py b/tests/api_connexion/endpoints/test_mapped_task_instance_endpoint.py index 7a4d802f6cc15..79dcde6cbd2a5 100644 --- a/tests/api_connexion/endpoints/test_mapped_task_instance_endpoint.py +++ b/tests/api_connexion/endpoints/test_mapped_task_instance_endpoint.py @@ -129,7 +129,7 @@ def create_dag_runs_with_mapped_tasks(self, dag_maker, session, dags=None): self.app.dag_bag.sync_to_db("dags-folder", None) session.flush() - mapped.expand_mapped_task(dr.run_id, session=session) + TaskMap.expand_mapped_task(mapped, dr.run_id, session=session) @pytest.fixture def one_task_with_mapped_tis(self, dag_maker, session): diff --git a/tests/api_fastapi/core_api/routes/public/test_task_instances.py b/tests/api_fastapi/core_api/routes/public/test_task_instances.py index 24572c37906dc..f09f6293e9082 100644 --- a/tests/api_fastapi/core_api/routes/public/test_task_instances.py +++ b/tests/api_fastapi/core_api/routes/public/test_task_instances.py @@ -529,7 +529,7 @@ def create_dag_runs_with_mapped_tasks(self, dag_maker, session, dags=None): dagbag.sync_to_db("dags-folder", None) session.flush() - mapped.expand_mapped_task(dr.run_id, session=session) + TaskMap.expand_mapped_task(mapped, dr.run_id, session=session) @pytest.fixture def one_task_with_mapped_tis(self, dag_maker, session): diff --git a/tests/decorators/test_python.py b/tests/decorators/test_python.py index 6ffe935348e6d..bbaa4236f3100 100644 --- a/tests/decorators/test_python.py +++ b/tests/decorators/test_python.py @@ -26,16 +26,14 @@ from airflow.decorators import setup, task as task_decorator, teardown from airflow.decorators.base import DecoratedMappedOperator from airflow.exceptions import AirflowException, XComNotFound -from airflow.models.baseoperator import BaseOperator -from airflow.models.dag import DAG from airflow.models.expandinput import DictOfListsExpandInput -from airflow.models.mappedoperator import MappedOperator from airflow.models.taskinstance import TaskInstance from airflow.models.taskmap import TaskMap -from airflow.models.xcom_arg import PlainXComArg, XComArg +from airflow.models.xcom_arg import PlainXComArg +from airflow.sdk import DAG, BaseOperator, TaskGroup, XComArg +from airflow.sdk.definitions.mappedoperator import MappedOperator from airflow.utils import timezone from airflow.utils.state import State -from airflow.utils.task_group import TaskGroup from airflow.utils.task_instance_session import set_current_task_instance_session from airflow.utils.trigger_rule import TriggerRule from airflow.utils.types import DagRunType diff --git a/tests/models/test_baseoperator.py b/tests/models/test_baseoperator.py index 0ff8896746a4c..13dad2bc8a4c3 100644 --- a/tests/models/test_baseoperator.py +++ b/tests/models/test_baseoperator.py @@ -411,28 +411,6 @@ def task0(): copy.deepcopy(dag) -@pytest.mark.db_test -def test_find_mapped_dependants_in_another_group(dag_maker): - from 
airflow.utils.task_group import TaskGroup - - @task_decorator - def gen(x): - return list(range(x)) - - @task_decorator - def add(x, y): - return x + y - - with dag_maker(): - with TaskGroup(group_id="g1"): - gen_result = gen(3) - with TaskGroup(group_id="g2"): - add_result = add.partial(y=1).expand(x=gen_result) - - dependants = list(gen_result.operator.iter_mapped_dependants()) - assert dependants == [add_result.operator] - - def get_states(dr): """ For a given dag run, get a dict of states. diff --git a/tests/models/test_dag.py b/tests/models/test_dag.py index fb3519b708928..77b749156af3c 100644 --- a/tests/models/test_dag.py +++ b/tests/models/test_dag.py @@ -1514,26 +1514,24 @@ def test_clear_set_dagrun_state(self, dag_run_state): assert dagrun.state == dag_run_state @pytest.mark.parametrize("dag_run_state", [DagRunState.QUEUED, DagRunState.RUNNING]) - def test_clear_set_dagrun_state_for_mapped_task(self, dag_run_state): + @pytest.mark.need_serialized_dag + def test_clear_set_dagrun_state_for_mapped_task(self, dag_maker, dag_run_state): dag_id = "test_clear_set_dagrun_state" self._clean_up(dag_id) task_id = "t1" - dag = DAG(dag_id, schedule=None, start_date=DEFAULT_DATE, max_active_runs=1) + with dag_maker(dag_id, schedule=None, start_date=DEFAULT_DATE, max_active_runs=1) as dag: - @dag.task - def make_arg_lists(): - return [[1], [2], [{"a": "b"}]] + @task_decorator + def make_arg_lists(): + return [[1], [2], [{"a": "b"}]] - def consumer(value): - print(value) + def consumer(value): + print(value) - mapped = PythonOperator.partial(task_id=task_id, dag=dag, python_callable=consumer).expand( - op_args=make_arg_lists() - ) + PythonOperator.partial(task_id=task_id, python_callable=consumer).expand(op_args=make_arg_lists()) - session = settings.Session() - triggered_by_kwargs = {"triggered_by": DagRunTriggeredByType.TEST} if AIRFLOW_V_3_0_PLUS else {} + session = dag_maker.session dagrun_1 = dag.create_dagrun( run_id="backfill", run_type=DagRunType.BACKFILL_JOB, @@ -1542,8 +1540,10 @@ def consumer(value): logical_date=DEFAULT_DATE, session=session, data_interval=(DEFAULT_DATE, DEFAULT_DATE), - **triggered_by_kwargs, + triggered_by=DagRunTriggeredByType.TEST, ) + # Get the (de)serialized MappedOperator + mapped = dag.get_task(task_id) expand_mapped_task(mapped, dagrun_1.run_id, "make_arg_lists", length=2, session=session) upstream_ti = dagrun_1.get_task_instance("make_arg_lists", session=session) @@ -2732,20 +2732,21 @@ def get_ti_from_db(task): } +@pytest.mark.need_serialized_dag def test_set_task_instance_state_mapped(dag_maker, session): """Test that when setting an individual mapped TI that the other TIs are not affected""" task_id = "t1" with dag_maker(session=session) as dag: - @dag.task + @task_decorator def make_arg_lists(): return [[1], [2], [{"a": "b"}]] def consumer(value): print(value) - mapped = PythonOperator.partial(task_id=task_id, dag=dag, python_callable=consumer).expand( + mapped = PythonOperator.partial(task_id=task_id, python_callable=consumer).expand( op_args=make_arg_lists() ) @@ -2755,6 +2756,8 @@ def consumer(value): run_type=DagRunType.SCHEDULED, state=DagRunState.FAILED, ) + + mapped = dag.get_task(task_id) expand_mapped_task(mapped, dr1.run_id, "make_arg_lists", length=2, session=session) # set_state(future=True) only applies to scheduled runs diff --git a/tests/models/test_mappedoperator.py b/tests/models/test_mappedoperator.py index ff107d8204d72..92130bd3471c8 100644 --- a/tests/models/test_mappedoperator.py +++ b/tests/models/test_mappedoperator.py @@ -18,12 
+18,10 @@ from __future__ import annotations from collections import defaultdict -from datetime import timedelta from typing import TYPE_CHECKING from unittest import mock from unittest.mock import patch -import pendulum import pytest from sqlalchemy import select @@ -31,16 +29,13 @@ from airflow.exceptions import AirflowSkipException from airflow.models.baseoperator import BaseOperator from airflow.models.dag import DAG -from airflow.models.mappedoperator import MappedOperator -from airflow.models.param import ParamsDict from airflow.models.taskinstance import TaskInstance from airflow.models.taskmap import TaskMap -from airflow.models.xcom_arg import XComArg from airflow.providers.standard.operators.python import PythonOperator +from airflow.sdk.definitions.mappedoperator import MappedOperator from airflow.utils.state import TaskInstanceState from airflow.utils.task_group import TaskGroup from airflow.utils.task_instance_session import set_current_task_instance_session -from airflow.utils.trigger_rule import TriggerRule from airflow.utils.xcom import XCOM_RETURN_KEY from tests.models import DEFAULT_DATE @@ -57,25 +52,6 @@ from airflow.sdk.definitions.context import Context -def test_task_mapping_with_dag(): - with DAG("test-dag", schedule=None, start_date=DEFAULT_DATE) as dag: - task1 = BaseOperator(task_id="op1") - literal = ["a", "b", "c"] - mapped = MockOperator.partial(task_id="task_2").expand(arg2=literal) - finish = MockOperator(task_id="finish") - - task1 >> mapped >> finish - - assert task1.downstream_list == [mapped] - assert mapped in dag.tasks - assert mapped.task_group == dag.task_group - # At parse time there should only be three tasks! - assert len(dag.tasks) == 3 - - assert finish.upstream_list == [mapped] - assert mapped.downstream_list == [finish] - - @patch("airflow.models.abstractoperator.AbstractOperator.render_template") def test_task_mapping_with_dag_and_list_of_pandas_dataframe(mock_render_template, caplog): class UnrenderableClass: @@ -105,66 +81,6 @@ def execute(self, context: Context): mock_render_template.assert_called() -def test_task_mapping_without_dag_context(): - with DAG("test-dag", schedule=None, start_date=DEFAULT_DATE) as dag: - task1 = BaseOperator(task_id="op1") - literal = ["a", "b", "c"] - mapped = MockOperator.partial(task_id="task_2").expand(arg2=literal) - - task1 >> mapped - - assert isinstance(mapped, MappedOperator) - assert mapped in dag.tasks - assert task1.downstream_list == [mapped] - assert mapped in dag.tasks - # At parse time there should only be two tasks! 
- assert len(dag.tasks) == 2 - - -def test_task_mapping_default_args(): - default_args = {"start_date": DEFAULT_DATE.now(), "owner": "test"} - with DAG("test-dag", schedule=None, start_date=DEFAULT_DATE, default_args=default_args): - task1 = BaseOperator(task_id="op1") - literal = ["a", "b", "c"] - mapped = MockOperator.partial(task_id="task_2").expand(arg2=literal) - - task1 >> mapped - - assert mapped.partial_kwargs["owner"] == "test" - assert mapped.start_date == pendulum.instance(default_args["start_date"]) - - -def test_task_mapping_override_default_args(): - default_args = {"retries": 2, "start_date": DEFAULT_DATE.now()} - with DAG("test-dag", schedule=None, start_date=DEFAULT_DATE, default_args=default_args): - literal = ["a", "b", "c"] - mapped = MockOperator.partial(task_id="task", retries=1).expand(arg2=literal) - - # retries should be 1 because it is provided as a partial arg - assert mapped.partial_kwargs["retries"] == 1 - # start_date should be equal to default_args["start_date"] because it is not provided as partial arg - assert mapped.start_date == pendulum.instance(default_args["start_date"]) - # owner should be equal to Airflow default owner (airflow) because it is not provided at all - assert mapped.owner == "airflow" - - -def test_map_unknown_arg_raises(): - with pytest.raises(TypeError, match=r"argument 'file'"): - BaseOperator.partial(task_id="a").expand(file=[1, 2, {"a": "b"}]) - - -def test_map_xcom_arg(): - """Test that dependencies are correct when mapping with an XComArg""" - with DAG("test-dag", schedule=None, start_date=DEFAULT_DATE): - task1 = BaseOperator(task_id="op1") - mapped = MockOperator.partial(task_id="task_2").expand(arg2=task1.output) - finish = MockOperator(task_id="finish") - - mapped >> finish - - assert task1.downstream_list == [mapped] - - def test_map_xcom_arg_multiple_upstream_xcoms(dag_maker, session): """Test that the correct number of downstream tasks are generated when mapping with an XComArg""" @@ -189,12 +105,12 @@ def execute(self, context): ti_1 = dr.get_task_instance("task_1", session) ti_1.run() - ti_2s, _ = task2.expand_mapped_task(dr.run_id, session=session) + ti_2s, _ = TaskMap.expand_mapped_task(task2, dr.run_id, session=session) for ti in ti_2s: ti.refresh_from_task(dag.get_task("task_2")) ti.run() - ti_3s, _ = task3.expand_mapped_task(dr.run_id, session=session) + ti_3s, _ = TaskMap.expand_mapped_task(task3, dr.run_id, session=session) for ti in ti_3s: ti.refresh_from_task(dag.get_task("task_3")) ti.run() @@ -202,37 +118,6 @@ def execute(self, context): assert len(ti_3s) == len(ti_2s) == len(upstream_return) -def test_partial_on_instance() -> None: - """`.partial` on an instance should fail -- it's only designed to be called on classes""" - with pytest.raises(TypeError): - MockOperator(task_id="a").partial() - - -def test_partial_on_class() -> None: - # Test that we accept args for superclasses too - op = MockOperator.partial(task_id="a", arg1="a", trigger_rule=TriggerRule.ONE_FAILED) - assert op.kwargs["arg1"] == "a" - assert op.kwargs["trigger_rule"] == TriggerRule.ONE_FAILED - - -def test_partial_on_class_invalid_ctor_args() -> None: - """Test that when we pass invalid args to partial(). - - I.e. 
if an arg is not known on the class or any of its parent classes we error at parse time - """ - with pytest.raises(TypeError, match=r"arguments 'foo', 'bar'"): - MockOperator.partial(task_id="a", foo="bar", bar=2) - - -def test_partial_on_invalid_pool_slots_raises() -> None: - """Test that when we pass an invalid value to pool_slots in partial(), - - i.e. if the value is not an integer, an error is raised at import time.""" - - with pytest.raises(TypeError, match="'<' not supported between instances of 'str' and 'int'"): - MockOperator.partial(task_id="pool_slots_test", pool="test", pool_slots="a").expand(arg1=[1, 2, 3]) - - @pytest.mark.parametrize( ["num_existing_tis", "expected"], ( @@ -288,7 +173,7 @@ def test_expand_mapped_task_instance(dag_maker, session, num_existing_tis, expec session.add(ti) session.flush() - mapped.expand_mapped_task(dr.run_id, session=session) + TaskMap.expand_mapped_task(mapped, dr.run_id, session=session) indices = ( session.query(TaskInstance.map_index, TaskInstance.state) @@ -339,7 +224,7 @@ def test_expand_mapped_task_failed_state_in_db(dag_maker, session): # Make sure we have the faulty state in the database assert indices == [(-1, None), (0, "success"), (1, "success")] - mapped.expand_mapped_task(dr.run_id, session=session) + TaskMap.expand_mapped_task(mapped, dr.run_id, session=session) indices = ( session.query(TaskInstance.map_index, TaskInstance.state) @@ -370,52 +255,6 @@ def test_expand_mapped_task_instance_skipped_on_zero(dag_maker, session): assert indices == [(-1, TaskInstanceState.SKIPPED)] -def test_mapped_task_applies_default_args_classic(dag_maker): - with dag_maker(default_args={"execution_timeout": timedelta(minutes=30)}) as dag: - MockOperator(task_id="simple", arg1=None, arg2=0) - MockOperator.partial(task_id="mapped").expand(arg1=[1], arg2=[2, 3]) - - assert dag.get_task("simple").execution_timeout == timedelta(minutes=30) - assert dag.get_task("mapped").execution_timeout == timedelta(minutes=30) - - -def test_mapped_task_applies_default_args_taskflow(dag_maker): - with dag_maker(default_args={"execution_timeout": timedelta(minutes=30)}) as dag: - - @dag.task - def simple(arg): - pass - - @dag.task - def mapped(arg): - pass - - simple(arg=0) - mapped.expand(arg=[1, 2]) - - assert dag.get_task("simple").execution_timeout == timedelta(minutes=30) - assert dag.get_task("mapped").execution_timeout == timedelta(minutes=30) - - -@pytest.mark.parametrize( - "dag_params, task_params, expected_partial_params", - [ - pytest.param(None, None, ParamsDict(), id="none"), - pytest.param({"a": -1}, None, ParamsDict({"a": -1}), id="dag"), - pytest.param(None, {"b": -2}, ParamsDict({"b": -2}), id="task"), - pytest.param({"a": -1}, {"b": -2}, ParamsDict({"a": -1, "b": -2}), id="merge"), - ], -) -def test_mapped_expand_against_params(dag_maker, dag_params, task_params, expected_partial_params): - with dag_maker(params=dag_params) as dag: - MockOperator.partial(task_id="t", params=task_params).expand(params=[{"c": "x"}, {"d": 1}]) - - t = dag.get_task("t") - assert isinstance(t, MappedOperator) - assert t.params == expected_partial_params - assert t.expand_input.value == {"params": [{"c": "x"}, {"d": 1}]} - - def test_mapped_render_template_fields_validating_operator(dag_maker, session, tmp_path): file_template_dir = tmp_path / "path" / "to" file_template_dir.mkdir(parents=True, exist_ok=True) @@ -609,7 +448,7 @@ def test_expand_kwargs_mapped_task_instance(dag_maker, session, num_existing_tis session.add(ti) session.flush() - 
mapped.expand_mapped_task(dr.run_id, session=session) + TaskMap.expand_mapped_task(mapped, dr.run_id, session=session) indices = ( session.query(TaskInstance.map_index, TaskInstance.state) @@ -785,29 +624,6 @@ def test_expand_kwargs_render_template_fields_validating_operator(dag_maker, ses assert ti.task.arg2 == "a" -def test_xcomarg_property_of_mapped_operator(dag_maker): - with dag_maker("test_xcomarg_property_of_mapped_operator"): - op_a = MockOperator.partial(task_id="a").expand(arg1=["x", "y", "z"]) - dag_maker.create_dagrun() - - assert op_a.output == XComArg(op_a) - - -def test_set_xcomarg_dependencies_with_mapped_operator(dag_maker): - with dag_maker("test_set_xcomargs_dependencies_with_mapped_operator"): - op1 = MockOperator.partial(task_id="op1").expand(arg1=[1, 2, 3]) - op2 = MockOperator.partial(task_id="op2").expand(arg2=["a", "b", "c"]) - op3 = MockOperator(task_id="op3", arg1=op1.output) - op4 = MockOperator(task_id="op4", arg1=[op1.output, op2.output]) - op5 = MockOperator(task_id="op5", arg1={"op1": op1.output, "op2": op2.output}) - - assert op1 in op3.upstream_list - assert op1 in op4.upstream_list - assert op2 in op4.upstream_list - assert op1 in op5.upstream_list - assert op2 in op5.upstream_list - - def test_all_xcomargs_from_mapped_tasks_are_consumable(dag_maker, session): class PushXcomOperator(MockOperator): def __init__(self, arg1, **kwargs): @@ -830,48 +646,6 @@ def execute(self, context): ti.run() -def test_task_mapping_with_task_group_context(): - with DAG("test-dag", schedule=None, start_date=DEFAULT_DATE) as dag: - task1 = BaseOperator(task_id="op1") - finish = MockOperator(task_id="finish") - - with TaskGroup("test-group") as group: - literal = ["a", "b", "c"] - mapped = MockOperator.partial(task_id="task_2").expand(arg2=literal) - - task1 >> group >> finish - - assert task1.downstream_list == [mapped] - assert mapped.upstream_list == [task1] - - assert mapped in dag.tasks - assert mapped.task_group == group - - assert finish.upstream_list == [mapped] - assert mapped.downstream_list == [finish] - - -def test_task_mapping_with_explicit_task_group(): - with DAG("test-dag", schedule=None, start_date=DEFAULT_DATE) as dag: - task1 = BaseOperator(task_id="op1") - finish = MockOperator(task_id="finish") - - group = TaskGroup("test-group") - literal = ["a", "b", "c"] - mapped = MockOperator.partial(task_id="task_2", task_group=group).expand(arg2=literal) - - task1 >> group >> finish - - assert task1.downstream_list == [mapped] - assert mapped.upstream_list == [task1] - - assert mapped in dag.tasks - assert mapped.task_group == group - - assert finish.upstream_list == [mapped] - assert mapped.downstream_list == [finish] - - class TestMappedSetupTeardown: @staticmethod def get_states(dr): diff --git a/tests/models/test_renderedtifields.py b/tests/models/test_renderedtifields.py index 3bdf9810a1e87..dfb7b652252b4 100644 --- a/tests/models/test_renderedtifields.py +++ b/tests/models/test_renderedtifields.py @@ -33,6 +33,7 @@ from airflow.decorators import task as task_decorator from airflow.models import DagRun, Variable from airflow.models.renderedtifields import RenderedTaskInstanceFields as RTIF +from airflow.models.taskmap import TaskMap from airflow.providers.standard.operators.bash import BashOperator from airflow.providers.standard.operators.python import PythonOperator from airflow.utils.task_instance_session import set_current_task_instance_session @@ -289,7 +290,7 @@ def test_delete_old_records_mapped( run_id=f"run_{num}", logical_date=dag.start_date + 
timedelta(days=num) ) - mapped.expand_mapped_task(dr.run_id, session=dag_maker.session) + TaskMap.expand_mapped_task(mapped, dr.run_id, session=dag_maker.session) session.refresh(dr) for ti in dr.task_instances: ti.task = dag.get_task(ti.task_id) diff --git a/tests/models/test_taskinstance.py b/tests/models/test_taskinstance.py index fca20b4bed00c..6494236a6cc20 100644 --- a/tests/models/test_taskinstance.py +++ b/tests/models/test_taskinstance.py @@ -4055,7 +4055,6 @@ def test_operator_field_with_serialization(self, create_task_instance): deserialized_op = SerializedBaseOperator.deserialize_operator(serialized_op) assert deserialized_op.task_type == "EmptyOperator" # Verify that ti.operator field renders correctly "with" Serialization - deserialized_op.dag = ti.task.dag ser_ti = TI(task=deserialized_op, run_id=None) assert ser_ti.operator == "EmptyOperator" assert ser_ti.task.operator_name == "EmptyOperator" @@ -4904,7 +4903,7 @@ def show(value): emit_ti.run() show_task = dag.get_task("show") - mapped_tis, max_map_index = show_task.expand_mapped_task(dag_run.run_id, session=session) + mapped_tis, max_map_index = TaskMap.expand_mapped_task(show_task, dag_run.run_id, session=session) assert max_map_index + 1 == len(mapped_tis) == len(upstream_return) for ti in sorted(mapped_tis, key=operator.attrgetter("map_index")): @@ -4938,7 +4937,7 @@ def show(number, letter): ti.run() show_task = dag.get_task("show") - mapped_tis, max_map_index = show_task.expand_mapped_task(dag_run.run_id, session=session) + mapped_tis, max_map_index = TaskMap.expand_mapped_task(show_task, dag_run.run_id, session=session) assert max_map_index + 1 == len(mapped_tis) == 6 for ti in sorted(mapped_tis, key=operator.attrgetter("map_index")): @@ -4978,7 +4977,7 @@ def show(a, b): show_task = dag.get_task("show") with pytest.raises(NotFullyPopulated): assert show_task.get_parse_time_mapped_ti_count() - mapped_tis, max_map_index = show_task.expand_mapped_task(dag_run.run_id, session=session) + mapped_tis, max_map_index = TaskMap.expand_mapped_task(show_task, dag_run.run_id, session=session) assert max_map_index + 1 == len(mapped_tis) == 4 for ti in sorted(mapped_tis, key=operator.attrgetter("map_index")): @@ -5002,7 +5001,7 @@ def show(a, b): show_task = dag.get_task("show") assert show_task.get_parse_time_mapped_ti_count() == 6 - mapped_tis, max_map_index = show_task.expand_mapped_task(dag_run.run_id, session=session) + mapped_tis, max_map_index = TaskMap.expand_mapped_task(show_task, dag_run.run_id, session=session) assert len(mapped_tis) == 0 # Expanded at parse! assert max_map_index == 5 @@ -5050,7 +5049,9 @@ def cmds(): ti.run() bash_task = dag.get_task("dynamic.bash") - mapped_bash_tis, max_map_index = bash_task.expand_mapped_task(dag_run.run_id, session=session) + mapped_bash_tis, max_map_index = TaskMap.expand_mapped_task( + bash_task, dag_run.run_id, session=session + ) assert max_map_index == 3 # 2 * 2 mapped tasks. 
for ti in sorted(mapped_bash_tis, key=operator.attrgetter("map_index")): ti.refresh_from_task(bash_task) @@ -5170,7 +5171,7 @@ def add_one(x): ti.run() task_345 = dag.get_task("add_one__1") - for ti in task_345.expand_mapped_task(dagrun.run_id, session=session)[0]: + for ti in TaskMap.expand_mapped_task(task_345, dagrun.run_id, session=session)[0]: ti.refresh_from_task(task_345) ti.run() diff --git a/tests/models/test_xcom_arg_map.py b/tests/models/test_xcom_arg_map.py index b2e885e940833..a654ef1dd4a2e 100644 --- a/tests/models/test_xcom_arg_map.py +++ b/tests/models/test_xcom_arg_map.py @@ -365,7 +365,7 @@ def convert_zipped(zipped): def test_xcom_concat(dag_maker, session): - from airflow.models.xcom_arg import _ConcatResult + from airflow.sdk.definitions.xcom_arg import _ConcatResult agg_results = set() all_results = None diff --git a/tests/serialization/test_dag_serialization.py b/tests/serialization/test_dag_serialization.py index 2b5d4cce4c7bb..84a63674e5119 100644 --- a/tests/serialization/test_dag_serialization.py +++ b/tests/serialization/test_dag_serialization.py @@ -657,6 +657,9 @@ def validate_deserialized_task( task, ): """Verify non-Airflow operators are casted to BaseOperator or MappedOperator.""" + from airflow.sdk import BaseOperator + from airflow.sdk.definitions.mappedoperator import MappedOperator + assert not isinstance(task, SerializedBaseOperator) assert isinstance(task, (BaseOperator, MappedOperator)) @@ -2650,9 +2653,9 @@ def test_sensor_expand_deserialized_unmap(): deser_unmapped = deser_mapped.unmap(None) ser_normal = SerializedBaseOperator.serialize(normal) deser_normal = SerializedBaseOperator.deserialize(ser_normal) - deser_normal.dag = dag comps = set(BashSensor._comps) comps.remove("task_id") + comps.remove("dag_id") assert all(getattr(deser_unmapped, c, None) == getattr(deser_normal, c, None) for c in comps) @@ -2896,7 +2899,7 @@ def x(arg1, arg2, arg3): def test_mapped_task_group_serde(): from airflow.decorators.task_group import task_group from airflow.models.expandinput import DictOfListsExpandInput - from airflow.utils.task_group import MappedTaskGroup + from airflow.sdk.definitions.taskgroup import MappedTaskGroup with DAG("test-dag", schedule=None, start_date=datetime(2020, 1, 1)) as dag: diff --git a/tests_common/pytest_plugin.py b/tests_common/pytest_plugin.py index 413c83eb51bdb..1b68f039eaa17 100644 --- a/tests_common/pytest_plugin.py +++ b/tests_common/pytest_plugin.py @@ -784,6 +784,8 @@ def __init__(self): self.dagbag = DagBag(os.devnull, include_examples=False, read_dags_from_db=False) def __enter__(self): + self.serialized_model = None + self.dag.__enter__() if self.want_serialized: return lazy_object_proxy.Proxy(self._serialized_dag) diff --git a/tests_common/test_utils/compat.py b/tests_common/test_utils/compat.py index 3bd4b89dfc1c4..7757a879a7c1a 100644 --- a/tests_common/test_utils/compat.py +++ b/tests_common/test_utils/compat.py @@ -21,7 +21,6 @@ from typing import TYPE_CHECKING, Any, cast from airflow.exceptions import AirflowOptionalProviderFeatureException -from airflow.models import Connection, Operator from airflow.utils.helpers import prune_dict from tests_common.test_utils.version_compat import AIRFLOW_V_2_10_PLUS @@ -60,6 +59,7 @@ if TYPE_CHECKING: + from airflow.models import Connection from airflow.models.asset import ( AssetAliasModel, AssetDagRunQueue, @@ -68,6 +68,7 @@ DagScheduleAssetReference, TaskOutletAssetReference, ) + from airflow.sdk.types import Operator else: try: from airflow.models.asset import ( @@ -103,7 
+104,7 @@ def deserialize_operator(serialized_operator: dict[str, Any]) -> Operator: # are updated to airflow 2.10+. from airflow.serialization.serialized_objects import BaseSerialization - return cast(Operator, BaseSerialization.deserialize(serialized_operator)) + return BaseSerialization.deserialize(serialized_operator) else: from airflow.serialization.serialized_objects import SerializedBaseOperator diff --git a/tests_common/test_utils/mapping.py b/tests_common/test_utils/mapping.py index 15a57679cf024..dbcb0ecc364e6 100644 --- a/tests_common/test_utils/mapping.py +++ b/tests_common/test_utils/mapping.py @@ -23,7 +23,7 @@ if TYPE_CHECKING: from sqlalchemy.orm import Session - from airflow.models.mappedoperator import MappedOperator + from airflow.sdk.definitions.mappedoperator import MappedOperator def expand_mapped_task( @@ -41,4 +41,4 @@ def expand_mapped_task( ) session.flush() - mapped.expand_mapped_task(run_id, session=session) + TaskMap.expand_mapped_task(mapped, run_id, session=session) diff --git a/tests_common/test_utils/system_tests.py b/tests_common/test_utils/system_tests.py index 9be67c06822ed..5428eae9114ab 100644 --- a/tests_common/test_utils/system_tests.py +++ b/tests_common/test_utils/system_tests.py @@ -25,6 +25,7 @@ from airflow.utils.state import DagRunState if TYPE_CHECKING: + from airflow.models.dagrun import DagRun from airflow.sdk.definitions.context import Context logger = logging.getLogger(__name__) @@ -32,6 +33,8 @@ def get_test_run(dag, **test_kwargs): def callback(context: Context): + if TYPE_CHECKING: + assert isinstance(context["dag_run"], DagRun) ti = context["dag_run"].get_task_instances() if not ti: logger.warning("could not retrieve tasks that ran in the DAG, cannot display a summary")