diff --git a/airflow/api_fastapi/execution_api/app.py b/airflow/api_fastapi/execution_api/app.py index 61283dc2cf87f..2b85be363f25f 100644 --- a/airflow/api_fastapi/execution_api/app.py +++ b/airflow/api_fastapi/execution_api/app.py @@ -77,8 +77,14 @@ def custom_openapi() -> dict: def get_extra_schemas() -> dict[str, dict]: """Get all the extra schemas that are not part of the main FastAPI app.""" - from airflow.api_fastapi.execution_api.datamodels import taskinstance + from airflow.api_fastapi.execution_api.datamodels.taskinstance import TaskInstance + from airflow.executors.workloads import BundleInfo + from airflow.utils.state import TerminalTIState return { - "TaskInstance": taskinstance.TaskInstance.model_json_schema(), + "TaskInstance": TaskInstance.model_json_schema(), + "BundleInfo": BundleInfo.model_json_schema(), + # Include the combined state enum too. In the datamodels we separate out SUCCESS from the other states + # as that has different payload requirements + "TerminalTIState": {"type": "string", "enum": list(TerminalTIState)}, } diff --git a/airflow/api_fastapi/execution_api/datamodels/taskinstance.py b/airflow/api_fastapi/execution_api/datamodels/taskinstance.py index d4e268fd12fe7..0c8e5eb1b6a3f 100644 --- a/airflow/api_fastapi/execution_api/datamodels/taskinstance.py +++ b/airflow/api_fastapi/execution_api/datamodels/taskinstance.py @@ -14,11 +14,11 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. - from __future__ import annotations import uuid from datetime import timedelta +from enum import Enum from typing import Annotated, Any, Literal, Union from pydantic import ( @@ -60,15 +60,20 @@ class TIEnterRunningPayload(StrictBaseModel): """When the task started executing""" +# Create an enum to give a nice name in the generated datamodels +class TerminalStateNonSuccess(str, Enum): + """TaskInstance states that can be reported without extra information.""" + + FAILED = TerminalTIState.FAILED + SKIPPED = TerminalTIState.SKIPPED + REMOVED = TerminalTIState.REMOVED + FAIL_WITHOUT_RETRY = TerminalTIState.FAIL_WITHOUT_RETRY + + class TITerminalStatePayload(StrictBaseModel): """Schema for updating TaskInstance to a terminal state except SUCCESS state.""" - state: Literal[ - TerminalTIState.FAILED, - TerminalTIState.SKIPPED, - TerminalTIState.REMOVED, - TerminalTIState.FAIL_WITHOUT_RETRY, - ] + state: TerminalStateNonSuccess end_date: UtcDateTime """When the task completed executing""" @@ -242,6 +247,8 @@ class TIRunContext(BaseModel): connections: Annotated[list[ConnectionResponse], Field(default_factory=list)] """Connections that can be accessed by the task instance.""" + upstream_map_indexes: dict[str, int] | None = None + class PrevSuccessfulDagRunResponse(BaseModel): """Schema for response with previous successful DagRun information for Task Template Context.""" diff --git a/airflow/api_fastapi/execution_api/routes/task_instances.py b/airflow/api_fastapi/execution_api/routes/task_instances.py index fa7355c0c876d..1c605b5cadcf0 100644 --- a/airflow/api_fastapi/execution_api/routes/task_instances.py +++ b/airflow/api_fastapi/execution_api/routes/task_instances.py @@ -65,6 +65,7 @@ status.HTTP_409_CONFLICT: {"description": "The TI is already in the requested state"}, status.HTTP_422_UNPROCESSABLE_ENTITY: {"description": "Invalid payload for the state transition"}, }, + response_model_exclude_unset=True, ) def ti_run( task_instance_id: UUID, ti_run_payload: Annotated[TIEnterRunningPayload, 
Body()], session: SessionDep @@ -149,6 +150,7 @@ def ti_run( DR.run_type, DR.conf, DR.logical_date, + DR.external_trigger, ).filter_by(dag_id=dag_id, run_id=run_id) ).one_or_none() diff --git a/airflow/api_fastapi/execution_api/routes/xcoms.py b/airflow/api_fastapi/execution_api/routes/xcoms.py index f330744536b33..9a8cef62e07a0 100644 --- a/airflow/api_fastapi/execution_api/routes/xcoms.py +++ b/airflow/api_fastapi/execution_api/routes/xcoms.py @@ -21,7 +21,8 @@ import logging from typing import Annotated -from fastapi import Body, HTTPException, Query, status +from fastapi import Body, Depends, HTTPException, Query, Response, status +from sqlalchemy.sql.selectable import Select from airflow.api_fastapi.common.db.common import SessionDep from airflow.api_fastapi.common.router import AirflowRouter @@ -30,6 +31,7 @@ from airflow.api_fastapi.execution_api.datamodels.xcom import XComResponse from airflow.models.taskmap import TaskMap from airflow.models.xcom import BaseXCom +from airflow.utils.db import get_query_count # TODO: Add dependency on JWT token router = AirflowRouter( @@ -42,20 +44,15 @@ log = logging.getLogger(__name__) -@router.get( - "/{dag_id}/{run_id}/{task_id}/{key}", - responses={status.HTTP_404_NOT_FOUND: {"description": "XCom not found"}}, -) -def get_xcom( +async def xcom_query( dag_id: str, run_id: str, task_id: str, key: str, - token: deps.TokenDep, session: SessionDep, - map_index: Annotated[int, Query()] = -1, -) -> XComResponse: - """Get an Airflow XCom from database - not other XCom Backends.""" + token: deps.TokenDep, + map_index: Annotated[int | None, Query()] = None, +) -> Select: if not has_xcom_access(dag_id, run_id, task_id, key, token): raise HTTPException( status_code=status.HTTP_403_FORBIDDEN, @@ -65,29 +62,87 @@ def get_xcom( }, ) - # We use `BaseXCom.get_many` to fetch XComs directly from the database, bypassing the XCom Backend. - # This avoids deserialization via the backend (e.g., from a remote storage like S3) and instead - # retrieves the raw serialized value from the database. By not relying on `XCom.get_many` or `XCom.get_one` - # (which automatically deserializes using the backend), we avoid potential - # performance hits from retrieving large data files into the API server. query = BaseXCom.get_many( run_id=run_id, key=key, task_ids=task_id, dag_ids=dag_id, map_indexes=map_index, - limit=1, session=session, ) + return query.with_entities(BaseXCom.value) + + +@router.head( + "/{dag_id}/{run_id}/{task_id}/{key}", + responses={ + status.HTTP_404_NOT_FOUND: {"description": "XCom not found"}, + status.HTTP_200_OK: { + "description": "Metadata about the number of matching XCom values", + "headers": { + "Content-Range": { + "pattern": r"^map_indexes \d+$", + "description": "The number of (mapped) XCom values found for this task.", + }, + }, + }, + }, + description="Return the count of the number of XCom values found via the Content-Range response header", +) +def head_xcom( + response: Response, + token: deps.TokenDep, + session: SessionDep, + xcom_query: Annotated[Select, Depends(xcom_query)], + map_index: Annotated[int | None, Query()] = None, +) -> None: + """Get the count of XComs from database - not other XCom Backends.""" + if map_index is not None: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail={"reason": "invalid_request", "message": "Cannot specify map_index in a HEAD request"}, + ) + + count = get_query_count(xcom_query, session=session) + # Tell the caller how many items in this query. 
We define a custom range unit (HTTP spec only defines + # "bytes" but we can add our own) + response.headers["Content-Range"] = f"map_indexes {count}" + + +@router.get( + "/{dag_id}/{run_id}/{task_id}/{key}", + responses={status.HTTP_404_NOT_FOUND: {"description": "XCom not found"}}, + description="Get a single XCom Value", +) +def get_xcom( + session: SessionDep, + dag_id: str, + run_id: str, + task_id: str, + key: str, + xcom_query: Annotated[Select, Depends(xcom_query)], + map_index: Annotated[int, Query()] = -1, +) -> XComResponse: + """Get an Airflow XCom from database - not other XCom Backends.""" + # The xcom_query allows no map_index to be passed. This endpoint should always return just a single item, + # so we override that query value + + xcom_query = xcom_query.filter_by(map_index=map_index) + + # We use `BaseXCom.get_many` to fetch XComs directly from the database, bypassing the XCom Backend. + # This avoids deserialization via the backend (e.g., from a remote storage like S3) and instead + # retrieves the raw serialized value from the database. By not relying on `XCom.get_many` or `XCom.get_one` + # (which automatically deserializes using the backend), we avoid potential + # performance hits from retrieving large data files into the API server. - result = query.with_entities(BaseXCom.value).first() + result = xcom_query.limit(1).first() if result is None: raise HTTPException( status_code=status.HTTP_404_NOT_FOUND, detail={ "reason": "not_found", - "message": f"XCom with key '{key}' not found for task '{task_id}' in DAG '{dag_id}'", + "message": f"XCom with {key=} {map_index=} not found for task {task_id!r} in DAG run {run_id!r} of {dag_id!r}", }, ) diff --git a/airflow/decorators/base.py b/airflow/decorators/base.py index 9db6d058adeb5..b79819bedcc5e 100644 --- a/airflow/decorators/base.py +++ b/airflow/decorators/base.py @@ -55,8 +55,6 @@ from airflow.utils.types import NOTSET if TYPE_CHECKING: - from sqlalchemy.orm import Session - from airflow.models.expandinput import ( ExpandInput, OperatorExpandArgument, @@ -184,7 +182,9 @@ def __init__( kwargs_to_upstream: dict[str, Any] | None = None, **kwargs, ) -> None: - task_id = get_unique_task_id(task_id, kwargs.get("dag"), kwargs.get("task_group")) + if not getattr(self, "_BaseOperator__from_mapped", False): + # If we are being created from calling unmap(), then don't mangle the task id + task_id = get_unique_task_id(task_id, kwargs.get("dag"), kwargs.get("task_group")) self.python_callable = python_callable kwargs_to_upstream = kwargs_to_upstream or {} op_args = op_args or [] @@ -218,10 +218,10 @@ def __init__( The function signature broke while assigning defaults to context key parameters. 
The decorator is replacing the signature - > {python_callable.__name__}({', '.join(str(param) for param in signature.parameters.values())}) + > {python_callable.__name__}({", ".join(str(param) for param in signature.parameters.values())}) with - > {python_callable.__name__}({', '.join(str(param) for param in parameters)}) + > {python_callable.__name__}({", ".join(str(param) for param in parameters)}) which isn't valid: {err} """ @@ -568,13 +568,11 @@ def __attrs_post_init__(self): super(DecoratedMappedOperator, DecoratedMappedOperator).__attrs_post_init__(self) XComArg.apply_upstream_relationship(self, self.op_kwargs_expand_input.value) - def _expand_mapped_kwargs( - self, context: Mapping[str, Any], session: Session, *, include_xcom: bool - ) -> tuple[Mapping[str, Any], set[int]]: + def _expand_mapped_kwargs(self, context: Mapping[str, Any]) -> tuple[Mapping[str, Any], set[int]]: # We only use op_kwargs_expand_input so this must always be empty. if self.expand_input is not EXPAND_INPUT_EMPTY: raise AssertionError(f"unexpected expand_input: {self.expand_input}") - op_kwargs, resolved_oids = super()._expand_mapped_kwargs(context, session, include_xcom=include_xcom) + op_kwargs, resolved_oids = super()._expand_mapped_kwargs(context) return {"op_kwargs": op_kwargs}, resolved_oids def _get_unmap_kwargs(self, mapped_kwargs: Mapping[str, Any], *, strict: bool) -> dict[str, Any]: diff --git a/airflow/models/abstractoperator.py b/airflow/models/abstractoperator.py index b8fb54f6966fd..98fd977c59128 100644 --- a/airflow/models/abstractoperator.py +++ b/airflow/models/abstractoperator.py @@ -27,7 +27,6 @@ from airflow.configuration import conf from airflow.exceptions import AirflowException -from airflow.models.expandinput import NotFullyPopulated from airflow.sdk.definitions._internal.abstractoperator import ( AbstractOperator as TaskSDKAbstractOperator, NotMapped as NotMapped, # Re-export this for compat @@ -237,6 +236,7 @@ def expand_mapped_task(self, run_id: str, *, session: Session) -> tuple[Sequence ) from airflow.models.baseoperator import BaseOperator as DBBaseOperator + from airflow.models.expandinput import NotFullyPopulated try: total_length: int | None = DBBaseOperator.get_mapped_ti_count(self, run_id, session=session) diff --git a/airflow/models/baseoperator.py b/airflow/models/baseoperator.py index 425cbaca68880..3361fe33df6f7 100644 --- a/airflow/models/baseoperator.py +++ b/airflow/models/baseoperator.py @@ -848,11 +848,20 @@ def _(cls, task: TaskSDKAbstractOperator, run_id: str, *, session: Session) -> i @get_mapped_ti_count.register(MappedOperator) @classmethod def _(cls, task: MappedOperator, run_id: str, *, session: Session) -> int: - from airflow.serialization.serialized_objects import _ExpandInputRef + from airflow.serialization.serialized_objects import BaseSerialization, _ExpandInputRef exp_input = task._get_specified_expand_input() if isinstance(exp_input, _ExpandInputRef): exp_input = exp_input.deref(task.dag) + # TODO: TaskSDK This is only needed to support `dag.test()` etc until we port it over to use the + # task sdk runner. 
+ if not hasattr(exp_input, "get_total_map_length"): + exp_input = _ExpandInputRef( + type(exp_input).EXPAND_INPUT_TYPE, + BaseSerialization.deserialize(BaseSerialization.serialize(exp_input.value)), + ) + exp_input = exp_input.deref(task.dag) + current_count = exp_input.get_total_map_length(run_id, session=session) group = task.get_closest_mapped_task_group() @@ -878,18 +887,24 @@ def _(cls, group: TaskGroup, run_id: str, *, session: Session) -> int: :raise NotFullyPopulated: If upstream tasks are not all complete yet. :return: Total number of mapped TIs this task should have. """ + from airflow.serialization.serialized_objects import BaseSerialization, _ExpandInputRef - def iter_mapped_task_groups(group) -> Iterator[MappedTaskGroup]: + def iter_mapped_task_group_lengths(group) -> Iterator[int]: while group is not None: if isinstance(group, MappedTaskGroup): - yield group + exp_input = group._expand_input + # TODO: TaskSDK This is only needed to support `dag.test()` etc until we port it over to use the + # task sdk runner. + if not hasattr(exp_input, "get_total_map_length"): + exp_input = _ExpandInputRef( + type(exp_input).EXPAND_INPUT_TYPE, + BaseSerialization.deserialize(BaseSerialization.serialize(exp_input.value)), + ) + exp_input = exp_input.deref(group.dag) + yield exp_input.get_total_map_length(run_id, session=session) group = group.parent_group - groups = iter_mapped_task_groups(group) - return functools.reduce( - operator.mul, - (g._expand_input.get_total_map_length(run_id, session=session) for g in groups), - ) + return functools.reduce(operator.mul, iter_mapped_task_group_lengths(group)) def chain(*tasks: DependencyMixin | Sequence[DependencyMixin]) -> None: diff --git a/airflow/models/dagrun.py b/airflow/models/dagrun.py index d7deaf475d0f3..d9bf14c3f983b 100644 --- a/airflow/models/dagrun.py +++ b/airflow/models/dagrun.py @@ -63,7 +63,6 @@ from airflow.models.backfill import Backfill from airflow.models.base import Base, StringID from airflow.models.dag_version import DagVersion -from airflow.models.expandinput import NotFullyPopulated from airflow.models.taskinstance import TaskInstance as TI from airflow.models.tasklog import LogTemplate from airflow.models.taskmap import TaskMap @@ -1354,6 +1353,7 @@ def _check_for_removed_or_restored_tasks( """ from airflow.models.baseoperator import BaseOperator + from airflow.models.expandinput import NotFullyPopulated tis = self.get_task_instances(session=session) @@ -1491,6 +1491,7 @@ def _create_tasks( :param task_creator: Function to create task instances """ from airflow.models.baseoperator import BaseOperator + from airflow.models.expandinput import NotFullyPopulated map_indexes: Iterable[int] for task in tasks: @@ -1562,6 +1563,7 @@ def _revise_map_indexes_if_mapped(self, task: Operator, *, session: Session) -> for more details. """ from airflow.models.baseoperator import BaseOperator + from airflow.models.expandinput import NotFullyPopulated from airflow.settings import task_instance_mutation_hook try: diff --git a/airflow/models/expandinput.py b/airflow/models/expandinput.py index 8fb35f7032965..72f7b0eca22ba 100644 --- a/airflow/models/expandinput.py +++ b/airflow/models/expandinput.py @@ -17,108 +17,54 @@ # under the License. 
from __future__ import annotations -import collections.abc import functools import operator -from collections.abc import Iterable, Mapping, Sequence, Sized -from typing import TYPE_CHECKING, Any, NamedTuple, Union +from collections.abc import Iterable, Sized +from typing import TYPE_CHECKING, Any -import attr - -from airflow.sdk.definitions._internal.mixins import ResolveMixin -from airflow.utils.session import NEW_SESSION, provide_session +import attrs if TYPE_CHECKING: from sqlalchemy.orm import Session - from airflow.models.xcom_arg import XComArg - from airflow.sdk.types import Operator - from airflow.serialization.serialized_objects import _ExpandInputRef + from airflow.models.xcom_arg import SchedulerXComArg from airflow.typing_compat import TypeGuard -ExpandInput = Union["DictOfListsExpandInput", "ListOfDictsExpandInput"] - -# Each keyword argument to expand() can be an XComArg, sequence, or dict (not -# any mapping since we need the value to be ordered). -OperatorExpandArgument = Union["MappedArgument", "XComArg", Sequence, dict[str, Any]] - -# The single argument of expand_kwargs() can be an XComArg, or a list with each -# element being either an XComArg or a dict. -OperatorExpandKwargsArgument = Union["XComArg", Sequence[Union["XComArg", Mapping[str, Any]]]] - - -@attr.define(kw_only=True) -class MappedArgument(ResolveMixin): - """ - Stand-in stub for task-group-mapping arguments. - - This is very similar to an XComArg, but resolved differently. Declared here - (instead of in the task group module) to avoid import cycles. - """ - - _input: ExpandInput - _key: str - - def iter_references(self) -> Iterable[tuple[Operator, str]]: - yield from self._input.iter_references() - - @provide_session - def resolve( - self, context: Mapping[str, Any], *, include_xcom: bool = True, session: Session = NEW_SESSION - ) -> Any: - data, _ = self._input.resolve(context, session=session, include_xcom=include_xcom) - return data[self._key] - - -# To replace tedious isinstance() checks. -def is_mappable(v: Any) -> TypeGuard[OperatorExpandArgument]: - from airflow.models.xcom_arg import XComArg - - return isinstance(v, (MappedArgument, XComArg, Mapping, Sequence)) and not isinstance(v, str) - - -# To replace tedious isinstance() checks. -def _is_parse_time_mappable(v: OperatorExpandArgument) -> TypeGuard[Mapping | Sequence]: - from airflow.models.xcom_arg import XComArg - - return not isinstance(v, (MappedArgument, XComArg)) - - -# To replace tedious isinstance() checks. -def _needs_run_time_resolution(v: OperatorExpandArgument) -> TypeGuard[MappedArgument | XComArg]: - from airflow.models.xcom_arg import XComArg - - return isinstance(v, (MappedArgument, XComArg)) - - -class NotFullyPopulated(RuntimeError): - """ - Raise when ``get_map_lengths`` cannot populate all mapping metadata. - - This is generally due to not all upstream tasks have finished when the - function is called. 
- """ +from airflow.sdk.definitions._internal.expandinput import ( + DictOfListsExpandInput, + ExpandInput, + ListOfDictsExpandInput, + MappedArgument, + NotFullyPopulated, + OperatorExpandArgument, + OperatorExpandKwargsArgument, + is_mappable, +) - def __init__(self, missing: set[str]) -> None: - self.missing = missing +__all__ = [ + "DictOfListsExpandInput", + "ListOfDictsExpandInput", + "MappedArgument", + "NotFullyPopulated", + "OperatorExpandArgument", + "OperatorExpandKwargsArgument", + "is_mappable", +] - def __str__(self) -> str: - keys = ", ".join(repr(k) for k in sorted(self.missing)) - return f"Failed to populate all mapping metadata; missing: {keys}" +def _needs_run_time_resolution(v: OperatorExpandArgument) -> TypeGuard[MappedArgument | SchedulerXComArg]: + from airflow.models.xcom_arg import SchedulerXComArg -class DictOfListsExpandInput(NamedTuple): - """ - Storage type of a mapped operator's mapped kwargs. + return isinstance(v, (MappedArgument, SchedulerXComArg)) - This is created from ``expand(**kwargs)``. - """ - value: dict[str, OperatorExpandArgument] +@attrs.define +class SchedulerDictOfListsExpandInput: + value: dict def _iter_parse_time_resolved_kwargs(self) -> Iterable[tuple[str, Sized]]: """Generate kwargs with values available on parse-time.""" - return ((k, v) for k, v in self.value.items() if _is_parse_time_mappable(v)) + return ((k, v) for k, v in self.value.items() if not _needs_run_time_resolution(v)) def get_parse_time_mapped_ti_count(self) -> int: if not self.value: @@ -136,13 +82,12 @@ def _get_map_lengths(self, run_id: str, *, session: Session) -> dict[str, int]: If any arguments are not known right now (upstream task not finished), they will not be present in the dict. """ + from airflow.models.xcom_arg import SchedulerXComArg, get_task_map_length # TODO: This initiates one database call for each XComArg. Would it be # more efficient to do one single db call and unpack the value here? def _get_length(v: OperatorExpandArgument) -> int | None: - from airflow.models.xcom_arg import get_task_map_length - - if _needs_run_time_resolution(v): + if isinstance(v, SchedulerXComArg): return get_task_map_length(v, run_id, session=session) # Unfortunately a user-defined TypeGuard cannot apply negative type @@ -164,150 +109,34 @@ def get_total_map_length(self, run_id: str, *, session: Session) -> int: lengths = self._get_map_lengths(run_id, session=session) return functools.reduce(operator.mul, (lengths[name] for name in self.value), 1) - def _expand_mapped_field( - self, key: str, value: Any, context: Mapping[str, Any], *, session: Session, include_xcom: bool - ) -> Any: - if _needs_run_time_resolution(value): - value = ( - value.resolve(context, session=session, include_xcom=include_xcom) - if include_xcom - else str(value) - ) - map_index = context["ti"].map_index - if map_index < 0: - raise RuntimeError("can't resolve task-mapping argument without expanding") - all_lengths = self._get_map_lengths(context["run_id"], session=session) - - def _find_index_for_this_field(index: int) -> int: - # Need to use the original user input to retain argument order. 
- for mapped_key in reversed(self.value): - mapped_length = all_lengths[mapped_key] - if mapped_length < 1: - raise RuntimeError(f"cannot expand field mapped to length {mapped_length!r}") - if mapped_key == key: - return index % mapped_length - index //= mapped_length - return -1 - - found_index = _find_index_for_this_field(map_index) - if found_index < 0: - return value - if isinstance(value, collections.abc.Sequence): - return value[found_index] - if not isinstance(value, dict): - raise TypeError(f"can't map over value of type {type(value)}") - for i, (k, v) in enumerate(value.items()): - if i == found_index: - return k, v - raise IndexError(f"index {map_index} is over mapped length") - - def iter_references(self) -> Iterable[tuple[Operator, str]]: - from airflow.models.xcom_arg import XComArg - - for x in self.value.values(): - if isinstance(x, XComArg): - yield from x.iter_references() - - def resolve( - self, context: Mapping[str, Any], session: Session, *, include_xcom: bool = True - ) -> tuple[Mapping[str, Any], set[int]]: - data = { - k: self._expand_mapped_field(k, v, context, session=session, include_xcom=include_xcom) - for k, v in self.value.items() - } - literal_keys = {k for k, _ in self._iter_parse_time_resolved_kwargs()} - resolved_oids = {id(v) for k, v in data.items() if k not in literal_keys} - return data, resolved_oids - - -def _describe_type(value: Any) -> str: - if value is None: - return "None" - return type(value).__name__ - - -class ListOfDictsExpandInput(NamedTuple): - """ - Storage type of a mapped operator's mapped kwargs. - - This is created from ``expand_kwargs(xcom_arg)``. - """ - value: OperatorExpandKwargsArgument +@attrs.define +class SchedulerListOfDictsExpandInput: + value: list def get_parse_time_mapped_ti_count(self) -> int: - if isinstance(self.value, collections.abc.Sized): + if isinstance(self.value, Sized): return len(self.value) raise NotFullyPopulated({"expand_kwargs() argument"}) def get_total_map_length(self, run_id: str, *, session: Session) -> int: from airflow.models.xcom_arg import get_task_map_length - if isinstance(self.value, collections.abc.Sized): + if isinstance(self.value, Sized): return len(self.value) length = get_task_map_length(self.value, run_id, session=session) if length is None: raise NotFullyPopulated({"expand_kwargs() argument"}) return length - def iter_references(self) -> Iterable[tuple[Operator, str]]: - from airflow.models.xcom_arg import XComArg - - if isinstance(self.value, XComArg): - yield from self.value.iter_references() - else: - for x in self.value: - if isinstance(x, XComArg): - yield from x.iter_references() - - def resolve( - self, context: Mapping[str, Any], session: Session, *, include_xcom: bool = True - ) -> tuple[Mapping[str, Any], set[int]]: - map_index = context["ti"].map_index - if map_index < 0: - raise RuntimeError("can't resolve task-mapping argument without expanding") - - mapping: Any - if isinstance(self.value, collections.abc.Sized): - mapping = self.value[map_index] - if not isinstance(mapping, collections.abc.Mapping): - mapping = mapping.resolve(context, session, include_xcom=include_xcom) - elif include_xcom: - mappings = self.value.resolve(context, session, include_xcom=include_xcom) - if not isinstance(mappings, collections.abc.Sequence): - raise ValueError(f"expand_kwargs() expects a list[dict], not {_describe_type(mappings)}") - mapping = mappings[map_index] - - if not isinstance(mapping, collections.abc.Mapping): - raise ValueError(f"expand_kwargs() expects a list[dict], not 
list[{_describe_type(mapping)}]") - - for key in mapping: - if not isinstance(key, str): - raise ValueError( - f"expand_kwargs() input dict keys must all be str, " - f"but {key!r} is of type {_describe_type(key)}" - ) - # filter out parse time resolved values from the resolved_oids - resolved_oids = {id(v) for k, v in mapping.items() if not _is_parse_time_mappable(v)} - - return mapping, resolved_oids - EXPAND_INPUT_EMPTY = DictOfListsExpandInput({}) # Sentinel value. _EXPAND_INPUT_TYPES = { - "dict-of-lists": DictOfListsExpandInput, - "list-of-dicts": ListOfDictsExpandInput, + "dict-of-lists": SchedulerDictOfListsExpandInput, + "list-of-dicts": SchedulerListOfDictsExpandInput, } -def get_map_type_key(expand_input: ExpandInput | _ExpandInputRef) -> str: - from airflow.serialization.serialized_objects import _ExpandInputRef - - if isinstance(expand_input, _ExpandInputRef): - return expand_input.key - return next(k for k, v in _EXPAND_INPUT_TYPES.items() if isinstance(expand_input, v)) - - def create_expand_input(kind: str, value: Any) -> ExpandInput: return _EXPAND_INPUT_TYPES[kind](value) diff --git a/airflow/models/mappedoperator.py b/airflow/models/mappedoperator.py index e7352ad1323d3..9b0b90b5814f5 100644 --- a/airflow/models/mappedoperator.py +++ b/airflow/models/mappedoperator.py @@ -17,10 +17,7 @@ # under the License. from __future__ import annotations -import contextlib -import copy -from collections.abc import Mapping -from typing import TYPE_CHECKING, Any +from typing import TYPE_CHECKING import attrs @@ -51,49 +48,6 @@ class MappedOperator(TaskSDKMappedOperator, AbstractOperator): # type: ignore[misc] # It complains about weight_rule being different """Object representing a mapped operator in a DAG.""" - def _expand_mapped_kwargs( - self, context: Mapping[str, Any], session: Session, *, include_xcom: bool - ) -> tuple[Mapping[str, Any], set[int]]: - """ - Get the kwargs to create the unmapped operator. - - This exists because taskflow operators expand against op_kwargs, not the - entire operator kwargs dict. - """ - return self._get_specified_expand_input().resolve(context, session, include_xcom=include_xcom) - - def _get_unmap_kwargs(self, mapped_kwargs: Mapping[str, Any], *, strict: bool) -> dict[str, Any]: - """ - Get init kwargs to unmap the underlying operator class. - - :param mapped_kwargs: The dict returned by ``_expand_mapped_kwargs``. - """ - if strict: - prevent_duplicates( - self.partial_kwargs, - mapped_kwargs, - fail_reason="unmappable or already specified", - ) - - # If params appears in the mapped kwargs, we need to merge it into the - # partial params, overriding existing keys. - params = copy.copy(self.params) - with contextlib.suppress(KeyError): - params.update(mapped_kwargs["params"]) - - # Ordering is significant; mapped kwargs should override partial ones, - # and the specially handled params should be respected. - return { - "task_id": self.task_id, - "dag": self.dag, - "task_group": self.task_group, - "start_date": self.start_date, - "end_date": self.end_date, - **self.partial_kwargs, - **mapped_kwargs, - "params": params, - } - def expand_start_from_trigger(self, *, context: Context, session: Session) -> bool: """ Get the start_from_trigger value of the current abstract operator. 
@@ -107,7 +61,7 @@ def expand_start_from_trigger(self, *, context: Context, session: Session) -> bo if not self.start_trigger_args: return False - mapped_kwargs, _ = self._expand_mapped_kwargs(context, session, include_xcom=False) + mapped_kwargs, _ = self._expand_mapped_kwargs(context) if self._disallow_kwargs_override: prevent_duplicates( self.partial_kwargs, @@ -129,7 +83,7 @@ def expand_start_trigger_args(self, *, context: Context, session: Session) -> St if not self.start_trigger_args: return None - mapped_kwargs, _ = self._expand_mapped_kwargs(context, session, include_xcom=False) + mapped_kwargs, _ = self._expand_mapped_kwargs(context) if self._disallow_kwargs_override: prevent_duplicates( self.partial_kwargs, diff --git a/airflow/models/serialized_dag.py b/airflow/models/serialized_dag.py index 5da2f24957ff0..2787ce1b82993 100644 --- a/airflow/models/serialized_dag.py +++ b/airflow/models/serialized_dag.py @@ -291,7 +291,7 @@ def read_all_dags(cls, session: Session = NEW_SESSION) -> dict[str, SerializedDA @property def data(self) -> dict | None: # use __data_cache to avoid decompress and loads - if not hasattr(self, "__data_cache") or self.__data_cache is None: + if not hasattr(self, "_SerializedDagModel__data_cache") or self.__data_cache is None: if self._data_compressed: self.__data_cache = json.loads(zlib.decompress(self._data_compressed)) else: diff --git a/airflow/models/skipmixin.py b/airflow/models/skipmixin.py index 3b7d21d7b386b..78e1a631e080d 100644 --- a/airflow/models/skipmixin.py +++ b/airflow/models/skipmixin.py @@ -97,7 +97,7 @@ def skip( dag_id: str, run_id: str, tasks: Iterable[DAGNode], - map_index: int = -1, + map_index: int | None = -1, session: Session = NEW_SESSION, ): """ @@ -126,6 +126,9 @@ def skip( if task_id is not None: from airflow.models.xcom import XCom + if map_index is None: + map_index = -1 + XCom.set( key=XCOM_SKIPMIXIN_KEY, value={XCOM_SKIPMIXIN_SKIPPED: task_ids_list}, diff --git a/airflow/models/taskinstance.py b/airflow/models/taskinstance.py index 41d5d89ac4836..ceb0acfcf9946 100644 --- a/airflow/models/taskinstance.py +++ b/airflow/models/taskinstance.py @@ -573,39 +573,6 @@ def _xcom_pull( default: Any = None, run_id: str | None = None, ) -> Any: - """ - Pull XComs that optionally meet certain criteria. - - :param key: A key for the XCom. If provided, only XComs with matching - keys will be returned. The default key is ``'return_value'``, also - available as constant ``XCOM_RETURN_KEY``. This key is automatically - given to XComs returned by tasks (as opposed to being pushed - manually). To remove the filter, pass *None*. - :param task_ids: Only XComs from tasks with matching ids will be - pulled. Pass *None* to remove the filter. - :param dag_id: If provided, only pulls XComs from this DAG. If *None* - (default), the DAG of the calling task is used. - :param map_indexes: If provided, only pull XComs with matching indexes. - If *None* (default), this is inferred from the task(s) being pulled - (see below for details). - :param include_prior_dates: If False, only XComs from the current - logical_date are returned. If *True*, XComs from previous dates - are returned as well. - :param run_id: If provided, only pulls XComs from a DagRun w/a matching run_id. - If *None* (default), the run_id of the calling task is used. - - When pulling one single task (``task_id`` is *None* or a str) without - specifying ``map_indexes``, the return value is inferred from whether - the specified task is mapped. 
If not, value from the one single task - instance is returned. If the task to pull is mapped, an iterator (not a - list) yielding XComs from mapped task instances is returned. In either - case, ``default`` (*None* if not specified) is returned if no matching - XComs are found. - - When pulling multiple tasks (i.e. either ``task_id`` or ``map_index`` is - a non-str iterable), a list of matching XComs is returned. Elements in - the list is ordered by item ordering in ``task_id`` and ``map_index``. - """ if dag_id is None: dag_id = ti.dag_id if run_id is None: @@ -634,11 +601,10 @@ def _xcom_pull( return default if map_indexes is not None or first.map_index < 0: return XCom.deserialize_value(first) - return LazyXComSelectSequence.from_select( - query.with_entities(XCom.value).order_by(None).statement, - order_by=[XCom.map_index], - session=session, - ) + + # raise RuntimeError("Nothing should hit this anymore") + + # TODO: TaskSDK: We should remove this, but many tests still currently call `ti.run()`. See #45549 # At this point either task_ids or map_indexes is explicitly multi-value. # Order return values to match task_ids and map_indexes ordering. @@ -1035,14 +1001,18 @@ def get_triggering_events() -> dict[str, list[AssetEvent]]: ) context["expanded_ti_count"] = expanded_ti_count if expanded_ti_count: - context["_upstream_map_indexes"] = { # type: ignore[typeddict-unknown-key] - upstream.task_id: task_instance.get_relevant_upstream_map_indexes( - upstream, - expanded_ti_count, - session=session, - ) - for upstream in task.upstream_list - } + setattr( + task_instance, + "_upstream_map_indexes", + { + upstream.task_id: task_instance.get_relevant_upstream_map_indexes( + upstream, + expanded_ti_count, + session=session, + ) + for upstream in task.upstream_list + }, + ) except NotMapped: pass @@ -3380,39 +3350,8 @@ def xcom_pull( default: Any = None, run_id: str | None = None, ) -> Any: - """ - Pull XComs that optionally meet certain criteria. - - :param key: A key for the XCom. If provided, only XComs with matching - keys will be returned. The default key is ``'return_value'``, also - available as constant ``XCOM_RETURN_KEY``. This key is automatically - given to XComs returned by tasks (as opposed to being pushed - manually). To remove the filter, pass *None*. - :param task_ids: Only XComs from tasks with matching ids will be - pulled. Pass *None* to remove the filter. - :param dag_id: If provided, only pulls XComs from this DAG. If *None* - (default), the DAG of the calling task is used. - :param map_indexes: If provided, only pull XComs with matching indexes. - If *None* (default), this is inferred from the task(s) being pulled - (see below for details). - :param include_prior_dates: If False, only XComs from the current - logical_date are returned. If *True*, XComs from previous dates - are returned as well. - :param run_id: If provided, only pulls XComs from a DagRun w/a matching run_id. - If *None* (default), the run_id of the calling task is used. - - When pulling one single task (``task_id`` is *None* or a str) without - specifying ``map_indexes``, the return value is inferred from whether - the specified task is mapped. If not, value from the one single task - instance is returned. If the task to pull is mapped, an iterator (not a - list) yielding XComs from mapped task instances is returned. In either - case, ``default`` (*None* if not specified) is returned if no matching - XComs are found. - - When pulling multiple tasks (i.e. 
either ``task_id`` or ``map_index`` is - a non-str iterable), a list of matching XComs is returned. Elements in - the list is ordered by item ordering in ``task_id`` and ``map_index``. - """ + """:meta private:""" # noqa: D400 + # This is only kept for compatibility in tests for now while AIP-72 is in progress. return _xcom_pull( ti=self, task_ids=task_ids, diff --git a/airflow/models/xcom_arg.py b/airflow/models/xcom_arg.py index 078a9e6ff5223..1d885fb5bd1b0 100644 --- a/airflow/models/xcom_arg.py +++ b/airflow/models/xcom_arg.py @@ -17,46 +17,108 @@ from __future__ import annotations +from collections.abc import Sequence from functools import singledispatch -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Any +import attrs from sqlalchemy import func, or_, select from sqlalchemy.orm import Session from airflow.sdk.definitions._internal.types import ArgNotSet from airflow.sdk.definitions.mappedoperator import MappedOperator from airflow.sdk.definitions.xcom_arg import ( - ConcatXComArg, - MapXComArg, - PlainXComArg, XComArg, - ZipXComArg, ) from airflow.utils.db import exists_query from airflow.utils.state import State +from airflow.utils.types import NOTSET from airflow.utils.xcom import XCOM_RETURN_KEY __all__ = ["XComArg", "get_task_map_length"] if TYPE_CHECKING: - from airflow.models.expandinput import OperatorExpandArgument + from airflow.models.dag import DAG as SchedulerDAG + from airflow.models.operator import Operator + from airflow.typing_compat import Self + + +@attrs.define +class SchedulerXComArg: + @classmethod + def _deserialize(cls, data: dict[str, Any], dag: SchedulerDAG) -> Self: + """ + Deserialize an XComArg. + + The implementation should be the inverse function to ``serialize``, + implementing given a data dict converted from this XComArg derivative, + how the original XComArg should be created. DAG serialization relies on + additional information added in ``serialize_xcom_arg`` to dispatch data + dicts to the correct ``_deserialize`` information, so this function does + not need to validate whether the incoming data contains correct keys. + """ + raise NotImplementedError() + + +@attrs.define +class SchedulerPlainXComArg(SchedulerXComArg): + operator: Operator + key: str + + @classmethod + def _deserialize(cls, data: dict[str, Any], dag: SchedulerDAG) -> Self: + return cls(dag.get_task(data["task_id"]), data["key"]) + + +@attrs.define +class SchedulerMapXComArg(SchedulerXComArg): + arg: SchedulerXComArg + callables: Sequence[str] + + @classmethod + def _deserialize(cls, data: dict[str, Any], dag: SchedulerDAG) -> Self: + # We are deliberately NOT deserializing the callables. These are shown + # in the UI, and displaying a function object is useless. 
+ return cls(deserialize_xcom_arg(data["arg"], dag), data["callables"]) + + +@attrs.define +class SchedulerConcatXComArg(SchedulerXComArg): + args: Sequence[SchedulerXComArg] + + @classmethod + def _deserialize(cls, data: dict[str, Any], dag: SchedulerDAG) -> Self: + return cls([deserialize_xcom_arg(arg, dag) for arg in data["args"]]) + + +@attrs.define +class SchedulerZipXComArg(SchedulerXComArg): + args: Sequence[SchedulerXComArg] + fillvalue: Any + + @classmethod + def _deserialize(cls, data: dict[str, Any], dag: SchedulerDAG) -> Self: + return cls( + [deserialize_xcom_arg(arg, dag) for arg in data["args"]], + fillvalue=data.get("fillvalue", NOTSET), + ) @singledispatch -def get_task_map_length(xcom_arg: OperatorExpandArgument, run_id: str, *, session: Session) -> int | None: +def get_task_map_length(xcom_arg: SchedulerXComArg, run_id: str, *, session: Session) -> int | None: # The base implementation -- specific XComArg subclasses have specialised implementations - raise NotImplementedError() + raise NotImplementedError(f"get_task_map_length not implemented for {type(xcom_arg)}") @get_task_map_length.register -def _(xcom_arg: PlainXComArg, run_id: str, *, session: Session): +def _(xcom_arg: SchedulerPlainXComArg, run_id: str, *, session: Session): from airflow.models.taskinstance import TaskInstance from airflow.models.taskmap import TaskMap from airflow.models.xcom import XCom dag_id = xcom_arg.operator.dag_id task_id = xcom_arg.operator.task_id - is_mapped = isinstance(xcom_arg.operator, MappedOperator) + is_mapped = xcom_arg.operator.is_mapped or isinstance(xcom_arg.operator, MappedOperator) if is_mapped: unfinished_ti_exists = exists_query( @@ -92,12 +154,12 @@ def _(xcom_arg: PlainXComArg, run_id: str, *, session: Session): @get_task_map_length.register -def _(xcom_arg: MapXComArg, run_id: str, *, session: Session): +def _(xcom_arg: SchedulerMapXComArg, run_id: str, *, session: Session): return get_task_map_length(xcom_arg.arg, run_id, session=session) @get_task_map_length.register -def _(xcom_arg: ZipXComArg, run_id: str, *, session: Session): +def _(xcom_arg: SchedulerZipXComArg, run_id: str, *, session: Session): all_lengths = (get_task_map_length(arg, run_id, session=session) for arg in xcom_arg.args) ready_lengths = [length for length in all_lengths if length is not None] if len(ready_lengths) != len(xcom_arg.args): @@ -108,9 +170,23 @@ def _(xcom_arg: ZipXComArg, run_id: str, *, session: Session): @get_task_map_length.register -def _(xcom_arg: ConcatXComArg, run_id: str, *, session: Session): +def _(xcom_arg: SchedulerConcatXComArg, run_id: str, *, session: Session): all_lengths = (get_task_map_length(arg, run_id, session=session) for arg in xcom_arg.args) ready_lengths = [length for length in all_lengths if length is not None] if len(ready_lengths) != len(xcom_arg.args): return None # If any of the referenced XComs is not ready, we are not ready either. 
return sum(ready_lengths) + + +def deserialize_xcom_arg(data: dict[str, Any], dag: SchedulerDAG): + """DAG serialization interface.""" + klass = _XCOM_ARG_TYPES[data.get("type", "")] + return klass._deserialize(data, dag) + + +_XCOM_ARG_TYPES: dict[str, type[SchedulerXComArg]] = { + "": SchedulerPlainXComArg, + "concat": SchedulerConcatXComArg, + "map": SchedulerMapXComArg, + "zip": SchedulerZipXComArg, +} diff --git a/airflow/serialization/serialized_objects.py b/airflow/serialization/serialized_objects.py index 89ea668f02ffe..08d032e873c6a 100644 --- a/airflow/serialization/serialized_objects.py +++ b/airflow/serialization/serialized_objects.py @@ -45,10 +45,10 @@ from airflow.models.expandinput import ( EXPAND_INPUT_EMPTY, create_expand_input, - get_map_type_key, ) from airflow.models.taskinstance import SimpleTaskInstance from airflow.models.taskinstancekey import TaskInstanceKey +from airflow.models.xcom_arg import SchedulerXComArg, deserialize_xcom_arg from airflow.providers_manager import ProvidersManager from airflow.sdk.definitions.asset import ( Asset, @@ -66,7 +66,7 @@ from airflow.sdk.definitions.mappedoperator import MappedOperator from airflow.sdk.definitions.param import Param, ParamsDict from airflow.sdk.definitions.taskgroup import MappedTaskGroup, TaskGroup -from airflow.sdk.definitions.xcom_arg import XComArg, deserialize_xcom_arg, serialize_xcom_arg +from airflow.sdk.definitions.xcom_arg import XComArg, serialize_xcom_arg from airflow.sdk.execution_time.context import OutletEventAccessor, OutletEventAccessors from airflow.serialization.dag_dependency import DagDependency from airflow.serialization.enums import DagAttributeTypes as DAT, Encoding @@ -493,7 +493,7 @@ class _XComRef(NamedTuple): data: dict - def deref(self, dag: DAG) -> XComArg: + def deref(self, dag: DAG) -> SchedulerXComArg: return deserialize_xcom_arg(self.data, dag) @@ -1195,7 +1195,7 @@ def serialize_mapped_operator(cls, op: MappedOperator) -> dict[str, Any]: if TYPE_CHECKING: # Let Mypy check the input type for us! _ExpandInputRef.validate_expand_input_value(expansion_kwargs.value) serialized_op[op._expand_input_attr] = { - "type": get_map_type_key(expansion_kwargs), + "type": type(expansion_kwargs).EXPAND_INPUT_TYPE, "value": cls.serialize(expansion_kwargs.value), } @@ -1792,7 +1792,7 @@ def serialize_task_group(cls, task_group: TaskGroup) -> dict[str, Any] | None: if isinstance(task_group, MappedTaskGroup): expand_input = task_group._expand_input encoded["expand_input"] = { - "type": get_map_type_key(expand_input), + "type": expand_input.EXPAND_INPUT_TYPE, "value": cls.serialize(expand_input.value), } encoded["is_mapped"] = True diff --git a/airflow/utils/context.py b/airflow/utils/context.py index 0415542c6ca8c..62308605dbb37 100644 --- a/airflow/utils/context.py +++ b/airflow/utils/context.py @@ -66,7 +66,6 @@ from sqlalchemy.orm import Session from sqlalchemy.sql.expression import Select, TextClause - from airflow.sdk.definitions.baseoperator import BaseOperator from airflow.sdk.types import OutletEventAccessorsProtocol # NOTE: Please keep this in sync with the following: @@ -293,24 +292,6 @@ def context_merge(context: Context, *args: Any, **kwargs: Any) -> None: context.update(*args, **kwargs) -def context_update_for_unmapped(context: Context, task: BaseOperator) -> None: - """ - Update context after task unmapping. - - Since ``get_template_context()`` is called before unmapping, the context - contains information about the mapped task. 
We need to do some in-place - updates to ensure the template context reflects the unmapped task instead. - - :meta private: - """ - from airflow.sdk.definitions.param import process_params - - context["task"] = context["ti"].task = task - context["params"] = process_params( - context["dag"], task, context["dag_run"].conf, suppress_exception=False - ) - - def context_copy_partial(source: Context, keys: Container[str]) -> Context: """ Create a context by copying items under selected keys in ``source``. diff --git a/airflow/utils/setup_teardown.py b/airflow/utils/setup_teardown.py index 3108657d30ac2..32d19d316844c 100644 --- a/airflow/utils/setup_teardown.py +++ b/airflow/utils/setup_teardown.py @@ -23,8 +23,8 @@ if TYPE_CHECKING: from airflow.models.taskmixin import DependencyMixin - from airflow.models.xcom_arg import PlainXComArg from airflow.sdk.definitions._internal.abstractoperator import AbstractOperator + from airflow.sdk.definitions.xcom_arg import PlainXComArg class BaseSetupTeardownContext: @@ -335,7 +335,7 @@ class SetupTeardownContext(BaseSetupTeardownContext): @staticmethod def add_task(task: AbstractOperator | PlainXComArg): """Add task to context manager.""" - from airflow.models.xcom_arg import PlainXComArg + from airflow.sdk.definitions.xcom_arg import PlainXComArg if not SetupTeardownContext.active: raise AirflowException("Cannot add task to context outside the context manager.") diff --git a/airflow/utils/task_group.py b/airflow/utils/task_group.py index b74921f75d533..2ebb95709913c 100644 --- a/airflow/utils/task_group.py +++ b/airflow/utils/task_group.py @@ -19,61 +19,15 @@ from __future__ import annotations -import functools -import operator -from collections.abc import Iterator from typing import TYPE_CHECKING import airflow.sdk.definitions.taskgroup if TYPE_CHECKING: - from sqlalchemy.orm import Session - from airflow.typing_compat import TypeAlias TaskGroup: TypeAlias = airflow.sdk.definitions.taskgroup.TaskGroup - - -class MappedTaskGroup(airflow.sdk.definitions.taskgroup.MappedTaskGroup): # noqa: D101 - # TODO: Rename this to SerializedMappedTaskGroup perhaps? - - def iter_mapped_task_groups(self) -> Iterator[MappedTaskGroup]: - """ - Return mapped task groups in the hierarchy. - - Groups are returned from the closest to the outmost. If *self* is a - mapped task group, it is returned first. - - :meta private: - """ - group: TaskGroup | None = self - while group is not None: - if isinstance(group, MappedTaskGroup): - yield group - group = group.parent_group - - def get_mapped_ti_count(self, run_id: str, *, session: Session) -> int: - """ - Return the number of instances a task in this group should be mapped to at run time. - - This considers both literal and non-literal mapped arguments, and the - result is therefore available when all depended tasks have finished. The - return value should be identical to ``parse_time_mapped_ti_count`` if - all mapped arguments are literal. - - If this group is inside mapped task groups, all the nested counts are - multiplied and accounted. - - :meta private: - - :raise NotFullyPopulated: If upstream tasks are not all complete yet. - :return: Total number of mapped TIs this task should have. 
- """ - groups = self.iter_mapped_task_groups() - return functools.reduce( - operator.mul, - (g._expand_input.get_total_map_length(run_id, session=session) for g in groups), - ) +MappedTaskGroup: TypeAlias = airflow.sdk.definitions.taskgroup.MappedTaskGroup def task_group_to_dict(task_item_or_group): diff --git a/providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/operators/pod.py b/providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/operators/pod.py index 1d77f459c9863..d4e2e76a66af6 100644 --- a/providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/operators/pod.py +++ b/providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/operators/pod.py @@ -489,7 +489,7 @@ def _get_ti_pod_labels(context: Context | None = None, include_try_number: bool } map_index = ti.map_index - if map_index >= 0: + if map_index is not None and map_index >= 0: labels["map_index"] = str(map_index) if include_try_number: diff --git a/providers/src/airflow/providers/microsoft/azure/operators/msgraph.py b/providers/src/airflow/providers/microsoft/azure/operators/msgraph.py index 7aefd2971d4e6..8d2b1ae1eee9e 100644 --- a/providers/src/airflow/providers/microsoft/azure/operators/msgraph.py +++ b/providers/src/airflow/providers/microsoft/azure/operators/msgraph.py @@ -241,7 +241,6 @@ def pull_xcom(self, context: Context) -> list: key=self.key, task_ids=self.task_id, dag_id=self.dag_id, - map_indexes=map_index, ) or [] ) diff --git a/providers/standard/tests/provider_tests/standard/decorators/test_python.py b/providers/standard/tests/provider_tests/standard/decorators/test_python.py index 84a58faae72e2..ae661a17dd583 100644 --- a/providers/standard/tests/provider_tests/standard/decorators/test_python.py +++ b/providers/standard/tests/provider_tests/standard/decorators/test_python.py @@ -18,7 +18,7 @@ import sys import typing from collections import namedtuple -from datetime import date, timedelta +from datetime import date from typing import Union import pytest @@ -26,15 +26,10 @@ from airflow.decorators import setup, task as task_decorator, teardown from airflow.decorators.base import DecoratedMappedOperator from airflow.exceptions import AirflowException, XComNotFound -from airflow.models.baseoperator import BaseOperator -from airflow.models.dag import DAG -from airflow.models.expandinput import DictOfListsExpandInput from airflow.models.taskinstance import TaskInstance from airflow.models.taskmap import TaskMap -from airflow.models.xcom_arg import PlainXComArg, XComArg from airflow.utils import timezone from airflow.utils.state import State -from airflow.utils.task_group import TaskGroup from airflow.utils.task_instance_session import set_current_task_instance_session from airflow.utils.trigger_rule import TriggerRule from airflow.utils.types import DagRunType @@ -44,10 +39,17 @@ from tests_common.test_utils.version_compat import AIRFLOW_V_3_0_PLUS if AIRFLOW_V_3_0_PLUS: + from airflow.sdk import DAG, BaseOperator, TaskGroup, XComArg + from airflow.sdk.definitions._internal.expandinput import DictOfListsExpandInput from airflow.sdk.definitions.mappedoperator import MappedOperator from airflow.utils.types import DagRunTriggeredByType else: + from airflow.models.baseoperator import BaseOperator + from airflow.models.dag import DAG # type: ignore[assignment] + from airflow.models.expandinput import DictOfListsExpandInput from airflow.models.mappedoperator import MappedOperator + from airflow.models.xcom_arg import XComArg + from airflow.utils.task_group import TaskGroup pytestmark 
= pytest.mark.db_test @@ -733,6 +735,11 @@ def double(number: int): def test_partial_mapped_decorator() -> None: + if AIRFLOW_V_3_0_PLUS: + from airflow.sdk.definitions.xcom_arg import PlainXComArg + else: + from airflow.models.xcom_arg import PlainXComArg # type: ignore[attr-defined, no-redef] + @task_decorator def product(number: int, multiple: int): return number * multiple @@ -789,38 +796,12 @@ def task2(arg1, arg2): ... dec = run.task_instance_scheduling_decisions(session=session) assert [ti.task_id for ti in dec.schedulable_tis] == ["task2"] ti = dec.schedulable_tis[0] - unmapped = ti.task.unmap((ti.get_template_context(session), session)) - assert set(unmapped.op_kwargs) == {"arg1", "arg2"} - - -def test_mapped_decorator_converts_partial_kwargs(dag_maker, session): - with dag_maker(session=session): - - @task_decorator - def task1(arg): - return ["x" * arg] - - @task_decorator(retry_delay=30) - def task2(arg1, arg2): ... - task2.partial(arg1=1).expand(arg2=task1.expand(arg=[1, 2])) - - run = dag_maker.create_dagrun() - - # Expand and run task1. - dec = run.task_instance_scheduling_decisions(session=session) - assert [ti.task_id for ti in dec.schedulable_tis] == ["task1", "task1"] - for ti in dec.schedulable_tis: - ti.run(session=session) - assert not isinstance(ti.task, MappedOperator) - assert ti.task.retry_delay == timedelta(seconds=300) # Operator default. - - # Expand task2. - dec = run.task_instance_scheduling_decisions(session=session) - assert [ti.task_id for ti in dec.schedulable_tis] == ["task2", "task2"] - for ti in dec.schedulable_tis: + if AIRFLOW_V_3_0_PLUS: + unmapped = ti.task.unmap((ti.get_template_context(session),)) + else: unmapped = ti.task.unmap((ti.get_template_context(session), session)) - assert unmapped.retry_delay == timedelta(seconds=30) + assert set(unmapped.op_kwargs) == {"arg1", "arg2"} def test_mapped_render_template_fields(dag_maker, session): diff --git a/providers/standard/tests/provider_tests/standard/operators/test_datetime.py b/providers/standard/tests/provider_tests/standard/operators/test_datetime.py index b15ab5a117d7a..36df3942e9402 100644 --- a/providers/standard/tests/provider_tests/standard/operators/test_datetime.py +++ b/providers/standard/tests/provider_tests/standard/operators/test_datetime.py @@ -74,13 +74,13 @@ def base_tests_setup(self, dag_maker): self.branch_1.set_upstream(self.branch_op) self.branch_2.set_upstream(self.branch_op) - self.dr = dag_maker.create_dagrun( - run_id="manual__", - start_date=DEFAULT_DATE, - logical_date=DEFAULT_DATE, - state=State.RUNNING, - data_interval=(DEFAULT_DATE, DEFAULT_DATE), - ) + self.dr = dag_maker.create_dagrun( + run_id="manual__", + start_date=DEFAULT_DATE, + logical_date=DEFAULT_DATE, + state=State.RUNNING, + data_interval=(DEFAULT_DATE, DEFAULT_DATE), + ) def teardown_method(self): with create_session() as session: diff --git a/providers/standard/tests/provider_tests/standard/operators/test_python.py b/providers/standard/tests/provider_tests/standard/operators/test_python.py index d87de0f62dfb4..9f0f3515c696a 100644 --- a/providers/standard/tests/provider_tests/standard/operators/test_python.py +++ b/providers/standard/tests/provider_tests/standard/operators/test_python.py @@ -42,7 +42,6 @@ from slugify import slugify from airflow.config_templates.airflow_local_settings import DEFAULT_LOGGING_CONFIG -from airflow.decorators import task_group from airflow.exceptions import ( AirflowException, DeserializingResultError, @@ -110,8 +109,12 @@ def base_tests_setup(self, request, 
create_serialized_task_instance_of_operator, self.run_id = f"run_{slugify(request.node.name, max_length=40)}" self.ds_templated = self.default_date.date().isoformat() self.ti_maker = create_serialized_task_instance_of_operator + self.dag_maker = dag_maker self.dag_non_serialized = self.dag_maker(self.dag_id, template_searchpath=TEMPLATE_SEARCHPATH).dag + # We need to enter the context in order for the factory to create things + with self.dag_maker: + ... clear_db_runs() yield clear_db_runs() @@ -138,6 +141,10 @@ def default_kwargs(**kwargs): return kwargs def create_dag_run(self) -> DagRun: + from airflow.models.serialized_dag import SerializedDagModel + + # Update the serialized DAG with any tasks added after the initial dag was created + self.dag_maker.serialized_model = SerializedDagModel(self.dag_non_serialized) return self.dag_maker.create_dagrun( state=DagRunState.RUNNING, start_date=self.dag_maker.start_date, @@ -753,39 +760,6 @@ def test_xcom_push_skipped_tasks(self): "skipped": ["empty_task"] } - def test_mapped_xcom_push_skipped_tasks(self, session): - with self.dag_non_serialized: - - @task_group - def group(x): - short_op_push_xcom = ShortCircuitOperator( - task_id="push_xcom_from_shortcircuit", - python_callable=lambda arg: arg % 2 == 0, - op_kwargs={"arg": x}, - ) - empty_task = EmptyOperator(task_id="empty_task") - short_op_push_xcom >> empty_task - - group.expand(x=[0, 1]) - dr = self.create_dag_run() - decision = dr.task_instance_scheduling_decisions(session=session) - for ti in decision.schedulable_tis: - ti.run() - # dr.run(start_date=self.default_date, end_date=self.default_date) - tis = dr.get_task_instances() - - assert ( - tis[0].xcom_pull(task_ids="group.push_xcom_from_shortcircuit", key="return_value", map_indexes=0) - is True - ) - assert ( - tis[0].xcom_pull(task_ids="group.push_xcom_from_shortcircuit", key="skipmixin_key", map_indexes=0) - is None - ) - assert tis[0].xcom_pull( - task_ids="group.push_xcom_from_shortcircuit", key="skipmixin_key", map_indexes=1 - ) == {"skipped": ["group.empty_task"]} - virtualenv_string_args: list[str] = [] diff --git a/scripts/ci/pre_commit/template_context_key_sync.py b/scripts/ci/pre_commit/template_context_key_sync.py index 01977598ceefd..525a49ca59642 100755 --- a/scripts/ci/pre_commit/template_context_key_sync.py +++ b/scripts/ci/pre_commit/template_context_key_sync.py @@ -73,12 +73,13 @@ def extract_keys_from_dict(node: ast.Dict) -> typing.Iterator[str]: raise ValueError("'context' is not assigned a dictionary literal") yield from extract_keys_from_dict(context_assignment.value) - # Handle keys added conditionally in `if self._ti_context_from_server` + # Handle keys added conditionally in `if x := self._ti_context_from_server` for stmt in fn_get_template_context.body: if ( isinstance(stmt, ast.If) - and isinstance(stmt.test, ast.Attribute) - and stmt.test.attr == "_ti_context_from_server" + and isinstance(stmt.test, ast.NamedExpr) + and isinstance(stmt.test.value, ast.Attribute) + and stmt.test.value.attr == "_ti_context_from_server" ): for sub_stmt in stmt.body: # Get keys from `context_from_server` assignment diff --git a/task_sdk/pyproject.toml b/task_sdk/pyproject.toml index 7b0c7ed7b2b14..22872e3f4f9d6 100644 --- a/task_sdk/pyproject.toml +++ b/task_sdk/pyproject.toml @@ -106,7 +106,7 @@ exclude_also = [ [dependency-groups] codegen = [ - "datamodel-code-generator[http]>=0.26.3", + "datamodel-code-generator[http]>=0.26.5", ] [tool.black] diff --git a/task_sdk/src/airflow/sdk/api/client.py
b/task_sdk/src/airflow/sdk/api/client.py index 84992b9ab1945..b1a85fc78b2c8 100644 --- a/task_sdk/src/airflow/sdk/api/client.py +++ b/task_sdk/src/airflow/sdk/api/client.py @@ -39,6 +39,7 @@ ConnectionResponse, DagRunType, PrevSuccessfulDagRunResponse, + TerminalStateNonSuccess, TerminalTIState, TIDeferredStatePayload, TIEnterRunningPayload, @@ -132,10 +133,12 @@ def start(self, id: uuid.UUID, pid: int, when: datetime) -> TIRunContext: resp = self.client.patch(f"task-instances/{id}/run", content=body.model_dump_json()) return TIRunContext.model_validate_json(resp.read()) - def finish(self, id: uuid.UUID, state: TerminalTIState, when: datetime): + def finish(self, id: uuid.UUID, state: TerminalStateNonSuccess, when: datetime): """Tell the API server that this TI has reached a terminal state.""" + if state == TerminalTIState.SUCCESS: + raise ValueError("Logic error. SUCCESS state should call the `succeed` function instead") # TODO: handle the naming better. finish sounds wrong as "even" deferred is essentially finishing. - body = TITerminalStatePayload(end_date=when, state=TerminalTIState(state)) + body = TITerminalStatePayload(end_date=when, state=TerminalStateNonSuccess(state)) self.client.patch(f"task-instances/{id}/state", content=body.model_dump_json()) def succeed(self, id: uuid.UUID, when: datetime, task_outlets, outlet_events): @@ -253,6 +256,17 @@ class XComOperations: def __init__(self, client: Client): self.client = client + def head(self, dag_id: str, run_id: str, task_id: str, key: str) -> int: + """Get the number of mapped XCom values.""" + resp = self.client.head(f"xcoms/{dag_id}/{run_id}/{task_id}/{key}") + + # content_range: str | None + if not (content_range := resp.headers["Content-Range"]) or not content_range.startswith( + "map_indexes " + ): + raise RuntimeError(f"Unable to parse Content-Range header from HEAD {resp.request.url}") + return int(content_range[len("map_indexes ") :]) + def get( self, dag_id: str, run_id: str, task_id: str, key: str, map_index: int | None = None ) -> XComResponse: @@ -260,7 +274,7 @@ def get( # TODO: check if we need to use map_index as params in the uri # ref: https://github.com/apache/airflow/blob/v2-10-stable/airflow/api_connexion/openapi/v1.yaml#L1785C1-L1785C81 params = {} - if map_index is not None: + if map_index is not None and map_index >= 0: params.update({"map_index": map_index}) try: resp = self.client.get(f"xcoms/{dag_id}/{run_id}/{task_id}/{key}", params=params) @@ -290,7 +304,7 @@ def set( # TODO: check if we need to use map_index as params in the uri # ref: https://github.com/apache/airflow/blob/v2-10-stable/airflow/api_connexion/openapi/v1.yaml#L1785C1-L1785C81 params = {} - if map_index: + if map_index is not None and map_index >= 0: params = {"map_index": map_index} self.client.post(f"xcoms/{dag_id}/{run_id}/{task_id}/{key}", params=params, json=value) # Any error from the server will anyway be propagated down to the supervisor, @@ -428,7 +442,7 @@ def xcoms(self) -> XComOperations: @lru_cache() # type: ignore[misc] @property def assets(self) -> AssetOperations: - """Operations related to XComs.""" + """Operations related to Assets.""" return AssetOperations(self) diff --git a/task_sdk/src/airflow/sdk/api/datamodels/_generated.py b/task_sdk/src/airflow/sdk/api/datamodels/_generated.py index 4f9ad8adb6b93..ac459e817f2eb 100644 --- a/task_sdk/src/airflow/sdk/api/datamodels/_generated.py +++ b/task_sdk/src/airflow/sdk/api/datamodels/_generated.py @@ -1,3 +1,7 @@ +# generated by datamodel-codegen: +# filename: 
http://0.0.0.0:9091/execution/openapi.json +# version: 0.26.5 + # Licensed to the Apache Software Foundation (ASF) under one # or more contributor license agreements. See the NOTICE file # distributed with this work for additional information @@ -14,11 +18,6 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. - -# generated by datamodel-codegen: -# filename: http://0.0.0.0:9091/execution/openapi.json -# version: 0.26.3 - from __future__ import annotations from datetime import datetime, timedelta @@ -26,7 +25,7 @@ from typing import Annotated, Any, Literal from uuid import UUID -from pydantic import BaseModel, ConfigDict, Field +from pydantic import BaseModel, ConfigDict, Field, RootModel class AssetProfile(BaseModel): @@ -97,6 +96,10 @@ class IntermediateTIState(str, Enum): DEFERRED = "deferred" +class JsonValue(RootModel[Any]): + root: Any + + class PrevSuccessfulDagRunResponse(BaseModel): """ Schema for response with previous successful DagRun information for Task Template Context. @@ -200,12 +203,11 @@ class TITargetStatePayload(BaseModel): state: IntermediateTIState -class TerminalTIState(str, Enum): +class TerminalStateNonSuccess(str, Enum): """ - States that a Task Instance can be in that indicate it has reached a terminal state. + TaskInstance states that can be reported without extra information. """ - SUCCESS = "success" FAILED = "failed" SKIPPED = "skipped" REMOVED = "removed" @@ -248,11 +250,6 @@ class XComResponse(BaseModel): value: Annotated[Any, Field(title="Value")] -class BundleInfo(BaseModel): - name: str - version: str | None = None - - class TaskInstance(BaseModel): """ Schema for TaskInstance model with minimal required fields needed for Runtime. @@ -266,10 +263,27 @@ class TaskInstance(BaseModel): dag_id: Annotated[str, Field(title="Dag Id")] run_id: Annotated[str, Field(title="Run Id")] try_number: Annotated[int, Field(title="Try Number")] - map_index: Annotated[int, Field(title="Map Index")] = -1 + map_index: Annotated[int | None, Field(title="Map Index")] = -1 hostname: Annotated[str | None, Field(title="Hostname")] = None +class BundleInfo(BaseModel): + """ + Schema for telling task which bundle to run with. + """ + + name: Annotated[str, Field(title="Name")] + version: Annotated[str | None, Field(title="Version")] = None + + +class TerminalTIState(str, Enum): + SUCCESS = "success" + FAILED = "failed" + SKIPPED = "skipped" + REMOVED = "removed" + FAIL_WITHOUT_RETRY = "fail_without_retry" + + class DagRun(BaseModel): """ Schema for DagRun model with minimal required fields needed for Runtime. 
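Editor's note on the client change shown above: `finish()` now refuses the SUCCESS state and the regenerated `TerminalStateNonSuccess` enum omits it, so success must go through `succeed()` with its extra outlet payload. The sketch below is an illustration of the intended calling pattern, not code from this patch; the helper name `report_terminal_state` and the keyword-style call sites are assumptions.

```python
# Minimal sketch, assuming a Task SDK `Client` instance and the method
# signatures shown in the client.py hunk above.
from datetime import datetime, timezone

from airflow.sdk.api.datamodels._generated import TerminalStateNonSuccess, TerminalTIState


def report_terminal_state(client, ti_id, state: TerminalTIState, task_outlets=(), outlet_events=()):
    now = datetime.now(tz=timezone.utc)
    if state == TerminalTIState.SUCCESS:
        # SUCCESS carries extra payload (outlets and events), so it goes through `succeed`.
        client.task_instances.succeed(
            ti_id, when=now, task_outlets=list(task_outlets), outlet_events=list(outlet_events)
        )
    else:
        # Every other terminal state uses `finish`, which now raises if handed SUCCESS.
        client.task_instances.finish(ti_id, state=TerminalStateNonSuccess(state), when=now)
```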
@@ -288,7 +302,7 @@ class DagRun(BaseModel): end_date: Annotated[datetime | None, Field(title="End Date")] = None run_type: DagRunType conf: Annotated[dict[str, Any] | None, Field(title="Conf")] = None - external_trigger: Annotated[bool, Field(title="External Trigger")] = False + external_trigger: Annotated[bool | None, Field(title="External Trigger")] = False class HTTPValidationError(BaseModel): @@ -304,6 +318,7 @@ class TIRunContext(BaseModel): max_tries: Annotated[int, Field(title="Max Tries")] variables: Annotated[list[VariableResponse] | None, Field(title="Variables")] = None connections: Annotated[list[ConnectionResponse] | None, Field(title="Connections")] = None + upstream_map_indexes: Annotated[dict[str, int] | None, Field(title="Upstream Map Indexes")] = None class TITerminalStatePayload(BaseModel): @@ -314,6 +329,5 @@ class TITerminalStatePayload(BaseModel): model_config = ConfigDict( extra="forbid", ) - - state: TerminalTIState + state: TerminalStateNonSuccess end_date: Annotated[datetime, Field(title="End Date")] diff --git a/task_sdk/src/airflow/sdk/definitions/_internal/expandinput.py b/task_sdk/src/airflow/sdk/definitions/_internal/expandinput.py new file mode 100644 index 0000000000000..18e7d11fb08ae --- /dev/null +++ b/task_sdk/src/airflow/sdk/definitions/_internal/expandinput.py @@ -0,0 +1,278 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from __future__ import annotations + +import functools +import operator +from collections.abc import Iterable, Mapping, Sequence, Sized +from typing import TYPE_CHECKING, Any, ClassVar, Union + +import attrs + +from airflow.sdk.definitions._internal.mixins import ResolveMixin + +if TYPE_CHECKING: + from airflow.sdk.definitions.xcom_arg import XComArg + from airflow.sdk.types import Operator + from airflow.typing_compat import TypeGuard + +ExpandInput = Union["DictOfListsExpandInput", "ListOfDictsExpandInput"] + +# Each keyword argument to expand() can be an XComArg, sequence, or dict (not +# any mapping since we need the value to be ordered). +OperatorExpandArgument = Union["MappedArgument", "XComArg", Sequence, dict[str, Any]] + +# The single argument of expand_kwargs() can be an XComArg, or a list with each +# element being either an XComArg or a dict. +OperatorExpandKwargsArgument = Union["XComArg", Sequence[Union["XComArg", Mapping[str, Any]]]] + + +class NotFullyPopulated(RuntimeError): + """ + Raise when ``get_map_lengths`` cannot populate all mapping metadata. + + This is generally due to not all upstream tasks have finished when the + function is called. 
+ """ + + def __init__(self, missing: set[str]) -> None: + self.missing = missing + + def __str__(self) -> str: + keys = ", ".join(repr(k) for k in sorted(self.missing)) + return f"Failed to populate all mapping metadata; missing: {keys}" + + +# To replace tedious isinstance() checks. +def is_mappable(v: Any) -> TypeGuard[OperatorExpandArgument]: + from airflow.sdk.definitions.xcom_arg import XComArg + + return isinstance(v, (MappedArgument, XComArg, Mapping, Sequence)) and not isinstance(v, str) + + +# To replace tedious isinstance() checks. +def _is_parse_time_mappable(v: OperatorExpandArgument) -> TypeGuard[Mapping | Sequence]: + from airflow.models.xcom_arg import XComArg + + return not isinstance(v, (MappedArgument, XComArg)) + + +# To replace tedious isinstance() checks. +def _needs_run_time_resolution(v: OperatorExpandArgument) -> TypeGuard[MappedArgument | XComArg]: + from airflow.models.xcom_arg import XComArg + + return isinstance(v, (MappedArgument, XComArg)) + + +@attrs.define(kw_only=True) +class MappedArgument(ResolveMixin): + """ + Stand-in stub for task-group-mapping arguments. + + This is very similar to an XComArg, but resolved differently. Declared here + (instead of in the task group module) to avoid import cycles. + """ + + _input: ExpandInput = attrs.field() + _key: str + + @_input.validator + def _validate_input(self, _, input): + if isinstance(input, DictOfListsExpandInput): + for value in input.value.values(): + if isinstance(value, MappedArgument): + raise ValueError("Nested Mapped TaskGroups are not yet supported") + + def iter_references(self) -> Iterable[tuple[Operator, str]]: + yield from self._input.iter_references() + + def resolve(self, context: Mapping[str, Any]) -> Any: + data, _ = self._input.resolve(context) + return data[self._key] + + +@attrs.define() +class DictOfListsExpandInput(ResolveMixin): + """ + Storage type of a mapped operator's mapped kwargs. + + This is created from ``expand(**kwargs)``. + """ + + value: dict[str, OperatorExpandArgument] + + EXPAND_INPUT_TYPE: ClassVar[str] = "dict-of-lists" + + def _iter_parse_time_resolved_kwargs(self) -> Iterable[tuple[str, Sized]]: + """Generate kwargs with values available on parse-time.""" + return ((k, v) for k, v in self.value.items() if _is_parse_time_mappable(v)) + + def get_parse_time_mapped_ti_count(self) -> int: + if not self.value: + return 0 + literal_values = [len(v) for _, v in self._iter_parse_time_resolved_kwargs()] + if len(literal_values) != len(self.value): + literal_keys = (k for k, _ in self._iter_parse_time_resolved_kwargs()) + raise NotFullyPopulated(set(self.value).difference(literal_keys)) + return functools.reduce(operator.mul, literal_values, 1) + + def _get_map_lengths( + self, resolved_vals: dict[str, Sized], upstream_map_indexes: dict[str, int] + ) -> dict[str, int]: + """ + Return dict of argument name to map length. + + If any arguments are not known right now (upstream task not finished), + they will not be present in the dict. + """ + + # TODO: This initiates one API call for each XComArg. Would it be + # more efficient to do one single call and unpack the value here? + def _get_length(k: str, v: OperatorExpandArgument) -> int | None: + from airflow.sdk.definitions.xcom_arg import XComArg, get_task_map_length + + if isinstance(v, XComArg): + return get_task_map_length(v, resolved_vals[k], upstream_map_indexes) + + # Unfortunately a user-defined TypeGuard cannot apply negative type + # narrowing. 
https://github.com/python/typing/discussions/1013 + if TYPE_CHECKING: + assert isinstance(v, Sized) + return len(v) + + map_lengths = { + k: res for k, v in self.value.items() if v is not None if (res := _get_length(k, v)) is not None + } + if len(map_lengths) < len(self.value): + raise NotFullyPopulated(set(self.value).difference(map_lengths)) + return map_lengths + + def _expand_mapped_field(self, key: str, value: Any, map_index: int, all_lengths: dict[str, int]) -> Any: + def _find_index_for_this_field(index: int) -> int: + # Need to use the original user input to retain argument order. + for mapped_key in reversed(self.value): + mapped_length = all_lengths[mapped_key] + if mapped_length < 1: + raise RuntimeError(f"cannot expand field mapped to length {mapped_length!r}") + if mapped_key == key: + return index % mapped_length + index //= mapped_length + return -1 + + found_index = _find_index_for_this_field(map_index) + if found_index < 0: + return value + if isinstance(value, Sequence): + return value[found_index] + if not isinstance(value, dict): + raise TypeError(f"can't map over value of type {type(value)}") + for i, (k, v) in enumerate(value.items()): + if i == found_index: + return k, v + raise IndexError(f"index {map_index} is over mapped length") + + def iter_references(self) -> Iterable[tuple[Operator, str]]: + from airflow.models.xcom_arg import XComArg + + for x in self.value.values(): + if isinstance(x, XComArg): + yield from x.iter_references() + + def resolve(self, context: Mapping[str, Any]) -> tuple[Mapping[str, Any], set[int]]: + map_index: int | None = context["ti"].map_index + if map_index is None or map_index < 0: + raise RuntimeError("can't resolve task-mapping argument without expanding") + + upstream_map_indexes = getattr(context["ti"], "_upstream_map_indexes", {}) + + # TODO: This initiates one API call for each XComArg. Would it be + # more efficient to do one single call and unpack the value here? + resolved = { + k: v.resolve(context) if _needs_run_time_resolution(v) else v for k, v in self.value.items() + } + + all_lengths = self._get_map_lengths(resolved, upstream_map_indexes) + + data = {k: self._expand_mapped_field(k, v, map_index, all_lengths) for k, v in resolved.items()} + literal_keys = {k for k, _ in self._iter_parse_time_resolved_kwargs()} + resolved_oids = {id(v) for k, v in data.items() if k not in literal_keys} + return data, resolved_oids + + +def _describe_type(value: Any) -> str: + if value is None: + return "None" + return type(value).__name__ + + +@attrs.define() +class ListOfDictsExpandInput(ResolveMixin): + """ + Storage type of a mapped operator's mapped kwargs. + + This is created from ``expand_kwargs(xcom_arg)``. 
+ """ + + value: OperatorExpandKwargsArgument + + EXPAND_INPUT_TYPE: ClassVar[str] = "list-of-dicts" + + def get_parse_time_mapped_ti_count(self) -> int: + if isinstance(self.value, Sized): + return len(self.value) + raise NotFullyPopulated({"expand_kwargs() argument"}) + + def iter_references(self) -> Iterable[tuple[Operator, str]]: + from airflow.models.xcom_arg import XComArg + + if isinstance(self.value, XComArg): + yield from self.value.iter_references() + else: + for x in self.value: + if isinstance(x, XComArg): + yield from x.iter_references() + + def resolve(self, context: Mapping[str, Any]) -> tuple[Mapping[str, Any], set[int]]: + map_index = context["ti"].map_index + if map_index < 0: + raise RuntimeError("can't resolve task-mapping argument without expanding") + + mapping: Any = None + if isinstance(self.value, Sized): + mapping = self.value[map_index] + if not isinstance(mapping, Mapping): + mapping = mapping.resolve(context) + else: + mappings = self.value.resolve(context) + if not isinstance(mappings, Sequence): + raise ValueError(f"expand_kwargs() expects a list[dict], not {_describe_type(mappings)}") + mapping = mappings[map_index] + + if not isinstance(mapping, Mapping): + raise ValueError(f"expand_kwargs() expects a list[dict], not list[{_describe_type(mapping)}]") + + for key in mapping: + if not isinstance(key, str): + raise ValueError( + f"expand_kwargs() input dict keys must all be str, " + f"but {key!r} is of type {_describe_type(key)}" + ) + # filter out parse time resolved values from the resolved_oids + resolved_oids = {id(v) for k, v in mapping.items() if not _is_parse_time_mappable(v)} + + return mapping, resolved_oids diff --git a/task_sdk/src/airflow/sdk/definitions/_internal/mixins.py b/task_sdk/src/airflow/sdk/definitions/_internal/mixins.py index fcd68ba20b2c6..93fd9431cbe38 100644 --- a/task_sdk/src/airflow/sdk/definitions/_internal/mixins.py +++ b/task_sdk/src/airflow/sdk/definitions/_internal/mixins.py @@ -133,7 +133,7 @@ def iter_references(self) -> Iterable[tuple[Operator, str]]: """ raise NotImplementedError - def resolve(self, context: Context, *, include_xcom: bool = True) -> Any: + def resolve(self, context: Context) -> Any: """ Resolve this value for runtime. diff --git a/task_sdk/src/airflow/sdk/definitions/_internal/templater.py b/task_sdk/src/airflow/sdk/definitions/_internal/templater.py index d7028d4d6bca7..b50c4dbb3cadb 100644 --- a/task_sdk/src/airflow/sdk/definitions/_internal/templater.py +++ b/task_sdk/src/airflow/sdk/definitions/_internal/templater.py @@ -50,7 +50,7 @@ class LiteralValue(ResolveMixin): def iter_references(self) -> Iterable[tuple[Operator, str]]: return () - def resolve(self, context: Context, *, include_xcom: bool = True) -> Any: + def resolve(self, context: Context) -> Any: return self.value @@ -179,7 +179,7 @@ def render_template( return self._render_object_storage_path(value, context, jinja_env) if resolve := getattr(value, "resolve", None): - return resolve(context, include_xcom=True) + return resolve(context) # Fast path for common built-in collections. 
if value.__class__ is tuple: diff --git a/task_sdk/src/airflow/sdk/definitions/dag.py b/task_sdk/src/airflow/sdk/definitions/dag.py index 5662d542859f7..3c8d99737ab8e 100644 --- a/task_sdk/src/airflow/sdk/definitions/dag.py +++ b/task_sdk/src/airflow/sdk/definitions/dag.py @@ -359,6 +359,13 @@ class DAG: # argument "description" in call to "DAG"`` etc), so for init=True args we use the `default=Factory()` # style + def __rich_repr__(self): + yield "dag_id", self.dag_id + yield "schedule", self.schedule + yield "#tasks", len(self.tasks) + + __rich_repr__.angular = True # type: ignore[attr-defined] + # NOTE: When updating arguments here, please also keep arguments in @dag() # below in sync. (Search for 'def dag(' in this file.) dag_id: str = attrs.field(kw_only=False, validator=attrs.validators.instance_of(str)) diff --git a/task_sdk/src/airflow/sdk/definitions/mappedoperator.py b/task_sdk/src/airflow/sdk/definitions/mappedoperator.py index dddeaf21cf2bc..92f83792f3c79 100644 --- a/task_sdk/src/airflow/sdk/definitions/mappedoperator.py +++ b/task_sdk/src/airflow/sdk/definitions/mappedoperator.py @@ -27,11 +27,6 @@ import methodtools from airflow.models.abstractoperator import NotMapped -from airflow.models.expandinput import ( - DictOfListsExpandInput, - ListOfDictsExpandInput, - is_mappable, -) from airflow.sdk.definitions._internal.abstractoperator import ( DEFAULT_EXECUTOR, DEFAULT_IGNORE_FIRST_DEPENDS_ON_PAST, @@ -47,13 +42,16 @@ DEFAULT_WEIGHT_RULE, AbstractOperator, ) +from airflow.sdk.definitions._internal.expandinput import ( + DictOfListsExpandInput, + ListOfDictsExpandInput, + is_mappable, +) from airflow.sdk.definitions._internal.types import NOTSET from airflow.serialization.enums import DagAttributeTypes from airflow.task.priority_strategy import PriorityWeightStrategy, validate_and_load_priority_weight_strategy -from airflow.triggers.base import StartTriggerArgs from airflow.typing_compat import Literal from airflow.utils.helpers import is_container, prevent_duplicates -from airflow.utils.task_instance_session import get_current_task_instance_session from airflow.utils.xcom import XCOM_RETURN_KEY if TYPE_CHECKING: @@ -61,7 +59,6 @@ import jinja2 # Slow import. import pendulum - from sqlalchemy.orm.session import Session from airflow.models.abstractoperator import ( TaskStateChangeCallback, @@ -78,6 +75,7 @@ from airflow.sdk.definitions.param import ParamsDict from airflow.sdk.types import Operator from airflow.ti_deps.deps.base_ti_dep import BaseTIDep + from airflow.triggers.base import StartTriggerArgs from airflow.typing_compat import TypeGuard from airflow.utils.context import Context from airflow.utils.operator_resources import Resources @@ -683,16 +681,14 @@ def serialize_for_task_group(self) -> tuple[DagAttributeTypes, Any]: """Implement DAGNode.""" return DagAttributeTypes.OP, self.task_id - def _expand_mapped_kwargs( - self, context: Mapping[str, Any], session: Session, *, include_xcom: bool - ) -> tuple[Mapping[str, Any], set[int]]: + def _expand_mapped_kwargs(self, context: Mapping[str, Any]) -> tuple[Mapping[str, Any], set[int]]: """ Get the kwargs to create the unmapped operator. This exists because taskflow operators expand against op_kwargs, not the entire operator kwargs dict. 
""" - return self._get_specified_expand_input().resolve(context, session, include_xcom=include_xcom) + return self._get_specified_expand_input().resolve(context) def _get_unmap_kwargs(self, mapped_kwargs: Mapping[str, Any], *, strict: bool) -> dict[str, Any]: """ @@ -726,70 +722,7 @@ def _get_unmap_kwargs(self, mapped_kwargs: Mapping[str, Any], *, strict: bool) - "params": params, } - def expand_start_from_trigger(self, *, context: Context, session: Session) -> bool: - """ - Get the start_from_trigger value of the current abstract operator. - - MappedOperator uses this to unmap start_from_trigger to decide whether to start the task - execution directly from triggerer. - - :meta private: - """ - # start_from_trigger only makes sense when start_trigger_args exists. - if not self.start_trigger_args: - return False - - mapped_kwargs, _ = self._expand_mapped_kwargs(context, session, include_xcom=False) - if self._disallow_kwargs_override: - prevent_duplicates( - self.partial_kwargs, - mapped_kwargs, - fail_reason="unmappable or already specified", - ) - - # Ordering is significant; mapped kwargs should override partial ones. - return mapped_kwargs.get( - "start_from_trigger", self.partial_kwargs.get("start_from_trigger", self.start_from_trigger) - ) - - def expand_start_trigger_args(self, *, context: Context, session: Session) -> StartTriggerArgs | None: - """ - Get the kwargs to create the unmapped start_trigger_args. - - This method is for allowing mapped operator to start execution from triggerer. - """ - if not self.start_trigger_args: - return None - - mapped_kwargs, _ = self._expand_mapped_kwargs(context, session, include_xcom=False) - if self._disallow_kwargs_override: - prevent_duplicates( - self.partial_kwargs, - mapped_kwargs, - fail_reason="unmappable or already specified", - ) - - # Ordering is significant; mapped kwargs should override partial ones. - trigger_kwargs = mapped_kwargs.get( - "trigger_kwargs", - self.partial_kwargs.get("trigger_kwargs", self.start_trigger_args.trigger_kwargs), - ) - next_kwargs = mapped_kwargs.get( - "next_kwargs", - self.partial_kwargs.get("next_kwargs", self.start_trigger_args.next_kwargs), - ) - timeout = mapped_kwargs.get( - "trigger_timeout", self.partial_kwargs.get("trigger_timeout", self.start_trigger_args.timeout) - ) - return StartTriggerArgs( - trigger_cls=self.start_trigger_args.trigger_cls, - trigger_kwargs=trigger_kwargs, - next_method=self.start_trigger_args.next_method, - next_kwargs=next_kwargs, - timeout=timeout, - ) - - def unmap(self, resolve: None | Mapping[str, Any] | tuple[Context, Session]) -> BaseOperator: + def unmap(self, resolve: None | Mapping[str, Any]) -> BaseOperator: """ Get the "normal" Operator after applying the current mapping. @@ -798,30 +731,21 @@ def unmap(self, resolve: None | Mapping[str, Any] | tuple[Context, Session]) -> not a class (i.e. this DAG has been deserialized), this returns a SerializedBaseOperator that "looks like" the actual unmapping result. - If *resolve* is a two-tuple (context, session), the information is used - to resolve the mapped arguments into init arguments. If it is a mapping, - no resolving happens, the mapping directly provides those init arguments - resolved from mapped kwargs. 
- :meta private: """ if isinstance(self.operator_class, type): if isinstance(resolve, Mapping): kwargs = resolve elif resolve is not None: - kwargs, _ = self._expand_mapped_kwargs(*resolve, include_xcom=True) + kwargs, _ = self._expand_mapped_kwargs(*resolve) else: raise RuntimeError("cannot unmap a non-serialized operator without context") kwargs = self._get_unmap_kwargs(kwargs, strict=self._disallow_kwargs_override) is_setup = kwargs.pop("is_setup", False) is_teardown = kwargs.pop("is_teardown", False) on_failure_fail_dagrun = kwargs.pop("on_failure_fail_dagrun", False) + kwargs["task_id"] = self.task_id op = self.operator_class(**kwargs, _airflow_from_mapped=True) - # We need to overwrite task_id here because BaseOperator further - # mangles the task_id based on the task hierarchy (namely, group_id - # is prepended, and '__N' appended to deduplicate). This is hacky, - # but better than duplicating the whole mangling logic. - op.task_id = self.task_id op.is_setup = is_setup op.is_teardown = is_teardown op.on_failure_fail_dagrun = on_failure_fail_dagrun @@ -856,7 +780,7 @@ def prepare_for_execution(self) -> MappedOperator: def iter_mapped_dependencies(self) -> Iterator[Operator]: """Upstream dependencies that provide XComs used by this task for task mapping.""" - from airflow.models.xcom_arg import XComArg + from airflow.sdk.definitions.xcom_arg import XComArg for operator, _ in XComArg.iter_xcom_references(self._get_specified_expand_input()): yield operator @@ -886,23 +810,14 @@ def render_template_fields( :param context: Context dict with values to apply on content. :param jinja_env: Jinja environment to use for rendering. """ - from airflow.utils.context import context_update_for_unmapped + from airflow.sdk.execution_time.context import context_update_for_unmapped if not jinja_env: jinja_env = self.get_template_env() - # We retrieve the session here, stored by _run_raw_task in set_current_task_session - # context manager - we cannot pass the session via @provide_session because the signature - # of render_template_fields is defined by BaseOperator and there are already many subclasses - # overriding it, so changing the signature is not an option. However render_template_fields is - # always executed within "_run_raw_task" so we make sure that _run_raw_task uses the - # set_current_task_session context manager to store the session in the current task. - session = get_current_task_instance_session() - - mapped_kwargs, seen_oids = self._expand_mapped_kwargs(context, session, include_xcom=True) + mapped_kwargs, seen_oids = self._expand_mapped_kwargs(context) unmapped_task = self.unmap(mapped_kwargs) - # TODO: Task-SDK: remove arg-type ignore once Kaxil's PR lands - context_update_for_unmapped(context, unmapped_task) # type: ignore[arg-type] + context_update_for_unmapped(context, unmapped_task) # Since the operators that extend `BaseOperator` are not subclasses of # `MappedOperator`, we need to call `_do_render_template_fields` from diff --git a/task_sdk/src/airflow/sdk/definitions/param.py b/task_sdk/src/airflow/sdk/definitions/param.py index cd3ccec26a48a..d9eec82a147fc 100644 --- a/task_sdk/src/airflow/sdk/definitions/param.py +++ b/task_sdk/src/airflow/sdk/definitions/param.py @@ -67,8 +67,7 @@ def _check_json(value): json.dumps(value) except Exception: raise ParamValidationError( - "All provided parameters must be json-serializable. " - f"The value '{value}' is not serializable." + f"All provided parameters must be json-serializable. The value '{value}' is not serializable." 
) def resolve(self, value: Any = NOTSET, suppress_exception: bool = False) -> Any: @@ -294,7 +293,7 @@ def __init__(self, current_dag: DAG, name: str, default: Any = NOTSET): def iter_references(self) -> Iterable[tuple[Operator, str]]: return () - def resolve(self, context: Context, *, include_xcom: bool = True) -> Any: + def resolve(self, context: Context) -> Any: """Pull DagParam value from DagRun context. This method is run during ``op.execute()``.""" with contextlib.suppress(KeyError): if context["dag_run"].conf: diff --git a/task_sdk/src/airflow/sdk/definitions/xcom_arg.py b/task_sdk/src/airflow/sdk/definitions/xcom_arg.py index 436cd9d005012..42eef8321f4b7 100644 --- a/task_sdk/src/airflow/sdk/definitions/xcom_arg.py +++ b/task_sdk/src/airflow/sdk/definitions/xcom_arg.py @@ -20,30 +20,28 @@ import contextlib import inspect import itertools -from collections.abc import Iterable, Iterator, Mapping, Sequence -from typing import TYPE_CHECKING, Any, Callable, Union, overload +from collections.abc import Iterable, Iterator, Mapping, Sequence, Sized +from functools import singledispatch +from typing import TYPE_CHECKING, Any, Callable, overload from airflow.exceptions import AirflowException, XComNotFound from airflow.sdk.definitions._internal.abstractoperator import AbstractOperator from airflow.sdk.definitions._internal.mixins import DependencyMixin, ResolveMixin from airflow.sdk.definitions._internal.types import NOTSET, ArgNotSet -from airflow.utils.session import NEW_SESSION, provide_session +from airflow.sdk.execution_time.lazy_sequence import LazyXComSequence from airflow.utils.setup_teardown import SetupTeardownContext from airflow.utils.trigger_rule import TriggerRule from airflow.utils.xcom import XCOM_RETURN_KEY if TYPE_CHECKING: - from sqlalchemy.orm import Session - from airflow.sdk.definitions.baseoperator import BaseOperator - from airflow.sdk.definitions.dag import DAG from airflow.sdk.types import Operator from airflow.utils.edgemodifier import EdgeModifier # Callable objects contained by MapXComArg. We only accept callables from # the user, but deserialize them into strings in a serialized XComArg for # safety (those callables are arbitrary user code). -MapCallables = Sequence[Union[Callable[[Any], Any], str]] +MapCallables = Sequence[Callable[[Any], Any]] class XComArg(ResolveMixin, DependencyMixin): @@ -168,20 +166,6 @@ def _serialize(self) -> dict[str, Any]: """ raise NotImplementedError() - @classmethod - def _deserialize(cls, data: dict[str, Any], dag: DAG) -> XComArg: - """ - Deserialize an XComArg. - - The implementation should be the inverse function to ``serialize``, - implementing given a data dict converted from this XComArg derivative, - how the original XComArg should be created. DAG serialization relies on - additional information added in ``serialize_xcom_arg`` to dispatch data - dicts to the correct ``_deserialize`` information, so this function does - not need to validate whether the incoming data contains correct keys. 
- """ - raise NotImplementedError() - def map(self, f: Callable[[Any], Any]) -> MapXComArg: return MapXComArg(self, [f]) @@ -191,9 +175,7 @@ def zip(self, *others: XComArg, fillvalue: Any = NOTSET) -> ZipXComArg: def concat(self, *others: XComArg) -> ConcatXComArg: return ConcatXComArg([self, *others]) - def resolve( - self, context: Mapping[str, Any], session: Session | None = None, *, include_xcom: bool = True - ) -> Any: + def resolve(self, context: Mapping[str, Any]) -> Any: raise NotImplementedError() def __enter__(self): @@ -266,7 +248,7 @@ def __str__(self) -> str: **Example**: to use XComArg at BashOperator:: - BashOperator(cmd=f"... { xcomarg } ...") + BashOperator(cmd=f"... {xcomarg} ...") :return: """ @@ -285,10 +267,6 @@ def __str__(self) -> str: def _serialize(self) -> dict[str, Any]: return {"task_id": self.operator.task_id, "key": self.key} - @classmethod - def _deserialize(cls, data: dict[str, Any], dag: DAG) -> XComArg: - return cls(dag.get_task(data["task_id"]), data["key"]) - @property def is_setup(self) -> bool: return self.operator.is_setup @@ -354,17 +332,15 @@ def concat(self, *others: XComArg) -> ConcatXComArg: raise ValueError("cannot concatenate non-return XCom") return super().concat(*others) - # TODO: Task-SDK: Remove session argument once everything is ported over to Task SDK - def resolve( - self, context: Mapping[str, Any], session: Session | None = None, *, include_xcom: bool = True - ) -> Any: + def resolve(self, context: Mapping[str, Any]) -> Any: ti = context["ti"] task_id = self.operator.task_id - map_indexes = context.get("_upstream_map_indexes", {}).get(task_id) + + if self.operator.is_mapped: + return LazyXComSequence[Any](xcom_arg=self, ti=ti) result = ti.xcom_pull( task_ids=task_id, - map_indexes=map_indexes, key=self.key, default=NOTSET, ) @@ -403,18 +379,8 @@ def __init__(self, value: Sequence | dict, callables: MapCallables) -> None: def __getitem__(self, index: Any) -> Any: value = self.value[index] - # In the worker, we can access all actual callables. Call them. - callables = [f for f in self.callables if callable(f)] - if len(callables) == len(self.callables): - for f in callables: - value = f(value) - return value - - # In the scheduler, we don't have access to the actual callables, nor do - # we want to run it since it's arbitrary code. This builds a string to - # represent the call chain in the UI or logs instead. - for v in self.callables: - value = f"{_get_callable_name(v)}({value})" + for f in self.callables: + value = f(value) return value def __len__(self) -> int: @@ -448,12 +414,6 @@ def _serialize(self) -> dict[str, Any]: "callables": [inspect.getsource(c) if callable(c) else c for c in self.callables], } - @classmethod - def _deserialize(cls, data: dict[str, Any], dag: DAG) -> XComArg: - # We are deliberately NOT deserializing the callables. These are shown - # in the UI, and displaying a function object is useless. - return cls(deserialize_xcom_arg(data["arg"], dag), data["callables"]) - def iter_references(self) -> Iterator[tuple[Operator, str]]: yield from self.arg.iter_references() @@ -461,12 +421,8 @@ def map(self, f: Callable[[Any], Any]) -> MapXComArg: # Flatten arg.map(f1).map(f2) into one MapXComArg. 
return MapXComArg(self.arg, [*self.callables, f]) - # TODO: Task-SDK: Remove session argument once everything is ported over to Task SDK - @provide_session - def resolve( - self, context: Mapping[str, Any], session: Session = NEW_SESSION, *, include_xcom: bool = True - ) -> Any: - value = self.arg.resolve(context, session=session, include_xcom=include_xcom) + def resolve(self, context: Mapping[str, Any]) -> Any: + value = self.arg.resolve(context) if not isinstance(value, (Sequence, dict)): raise ValueError(f"XCom map expects sequence or dict, not {type(value).__name__}") return _MapResult(value, self.callables) @@ -525,22 +481,12 @@ def _serialize(self) -> dict[str, Any]: return {"args": args} return {"args": args, "fillvalue": self.fillvalue} - @classmethod - def _deserialize(cls, data: dict[str, Any], dag: DAG) -> XComArg: - return cls( - [deserialize_xcom_arg(arg, dag) for arg in data["args"]], - fillvalue=data.get("fillvalue", NOTSET), - ) - def iter_references(self) -> Iterator[tuple[Operator, str]]: for arg in self.args: yield from arg.iter_references() - @provide_session - def resolve( - self, context: Mapping[str, Any], session: Session = NEW_SESSION, *, include_xcom: bool = True - ) -> Any: - values = [arg.resolve(context, session=session, include_xcom=include_xcom) for arg in self.args] + def resolve(self, context: Mapping[str, Any]) -> Any: + values = [arg.resolve(context) for arg in self.args] for value in values: if not isinstance(value, (Sequence, dict)): raise ValueError(f"XCom zip expects sequence or dict, not {type(value).__name__}") @@ -594,10 +540,6 @@ def __repr__(self) -> str: def _serialize(self) -> dict[str, Any]: return {"args": [serialize_xcom_arg(arg) for arg in self.args]} - @classmethod - def _deserialize(cls, data: dict[str, Any], dag: DAG) -> XComArg: - return cls([deserialize_xcom_arg(arg, dag) for arg in data["args"]]) - def iter_references(self) -> Iterator[tuple[Operator, str]]: for arg in self.args: yield from arg.iter_references() @@ -606,11 +548,8 @@ def concat(self, *others: XComArg) -> ConcatXComArg: # Flatten foo.concat(x).concat(y) into one call. return ConcatXComArg([*self.args, *others]) - @provide_session - def resolve( - self, context: Mapping[str, Any], session: Session = NEW_SESSION, *, include_xcom: bool = True - ) -> Any: - values = [arg.resolve(context, session=session, include_xcom=include_xcom) for arg in self.args] + def resolve(self, context: Mapping[str, Any]) -> Any: + values = [arg.resolve(context) for arg in self.args] for value in values: if not isinstance(value, (Sequence, dict)): raise ValueError(f"XCom concat expects sequence or dict, not {type(value).__name__}") @@ -633,7 +572,44 @@ def serialize_xcom_arg(value: XComArg) -> dict[str, Any]: return value._serialize() -def deserialize_xcom_arg(data: dict[str, Any], dag: DAG) -> XComArg: - """DAG serialization interface.""" - klass = _XCOM_ARG_TYPES[data.get("type", "")] - return klass._deserialize(data, dag) +@singledispatch +def get_task_map_length( + xcom_arg: XComArg, resolved_val: Sized, upstream_map_indexes: dict[str, int] +) -> int | None: + # The base implementation -- specific XComArg subclasses have specialised implementations + raise NotImplementedError() + + +@get_task_map_length.register +def _(xcom_arg: PlainXComArg, resolved_val: Sized, upstream_map_indexes: dict[str, int]): + task_id = xcom_arg.operator.task_id + + if xcom_arg.operator.is_mapped: + # TODO: How to tell if all the upstream TIs finished? 
+ pass + return (upstream_map_indexes.get(task_id) or 1) * len(resolved_val) + + +@get_task_map_length.register +def _(xcom_arg: MapXComArg, resolved_val: Sized, upstream_map_indexes: dict[str, int]): + return get_task_map_length(xcom_arg.arg, resolved_val, upstream_map_indexes) + + +@get_task_map_length.register +def _(xcom_arg: ZipXComArg, resolved_val: Sized, upstream_map_indexes: dict[str, int]): + all_lengths = (get_task_map_length(arg, resolved_val, upstream_map_indexes) for arg in xcom_arg.args) + ready_lengths = [length for length in all_lengths if length is not None] + if len(ready_lengths) != len(xcom_arg.args): + return None # If any of the referenced XComs is not ready, we are not ready either. + if isinstance(xcom_arg.fillvalue, ArgNotSet): + return min(ready_lengths) + return max(ready_lengths) + + +@get_task_map_length.register +def _(xcom_arg: ConcatXComArg, resolved_val: Sized, upstream_map_indexes: dict[str, int]): + all_lengths = (get_task_map_length(arg, resolved_val, upstream_map_indexes) for arg in xcom_arg.args) + ready_lengths = [length for length in all_lengths if length is not None] + if len(ready_lengths) != len(xcom_arg.args): + return None # If any of the referenced XComs is not ready, we are not ready either. + return sum(ready_lengths) diff --git a/task_sdk/src/airflow/sdk/execution_time/comms.py b/task_sdk/src/airflow/sdk/execution_time/comms.py index 1f398cb8b6009..a7749ed4ae27a 100644 --- a/task_sdk/src/airflow/sdk/execution_time/comms.py +++ b/task_sdk/src/airflow/sdk/execution_time/comms.py @@ -119,6 +119,11 @@ def from_xcom_response(cls, xcom_response: XComResponse) -> XComResult: return cls(**xcom_response.model_dump()) +class XComCountResponse(BaseModel): + len: int + type: Literal["XComLengthResponse"] = "XComLengthResponse" + + class ConnectionResult(ConnectionResponse): type: Literal["ConnectionResult"] = "ConnectionResult" @@ -184,6 +189,7 @@ class OKResponse(BaseModel): StartupDetails, VariableResult, XComResult, + XComCountResponse, OKResponse, ], Field(discriminator="type"), @@ -240,6 +246,16 @@ class GetXCom(BaseModel): type: Literal["GetXCom"] = "GetXCom" +class GetXComCount(BaseModel): + """Get the number of (mapped) XCom values available.""" + + key: str + dag_id: str + run_id: str + task_id: str + type: Literal["GetNumberXComs"] = "GetNumberXComs" + + class SetXCom(BaseModel): key: str value: Annotated[ @@ -324,6 +340,7 @@ class GetPrevSuccessfulDagRun(BaseModel): GetPrevSuccessfulDagRun, GetVariable, GetXCom, + GetXComCount, PutVariable, RescheduleTask, SetRenderedFields, diff --git a/task_sdk/src/airflow/sdk/execution_time/context.py b/task_sdk/src/airflow/sdk/execution_time/context.py index 984919ea1c86b..0674d27320a6c 100644 --- a/task_sdk/src/airflow/sdk/execution_time/context.py +++ b/task_sdk/src/airflow/sdk/execution_time/context.py @@ -42,6 +42,7 @@ if TYPE_CHECKING: from uuid import UUID + from airflow.sdk.definitions.baseoperator import BaseOperator from airflow.sdk.definitions.connection import Connection from airflow.sdk.definitions.context import Context from airflow.sdk.definitions.variable import Variable @@ -315,3 +316,21 @@ def set_current_context(context: Context) -> Generator[Context, None, None]: expected=context, got=expected_state, ) + + +def context_update_for_unmapped(context: Context, task: BaseOperator) -> None: + """ + Update context after task unmapping. + + Since ``get_template_context()`` is called before unmapping, the context + contains information about the mapped task. 
We need to do some in-place + updates to ensure the template context reflects the unmapped task instead. + + :meta private: + """ + from airflow.sdk.definitions.param import process_params + + context["task"] = context["ti"].task = task + context["params"] = process_params( + context["dag"], task, context["dag_run"].conf, suppress_exception=False + ) diff --git a/task_sdk/src/airflow/sdk/execution_time/lazy_sequence.py b/task_sdk/src/airflow/sdk/execution_time/lazy_sequence.py new file mode 100644 index 0000000000000..acbca29c0b954 --- /dev/null +++ b/task_sdk/src/airflow/sdk/execution_time/lazy_sequence.py @@ -0,0 +1,188 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from __future__ import annotations + +import itertools +from collections.abc import Iterator, Sequence +from typing import TYPE_CHECKING, Any, Literal, TypeVar, overload + +import attrs +import structlog + +if TYPE_CHECKING: + from airflow.sdk.definitions.xcom_arg import PlainXComArg + from airflow.sdk.execution_time.task_runner import RuntimeTaskInstance + +T = TypeVar("T") + +log = structlog.get_logger(logger_name=__name__) + + +@attrs.define +class LazyXComIterator(Iterator[T]): + seq: LazyXComSequence[T] + index: int = 0 + dir: Literal[1, -1] = 1 + + def __next__(self) -> T: + if self.index < 0: + # When iterating backwards, avoid extra HTTP request + raise StopIteration() + val = self.seq._get_item(self.index) + if val is None: + # None isn't the best signal (it's bad in fact) but it's the best we can do until https://github.com/apache/airflow/issues/46426 + raise StopIteration() + self.index += self.dir + return val + + def __iter__(self) -> Iterator[T]: + return self + + +@attrs.define +class LazyXComSequence(Sequence[T]): + _len: int | None = attrs.field(init=False, default=None) + _xcom_arg: PlainXComArg = attrs.field(alias="xcom_arg") + _ti: RuntimeTaskInstance = attrs.field(alias="ti") + + def __repr__(self) -> str: + if self._len is not None: + counter = "item" if (length := len(self)) == 1 else "items" + return f"LazyXComSequence([{length} {counter}])" + return "LazyXComSequence()" + + def __str__(self) -> str: + return repr(self) + + def __eq__(self, other: Any) -> bool: + if not isinstance(other, Sequence): + return NotImplemented + z = itertools.zip_longest(iter(self), iter(other), fillvalue=object()) + return all(x == y for x, y in z) + + def __iter__(self) -> Iterator[T]: + return LazyXComIterator(seq=self) + + def __len__(self) -> int: + if self._len is None: + from airflow.sdk.execution_time.comms import ErrorResponse, GetXComCount, XComCountResponse + from airflow.sdk.execution_time.task_runner import SUPERVISOR_COMMS + + task = self._xcom_arg.operator + + SUPERVISOR_COMMS.send_request( + log=log, + msg=GetXComCount( + key=self._xcom_arg.key, + 
dag_id=task.dag_id, + run_id=self._ti.run_id, + task_id=task.task_id, + ), + ) + msg = SUPERVISOR_COMMS.get_message() + if isinstance(msg, ErrorResponse): + raise RuntimeError(msg) + elif not isinstance(msg, XComCountResponse): + raise TypeError(f"Got unexpected response to GetXComCount: {msg}") + self._len = msg.len + return self._len + + @overload + def __getitem__(self, key: int) -> T: ... + + @overload + def __getitem__(self, key: slice) -> Sequence[T]: ... + + def __getitem__(self, key: int | slice) -> T | Sequence[T]: + if isinstance(key, int): + if key >= 0: + return self._get_item(key) + else: + # val[-1] etc. + return self._get_item(len(self) + key) + + elif isinstance(key, slice): + # This implements the slicing syntax. We want to optimize negative slicing (e.g. seq[-10:]) by not + # doing an additional COUNT query (via HEAD http request) if possible. We can do this unless the + # start and stop have different signs (i.e. one is positive and another negative). + ... + """ + Todo? + elif isinstance(key, slice): + start, stop, reverse = _coerce_slice(key) + if start >= 0: + if stop is None: + stmt = self._select_asc.offset(start) + elif stop >= 0: + stmt = self._select_asc.slice(start, stop) + else: + stmt = self._select_asc.slice(start, len(self) + stop) + rows = [self._process_row(row) for row in self._session.execute(stmt)] + if reverse: + rows.reverse() + else: + if stop is None: + stmt = self._select_desc.limit(-start) + elif stop < 0: + stmt = self._select_desc.slice(-stop, -start) + else: + stmt = self._select_desc.slice(len(self) - stop, -start) + rows = [self._process_row(row) for row in self._session.execute(stmt)] + if not reverse: + rows.reverse() + return rows + """ + raise TypeError(f"Sequence indices must be integers or slices, not {type(key).__name__}") + + def _get_item(self, index: int) -> T: + # TODO: maybe we need to call SUPERVISOR_COMMS manually so we can handle not found here? + return self._ti.xcom_pull( + task_ids=self._xcom_arg.operator.task_id, + key=self._xcom_arg.key, + map_indexes=index, + ) + + +def _coerce_index(value: Any) -> int | None: + """ + Check slice attribute's type and convert it to int. + + See CPython documentation on this: + https://docs.python.org/3/reference/datamodel.html#object.__index__ + """ + if value is None or isinstance(value, int): + return value + if (index := getattr(value, "__index__", None)) is not None: + return index() + raise TypeError("slice indices must be integers or None or have an __index__ method") + + +def _coerce_slice(key: slice) -> tuple[int, int | None, bool]: + """ + Check slice content and convert it for SQL. 
+ + See CPython documentation on this: + https://docs.python.org/3/reference/datamodel.html#slice-objects + """ + if key.step is None or key.step == 1: + reverse = False + elif key.step == -1: + reverse = True + else: + raise ValueError("non-trivial slice step not supported") + return _coerce_index(key.start) or 0, _coerce_index(key.stop), reverse diff --git a/task_sdk/src/airflow/sdk/execution_time/supervisor.py b/task_sdk/src/airflow/sdk/execution_time/supervisor.py index 3a65247e0ee40..31ad8077aec4b 100644 --- a/task_sdk/src/airflow/sdk/execution_time/supervisor.py +++ b/task_sdk/src/airflow/sdk/execution_time/supervisor.py @@ -70,6 +70,7 @@ GetPrevSuccessfulDagRun, GetVariable, GetXCom, + GetXComCount, PrevSuccessfulDagRunResult, PutVariable, RescheduleTask, @@ -81,6 +82,7 @@ TaskState, ToSupervisor, VariableResult, + XComCountResponse, XComResult, ) @@ -797,6 +799,9 @@ def _handle_request(self, msg: ToSupervisor, log: FilteringBoundLogger): xcom = self.client.xcoms.get(msg.dag_id, msg.run_id, msg.task_id, msg.key, msg.map_index) xcom_result = XComResult.from_xcom_response(xcom) resp = xcom_result.model_dump_json().encode() + elif isinstance(msg, GetXComCount): + len = self.client.xcoms.head(msg.dag_id, msg.run_id, msg.task_id, msg.key) + resp = XComCountResponse(len=len).model_dump_json().encode() elif isinstance(msg, DeferTask): self._terminal_state = IntermediateTIState.DEFERRED self.client.task_instances.defer(self.id, msg) diff --git a/task_sdk/src/airflow/sdk/execution_time/task_runner.py b/task_sdk/src/airflow/sdk/execution_time/task_runner.py index b89d6e8c548bb..77cd29d090c2d 100644 --- a/task_sdk/src/airflow/sdk/execution_time/task_runner.py +++ b/task_sdk/src/airflow/sdk/execution_time/task_runner.py @@ -35,8 +35,10 @@ from airflow.dag_processing.bundles.manager import DagBundlesManager from airflow.sdk.api.datamodels._generated import AssetProfile, TaskInstance, TerminalTIState, TIRunContext from airflow.sdk.definitions._internal.dag_parsing_context import _airflow_parsing_context_manager +from airflow.sdk.definitions._internal.types import NOTSET, ArgNotSet from airflow.sdk.definitions.asset import Asset, AssetAlias, AssetNameRef, AssetUriRef from airflow.sdk.definitions.baseoperator import BaseOperator +from airflow.sdk.definitions.mappedoperator import MappedOperator from airflow.sdk.definitions.param import process_params from airflow.sdk.execution_time.comms import ( DeferTask, @@ -67,6 +69,7 @@ import jinja2 from structlog.typing import FilteringBoundLogger as Logger + from airflow.sdk.definitions._internal.abstractoperator import AbstractOperator from airflow.sdk.definitions.context import Context @@ -83,6 +86,16 @@ class RuntimeTaskInstance(TaskInstance): max_tries: int = 0 """The maximum number of retries for the task.""" + def __rich_repr__(self): + yield "id", self.id + yield "task_id", self.task_id + yield "dag_id", self.dag_id + yield "run_id", self.run_id + yield "max_tries", self.max_tries + yield "task", type(self.task) + + __rich_repr__.angular = True # type: ignore[attr-defined] + def get_template_context(self) -> Context: # TODO: Move this to `airflow.sdk.execution_time.context` # once we port the entire context logic from airflow/utils/context.py ? 
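Editor's note on `LazyXComSequence` above: its `__len__` costs one `GetXComCount` round-trip (answered by the supervisor with an HTTP HEAD and a `Content-Range: map_indexes N` header), and each item access is a separate per-index `xcom_pull`. The standalone sketch below shows that lazy-sequence pattern in isolation; `fetch_count` and `fetch_item` are hypothetical stand-ins for those supervisor round-trips and are not part of the patch.

```python
# Minimal sketch of the lazy-sequence idea: nothing is fetched until the task
# actually touches the length or an item, and the length is cached.
from collections.abc import Sequence


class LazyRemoteSequence(Sequence):
    def __init__(self, fetch_count, fetch_item):
        self._fetch_count = fetch_count  # e.g. a HEAD request returning the mapped-XCom count
        self._fetch_item = fetch_item    # e.g. an xcom_pull for a single map index
        self._len = None

    def __len__(self):
        if self._len is None:            # only one count round-trip, cached afterwards
            self._len = self._fetch_count()
        return self._len

    def __getitem__(self, index):
        if isinstance(index, int):
            if index < 0:                # negative indexing needs the length
                index += len(self)
            return self._fetch_item(index)
        raise TypeError("only integer indices are supported in this sketch")


# Usage: the count and the item are only fetched when first needed.
seq = LazyRemoteSequence(lambda: 3, lambda i: f"xcom-{i}")
assert len(seq) == 3 and seq[-1] == "xcom-2"
```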
@@ -123,12 +136,12 @@ def get_template_context(self) -> Context: }, "conn": ConnectionAccessor(), } - if self._ti_context_from_server: - dag_run = self._ti_context_from_server.dag_run + if from_server := self._ti_context_from_server: + dag_run = from_server.dag_run context_from_server: Context = { # TODO: Assess if we need to pass these through timezone.coerce_datetime - "dag_run": dag_run, + "dag_run": dag_run, # type: ignore[typeddict-item] # Removable after #46522 "data_interval_end": dag_run.data_interval_end, "data_interval_start": dag_run.data_interval_start, "task_instance_key_str": f"{self.task.dag_id}__{self.task.task_id}__{dag_run.run_id}", @@ -164,6 +177,10 @@ def get_template_context(self) -> Context: "ts_nodash_with_tz": ts_nodash_with_tz, } ) + if from_server.upstream_map_indexes is not None: + # We stash this in here for later use, but we purposefully don't want to document it's + # existence. Should this be a private attribute on RuntimeTI instead perhaps? + setattr(self, "_upstream_map_indexes", from_server.upstream_map_indexes) return context @@ -195,10 +212,7 @@ def render_templates( # unmapped BaseOperator created by this function! This is because the # MappedOperator is useless for template rendering, and we need to be # able to access the unmapped task instead. - original_task.render_template_fields(context, jinja_env) - # TODO: Add support for rendering templates in the MappedOperator - # if isinstance(self.task, MappedOperator): - # self.task = context["ti"].task + self.task.render_template_fields(context, jinja_env) return original_task @@ -209,7 +223,7 @@ def xcom_pull( key: str = "return_value", # TODO: Make this a constant (``XCOM_RETURN_KEY``) include_prior_dates: bool = False, # TODO: Add support for this *, - map_indexes: int | Iterable[int] | None = None, + map_indexes: int | Iterable[int] | None | ArgNotSet = NOTSET, default: Any = None, run_id: str | None = None, ) -> Any: @@ -256,7 +270,7 @@ def xcom_pull( task_ids = self.task_id elif isinstance(task_ids, str): task_ids = [task_ids] - if map_indexes is None: + if isinstance(map_indexes, ArgNotSet): map_indexes = self.map_index elif isinstance(map_indexes, Iterable): # TODO: Handle multiple map_indexes or remove support @@ -364,8 +378,10 @@ def parse(what: StartupDetails) -> RuntimeTaskInstance: # TODO: Handle task not found task = dag.task_dict[what.ti.task_id] - if not isinstance(task, BaseOperator): - raise TypeError(f"task is of the wrong type, got {type(task)}, wanted {BaseOperator}") + if not isinstance(task, (BaseOperator, MappedOperator)): + raise TypeError( + f"task is of the wrong type, got {type(task)}, wanted {BaseOperator} or {MappedOperator}" + ) return RuntimeTaskInstance.model_construct( **what.ti.model_dump(exclude_unset=True), @@ -451,22 +467,10 @@ def startup() -> tuple[RuntimeTaskInstance, Logger]: else: raise RuntimeError(f"Unhandled startup message {type(msg)} {msg}") - # TODO: Render fields here - # 1. Implementing the part where we pull in the logic to render fields and add that here - # for all operators, we should do setattr(task, templated_field, rendered_templated_field) - # task.templated_fields should give all the templated_fields and each of those fields should - # give the rendered values. task.templated_fields should already be in a JSONable format and - # we should not have to handle that here. - - # 2. 
Once rendered, we call the `set_rtif` API to store the rtif in the metadata DB - - # so that we do not call the API unnecessarily - if rendered_fields := _get_rendered_fields(ti.task): - SUPERVISOR_COMMS.send_request(log=log, msg=SetRenderedFields(rendered_fields=rendered_fields)) return ti, log -def _get_rendered_fields(task: BaseOperator) -> dict[str, JsonValue]: +def _serialize_rendered_fields(task: AbstractOperator) -> dict[str, JsonValue]: # TODO: Port one of the following to Task SDK # airflow.serialization.helpers.serialize_template_field or # airflow.models.renderedtifields.get_serialized_template_fields @@ -505,6 +509,34 @@ def _process_outlets(context: Context, outlets: list[AssetProfile]): return task_outlets, outlet_events +def _prepare(ti: RuntimeTaskInstance, log: Logger, context: Context) -> ToSupervisor | None: + ti.hostname = get_hostname() + ti.task = ti.task.prepare_for_execution() + if ti.task.inlets or ti.task.outlets: + inlets = [asset.asprofile() for asset in ti.task.inlets if isinstance(asset, Asset)] + outlets = [asset.asprofile() for asset in ti.task.outlets if isinstance(asset, Asset)] + SUPERVISOR_COMMS.send_request(msg=RuntimeCheckOnTask(inlets=inlets, outlets=outlets), log=log) # type: ignore + ok_response = SUPERVISOR_COMMS.get_message() # type: ignore + if not isinstance(ok_response, OKResponse) or not ok_response.ok: + log.info("Runtime checks failed for task, marking task as failed..") + return TaskState( + state=TerminalTIState.FAILED, + end_date=datetime.now(tz=timezone.utc), + ) + + jinja_env = ti.task.dag.get_template_env() + ti.render_templates(context=context, jinja_env=jinja_env) + + if rendered_fields := _serialize_rendered_fields(ti.task): + # so that we do not call the API unnecessarily + SUPERVISOR_COMMS.send_request(log=log, msg=SetRenderedFields(rendered_fields=rendered_fields)) + + # TODO: Call pre execute etc. + + # No error, carry on and execute the task + return None + + def run(ti: RuntimeTaskInstance, log: Logger) -> ToSupervisor | None: """Run the task in this process.""" from airflow.exceptions import ( @@ -524,31 +556,17 @@ def run(ti: RuntimeTaskInstance, log: Logger) -> ToSupervisor | None: msg: ToSupervisor | None = None try: - # TODO: pre execute etc. 
-            # TODO: Get a real context object
-            ti.hostname = get_hostname()
-            ti.task = ti.task.prepare_for_execution()
-            if ti.task.inlets or ti.task.outlets:
-                inlets = [asset.asprofile() for asset in ti.task.inlets if isinstance(asset, Asset)]
-                outlets = [asset.asprofile() for asset in ti.task.outlets if isinstance(asset, Asset)]
-                SUPERVISOR_COMMS.send_request(msg=RuntimeCheckOnTask(inlets=inlets, outlets=outlets), log=log)  # type: ignore
-                ok_response = SUPERVISOR_COMMS.get_message()  # type: ignore
-                if not isinstance(ok_response, OKResponse) or not ok_response.ok:
-                    log.info("Runtime checks failed for task, marking task as failed..")
-                    msg = TaskState(
-                        state=TerminalTIState.FAILED,
-                        end_date=datetime.now(tz=timezone.utc),
-                    )
-                    return msg
         context = ti.get_template_context()
         with set_current_context(context):
-            jinja_env = ti.task.dag.get_template_env()
-            ti.task = ti.render_templates(context=context, jinja_env=jinja_env)
-            # TODO: Get things from _execute_task_with_callbacks
-            # - Pre Execute
-            # etc
-            result = _execute_task(context, ti.task)
+            # This is the earliest that we can render templates -- because if it raises for any reason we
+            # need to catch it and handle it like a normal task failure
+            if early_exit := _prepare(ti, log, context):
+                msg = early_exit
+                return msg
+
+            result = _execute_task(context, ti)
+
         log.info("Pushing xcom", ti=ti)
         _push_xcom_if_needed(result, ti)
         task_outlets, outlet_events = _process_outlets(context, ti.task.outlets)
@@ -629,10 +647,11 @@ def run(ti: RuntimeTaskInstance, log: Logger) -> ToSupervisor | None:
     return msg

-def _execute_task(context: Context, task: BaseOperator):
+def _execute_task(context: Context, ti: RuntimeTaskInstance):
     """Execute Task (optionally with a Timeout) and push Xcom results."""
     from airflow.exceptions import AirflowTaskTimeout

+    task = ti.task
     if task.execution_timeout:
         # TODO: handle timeout in case of deferral
         from airflow.utils.timeout import timeout
diff --git a/task_sdk/src/airflow/sdk/types.py b/task_sdk/src/airflow/sdk/types.py
index 99ad1b1c1fa7b..795130bf37462 100644
--- a/task_sdk/src/airflow/sdk/types.py
+++ b/task_sdk/src/airflow/sdk/types.py
@@ -17,8 +17,11 @@

 from __future__ import annotations

+from collections.abc import Iterable
 from typing import TYPE_CHECKING, Any, Protocol, Union

+from airflow.sdk.definitions._internal.types import NOTSET, ArgNotSet
+
 if TYPE_CHECKING:
     from collections.abc import Iterator
     from datetime import datetime
@@ -44,7 +47,9 @@ class DagRunProtocol(Protocol):
     run_type: Any
     run_after: datetime
     conf: dict[str, Any] | None
-    external_trigger: bool
+    # This shouldn't be "| None", but there's a bug in the datamodel generator, and None evaluates as falsy
+    # too, so this is "okay"
+    external_trigger: bool | None = False

 class RuntimeTaskInstanceProtocol(Protocol):
@@ -55,7 +60,7 @@ class RuntimeTaskInstanceProtocol(Protocol):
     dag_id: str
     run_id: str
     try_number: int
-    map_index: int
+    map_index: int | None
     max_tries: int
     hostname: str | None = None
@@ -67,7 +72,7 @@ def xcom_pull(
         # TODO: `include_prior_dates` isn't yet supported in the SDK
         # include_prior_dates: bool = False,
         *,
-        map_indexes: int | list[int] | None = None,
+        map_indexes: int | Iterable[int] | None | ArgNotSet = NOTSET,
         default: Any = None,
         run_id: str | None = None,
     ) -> Any: ...
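The `xcom_pull` change above replaces the `None` default with a `NOTSET` sentinel so that "the caller did not pass map_indexes" can be told apart from an explicit `map_indexes=None`. A minimal, self-contained sketch of that sentinel pattern (the `ArgNotSet`/`NOTSET` names mirror the diff; the `resolve_map_indexes` helper and its handling of `None` are illustrative only, not the SDK's actual code):

class ArgNotSet:
    """Sentinel type meaning "the caller never passed this argument"."""

NOTSET = ArgNotSet()

def resolve_map_indexes(map_indexes, current_map_index):
    # An omitted argument (NOTSET) falls back to this task instance's own map
    # index, mirroring the `isinstance(map_indexes, ArgNotSet)` branch above.
    if isinstance(map_indexes, ArgNotSet):
        return current_map_index
    # Any explicit value -- including an explicit None -- passes through untouched.
    return map_indexes

assert resolve_map_indexes(NOTSET, current_map_index=3) == 3
assert resolve_map_indexes(None, current_map_index=3) is None
assert resolve_map_indexes([0, 1], current_map_index=3) == [0, 1]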
diff --git a/task_sdk/tests/api/test_client.py b/task_sdk/tests/api/test_client.py index 319a00354bdb7..43e35dfec9ec0 100644 --- a/task_sdk/tests/api/test_client.py +++ b/task_sdk/tests/api/test_client.py @@ -228,7 +228,9 @@ def handle_request(request: httpx.Request) -> httpx.Response: assert resp == ti_context assert call_count == 4 - @pytest.mark.parametrize("state", [state for state in TerminalTIState]) + @pytest.mark.parametrize( + "state", [state for state in TerminalTIState if state != TerminalTIState.SUCCESS] + ) def test_task_instance_finish(self, state): # Simulate a successful response from the server that finishes (moved to terminal state) a task ti_id = uuid6.uuid7() diff --git a/task_sdk/tests/conftest.py b/task_sdk/tests/conftest.py index fff15b59bf8e7..bb40c641bf9b2 100644 --- a/task_sdk/tests/conftest.py +++ b/task_sdk/tests/conftest.py @@ -35,6 +35,9 @@ from structlog.typing import EventDict, WrappedLogger from airflow.sdk.api.datamodels._generated import TIRunContext + from airflow.sdk.definitions.baseoperator import BaseOperator + from airflow.sdk.execution_time.comms import StartupDetails + from airflow.sdk.execution_time.task_runner import RuntimeTaskInstance @pytest.hookimpl() @@ -246,3 +249,135 @@ def mock_supervisor_comms(): "airflow.sdk.execution_time.task_runner.SUPERVISOR_COMMS", create=True ) as supervisor_comms: yield supervisor_comms + + +@pytest.fixture +def mocked_parse(spy_agency): + """ + Fixture to set up an inline DAG and use it in a stubbed `parse` function. Use this fixture if you + want to isolate and test `parse` or `run` logic without having to define a DAG file. + + This fixture returns a helper function `set_dag` that: + 1. Creates an in line DAG with the given `dag_id` and `task` (limited to one task) + 2. Constructs a `RuntimeTaskInstance` based on the provided `StartupDetails` and task. + 3. Stubs the `parse` function using `spy_agency`, to return the mocked `RuntimeTaskInstance`. + + After adding the fixture in your test function signature, you can use it like this :: + + mocked_parse( + StartupDetails( + ti=TaskInstance(id=uuid7(), task_id="hello", dag_id="super_basic_run", run_id="c", try_number=1), + file="", + requests_fd=0, + ), + "example_dag_id", + CustomOperator(task_id="hello"), + ) + """ + + def set_dag(what: StartupDetails, dag_id: str, task: BaseOperator) -> RuntimeTaskInstance: + from airflow.sdk.definitions.dag import DAG + from airflow.sdk.execution_time.task_runner import RuntimeTaskInstance, parse + from airflow.utils import timezone + + if not task.has_dag(): + dag = DAG(dag_id=dag_id, start_date=timezone.datetime(2024, 12, 3)) + task.dag = dag + task = dag.task_dict[task.task_id] + else: + dag = task.dag + if what.ti_context.dag_run.conf: + dag.params = what.ti_context.dag_run.conf # type: ignore[assignment] + ti = RuntimeTaskInstance.model_construct( + **what.ti.model_dump(exclude_unset=True), + task=task, + _ti_context_from_server=what.ti_context, + max_tries=what.ti_context.max_tries, + ) + if hasattr(parse, "spy"): + spy_agency.unspy(parse) + spy_agency.spy_on(parse, call_fake=lambda _: ti) + return ti + + return set_dag + + +@pytest.fixture +def create_runtime_ti(mocked_parse, make_ti_context): + """ + Fixture to create a Runtime TaskInstance for testing purposes without defining a dag file. + + It mimics the behavior of the `parse` function by creating a `RuntimeTaskInstance` based on the provided + `StartupDetails` (formed from arguments) and task. 
This allows you to test the logic of a task without + having to define a DAG file, parse it, get context from the server, etc. + + Example usage: :: + + def test_custom_task_instance(create_runtime_ti): + class MyTaskOperator(BaseOperator): + def execute(self, context): + assert context["dag_run"].run_id == "test_run" + + task = MyTaskOperator(task_id="test_task") + ti = create_runtime_ti(task, context_from_server=make_ti_context(run_id="test_run")) + # Further test logic... + """ + from uuid6 import uuid7 + + from airflow.sdk.api.datamodels._generated import TaskInstance + from airflow.sdk.execution_time.comms import BundleInfo, StartupDetails + + def _create_task_instance( + task: BaseOperator, + dag_id: str = "test_dag", + run_id: str = "test_run", + logical_date: str | datetime = "2024-12-01T01:00:00Z", + data_interval_start: str | datetime = "2024-12-01T00:00:00Z", + data_interval_end: str | datetime = "2024-12-01T01:00:00Z", + start_date: str | datetime = "2024-12-01T01:00:00Z", + run_type: str = "manual", + try_number: int = 1, + map_index: int | None = -1, + upstream_map_indexes: dict[str, int] | None = None, + ti_id=None, + conf=None, + ) -> RuntimeTaskInstance: + if not ti_id: + ti_id = uuid7() + + if task.has_dag(): + dag_id = task.dag.dag_id + + ti_context = make_ti_context( + dag_id=dag_id, + run_id=run_id, + logical_date=logical_date, + data_interval_start=data_interval_start, + data_interval_end=data_interval_end, + start_date=start_date, + run_type=run_type, + conf=conf, + ) + + if upstream_map_indexes is not None: + ti_context.upstream_map_indexes = upstream_map_indexes + + startup_details = StartupDetails( + ti=TaskInstance( + id=ti_id, + task_id=task.task_id, + dag_id=dag_id, + run_id=run_id, + try_number=try_number, + map_index=map_index, + ), + dag_rel_path="", + bundle_info=BundleInfo(name="anything", version="any"), + requests_fd=0, + ti_context=ti_context, + ) + + ti = mocked_parse(startup_details, dag_id, task) + return ti + + return _create_task_instance diff --git a/task_sdk/tests/definitions/conftest.py b/task_sdk/tests/definitions/conftest.py new file mode 100644 index 0000000000000..02ed599c5ec60 --- /dev/null +++ b/task_sdk/tests/definitions/conftest.py @@ -0,0 +1,49 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+from __future__ import annotations + +from typing import TYPE_CHECKING + +import pytest +import structlog + +from airflow.sdk.execution_time.comms import SucceedTask, TaskState + +if TYPE_CHECKING: + from airflow.sdk.definitions.dag import DAG + + +@pytest.fixture +def run_ti(create_runtime_ti, mock_supervisor_comms): + def run(dag: DAG, task_id: str, map_index: int): + """Run the task and return the state that the SDK sent as the result for easier asserts""" + from airflow.sdk.execution_time.task_runner import run + + log = structlog.get_logger() + + mock_supervisor_comms.send_request.reset_mock() + ti = create_runtime_ti(dag.task_dict[task_id], map_index=map_index) + run(ti, log) + + for call in mock_supervisor_comms.send_request.mock_calls: + msg = call.kwargs["msg"] + if isinstance(msg, (TaskState, SucceedTask)): + return msg.state + raise RuntimeError("Unable to find call to TaskState") + + return run diff --git a/task_sdk/tests/definitions/test_baseoperator.py b/task_sdk/tests/definitions/test_baseoperator.py index af6bf592f5373..35f33818dc198 100644 --- a/task_sdk/tests/definitions/test_baseoperator.py +++ b/task_sdk/tests/definitions/test_baseoperator.py @@ -621,26 +621,3 @@ def _do_render(): assert expected_log in caplog.text if not_expected_log: assert not_expected_log not in caplog.text - - -def test_find_mapped_dependants_in_another_group(): - from airflow.decorators import task as task_decorator - from airflow.sdk import TaskGroup - - @task_decorator - def gen(x): - return list(range(x)) - - @task_decorator - def add(x, y): - return x + y - - with DAG(dag_id="test"): - with TaskGroup(group_id="g1"): - gen_result = gen(3) - with TaskGroup(group_id="g2"): - add_result = add.partial(y=1).expand(x=gen_result) - - # breakpoint() - dependants = list(gen_result.operator.iter_mapped_dependants()) - assert dependants == [add_result.operator] diff --git a/task_sdk/tests/definitions/test_dag.py b/task_sdk/tests/definitions/test_dag.py index b760cf71396c6..fa0cd65846ec5 100644 --- a/task_sdk/tests/definitions/test_dag.py +++ b/task_sdk/tests/definitions/test_dag.py @@ -25,7 +25,7 @@ from airflow.exceptions import DuplicateTaskIdFound from airflow.sdk.definitions.baseoperator import BaseOperator from airflow.sdk.definitions.dag import DAG, dag as dag_decorator -from airflow.sdk.definitions.param import Param, ParamsDict +from airflow.sdk.definitions.param import DagParam, Param, ParamsDict DEFAULT_DATE = datetime(2016, 1, 1, tzinfo=timezone.utc) @@ -350,6 +350,13 @@ def test__tags_mutable(): assert test_dag.tags == expected_tags +def test_create_dag_while_active_context(): + """Test that we can safely create a DAG whilst a DAG is activated via ``with dag1:``.""" + with DAG(dag_id="simple_dag"): + DAG(dag_id="dag2") + # No asserts needed, it just needs to not fail + + class TestDagDecorator: DEFAULT_ARGS = { "owner": "test", @@ -418,12 +425,6 @@ def noop_pipeline(value): ... with pytest.raises(TypeError): noop_pipeline() - def test_create_dag_while_active_context(self): - """Test that we can safely create a DAG whilst a DAG is activated via ``with dag1:``.""" - with DAG(dag_id="simple_dag"): - DAG(dag_id="dag2") - # No asserts needed, it just needs to not fail - def test_documentation_template_rendered(self): """Test that @dag uses function docs as doc_md for DAG object""" @@ -457,3 +458,22 @@ def markdown_docs(): ... 
dag = markdown_docs() assert dag.dag_id == "test-dag" assert dag.doc_md == raw_content + + def test_dag_param_resolves(self): + """Test that dag param is correctly resolved by operator""" + from airflow.decorators import task + + @dag_decorator(schedule=None, default_args=self.DEFAULT_ARGS) + def xcom_pass_to_op(value=self.VALUE): + @task + def return_num(num): + return num + + xcom_arg = return_num(value) + self.operator = xcom_arg.operator + + xcom_pass_to_op() + + assert isinstance(self.operator.op_args[0], DagParam) + self.operator.render_template_fields({}) + assert self.operator.op_args[0] == 42 diff --git a/task_sdk/tests/definitions/test_mappedoperator.py b/task_sdk/tests/definitions/test_mappedoperator.py index eeb79f31b4d47..64b0b8d7b8379 100644 --- a/task_sdk/tests/definitions/test_mappedoperator.py +++ b/task_sdk/tests/definitions/test_mappedoperator.py @@ -17,16 +17,20 @@ # under the License. from __future__ import annotations +import json from datetime import datetime, timedelta +from typing import TYPE_CHECKING, Callable +from unittest import mock import pendulum import pytest +from airflow.sdk.api.datamodels._generated import TerminalTIState from airflow.sdk.definitions.baseoperator import BaseOperator from airflow.sdk.definitions.dag import DAG from airflow.sdk.definitions.mappedoperator import MappedOperator -from airflow.sdk.definitions.param import ParamsDict from airflow.sdk.definitions.xcom_arg import XComArg +from airflow.sdk.execution_time.comms import GetXCom, SetXCom, XComResult from airflow.utils.trigger_rule import TriggerRule from tests_common.test_utils.mapping import expand_mapped_task # noqa: F401 @@ -120,9 +124,6 @@ def test_map_xcom_arg(): assert task1.downstream_list == [mapped] -# def test_map_xcom_arg_multiple_upstream_xcoms(dag_maker, session): - - def test_partial_on_instance() -> None: """`.partial` on an instance should fail -- it's only designed to be called on classes""" with pytest.raises(TypeError): @@ -154,15 +155,6 @@ def test_partial_on_invalid_pool_slots_raises() -> None: MockOperator.partial(task_id="pool_slots_test", pool="test", pool_slots="a").expand(arg1=[1, 2, 3]) -# def test_expand_mapped_task_instance(dag_maker, session, num_existing_tis, expected): - - -# def test_expand_mapped_task_failed_state_in_db(dag_maker, session): - - -# def test_expand_mapped_task_instance_skipped_on_zero(dag_maker, session): - - def test_mapped_task_applies_default_args_classic(): with DAG("test", default_args={"execution_timeout": timedelta(minutes=30)}) as dag: MockOperator(task_id="simple", arg1=None, arg2=0) @@ -191,43 +183,120 @@ def mapped(arg): @pytest.mark.parametrize( - "dag_params, task_params, expected_partial_params", + ("callable", "expected"), [ - pytest.param(None, None, ParamsDict(), id="none"), - pytest.param({"a": -1}, None, ParamsDict({"a": -1}), id="dag"), - pytest.param(None, {"b": -2}, ParamsDict({"b": -2}), id="task"), - pytest.param({"a": -1}, {"b": -2}, ParamsDict({"a": -1, "b": -2}), id="merge"), + pytest.param( + lambda partial, output1: partial.expand( + map_template=output1, map_static=output1, file_template=["/path/to/file.ext"] + ), + # Note to the next person to come across this. In #32272 we changed expand_kwargs so that it + # resolves the mapped template when it's in `expand_kwargs()`, but we _didn't_ do the same for + # things in `expand()`. This feels like a bug to me (ashb) but I am not changing that now, I have + # just moved and parametrized this test. 
+ "{{ ds }}", + id="expand", + ), + pytest.param( + lambda partial, output1: partial.expand_kwargs( + [{"map_template": "{{ ds }}", "map_static": "{{ ds }}", "file_template": "/path/to/file.ext"}] + ), + "2024-12-01", + id="expand_kwargs", + ), ], ) -def test_mapped_expand_against_params(dag_params, task_params, expected_partial_params): - with DAG("test", params=dag_params) as dag: - MockOperator.partial(task_id="t", params=task_params).expand(params=[{"c": "x"}, {"d": 1}]) - - t = dag.get_task("t") - assert isinstance(t, MappedOperator) - assert t.params == expected_partial_params - assert t.expand_input.value == {"params": [{"c": "x"}, {"d": 1}]} - +def test_mapped_render_template_fields_validating_operator( + tmp_path, create_runtime_ti, mock_supervisor_comms, callable, expected: bool +): + file_template_dir = tmp_path / "path" / "to" + file_template_dir.mkdir(parents=True, exist_ok=True) + file_template = file_template_dir / "file.ext" + file_template.write_text("loaded data") + + class MyOperator(BaseOperator): + template_fields = ("partial_template", "map_template", "file_template") + template_ext = (".ext",) + + def __init__( + self, partial_template, partial_static, map_template, map_static, file_template, **kwargs + ): + for value in [partial_template, partial_static, map_template, map_static, file_template]: + assert isinstance(value, str), "value should have been resolved before unmapping" + super().__init__(**kwargs) + self.partial_template = partial_template + self.partial_static = partial_static + self.map_template = map_template + self.map_static = map_static + self.file_template = file_template + + def execute(self, context): + pass -# def test_mapped_render_template_fields_validating_operator(dag_maker, session, tmp_path): + with DAG("test_dag", template_searchpath=tmp_path.__fspath__()): + task1 = BaseOperator(task_id="op1") + mapped = MyOperator.partial( + task_id="a", partial_template="{{ ti.task_id }}", partial_static="{{ ti.task_id }}" + ) + mapped = callable(mapped, task1.output) + mock_supervisor_comms.get_message.return_value = XComResult(key="return_value", value='["{{ ds }}"]') -# def test_mapped_expand_kwargs_render_template_fields_validating_operator(dag_maker, session, tmp_path): + mapped_ti = create_runtime_ti(task=mapped, map_index=0, upstream_map_indexes={task1.task_id: 1}) + assert isinstance(mapped_ti.task, MappedOperator) + mapped_ti.task.render_template_fields(context=mapped_ti.get_template_context()) + assert isinstance(mapped_ti.task, MyOperator) -# def test_mapped_render_nested_template_fields(dag_maker, session): + assert mapped_ti.task.partial_template == "a", "Should be templated!" + assert mapped_ti.task.partial_static == "{{ ti.task_id }}", "Should not be templated!" + assert mapped_ti.task.map_template == expected + assert mapped_ti.task.map_static == "{{ ds }}", "Should not be templated!" + assert mapped_ti.task.file_template == "loaded data", "Should be templated!" 
-# def test_expand_kwargs_mapped_task_instance(dag_maker, session, num_existing_tis, expected): +def test_mapped_render_nested_template_fields(create_runtime_ti, mock_supervisor_comms): + with DAG("test_dag"): + mapped = MockOperatorWithNestedFields.partial( + task_id="t", arg2=NestedFields(field_1="{{ ti.task_id }}", field_2="value_2") + ).expand(arg1=["{{ ti.task_id }}", ["s", "{{ ti.task_id }}"]]) + ti = create_runtime_ti(task=mapped, map_index=0, upstream_map_indexes={}) + ti.task.render_template_fields(context=ti.get_template_context()) + assert ti.task.arg1 == "t" + assert ti.task.arg2.field_1 == "t" + assert ti.task.arg2.field_2 == "value_2" -# def test_expand_mapped_task_instance_with_named_index( + ti = create_runtime_ti(task=mapped, map_index=1, upstream_map_indexes={}) + ti.task.render_template_fields(context=ti.get_template_context()) + assert ti.task.arg1 == ["s", "t"] + assert ti.task.arg2.field_1 == "t" + assert ti.task.arg2.field_2 == "value_2" -# def test_expand_mapped_task_task_instance_mutation_hook(dag_maker, session, create_mapped_task) -> None: +@pytest.mark.parametrize( + ("map_index", "expected"), + [ + pytest.param(0, "2024-12-01", id="0"), + pytest.param(1, 2, id="1"), + ], +) +def test_expand_kwargs_render_template_fields_validating_operator( + map_index, expected, create_runtime_ti, mock_supervisor_comms +): + with DAG("test_dag"): + task1 = BaseOperator(task_id="op1") + mapped = MockOperator.partial(task_id="a", arg2="{{ ti.task_id }}").expand_kwargs(task1.output) + mock_supervisor_comms.get_message.return_value = XComResult( + key="return_value", value=json.dumps([{"arg1": "{{ ds }}"}, {"arg1": 2}]) + ) -# def test_expand_kwargs_render_template_fields_validating_operator(dag_maker, session, map_index, expected): + ti = create_runtime_ti(task=mapped, map_index=map_index, upstream_map_indexes={}) + assert isinstance(ti.task, MappedOperator) + ti.task.render_template_fields(context=ti.get_template_context()) + assert isinstance(ti.task, MockOperator) + assert ti.task.arg1 == expected + assert ti.task.arg2 == "a" def test_xcomarg_property_of_mapped_operator(): @@ -252,9 +321,6 @@ def test_set_xcomarg_dependencies_with_mapped_operator(): assert op2 in op5.upstream_list -# def test_all_xcomargs_from_mapped_tasks_are_consumable(dag_maker, session): - - def test_task_mapping_with_task_group_context(): from airflow.sdk.definitions.taskgroup import TaskGroup @@ -299,3 +365,331 @@ def test_task_mapping_with_explicit_task_group(): assert finish.upstream_list == [mapped] assert mapped.downstream_list == [finish] + + +def test_nested_mapped_task_groups(): + from airflow.decorators import task, task_group + + with DAG("test"): + + @task + def t(): + return [[1, 2], [3, 4]] + + @task + def m(x): + return x + + @task_group + def g1(x): + @task_group + def g2(y): + return m(y) + + return g2.expand(y=x) + + # Add a test once nested mapped task groups become supported + with pytest.raises(ValueError, match="Nested Mapped TaskGroups are not yet supported"): + g1.expand(x=t()) + + +RunTI = Callable[[DAG, str, int], TerminalTIState] + + +def test_map_cross_product(run_ti: RunTI, mock_supervisor_comms): + outputs = [] + + with DAG(dag_id="cross_product") as dag: + + @dag.task + def emit_numbers(): + return [1, 2] + + @dag.task + def emit_letters(): + return {"a": "x", "b": "y", "c": "z"} + + @dag.task + def show(number, letter): + outputs.append((number, letter)) + + show.expand(number=emit_numbers(), letter=emit_letters()) + + def xcom_get(): + # TODO: Tidy this after #45927 is 
reopened and fixed properly + last_request = mock_supervisor_comms.send_request.mock_calls[-1].kwargs["msg"] + if not isinstance(last_request, GetXCom): + return mock.DEFAULT + task = dag.get_task(last_request.task_id) + value = json.dumps(task.python_callable()) + return XComResult(key="return_value", value=value) + + mock_supervisor_comms.get_message.side_effect = xcom_get + + states = [run_ti(dag, "show", map_index) for map_index in range(6)] + assert states == [TerminalTIState.SUCCESS] * 6 + assert outputs == [ + (1, ("a", "x")), + (1, ("b", "y")), + (1, ("c", "z")), + (2, ("a", "x")), + (2, ("b", "y")), + (2, ("c", "z")), + ] + + +def test_map_product_same(run_ti: RunTI, mock_supervisor_comms): + """Test a mapped task can refer to the same source multiple times.""" + outputs = [] + + with DAG(dag_id="product_same") as dag: + + @dag.task + def emit_numbers(): + return [1, 2] + + @dag.task + def show(a, b): + outputs.append((a, b)) + + emit_task = emit_numbers() + show.expand(a=emit_task, b=emit_task) + + def xcom_get(): + # TODO: Tidy this after #45927 is reopened and fixed properly + last_request = mock_supervisor_comms.send_request.mock_calls[-1].kwargs["msg"] + if not isinstance(last_request, GetXCom): + return mock.DEFAULT + task = dag.get_task(last_request.task_id) + value = json.dumps(task.python_callable()) + return XComResult(key="return_value", value=value) + + mock_supervisor_comms.get_message.side_effect = xcom_get + + states = [run_ti(dag, "show", map_index) for map_index in range(4)] + assert states == [TerminalTIState.SUCCESS] * 4 + assert outputs == [(1, 1), (1, 2), (2, 1), (2, 2)] + + +class NestedFields: + """Nested fields for testing purposes.""" + + def __init__(self, field_1, field_2): + self.field_1 = field_1 + self.field_2 = field_2 + + +class MockOperatorWithNestedFields(BaseOperator): + """Operator with nested fields for testing purposes.""" + + template_fields = ("arg1", "arg2") + + def __init__(self, arg1: str = "", arg2: NestedFields | None = None, **kwargs): + super().__init__(**kwargs) + self.arg1 = arg1 + self.arg2 = arg2 + + def _render_nested_template_fields(self, content, context, jinja_env, seen_oids) -> None: + if id(content) not in seen_oids: + template_fields: tuple | None = None + + if isinstance(content, NestedFields): + template_fields = ("field_1", "field_2") + + if template_fields: + seen_oids.add(id(content)) + self._do_render_template_fields(content, template_fields, context, jinja_env, seen_oids) + return + + super()._render_nested_template_fields(content, context, jinja_env, seen_oids) + + +def test_find_mapped_dependants_in_another_group(): + from airflow.decorators import task as task_decorator + from airflow.sdk import TaskGroup + + @task_decorator + def gen(x): + return list(range(x)) + + @task_decorator + def add(x, y): + return x + y + + with DAG(dag_id="test"): + with TaskGroup(group_id="g1"): + gen_result = gen(3) + with TaskGroup(group_id="g2"): + add_result = add.partial(y=1).expand(x=gen_result) + + dependants = list(gen_result.operator.iter_mapped_dependants()) + assert dependants == [add_result.operator] + + +@pytest.mark.parametrize( + "partial_params, mapped_params, expected", + [ + pytest.param(None, [{"a": 1}], [{"a": 1}], id="simple"), + pytest.param({"b": 2}, [{"a": 1}], [{"a": 1, "b": 2}], id="merge"), + pytest.param({"b": 2}, [{"a": 1, "b": 3}], [{"a": 1, "b": 3}], id="override"), + pytest.param({"b": 2}, [{"a": 1, "b": 3}, {"b": 1}], [{"a": 1, "b": 3}, {"b": 1}], id="multiple"), + ], +) +def 
test_mapped_expand_against_params(create_runtime_ti, partial_params, mapped_params, expected): + with DAG("test"): + task = BaseOperator.partial(task_id="t", params=partial_params).expand(params=mapped_params) + + for map_index, expansion in enumerate(expected): + mapped_ti = create_runtime_ti(task=task, map_index=map_index) + mapped_ti.task.render_template_fields(context=mapped_ti.get_template_context()) + assert mapped_ti.task.params == expansion + + +def test_operator_mapped_task_group_receives_value(create_runtime_ti, mock_supervisor_comms): + # Test the runtime expansion behaviour of mapped task groups + mapped operators + results = {} + + from airflow.decorators import task_group + + with DAG("test") as dag: + + @dag.task + def t(value, *, ti=None): + results[(ti.task_id, ti.map_index)] = value + return value + + @task_group + def tg(va): + # Each expanded group has one t1 and t2 each. + t1 = t.override(task_id="t1")(va) + t2 = t.override(task_id="t2")(t1) + + with pytest.raises(NotImplementedError) as ctx: + t.override(task_id="t4").expand(value=va) + assert str(ctx.value) == "operator expansion in an expanded task group is not yet supported" + + return t2 + + # The group is mapped by 3. + t2 = tg.expand(va=[["a", "b"], [4], ["z"]]) + + # Aggregates results from task group. + t.override(task_id="t3")(t2) + + def xcom_get(): + # TODO: Tidy this after #45927 is reopened and fixed properly + last_request = mock_supervisor_comms.send_request.mock_calls[-1].kwargs["msg"] + if not isinstance(last_request, GetXCom): + return mock.DEFAULT + key = (last_request.task_id, last_request.map_index) + if key in expected_values: + value = expected_values[key] + return XComResult(key="return_value", value=json.dumps(value)) + elif last_request.map_index is None: + # Get all mapped XComValues for this ti + value = [v for k, v in expected_values.items() if k[0] == last_request.task_id] + return XComResult(key="return_value", value=json.dumps(value)) + return mock.DEFAULT + + mock_supervisor_comms.get_message.side_effect = xcom_get + + expected_values = { + ("tg.t1", 0): ["a", "b"], + ("tg.t1", 1): [4], + ("tg.t1", 2): ["z"], + ("tg.t2", 0): ["a", "b"], + ("tg.t2", 1): [4], + ("tg.t2", 2): ["z"], + ("t3", None): [["a", "b"], [4], ["z"]], + } + + # We hard-code the number of expansions here as the server is in charge of that. 
+ expansion_per_task_id = { + "tg.t1": range(3), + "tg.t2": range(3), + "t3": [None], + } + for task in dag.tasks: + for map_index in expansion_per_task_id[task.task_id]: + mapped_ti = create_runtime_ti(task=task.prepare_for_execution(), map_index=map_index) + context = mapped_ti.get_template_context() + mapped_ti.task.render_template_fields(context) + mapped_ti.task.execute(context) + assert results == expected_values + + +@pytest.mark.xfail(reason="SkipMixin hasn't been ported over to use the Task Execution API yet") +def test_mapped_xcom_push_skipped_tasks(create_runtime_ti, mock_supervisor_comms): + from airflow.decorators import task_group + from airflow.operators.empty import EmptyOperator + + if TYPE_CHECKING: + from airflow.providers.standard.operators.python import ShortCircuitOperator + else: + ShortCircuitOperator = pytest.importorskip( + "airflow.providers.standard.operators.python" + ).ShortCircuitOperator + + with DAG("test") as dag: + + @task_group + def group(x): + short_op_push_xcom = ShortCircuitOperator( + task_id="push_xcom_from_shortcircuit", + python_callable=lambda arg: arg % 2 == 0, + op_kwargs={"arg": x}, + ) + empty_task = EmptyOperator(task_id="empty_task") + short_op_push_xcom >> empty_task + + group.expand(x=[0, 1]) + + for task in dag.tasks: + for map_index in range(2): + ti = create_runtime_ti(task=task.prepare_for_execution(), map_index=map_index) + context = ti.get_template_context() + ti.task.render_template_fields(context) + ti.task.execute(context) + + assert ti + # TODO: these tests might not be right + mock_supervisor_comms.send_request.assert_has_calls( + [ + SetXCom( + key="skipmixin_key", + value=None, + dag_id=ti.dag_id, + run_id=ti.run_id, + task_id="group.push_xcom_from_shortcircuit", + map_index=0, + ), + SetXCom( + key="return_value", + value=True, + dag_id=ti.dag_id, + run_id=ti.run_id, + task_id="group.push_xcom_from_shortcircuit", + map_index=0, + ), + SetXCom( + key="skipmixin_key", + value={"skipped": ["group.empty_task"]}, + dag_id=ti.dag_id, + run_id=ti.run_id, + task_id="group.push_xcom_from_shortcircuit", + map_index=1, + ), + ] + ) + # + # assert ( + # tis[0].xcom_pull(task_ids="group.push_xcom_from_shortcircuit", key="return_value", map_indexes=0) + # is True + # ) + # assert ( + # tis[0].xcom_pull(task_ids="group.push_xcom_from_shortcircuit", key="skipmixin_key", map_indexes=0) + # is None + # ) + # assert tis[0].xcom_pull( + # task_ids="group.push_xcom_from_shortcircuit", key="skipmixin_key", map_indexes=1 + # ) == {"skipped": ["group.empty_task"]} diff --git a/task_sdk/tests/definitions/test_xcom_arg.py b/task_sdk/tests/definitions/test_xcom_arg.py new file mode 100644 index 0000000000000..dc3c7f4916de0 --- /dev/null +++ b/task_sdk/tests/definitions/test_xcom_arg.py @@ -0,0 +1,360 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. 
See the License for the +# specific language governing permissions and limitations +# under the License. +from __future__ import annotations + +import json +from typing import Callable +from unittest import mock + +import pytest +import structlog +from pytest_unordered import unordered + +from airflow.exceptions import AirflowSkipException +from airflow.sdk.api.datamodels._generated import TerminalTIState +from airflow.sdk.definitions.dag import DAG +from airflow.sdk.execution_time.comms import GetXCom, XComResult + +log = structlog.get_logger() + +RunTI = Callable[[DAG, str, int], TerminalTIState] + + +def test_xcom_map(run_ti: RunTI, mock_supervisor_comms): + results = set() + with DAG("test") as dag: + + @dag.task + def push(): + return ["a", "b", "c"] + + @dag.task + def pull(value): + results.add(value) + + pull.expand_kwargs(push().map(lambda v: {"value": v * 2})) + + # The function passed to "map" is *NOT* a task. + assert set(dag.task_dict) == {"push", "pull"} + + # Mock xcom result from push task + mock_supervisor_comms.get_message.return_value = XComResult(key="return_value", value='["a", "b", "c"]') + + for map_index in range(3): + assert run_ti(dag, "pull", map_index) == TerminalTIState.SUCCESS + + assert results == {"aa", "bb", "cc"} + + +def test_xcom_map_transform_to_none(run_ti: RunTI, mock_supervisor_comms): + results = set() + + with DAG("test") as dag: + + @dag.task() + def push(): + return ["a", "b", "c"] + + @dag.task() + def pull(value): + results.add(value) + + def c_to_none(v): + if v == "c": + return None + return v + + pull.expand(value=push().map(c_to_none)) + + # Mock xcom result from push task + mock_supervisor_comms.get_message.return_value = XComResult(key="return_value", value='["a", "b", "c"]') + + # Run "pull". This should automatically convert "c" to None. + for map_index in range(3): + assert run_ti(dag, "pull", map_index) == TerminalTIState.SUCCESS + + assert results == {"a", "b", None} + + +def test_xcom_convert_to_kwargs_fails_task(run_ti: RunTI, mock_supervisor_comms, captured_logs): + results = set() + + with DAG("test") as dag: + + @dag.task() + def push(): + return ["a", "b", "c"] + + @dag.task() + def pull(value): + results.add(value) + + def c_to_none(v): + if v == "c": + return None + return {"value": v} + + pull.expand_kwargs(push().map(c_to_none)) + + # Mock xcom result from push task + mock_supervisor_comms.get_message.return_value = XComResult(key="return_value", value='["a", "b", "c"]') + + # The first two "pull" tis should succeed. + for map_index in range(2): + assert run_ti(dag, "pull", map_index) == TerminalTIState.SUCCESS + + # Clear captured logs from the above + captured_logs[:] = [] + + # But the third one fails because the map() result cannot be used as kwargs. 
+ assert run_ti(dag, "pull", 2) == TerminalTIState.FAILED + + assert captured_logs == unordered( + [ + { + "event": "Task failed with exception", + "level": "error", + "timestamp": mock.ANY, + "exception": [ + { + "exc_type": "ValueError", + "exc_value": "expand_kwargs() expects a list[dict], not list[None]", + "frames": mock.ANY, + "is_cause": False, + "syntax_error": None, + } + ], + }, + ] + ) + + +def test_xcom_map_error_fails_task(mock_supervisor_comms, run_ti, captured_logs): + with DAG("test") as dag: + + @dag.task() + def push(): + return ["a", "b", "c"] + + @dag.task() + def pull(value): + print(value) + + def does_not_work_with_c(v): + if v == "c": + raise RuntimeError("nope") + return {"value": v * 2} + + pull.expand_kwargs(push().map(does_not_work_with_c)) + + # Mock xcom result from push task + mock_supervisor_comms.get_message.return_value = XComResult(key="return_value", value='["a", "b", "c"]') + # The third one (for "c") will fail. + assert run_ti(dag, "pull", 2) == TerminalTIState.FAILED + + assert captured_logs == unordered( + [ + { + "event": "Task failed with exception", + "level": "error", + "timestamp": mock.ANY, + "exception": [ + { + "exc_type": "RuntimeError", + "exc_value": "nope", + "frames": mock.ANY, + "is_cause": False, + "syntax_error": None, + } + ], + }, + ] + ) + + +def test_xcom_map_nest(mock_supervisor_comms, run_ti): + results = set() + + with DAG("test") as dag: + + @dag.task() + def push(): + return ["a", "b", "c"] + + @dag.task() + def pull(value): + results.add(value) + + converted = push().map(lambda v: v * 2).map(lambda v: {"value": v}) + pull.expand_kwargs(converted) + + # Mock xcom result from push task + mock_supervisor_comms.get_message.return_value = XComResult(key="return_value", value='["a", "b", "c"]') + + # Now "pull" should apply the mapping functions in order. + for map_index in range(3): + assert run_ti(dag, "pull", map_index) == TerminalTIState.SUCCESS + assert results == {"aa", "bb", "cc"} + + +def test_xcom_map_zip_nest(mock_supervisor_comms, run_ti): + results = set() + + with DAG("test") as dag: + + @dag.task + def push_letters(): + return ["a", "b", "c", "d"] + + @dag.task + def push_numbers(): + return [1, 2, 3, 4] + + @dag.task + def pull(value): + results.add(value) + + doubled = push_numbers().map(lambda v: v * 2) + combined = doubled.zip(push_letters()) + + def convert_zipped(zipped): + letter, number = zipped + return letter * number + + pull.expand(value=combined.map(convert_zipped)) + + def xcom_get(): + # TODO: Tidy this after #45927 is reopened and fixed properly + last_request = mock_supervisor_comms.send_request.mock_calls[-1].kwargs["msg"] + if not isinstance(last_request, GetXCom): + return mock.DEFAULT + if last_request.task_id == "push_letters": + value = json.dumps(push_letters.function()) + return XComResult(key="return_value", value=value) + if last_request.task_id == "push_numbers": + value = json.dumps(push_numbers.function()) + return XComResult(key="return_value", value=value) + return mock.DEFAULT + + mock_supervisor_comms.get_message.side_effect = xcom_get + + # Run "pull". 
+ for map_index in range(4): + assert run_ti(dag, "pull", map_index) == TerminalTIState.SUCCESS + + assert results == {"aa", "bbbb", "cccccc", "dddddddd"} + + +def test_xcom_map_raise_to_skip(run_ti, mock_supervisor_comms): + result = [] + + with DAG("test") as dag: + + @dag.task() + def push(): + return ["a", "b", "c"] + + @dag.task() + def forward(value): + result.append(value) + + def skip_c(v): + if v == "c": + raise AirflowSkipException() + return {"value": v} + + forward.expand_kwargs(push().map(skip_c)) + + # Mock xcom result from push task + mock_supervisor_comms.get_message.return_value = XComResult(key="return_value", value='["a", "b", "c"]') + + # Run "forward". This should automatically skip "c". + states = [run_ti(dag, "forward", map_index) for map_index in range(3)] + + assert states == [TerminalTIState.SUCCESS, TerminalTIState.SUCCESS, TerminalTIState.SKIPPED] + + assert result == ["a", "b"] + + +def test_xcom_concat(run_ti, mock_supervisor_comms): + from airflow.sdk.definitions.xcom_arg import _ConcatResult + + agg_results = set() + all_results = None + + with DAG("test") as dag: + + @dag.task + def push_letters(): + return ["a", "b", "c"] + + @dag.task + def push_numbers(): + return [1, 2] + + @dag.task + def pull_one(value): + agg_results.add(value) + + @dag.task + def pull_all(value): + assert isinstance(value, _ConcatResult) + assert value[0] == "a" + assert value[1] == "b" + assert value[2] == "c" + assert value[3] == 1 + assert value[4] == 2 + with pytest.raises(IndexError): + value[5] + assert value[-5] == "a" + assert value[-4] == "b" + assert value[-3] == "c" + assert value[-2] == 1 + assert value[-1] == 2 + with pytest.raises(IndexError): + value[-6] + nonlocal all_results + all_results = list(value) + + pushed_values = push_letters().concat(push_numbers()) + + pull_one.expand(value=pushed_values) + pull_all(pushed_values) + + def xcom_get(): + # TODO: Tidy this after #45927 is reopened and fixed properly + last_request = mock_supervisor_comms.send_request.mock_calls[-1].kwargs["msg"] + if not isinstance(last_request, GetXCom): + return mock.DEFAULT + if last_request.task_id == "push_letters": + value = json.dumps(push_letters.function()) + return XComResult(key="return_value", value=value) + if last_request.task_id == "push_numbers": + value = json.dumps(push_numbers.function()) + return XComResult(key="return_value", value=value) + return mock.DEFAULT + + mock_supervisor_comms.get_message.side_effect = xcom_get + + # Run "pull_one" and "pull_all". 
+ assert run_ti(dag, "pull_all", None) == TerminalTIState.SUCCESS + assert all_results == ["a", "b", "c", 1, 2] + + states = [run_ti(dag, "pull_one", map_index) for map_index in range(5)] + assert states == [TerminalTIState.SUCCESS] * 5 + assert agg_results == {"a", "b", "c", 1, 2} diff --git a/task_sdk/tests/execution_time/conftest.py b/task_sdk/tests/execution_time/conftest.py index e2d592e7e5195..4a537373363aa 100644 --- a/task_sdk/tests/execution_time/conftest.py +++ b/task_sdk/tests/execution_time/conftest.py @@ -18,14 +18,6 @@ from __future__ import annotations import sys -from typing import TYPE_CHECKING - -if TYPE_CHECKING: - from datetime import datetime - - from airflow.sdk.definitions.baseoperator import BaseOperator - from airflow.sdk.execution_time.comms import StartupDetails - from airflow.sdk.execution_time.task_runner import RuntimeTaskInstance import pytest @@ -39,131 +31,3 @@ def disable_capturing(): sys.stderr = sys.__stderr__ yield sys.stdin, sys.stdout, sys.stderr = old_in, old_out, old_err - - -@pytest.fixture -def mocked_parse(spy_agency): - """ - Fixture to set up an inline DAG and use it in a stubbed `parse` function. Use this fixture if you - want to isolate and test `parse` or `run` logic without having to define a DAG file. - - This fixture returns a helper function `set_dag` that: - 1. Creates an in line DAG with the given `dag_id` and `task` (limited to one task) - 2. Constructs a `RuntimeTaskInstance` based on the provided `StartupDetails` and task. - 3. Stubs the `parse` function using `spy_agency`, to return the mocked `RuntimeTaskInstance`. - - After adding the fixture in your test function signature, you can use it like this :: - - mocked_parse( - StartupDetails( - ti=TaskInstance(id=uuid7(), task_id="hello", dag_id="super_basic_run", run_id="c", try_number=1), - file="", - requests_fd=0, - ), - "example_dag_id", - CustomOperator(task_id="hello"), - ) - """ - - def set_dag(what: StartupDetails, dag_id: str, task: BaseOperator) -> RuntimeTaskInstance: - from airflow.sdk.definitions.dag import DAG - from airflow.sdk.execution_time.task_runner import RuntimeTaskInstance, parse - from airflow.utils import timezone - - if task.has_dag(): - if what.ti_context.dag_run.conf: - task.dag.params = what.ti_context.dag_run.conf # type: ignore[assignment] - ti = RuntimeTaskInstance.model_construct( - **what.ti.model_dump(exclude_unset=True), - task=task, - _ti_context_from_server=what.ti_context, - max_tries=what.ti_context.max_tries, - ) - spy_agency.spy_on(parse, call_fake=lambda _: ti) - return ti - - dag = DAG(dag_id=dag_id, start_date=timezone.datetime(2024, 12, 3)) - if what.ti_context.dag_run.conf: - dag.params = what.ti_context.dag_run.conf # type: ignore[assignment] - task.dag = dag - t = dag.task_dict[task.task_id] - ti = RuntimeTaskInstance.model_construct( - **what.ti.model_dump(exclude_unset=True), - task=t, - _ti_context_from_server=what.ti_context, - max_tries=what.ti_context.max_tries, - ) - spy_agency.spy_on(parse, call_fake=lambda _: ti) - return ti - - return set_dag - - -@pytest.fixture -def create_runtime_ti(mocked_parse, make_ti_context): - """ - Fixture to create a Runtime TaskInstance for testing purposes without defining a dag file. - - It mimics the behavior of the `parse` function by creating a `RuntimeTaskInstance` based on the provided - `StartupDetails` (formed from arguments) and task. This allows you to test the logic of a task without - having to define a DAG file, parse it, get context from the server, etc. 
- - Example usage: :: - - def test_custom_task_instance(create_runtime_ti): - class MyTaskOperator(BaseOperator): - def execute(self, context): - assert context["dag_run"].run_id == "test_run" - - task = MyTaskOperator(task_id="test_task") - ti = create_runtime_ti(task, context_from_server=make_ti_context(run_id="test_run")) - # Further test logic... - """ - from uuid6 import uuid7 - - from airflow.sdk.api.datamodels._generated import TaskInstance - from airflow.sdk.execution_time.comms import BundleInfo, StartupDetails - - def _create_task_instance( - task: BaseOperator, - dag_id: str = "test_dag", - run_id: str = "test_run", - logical_date: str | datetime = "2024-12-01T01:00:00Z", - data_interval_start: str | datetime = "2024-12-01T00:00:00Z", - data_interval_end: str | datetime = "2024-12-01T01:00:00Z", - start_date: str | datetime = "2024-12-01T01:00:00Z", - run_after: str | datetime = "2024-12-01T01:00:00Z", - run_type: str = "manual", - try_number: int = 1, - conf=None, - ti_id=None, - ) -> RuntimeTaskInstance: - if not ti_id: - ti_id = uuid7() - - ti_context = make_ti_context( - dag_id=dag_id, - run_id=run_id, - logical_date=logical_date, - data_interval_start=data_interval_start, - data_interval_end=data_interval_end, - start_date=start_date, - run_type=run_type, - run_after=run_after, - conf=conf, - ) - - startup_details = StartupDetails( - ti=TaskInstance( - id=ti_id, task_id=task.task_id, dag_id=dag_id, run_id=run_id, try_number=try_number - ), - dag_rel_path="", - bundle_info=BundleInfo(name="anything", version="any"), - requests_fd=0, - ti_context=ti_context, - ) - - ti = mocked_parse(startup_details, dag_id, task) - return ti - - return _create_task_instance diff --git a/task_sdk/tests/execution_time/test_task_runner.py b/task_sdk/tests/execution_time/test_task_runner.py index 6d820cddd6d8d..85804836981d0 100644 --- a/task_sdk/tests/execution_time/test_task_runner.py +++ b/task_sdk/tests/execution_time/test_task_runner.py @@ -164,28 +164,6 @@ def test_parse(test_dags_dir: Path, make_ti_context): assert isinstance(ti.task.dag, DAG) -def test_run_basic(time_machine, create_runtime_ti, spy_agency, mock_supervisor_comms): - """Test running a basic task.""" - instant = timezone.datetime(2024, 12, 3, 10, 0) - time_machine.move_to(instant, tick=False) - - ti = create_runtime_ti(dag_id="super_basic_run", task=CustomOperator(task_id="hello")) - - # Ensure that task is locked for execution - spy_agency.spy_on(ti.task.prepare_for_execution) - assert not ti.task._lock_for_execution - - run(ti, log=mock.MagicMock()) - - spy_agency.assert_spy_called(ti.task.prepare_for_execution) - assert ti.task._lock_for_execution - - mock_supervisor_comms.send_request.assert_called_once_with( - msg=SucceedTask(state=TerminalTIState.SUCCESS, end_date=instant, task_outlets=[], outlet_events=[]), - log=mock.ANY, - ) - - def test_run_deferred_basic(time_machine, create_runtime_ti, mock_supervisor_comms): """Test that a task can transition to a deferred state.""" import datetime @@ -219,7 +197,7 @@ def test_run_deferred_basic(time_machine, create_runtime_ti, mock_supervisor_com run(ti, log=mock.MagicMock()) # send_request will only be called when the TaskDeferred exception is raised - mock_supervisor_comms.send_request.assert_called_once_with(msg=expected_defer_task, log=mock.ANY) + mock_supervisor_comms.send_request.assert_any_call(msg=expected_defer_task, log=mock.ANY) def test_run_basic_skipped(time_machine, create_runtime_ti, mock_supervisor_comms): @@ -239,7 +217,7 @@ def 
test_run_basic_skipped(time_machine, create_runtime_ti, mock_supervisor_comm run(ti, log=mock.MagicMock()) - mock_supervisor_comms.send_request.assert_called_once_with( + mock_supervisor_comms.send_request.assert_called_with( msg=TaskState(state=TerminalTIState.SKIPPED, end_date=instant), log=mock.ANY ) @@ -259,7 +237,7 @@ def test_run_raises_base_exception(time_machine, create_runtime_ti, mock_supervi run(ti, log=mock.MagicMock()) - mock_supervisor_comms.send_request.assert_called_once_with( + mock_supervisor_comms.send_request.assert_called_with( msg=TaskState( state=TerminalTIState.FAILED, end_date=instant, @@ -283,7 +261,7 @@ def test_run_raises_system_exit(time_machine, create_runtime_ti, mock_supervisor run(ti, log=mock.MagicMock()) - mock_supervisor_comms.send_request.assert_called_once_with( + mock_supervisor_comms.send_request.assert_called_with( msg=TaskState( state=TerminalTIState.FAILED, end_date=instant, @@ -309,7 +287,7 @@ def test_run_raises_airflow_exception(time_machine, create_runtime_ti, mock_supe run(ti, log=mock.MagicMock()) - mock_supervisor_comms.send_request.assert_called_once_with( + mock_supervisor_comms.send_request.assert_called_with( msg=TaskState( state=TerminalTIState.FAILED, end_date=instant, @@ -336,7 +314,7 @@ def test_run_task_timeout(time_machine, create_runtime_ti, mock_supervisor_comms run(ti, log=mock.MagicMock()) # this state can only be reached if the try block passed down the exception to handler of AirflowTaskTimeout - mock_supervisor_comms.send_request.assert_called_once_with( + mock_supervisor_comms.send_request.assert_called_with( msg=TaskState( state=TerminalTIState.FAILED, end_date=instant, @@ -345,7 +323,7 @@ def test_run_task_timeout(time_machine, create_runtime_ti, mock_supervisor_comms ) -def test_startup_basic_templated_dag(mocked_parse, make_ti_context, mock_supervisor_comms): +def test_basic_templated_dag(mocked_parse, make_ti_context, mock_supervisor_comms, spy_agency): """Test running a DAG with templated task.""" from airflow.providers.standard.operators.bash import BashOperator @@ -363,15 +341,23 @@ def test_startup_basic_templated_dag(mocked_parse, make_ti_context, mock_supervi requests_fd=0, ti_context=make_ti_context(), ) - mocked_parse(what, "basic_templated_dag", task) + ti = mocked_parse(what, "basic_templated_dag", task) - mock_supervisor_comms.get_message.return_value = what - startup() + # Ensure that task is locked for execution + spy_agency.spy_on(task.prepare_for_execution) + assert not task._lock_for_execution - mock_supervisor_comms.send_request.assert_called_once_with( + # mock_supervisor_comms.get_message.return_value = what + run(ti, log=mock.Mock()) + + spy_agency.assert_spy_called(task.prepare_for_execution) + assert ti.task._lock_for_execution + assert ti.task is not task, "ti.task should be a copy of the original task" + + mock_supervisor_comms.send_request.assert_any_call( msg=SetRenderedFields( rendered_fields={ - "bash_command": "echo 'Logical date is {{ logical_date }}'", + "bash_command": "echo 'Logical date is 2024-12-01 01:00:00+00:00'", "cwd": None, "env": None, } @@ -462,15 +448,14 @@ def execute(self, context): requests_fd=0, ti_context=make_ti_context(), ) - ti = mocked_parse(what, "basic_dag", task) + mocked_parse(what, "basic_dag", task) instant = timezone.datetime(2024, 12, 3, 10, 0) time_machine.move_to(instant, tick=False) mock_supervisor_comms.get_message.return_value = what - startup() - run(ti, log=mock.MagicMock()) + run(*startup()) expected_calls = [ mock.call.send_request( 
msg=SetRenderedFields(rendered_fields=expected_rendered_fields), @@ -497,11 +482,11 @@ def execute(self, context): ("{{ logical_date }}", "2024-12-01 01:00:00+00:00"), ], ) +@pytest.mark.usefixtures("mock_supervisor_comms") def test_startup_and_run_dag_with_templated_fields( - command, rendered_command, create_runtime_ti, time_machine, mock_supervisor_comms + command, rendered_command, create_runtime_ti, time_machine ): """Test startup of a DAG with various templated fields.""" - from airflow.providers.standard.operators.bash import BashOperator task = BashOperator(task_id="templated_task", bash_command=command) @@ -512,7 +497,6 @@ def test_startup_and_run_dag_with_templated_fields( instant = timezone.datetime(2024, 12, 3, 10, 0) time_machine.move_to(instant, tick=False) - run(ti, log=mock.MagicMock()) assert ti.task.bash_command == rendered_command @@ -725,6 +709,7 @@ def test_run_with_inlets_and_outlets( create_runtime_ti, mock_supervisor_comms, time_machine, ok, last_expected_msg ): """Test running a basic tasks with inlets and outlets.""" + instant = timezone.datetime(2024, 12, 3, 10, 0) time_machine.move_to(instant, tick=False) @@ -1040,8 +1025,6 @@ def execute(self, context): mock_supervisor_comms.get_message.return_value = XComResult(key="key", value='"value"') - spy_agency.spy_on(runtime_ti.xcom_pull, call_original=True) - run(runtime_ti, log=mock.MagicMock()) if isinstance(task_ids, str): diff --git a/tests/api_fastapi/execution_api/routes/test_xcoms.py b/tests/api_fastapi/execution_api/routes/test_xcoms.py index 3232622b3ac1f..bfb5f0ca73b6a 100644 --- a/tests/api_fastapi/execution_api/routes/test_xcoms.py +++ b/tests/api_fastapi/execution_api/routes/test_xcoms.py @@ -70,7 +70,7 @@ def test_xcom_not_found(self, client, create_task_instance): assert response.status_code == 404 assert response.json() == { "detail": { - "message": "XCom with key 'xcom_non_existent' not found for task 'task' in DAG 'dag'", + "message": "XCom with key='xcom_non_existent' map_index=-1 not found for task 'task' in DAG run 'runid' of 'dag'", "reason": "not_found", } } diff --git a/tests/decorators/test_task_group.py b/tests/decorators/test_task_group.py index d5fa174fa2511..9262ec0cffe45 100644 --- a/tests/decorators/test_task_group.py +++ b/tests/decorators/test_task_group.py @@ -215,6 +215,7 @@ def tg(a, b): @pytest.mark.db_test +@pytest.mark.need_serialized_dag def test_task_group_expand_kwargs_with_upstream(dag_maker, session, caplog): with dag_maker() as dag: @@ -239,6 +240,7 @@ def t2(): @pytest.mark.db_test +@pytest.mark.need_serialized_dag def test_task_group_expand_with_upstream(dag_maker, session, caplog): with dag_maker() as dag: diff --git a/tests/models/test_dag.py b/tests/models/test_dag.py index fc3fb439a130a..84801d6cfb5b7 100644 --- a/tests/models/test_dag.py +++ b/tests/models/test_dag.py @@ -149,6 +149,7 @@ def _create_dagrun( run_type: DagRunType, state: DagRunState = DagRunState.RUNNING, start_date: datetime.datetime | None = None, + **kwargs, ) -> DagRun: logical_date = timezone.coerce_datetime(logical_date) if not isinstance(data_interval, DataInterval): @@ -167,6 +168,7 @@ def _create_dagrun( state=state, start_date=start_date, triggered_by=DagRunTriggeredByType.TEST, + **kwargs, ) @@ -1535,16 +1537,14 @@ def consumer(value): PythonOperator.partial(task_id=task_id, python_callable=consumer).expand(op_args=make_arg_lists()) session = dag_maker.session - dagrun_1 = dag.create_dagrun( - run_id="backfill", + dagrun_1 = _create_dagrun( + dag, run_type=DagRunType.BACKFILL_JOB, - 
state=State.FAILED, + state=DagRunState.FAILED, start_date=DEFAULT_DATE, logical_date=DEFAULT_DATE, - session=session, data_interval=(DEFAULT_DATE, DEFAULT_DATE), - run_after=DEFAULT_DATE, - triggered_by=DagRunTriggeredByType.TEST, + session=session, ) # Get the (de)serialized MappedOperator mapped = dag.get_task(task_id) @@ -1662,26 +1662,6 @@ def check_task_2(my_input): mock_task_object_1.assert_called() mock_task_object_2.assert_not_called() - def test_dag_test_with_task_mapping(self): - dag = DAG(dag_id="test_local_testing_conn_file", schedule=None, start_date=DEFAULT_DATE) - mock_object = mock.MagicMock() - - @task_decorator() - def get_index(current_val, ti=None): - return ti.map_index - - @task_decorator - def check_task(my_input): - # we call a mock object with the combined map to ensure all expected indexes are called - mock_object(list(my_input)) - - with dag: - mapped_task = get_index.expand(current_val=[1, 1, 1, 1, 1]) - check_task(mapped_task) - - dag.test() - mock_object.assert_called_with([0, 1, 2, 3, 4]) - def test_dag_connection_file(self, tmp_path): test_connections_string = """ --- diff --git a/tests/models/test_dagrun.py b/tests/models/test_dagrun.py index 769ffe029ce0e..b7c2f2282c8be 100644 --- a/tests/models/test_dagrun.py +++ b/tests/models/test_dagrun.py @@ -55,7 +55,7 @@ from tests_common.test_utils.config import conf_vars from tests_common.test_utils.mock_operators import MockOperator -pytestmark = pytest.mark.db_test +pytestmark = [pytest.mark.db_test, pytest.mark.need_serialized_dag] if TYPE_CHECKING: @@ -329,7 +329,7 @@ def test_dagrun_deadlock(self, dag_maker, session): assert dr.state == DagRunState.RUNNING ti_op2.set_state(state=None, session=session) - op2.trigger_rule = "invalid" # type: ignore + ti_op2.task.trigger_rule = "invalid" # type: ignore dr.update_state(session=session) assert dr.state == DagRunState.FAILED @@ -760,7 +760,6 @@ def mutate_task_instance(task_instance): (None, False), ], ) - @pytest.mark.need_serialized_dag def test_depends_on_past(self, dag_maker, session, prev_ti_state, is_ti_schedulable): # DAG tests depends_on_past dependencies with dag_maker( @@ -803,7 +802,6 @@ def test_depends_on_past(self, dag_maker, session, prev_ti_state, is_ti_schedula (None, False), ], ) - @pytest.mark.need_serialized_dag def test_wait_for_downstream(self, dag_maker, session, prev_ti_state, is_ti_schedulable): dag_id = "test_wait_for_downstream" @@ -1091,7 +1089,6 @@ def test_expand_mapped_task_instance_at_create(is_noop, dag_maker, session): assert indices == [(0,), (1,), (2,), (3,)] -@pytest.mark.need_serialized_dag @pytest.mark.parametrize("is_noop", [True, False]) def test_expand_mapped_task_instance_task_decorator(is_noop, dag_maker, session): with mock.patch("airflow.settings.task_instance_mutation_hook") as mock_mut: @@ -1115,7 +1112,6 @@ def mynameis(arg): assert indices == [(0,), (1,), (2,), (3,)] -@pytest.mark.need_serialized_dag def test_mapped_literal_verify_integrity(dag_maker, session): """Test that when the length of a mapped literal changes we remove extra TIs""" @@ -1148,7 +1144,6 @@ def task_2(arg2): ... assert indices == [(0, None), (1, None), (2, TaskInstanceState.REMOVED), (3, TaskInstanceState.REMOVED)] -@pytest.mark.need_serialized_dag def test_mapped_literal_to_xcom_arg_verify_integrity(dag_maker, session): """Test that when we change from literal to a XComArg the TIs are removed""" @@ -1182,7 +1177,6 @@ def task_2(arg2): ... 
] -@pytest.mark.need_serialized_dag def test_mapped_literal_length_increase_adds_additional_ti(dag_maker, session): """Test that when the length of mapped literal increases, additional ti is added""" @@ -1225,7 +1219,6 @@ def task_2(arg2): ... ] -@pytest.mark.need_serialized_dag def test_mapped_literal_length_reduction_adds_removed_state(dag_maker, session): """Test that when the length of mapped literal reduces, removed state is added""" @@ -1266,7 +1259,6 @@ def task_2(arg2): ... ] -@pytest.mark.need_serialized_dag def test_mapped_length_increase_at_runtime_adds_additional_tis(dag_maker, session): """Test that when the length of mapped literal increases at runtime, additional ti is added""" # Variable.set(key="arg1", value=[1, 2, 3]) @@ -1318,7 +1310,6 @@ def task_2(arg2): ... ] -@pytest.mark.need_serialized_dag def test_mapped_literal_length_reduction_at_runtime_adds_removed_state(dag_maker, session): """ Test that when the length of mapped literal reduces at runtime, the missing task instances @@ -1406,7 +1397,6 @@ def task_2(arg2): ... assert len(decision.schedulable_tis) == 2 -@pytest.mark.need_serialized_dag def test_calls_to_verify_integrity_with_mapped_task_zero_length_at_runtime(dag_maker, session, caplog): """ Test zero length reduction in mapped task at runtime with calls to dagrun.verify_integrity @@ -1469,7 +1459,6 @@ def task_2(arg2): ... ) -@pytest.mark.need_serialized_dag def test_mapped_mixed_literal_not_expanded_at_create(dag_maker, session): literal = [1, 2, 3, 4] with dag_maker(session=session): @@ -1646,8 +1635,6 @@ def consumer(*args): def test_mapped_task_all_finish_before_downstream(dag_maker, session): - result = None - with dag_maker(session=session) as dag: @dag.task @@ -1660,8 +1647,8 @@ def double(value): @dag.task def consumer(value): - nonlocal result - result = list(value) + ... + # result = list(value) consumer(value=double.expand(value=make_list())) @@ -1675,26 +1662,29 @@ def _task_ids(tis): assert _task_ids(decision.schedulable_tis) == ["make_list"] # After make_list is run, double is expanded. - decision.schedulable_tis[0].run(verbose=False, session=session) + ti = decision.schedulable_tis[0] + ti.state = TaskInstanceState.SUCCESS + session.add(TaskMap.from_task_instance_xcom(ti, [1, 2])) + session.flush() + decision = dr.task_instance_scheduling_decisions(session=session) assert _task_ids(decision.schedulable_tis) == ["double", "double"] # Running just one of the mapped tis does not make downstream schedulable. - decision.schedulable_tis[0].run(verbose=False, session=session) + ti = decision.schedulable_tis[0] + ti.state = TaskInstanceState.SUCCESS + session.flush() + decision = dr.task_instance_scheduling_decisions(session=session) assert _task_ids(decision.schedulable_tis) == ["double"] - # Downstream is schedulable after all mapped tis are run. - decision.schedulable_tis[0].run(verbose=False, session=session) + # Downstream is schedulable after all mapped tis are run. + ti = decision.schedulable_tis[0] + ti.state = TaskInstanceState.SUCCESS + session.flush() decision = dr.task_instance_scheduling_decisions(session=session) assert _task_ids(decision.schedulable_tis) == ["consumer"] - # We should be able to get all values aggregated from mapped upstreams.
- decision.schedulable_tis[0].run(verbose=False, session=session) - decision = dr.task_instance_scheduling_decisions(session=session) - assert decision.schedulable_tis == [] - assert result == [2, 4] - def test_schedule_tis_map_index(dag_maker, session): with dag_maker(session=session, dag_id="test"): @@ -1772,6 +1762,7 @@ def test_schedule_tis_empty_operator_try_number(dag_maker, session: Session): assert empty_ti.try_number == 1 +@pytest.mark.xfail(reason="We can't keep this behaviour with remote workers where scheduler can't reach xcom") def test_schedule_tis_start_trigger_through_expand(dag_maker, session): """ Test that an operator with start_trigger_args set can be directly deferred during scheduling. @@ -1926,28 +1917,18 @@ def do_something_else(i): @pytest.mark.parametrize( "partial_params, mapped_params, expected", [ - pytest.param(None, [{"a": 1}], [[("a", 1)]], id="simple"), - pytest.param({"b": 2}, [{"a": 1}], [[("a", 1), ("b", 2)]], id="merge"), - pytest.param({"b": 2}, [{"a": 1, "b": 3}], [[("a", 1), ("b", 3)]], id="override"), + pytest.param(None, [{"a": 1}], 1, id="simple"), + pytest.param({"b": 2}, [{"a": 1}], 1, id="merge"), + pytest.param({"b": 2}, [{"a": 1, "b": 3}], 1, id="override"), ], ) def test_mapped_expand_against_params(dag_maker, partial_params, mapped_params, expected): - results = [] - - class PullOperator(BaseOperator): - def execute(self, context): - results.append(sorted(context["params"].items())) - with dag_maker(): - PullOperator.partial(task_id="t", params=partial_params).expand(params=mapped_params) + BaseOperator.partial(task_id="t", params=partial_params).expand(params=mapped_params) dr: DagRun = dag_maker.create_dagrun() decision = dr.task_instance_scheduling_decisions() - - for ti in decision.schedulable_tis: - ti.run() - - assert sorted(results) == expected + assert len(decision.schedulable_tis) == expected def test_mapped_task_group_expands(dag_maker, session): @@ -1992,9 +1973,7 @@ def test_operator_mapped_task_group_receives_value(dag_maker, session): with dag_maker(session=session): @task - def t(value, *, ti=None): - results[(ti.task_id, ti.map_index)] = value - return value + def t(value): ...
@task_group def tg(va): @@ -2016,24 +1995,29 @@ def tg(va): dr: DagRun = dag_maker.create_dagrun() - results = {} + results = set() decision = dr.task_instance_scheduling_decisions(session=session) for ti in decision.schedulable_tis: - ti.run() - assert results == {("tg.t1", 0): ["a", "b"], ("tg.t1", 1): [4], ("tg.t1", 2): ["z"]} + results.add((ti.task_id, ti.map_index)) + ti.state = TaskInstanceState.SUCCESS + session.flush() + assert results == {("tg.t1", 0), ("tg.t1", 1), ("tg.t1", 2)} - results = {} + results.clear() decision = dr.task_instance_scheduling_decisions(session=session) for ti in decision.schedulable_tis: - ti.run() - assert results == {("tg.t2", 0): ["a", "b"], ("tg.t2", 1): [4], ("tg.t2", 2): ["z"]} + results.add((ti.task_id, ti.map_index)) + ti.state = TaskInstanceState.SUCCESS + session.flush() + assert results == {("tg.t2", 0), ("tg.t2", 1), ("tg.t2", 2)} - results = {} + results.clear() decision = dr.task_instance_scheduling_decisions(session=session) for ti in decision.schedulable_tis: - ti.run() - assert len(results) == 1 - assert list(results[("t3", -1)]) == [["a", "b"], [4], ["z"]] + results.add((ti.task_id, ti.map_index)) + ti.state = TaskInstanceState.SUCCESS + session.flush() + assert results == {("t3", -1)} def test_mapping_against_empty_list(dag_maker, session): @@ -2101,13 +2085,15 @@ def print_value(value): decision = dr1.task_instance_scheduling_decisions(session=session) assert len(decision.schedulable_tis) == 2 for ti in decision.schedulable_tis: - ti.run(session=session) + ti.state = TaskInstanceState.SUCCESS + session.flush() # Now print_value in dr2 can run decision = dr2.task_instance_scheduling_decisions(session=session) assert len(decision.schedulable_tis) == 2 for ti in decision.schedulable_tis: - ti.run(session=session) + ti.state = TaskInstanceState.SUCCESS + session.flush() # Both runs are finished now. decision = dr1.task_instance_scheduling_decisions(session=session) @@ -2116,6 +2102,69 @@ def print_value(value): assert len(decision.unfinished_tis) == 0 +def test_xcom_map_skip_raised(dag_maker, session): + result = None + + with dag_maker(session=session) as dag: + # Note: this doesn't actually run this dag, the callbacks are for reference only. + + @dag.task() + def push(): + return ["a", "b", "c"] + + @dag.task() + def forward(value): + return value + + @dag.task(trigger_rule=TriggerRule.ALL_DONE) + def collect(value): + nonlocal result + result = list(value) + + def skip_c(v): + ... + # if v == "c": + # raise AirflowSkipException + # return {"value": v} + + collect(value=forward.expand_kwargs(push().map(skip_c))) + + dr: DagRun = dag_maker.create_dagrun(session=session) + + def _task_ids(tis): + return [(ti.task_id, ti.map_index) for ti in tis] + + # Check that when forward w/ map_index=2 ends up skipping, that the collect task can still be + # scheduled! + + # Run "push". + decision = dr.task_instance_scheduling_decisions(session=session) + assert _task_ids(decision.schedulable_tis) == [("push", -1)] + ti = decision.schedulable_tis[0] + ti.state = TaskInstanceState.SUCCESS + session.add(TaskMap.from_task_instance_xcom(ti, push.function())) + session.flush() + + decision = dr.task_instance_scheduling_decisions(session=session) + assert _task_ids(decision.schedulable_tis) == [ + ("forward", 0), + ("forward", 1), + ("forward", 2), + ] + # Run "forward". "c"/index 2 is skipped. 
Runtime behaviour checked in test_xcom_map_raise_to_skip in + # TaskSDK + for ti, state in zip( + decision.schedulable_tis, + [TaskInstanceState.SUCCESS, TaskInstanceState.SUCCESS, TaskInstanceState.SKIPPED], + ): + ti.state = state + session.flush() + + # Now "collect" should only get "a" and "b". + decision = dr.task_instance_scheduling_decisions(session=session) + assert _task_ids(decision.schedulable_tis) == [("collect", -1)] + + def test_clearing_task_and_moving_from_non_mapped_to_mapped(dag_maker, session): """ Test that clearing a task and moving from non-mapped to mapped clears existing @@ -2387,7 +2436,7 @@ def make_task(task_id, dag): with dag_maker() as dag: for line in input: - tasks = [make_task(x, dag) for x in line.split(" >> ")] + tasks = [make_task(x, dag_maker.dag) for x in line.split(" >> ")] reduce(lambda x, y: x >> y, tasks) dr = dag_maker.create_dagrun() diff --git a/tests/models/test_mappedoperator.py b/tests/models/test_mappedoperator.py index 2d06dc6216f84..256d74d88bc88 100644 --- a/tests/models/test_mappedoperator.py +++ b/tests/models/test_mappedoperator.py @@ -32,19 +32,13 @@ from airflow.models.taskinstance import TaskInstance from airflow.models.taskmap import TaskMap from airflow.providers.standard.operators.python import PythonOperator -from airflow.sdk.definitions.mappedoperator import MappedOperator +from airflow.sdk.execution_time.comms import XComCountResponse from airflow.utils.state import TaskInstanceState from airflow.utils.task_group import TaskGroup -from airflow.utils.task_instance_session import set_current_task_instance_session -from airflow.utils.xcom import XCOM_RETURN_KEY from tests.models import DEFAULT_DATE from tests_common.test_utils.mapping import expand_mapped_task -from tests_common.test_utils.mock_operators import ( - MockOperator, - MockOperatorWithNestedFields, - NestedFields, -) +from tests_common.test_utils.mock_operators import MockOperator pytestmark = pytest.mark.db_test @@ -81,43 +75,6 @@ def execute(self, context: Context): mock_render_template.assert_called() -def test_map_xcom_arg_multiple_upstream_xcoms(dag_maker, session): - """Test that the correct number of downstream tasks are generated when mapping with an XComArg""" - - class PushExtraXComOperator(BaseOperator): - """Push an extra XCom value along with the default return value.""" - - def __init__(self, return_value, **kwargs): - super().__init__(**kwargs) - self.return_value = return_value - - def execute(self, context): - context["task_instance"].xcom_push(key="extra_key", value="extra_value") - return self.return_value - - with dag_maker("test-dag", session=session, start_date=DEFAULT_DATE) as dag: - upstream_return = [1, 2, 3] - task1 = PushExtraXComOperator(return_value=upstream_return, task_id="task_1") - task2 = PushExtraXComOperator.partial(task_id="task_2").expand(return_value=task1.output) - task3 = PushExtraXComOperator.partial(task_id="task_3").expand(return_value=task2.output) - - dr = dag_maker.create_dagrun() - ti_1 = dr.get_task_instance("task_1", session) - ti_1.run() - - ti_2s, _ = TaskMap.expand_mapped_task(task2, dr.run_id, session=session) - for ti in ti_2s: - ti.refresh_from_task(dag.get_task("task_2")) - ti.run() - - ti_3s, _ = TaskMap.expand_mapped_task(task3, dr.run_id, session=session) - for ti in ti_3s: - ti.refresh_from_task(dag.get_task("task_3")) - ti.run() - - assert len(ti_3s) == len(ti_2s) == len(upstream_return) - - @pytest.mark.parametrize( ["num_existing_tis", "expected"], ( @@ -255,151 +212,6 @@ def 
test_expand_mapped_task_instance_skipped_on_zero(dag_maker, session): assert indices == [(-1, TaskInstanceState.SKIPPED)] -class _RenderTemplateFieldsValidationOperator(BaseOperator): - template_fields = ( - "partial_template", - "map_template_xcom", - "map_template_literal", - "map_template_file", - ) - template_ext = (".ext",) - - fields_to_test = [ - "partial_template", - "partial_static", - "map_template_xcom", - "map_template_literal", - "map_static", - "map_template_file", - ] - - def __init__( - self, - partial_template, - partial_static, - map_template_xcom, - map_template_literal, - map_static, - map_template_file, - **kwargs, - ): - for field in self.fields_to_test: - setattr(self, field, value := locals()[field]) - assert isinstance(value, str), "value should have been resolved before unmapping" - super().__init__(**kwargs) - - def execute(self, context): - pass - - -def test_mapped_render_template_fields_validating_operator(dag_maker, session, tmp_path): - file_template_dir = tmp_path / "path" / "to" - file_template_dir.mkdir(parents=True, exist_ok=True) - file_template = file_template_dir / "file.ext" - file_template.write_text("loaded data") - - with set_current_task_instance_session(session=session): - with dag_maker(session=session, template_searchpath=tmp_path.__fspath__()): - task1 = BaseOperator(task_id="op1") - output1 = task1.output - mapped = _RenderTemplateFieldsValidationOperator.partial( - task_id="a", partial_template="{{ ti.task_id }}", partial_static="{{ ti.task_id }}" - ).expand( - map_static=output1, - map_template_literal=["{{ ds }}"], - map_template_xcom=output1, - map_template_file=["/path/to/file.ext"], - ) - - dr = dag_maker.create_dagrun() - ti: TaskInstance = dr.get_task_instance(task1.task_id, session=session) - ti.xcom_push(key=XCOM_RETURN_KEY, value=["{{ ds }}"], session=session) - session.add( - TaskMap( - dag_id=dr.dag_id, - task_id=task1.task_id, - run_id=dr.run_id, - map_index=-1, - length=1, - keys=None, - ) - ) - session.flush() - - mapped_ti: TaskInstance = dr.get_task_instance(mapped.task_id, session=session) - mapped_ti.map_index = 0 - assert isinstance(mapped_ti.task, MappedOperator) - mapped.render_template_fields(context=mapped_ti.get_template_context(session=session)) - assert isinstance(mapped_ti.task, _RenderTemplateFieldsValidationOperator) - - assert mapped_ti.task.partial_template == "a", "Should be rendered!" - assert mapped_ti.task.partial_static == "{{ ti.task_id }}", "Should not be rendered!" - assert mapped_ti.task.map_static == "{{ ds }}", "Should not be rendered!" - assert mapped_ti.task.map_template_literal == "2016-01-01", "Should be rendered!" - assert mapped_ti.task.map_template_xcom == "{{ ds }}", "XCom resolved but not double rendered!" - assert mapped_ti.task.map_template_file == "loaded data", "Should be rendered!" 
- - -def test_mapped_expand_kwargs_render_template_fields_validating_operator(dag_maker, session, tmp_path): - file_template_dir = tmp_path / "path" / "to" - file_template_dir.mkdir(parents=True, exist_ok=True) - file_template = file_template_dir / "file.ext" - file_template.write_text("loaded data") - - with set_current_task_instance_session(session=session): - with dag_maker(session=session, template_searchpath=tmp_path.__fspath__()): - mapped = _RenderTemplateFieldsValidationOperator.partial( - task_id="a", partial_template="{{ ti.task_id }}", partial_static="{{ ti.task_id }}" - ).expand_kwargs( - [ - { - "map_template_literal": "{{ ds }}", - "map_static": "{{ ds }}", - "map_template_file": "/path/to/file.ext", - # This field is not tested since XCom inside a literal list - # is not rendered (matching BaseOperator rendering behavior). - "map_template_xcom": "", - } - ] - ) - - dr = dag_maker.create_dagrun() - mapped_ti: TaskInstance = dr.get_task_instance(mapped.task_id, session=session, map_index=0) - assert isinstance(mapped_ti.task, MappedOperator) - mapped.render_template_fields(context=mapped_ti.get_template_context(session=session)) - assert isinstance(mapped_ti.task, _RenderTemplateFieldsValidationOperator) - - assert mapped_ti.task.partial_template == "a", "Should be rendered!" - assert mapped_ti.task.partial_static == "{{ ti.task_id }}", "Should not be rendered!" - assert mapped_ti.task.map_template_literal == "2016-01-01", "Should be rendered!" - assert mapped_ti.task.map_static == "{{ ds }}", "Should not be rendered!" - assert mapped_ti.task.map_template_file == "loaded data", "Should be rendered!" - - -def test_mapped_render_nested_template_fields(dag_maker, session): - with dag_maker(session=session): - MockOperatorWithNestedFields.partial( - task_id="t", arg2=NestedFields(field_1="{{ ti.task_id }}", field_2="value_2") - ).expand(arg1=["{{ ti.task_id }}", ["s", "{{ ti.task_id }}"]]) - - dr = dag_maker.create_dagrun() - decision = dr.task_instance_scheduling_decisions() - tis = {(ti.task_id, ti.map_index): ti for ti in decision.schedulable_tis} - assert len(tis) == 2 - - ti = tis[("t", 0)] - ti.run(session=session) - assert ti.task.arg1 == "t" - assert ti.task.arg2.field_1 == "t" - assert ti.task.arg2.field_2 == "value_2" - - ti = tis[("t", 1)] - ti.run(session=session) - assert ti.task.arg1 == ["s", "t"] - assert ti.task.arg2.field_1 == "t" - assert ti.task.arg2.field_2 == "value_2" - - @pytest.mark.parametrize( ["num_existing_tis", "expected"], ( @@ -467,6 +279,45 @@ def test_expand_kwargs_mapped_task_instance(dag_maker, session, num_existing_tis assert indices == expected +def test_map_product_expansion(dag_maker, session): + """Test the cross-product effect of mapping two inputs""" + outputs = [] + + with dag_maker(dag_id="product", session=session) as dag: + + @dag.task + def emit_numbers(): + return [1, 2] + + @dag.task + def emit_letters(): + return {"a": "x", "b": "y", "c": "z"} + + @dag.task + def show(number, letter): + outputs.append((number, letter)) + + show.expand(number=emit_numbers(), letter=emit_letters()) + + dr = dag_maker.create_dagrun() + for fn in (emit_numbers, emit_letters): + session.add( + TaskMap( + dag_id=dr.dag_id, + task_id=fn.__name__, + run_id=dr.run_id, + map_index=-1, + length=len(fn.function()), + keys=None, + ) + ) + + session.flush() + show_task = dag.get_task("show") + mapped_tis, max_map_index = TaskMap.expand_mapped_task(show_task, dr.run_id, session=session) + assert max_map_index + 1 == len(mapped_tis) == 6 + + def 
_create_mapped_with_name_template_classic(*, task_id, map_names, template): class HasMapName(BaseOperator): def __init__(self, *, map_name: str, **kwargs): @@ -591,68 +442,6 @@ def test_expand_mapped_task_task_instance_mutation_hook(dag_maker, session, crea assert call.args[0].map_index == expected_map_index[index] -@pytest.mark.parametrize( - "map_index, expected", - [ - pytest.param(0, "2016-01-01", id="0"), - pytest.param(1, 2, id="1"), - ], -) -def test_expand_kwargs_render_template_fields_validating_operator(dag_maker, session, map_index, expected): - with set_current_task_instance_session(session=session): - with dag_maker(session=session): - task1 = BaseOperator(task_id="op1") - mapped = MockOperator.partial(task_id="a", arg2="{{ ti.task_id }}").expand_kwargs(task1.output) - - dr = dag_maker.create_dagrun() - ti: TaskInstance = dr.get_task_instance(task1.task_id, session=session) - - ti.xcom_push(key=XCOM_RETURN_KEY, value=[{"arg1": "{{ ds }}"}, {"arg1": 2}], session=session) - - session.add( - TaskMap( - dag_id=dr.dag_id, - task_id=task1.task_id, - run_id=dr.run_id, - map_index=-1, - length=2, - keys=None, - ) - ) - session.flush() - - ti: TaskInstance = dr.get_task_instance(mapped.task_id, session=session) - ti.refresh_from_task(mapped) - ti.map_index = map_index - assert isinstance(ti.task, MappedOperator) - mapped.render_template_fields(context=ti.get_template_context(session=session)) - assert isinstance(ti.task, MockOperator) - assert ti.task.arg1 == expected - assert ti.task.arg2 == "a" - - -def test_all_xcomargs_from_mapped_tasks_are_consumable(dag_maker, session): - class PushXcomOperator(MockOperator): - def __init__(self, arg1, **kwargs): - super().__init__(arg1=arg1, **kwargs) - - def execute(self, context): - return self.arg1 - - class ConsumeXcomOperator(PushXcomOperator): - def execute(self, context): - assert set(self.arg1) == {1, 2, 3} - - with dag_maker("test_all_xcomargs_from_mapped_tasks_are_consumable"): - op1 = PushXcomOperator.partial(task_id="op1").expand(arg1=[1, 2, 3]) - ConsumeXcomOperator(task_id="op2", arg1=op1.output) - - dr = dag_maker.create_dagrun() - tis = dr.get_task_instances(session=session) - for ti in tis: - ti.run() - - class TestMappedSetupTeardown: @staticmethod def get_states(dr): @@ -1474,7 +1263,14 @@ def my_teardown(val): my_work(s) tg1, tg2 = dag.task_group.children.values() tg1 >> tg2 - dr = dag.test() + + with mock.patch( + "airflow.sdk.execution_time.task_runner.SUPERVISOR_COMMS", create=True + ) as supervisor_comms: + # TODO: TaskSDK: this is a bit of a hack that we need to stub this at all. `dag.test()` should + # really work without this! 
+ supervisor_comms.get_message.return_value = XComCountResponse(len=3) + dr = dag.test() states = self.get_states(dr) expected = { "tg_1.my_pre_setup": "success", diff --git a/tests/models/test_skipmixin.py b/tests/models/test_skipmixin.py index 5c3654d7e69c5..8f3b7e65dc386 100644 --- a/tests/models/test_skipmixin.py +++ b/tests/models/test_skipmixin.py @@ -120,6 +120,7 @@ def get_state(ti): assert executed_states == expected_states + @pytest.mark.need_serialized_dag def test_mapped_tasks_skip_all_except(self, dag_maker): with dag_maker("dag_test_skip_all_except") as dag: diff --git a/tests/models/test_taskinstance.py b/tests/models/test_taskinstance.py index 188e9b718acbe..3d985394646d0 100644 --- a/tests/models/test_taskinstance.py +++ b/tests/models/test_taskinstance.py @@ -22,7 +22,6 @@ import operator import os import pathlib -import pickle import signal import sys import urllib @@ -59,7 +58,6 @@ from airflow.models.dag import DAG from airflow.models.dagbag import DagBag from airflow.models.dagrun import DagRun -from airflow.models.expandinput import EXPAND_INPUT_EMPTY, NotFullyPopulated from airflow.models.pool import Pool from airflow.models.renderedtifields import RenderedTaskInstanceFields from airflow.models.serialized_dag import SerializedDagModel @@ -73,7 +71,7 @@ from airflow.models.taskmap import TaskMap from airflow.models.taskreschedule import TaskReschedule from airflow.models.variable import Variable -from airflow.models.xcom import LazyXComSelectSequence, XCom +from airflow.models.xcom import XCom from airflow.notifications.basenotifier import BaseNotifier from airflow.providers.standard.operators.bash import BashOperator from airflow.providers.standard.operators.empty import EmptyOperator @@ -95,7 +93,6 @@ from airflow.utils.session import create_session, provide_session from airflow.utils.state import DagRunState, State, TaskInstanceState from airflow.utils.task_group import TaskGroup -from airflow.utils.task_instance_session import set_current_task_instance_session from airflow.utils.types import DagRunTriggeredByType, DagRunType from airflow.utils.xcom import XCOM_RETURN_KEY @@ -1607,192 +1604,6 @@ def test_are_dependents_done( session.flush() assert ti.are_dependents_done(session) == expected_are_dependents_done - def test_xcom_pull(self, dag_maker): - """Test xcom_pull, using different filtering methods.""" - with dag_maker(dag_id="test_xcom") as dag: - task_1 = EmptyOperator(task_id="test_xcom_1") - task_2 = EmptyOperator(task_id="test_xcom_2") - - dagrun = dag_maker.create_dagrun(start_date=timezone.datetime(2016, 6, 1, 0, 0, 0)) - ti1 = dagrun.get_task_instance(task_1.task_id) - - # Push a value - ti1.xcom_push(key="foo", value="bar") - - # Push another value with the same key (but by a different task) - XCom.set(key="foo", value="baz", task_id=task_2.task_id, dag_id=dag.dag_id, run_id=dagrun.run_id) - - # Pull with no arguments - result = ti1.xcom_pull() - assert result is None - # Pull the value pushed most recently by any task. 
- result = ti1.xcom_pull(key="foo") - assert result in "baz" - # Pull the value pushed by the first task - result = ti1.xcom_pull(task_ids="test_xcom_1", key="foo") - assert result == "bar" - # Pull the value pushed by the second task - result = ti1.xcom_pull(task_ids="test_xcom_2", key="foo") - assert result == "baz" - # Pull the values pushed by both tasks & Verify Order of task_ids pass & values returned - result = ti1.xcom_pull(task_ids=["test_xcom_1", "test_xcom_2"], key="foo") - assert result == ["bar", "baz"] - - def test_xcom_pull_mapped(self, dag_maker, session): - with dag_maker(dag_id="test_xcom", session=session): - # Use the private _expand() method to avoid the empty kwargs check. - # We don't care about how the operator runs here, only its presence. - task_1 = EmptyOperator.partial(task_id="task_1")._expand(EXPAND_INPUT_EMPTY, strict=False) - EmptyOperator(task_id="task_2") - - dagrun = dag_maker.create_dagrun(start_date=timezone.datetime(2016, 6, 1, 0, 0, 0)) - - ti_1_0 = dagrun.get_task_instance("task_1", session=session) - ti_1_0.map_index = 0 - ti_1_1 = session.merge(TI(task_1, run_id=dagrun.run_id, map_index=1, state=ti_1_0.state)) - session.flush() - - ti_1_0.xcom_push(key=XCOM_RETURN_KEY, value="a", session=session) - ti_1_1.xcom_push(key=XCOM_RETURN_KEY, value="b", session=session) - - ti_2 = dagrun.get_task_instance("task_2", session=session) - - assert set(ti_2.xcom_pull(["task_1"], session=session)) == {"a", "b"} # Ordering not guaranteed. - assert ti_2.xcom_pull(["task_1"], map_indexes=0, session=session) == ["a"] - - assert ti_2.xcom_pull(map_indexes=[0, 1], session=session) == ["a", "b"] - assert ti_2.xcom_pull("task_1", map_indexes=[1, 0], session=session) == ["b", "a"] - assert ti_2.xcom_pull(["task_1"], map_indexes=[0, 1], session=session) == ["a", "b"] - - assert ti_2.xcom_pull("task_1", map_indexes=1, session=session) == "b" - assert list(ti_2.xcom_pull("task_1", session=session)) == ["a", "b"] - - def test_xcom_pull_after_success(self, create_task_instance): - """ - tests xcom set/clear relative to a task in a 'success' rerun scenario - """ - key = "xcom_key" - value = "xcom_value" - - ti = create_task_instance( - dag_id="test_xcom", - schedule="@monthly", - task_id="test_xcom", - pool="test_xcom", - serialized=True, - ) - - ti.run(mark_success=True) - ti.xcom_push(key=key, value=value) - assert ti.xcom_pull(task_ids="test_xcom", key=key) == value - ti.run() - # Check that we do not clear Xcom until the task is certain to execute - assert ti.xcom_pull(task_ids="test_xcom", key=key) == value - # Xcom shouldn't be cleared if the task doesn't execute, even if dependencies are ignored - ti.run(ignore_all_deps=True, mark_success=True) - assert ti.xcom_pull(task_ids="test_xcom", key=key) == value - # Xcom IS finally cleared once task has executed - ti.run(ignore_all_deps=True) - assert ti.xcom_pull(task_ids="test_xcom", key=key) is None - - def test_xcom_pull_after_deferral(self, create_task_instance, session): - """ - tests xcom will not clear before a task runs its next method after deferral. 
- """ - - key = "xcom_key" - value = "xcom_value" - - ti = create_task_instance( - dag_id="test_xcom", - schedule="@monthly", - task_id="test_xcom", - pool="test_xcom", - ) - - ti.run(mark_success=True) - ti.xcom_push(key=key, value=value) - - ti.next_method = "execute" - session.merge(ti) - session.commit() - - ti.run(ignore_all_deps=True) - assert ti.xcom_pull(task_ids="test_xcom", key=key) == value - - def test_xcom_pull_different_logical_date(self, create_task_instance): - """ - tests xcom fetch behavior with different logical dates, using - both xcom_pull with "include_prior_dates" and without - """ - key = "xcom_key" - value = "xcom_value" - - ti = create_task_instance( - dag_id="test_xcom", - schedule="@monthly", - task_id="test_xcom", - pool="test_xcom", - ) - exec_date = ti.dag_run.logical_date - - ti.run(mark_success=True) - ti.xcom_push(key=key, value=value) - assert ti.xcom_pull(task_ids="test_xcom", key=key) == value - ti.run() - exec_date += datetime.timedelta(days=1) - dr = ti.task.dag.create_dagrun( - run_id="test2", - run_type=DagRunType.MANUAL, - logical_date=exec_date, - data_interval=(exec_date, exec_date), - run_after=exec_date, - state=None, - triggered_by=DagRunTriggeredByType.TEST, - ) - ti = TI(task=ti.task, run_id=dr.run_id) - ti.run() - # We have set a new logical date (and did not pass in - # 'include_prior_dates'which means this task should now have a cleared - # xcom value - assert ti.xcom_pull(task_ids="test_xcom", key=key) is None - # We *should* get a value using 'include_prior_dates' - assert ti.xcom_pull(task_ids="test_xcom", key=key, include_prior_dates=True) == value - - def test_xcom_pull_different_run_ids(self, create_task_instance): - """ - tests xcom fetch behavior w/different run ids - """ - key = "xcom_key" - task_id = "test_xcom" - diff_run_id = "diff_run_id" - same_run_id_value = "xcom_value_same_run_id" - diff_run_id_value = "xcom_value_different_run_id" - - ti_same_run_id = create_task_instance( - dag_id="test_xcom", - task_id=task_id, - ) - ti_same_run_id.run(mark_success=True) - ti_same_run_id.xcom_push(key=key, value=same_run_id_value) - - ti_diff_run_id = create_task_instance( - dag_id="test_xcom", - task_id=task_id, - run_id=diff_run_id, - ) - ti_diff_run_id.run(mark_success=True) - ti_diff_run_id.xcom_push(key=key, value=diff_run_id_value) - - assert ( - ti_same_run_id.xcom_pull(run_id=ti_same_run_id.dag_run.run_id, task_ids=task_id, key=key) - == same_run_id_value - ) - assert ( - ti_same_run_id.xcom_pull(run_id=ti_diff_run_id.dag_run.run_id, task_ids=task_id, key=key) - == diff_run_id_value - ) - def test_xcom_push_flag(self, dag_maker): """ Tests the option for Operators to push XComs @@ -2826,6 +2637,7 @@ def producer_with_inactive(*, outlet_events): assert asset_alias_obj.assets[0].uri == asset_name @pytest.mark.want_activate_assets(True) + @pytest.mark.need_serialized_dag def test_inlet_asset_extra(self, dag_maker, session): from airflow.sdk.definitions.asset import Asset @@ -2870,11 +2682,15 @@ def read(*, inlet_events): # Run "write1", "write2", and "write3" (in this order). decision = dr.task_instance_scheduling_decisions(session=session) for ti in sorted(decision.schedulable_tis, key=operator.attrgetter("task_id")): + # TODO: TaskSDK #45549 + ti.task = dag_maker.dag.get_task(ti.task_id) ti.run(session=session) # Run "read". 
decision = dr.task_instance_scheduling_decisions(session=session) for ti in decision.schedulable_tis: + # TODO: TaskSDK #45549 + ti.task = dag_maker.dag.get_task(ti.task_id) ti.run(session=session) # Should be done. @@ -2882,6 +2698,7 @@ def read(*, inlet_events): assert read_task_evaluated @pytest.mark.want_activate_assets(True) + @pytest.mark.need_serialized_dag def test_inlet_asset_alias_extra(self, dag_maker, session): from airflow.sdk.definitions.asset import Asset, AssetAlias @@ -2933,17 +2750,22 @@ def read(*, inlet_events): # Run "write1", "write2", and "write3" (in this order). decision = dr.task_instance_scheduling_decisions(session=session) for ti in sorted(decision.schedulable_tis, key=operator.attrgetter("task_id")): + # TODO: TaskSDK #45549 + ti.task = dag_maker.dag.get_task(ti.task_id) ti.run(session=session) # Run "read". decision = dr.task_instance_scheduling_decisions(session=session) for ti in decision.schedulable_tis: + # TODO: TaskSDK #45549 + ti.task = dag_maker.dag.get_task(ti.task_id) ti.run(session=session) # Should be done. assert not dr.task_instance_scheduling_decisions(session=session).schedulable_tis assert read_task_evaluated + @pytest.mark.need_serialized_dag def test_inlet_unresolved_asset_alias(self, dag_maker, session): asset_alias_name = "test_inlet_asset_extra_asset_alias" @@ -2964,6 +2786,8 @@ def read(*, inlet_events): dr: DagRun = dag_maker.create_dagrun() for ti in dr.get_task_instances(session=session): + # TODO: TaskSDK #45549 + ti.task = dag_maker.dag.get_task(ti.task_id) ti.run(session=session) # Should be done. @@ -4934,80 +4758,6 @@ def show(value): ti.run() assert outputs == expected_outputs - def test_map_product(self, dag_maker, session): - outputs = [] - - with dag_maker(dag_id="product", session=session) as dag: - - @dag.task - def emit_numbers(): - return [1, 2] - - @dag.task - def emit_letters(): - return {"a": "x", "b": "y", "c": "z"} - - @dag.task - def show(number, letter): - outputs.append((number, letter)) - - show.expand(number=emit_numbers(), letter=emit_letters()) - - dag_run = dag_maker.create_dagrun() - for task_id in ["emit_numbers", "emit_letters"]: - ti = dag_run.get_task_instance(task_id, session=session) - ti.refresh_from_task(dag.get_task(task_id)) - ti.run() - - show_task = dag.get_task("show") - mapped_tis, max_map_index = TaskMap.expand_mapped_task(show_task, dag_run.run_id, session=session) - assert max_map_index + 1 == len(mapped_tis) == 6 - - for ti in sorted(mapped_tis, key=operator.attrgetter("map_index")): - ti.refresh_from_task(show_task) - ti.run() - assert outputs == [ - (1, ("a", "x")), - (1, ("b", "y")), - (1, ("c", "z")), - (2, ("a", "x")), - (2, ("b", "y")), - (2, ("c", "z")), - ] - - def test_map_product_same(self, dag_maker, session): - """Test a mapped task can refer to the same source multiple times.""" - outputs = [] - - with dag_maker(dag_id="product_same", session=session) as dag: - - @dag.task - def emit_numbers(): - return [1, 2] - - @dag.task - def show(a, b): - outputs.append((a, b)) - - emit_task = emit_numbers() - show.expand(a=emit_task, b=emit_task) - - dag_run = dag_maker.create_dagrun() - ti = dag_run.get_task_instance("emit_numbers", session=session) - ti.refresh_from_task(dag.get_task("emit_numbers")) - ti.run() - - show_task = dag.get_task("show") - with pytest.raises(NotFullyPopulated): - assert show_task.get_parse_time_mapped_ti_count() - mapped_tis, max_map_index = TaskMap.expand_mapped_task(show_task, dag_run.run_id, session=session) - assert max_map_index + 1 == len(mapped_tis) == 
4 - - for ti in sorted(mapped_tis, key=operator.attrgetter("map_index")): - ti.refresh_from_task(show_task) - ti.run() - assert outputs == [(1, 1), (1, 2), (2, 1), (2, 2)] - def test_map_literal_cross_product(self, dag_maker, session): """Test a mapped task with literal cross product args expand properly.""" outputs = [] @@ -5112,132 +4862,6 @@ def _get_lazy_xcom_access_expected_sql_lines() -> list[str]: raise RuntimeError(f"unknown backend {backend!r}") -def test_lazy_xcom_access_does_not_pickle_session(dag_maker, session): - with dag_maker(session=session): - EmptyOperator(task_id="t") - - run: DagRun = dag_maker.create_dagrun() - run.get_task_instance("t", session=session).xcom_push("xxx", 123, session=session) - - with set_current_task_instance_session(session=session): - original = LazyXComSelectSequence.from_select( - select(XCom.value).filter_by( - dag_id=run.dag_id, - run_id=run.run_id, - task_id="t", - map_index=-1, - key="xxx", - ), - order_by=(), - ) - processed = pickle.loads(pickle.dumps(original)) - - # After the object went through pickling, the underlying ORM query should be - # replaced by one backed by a literal SQL string with all variables binded. - sql_lines = [line.strip() for line in str(processed._select_asc.compile(None)).splitlines()] - assert sql_lines == _get_lazy_xcom_access_expected_sql_lines() - - assert len(processed) == 1 - assert list(processed) == [123] - - -@mock.patch("airflow.models.taskinstance.XCom.deserialize_value", side_effect=XCom.deserialize_value) -def test_ti_xcom_pull_on_mapped_operator_return_lazy_iterable(mock_deserialize_value, dag_maker, session): - """Ensure we access XCom lazily when pulling from a mapped operator.""" - with dag_maker(dag_id="test_xcom", session=session): - # Use the private _expand() method to avoid the empty kwargs check. - # We don't care about how the operator runs here, only its presence. - task_1 = EmptyOperator.partial(task_id="task_1")._expand(EXPAND_INPUT_EMPTY, strict=False) - EmptyOperator(task_id="task_2") - - dagrun = dag_maker.create_dagrun() - - ti_1_0 = dagrun.get_task_instance("task_1", session=session) - ti_1_0.map_index = 0 - ti_1_1 = session.merge(TaskInstance(task_1, run_id=dagrun.run_id, map_index=1, state=ti_1_0.state)) - session.flush() - - ti_1_0.xcom_push(key=XCOM_RETURN_KEY, value="a", session=session) - ti_1_1.xcom_push(key=XCOM_RETURN_KEY, value="b", session=session) - - ti_2 = dagrun.get_task_instance("task_2", session=session) - - # Simply pulling the joined XCom value should not deserialize. - joined = ti_2.xcom_pull("task_1", session=session) - assert isinstance(joined, LazyXComSelectSequence) - assert mock_deserialize_value.call_count == 0 - - # Only when we go through the iterable does deserialization happen. 
- it = iter(joined) - assert next(it) == "a" - assert mock_deserialize_value.call_count == 1 - assert next(it) == "b" - assert mock_deserialize_value.call_count == 2 - with pytest.raises(StopIteration): - next(it) - - -def test_ti_mapped_depends_on_mapped_xcom_arg(dag_maker, session): - with dag_maker(session=session) as dag: - - @dag.task - def add_one(x): - return x + 1 - - two_three_four = add_one.expand(x=[1, 2, 3]) - add_one.expand(x=two_three_four) - - dagrun = dag_maker.create_dagrun() - for map_index in range(3): - ti = dagrun.get_task_instance("add_one", map_index=map_index, session=session) - ti.refresh_from_task(dag.get_task("add_one")) - ti.run() - - task_345 = dag.get_task("add_one__1") - for ti in TaskMap.expand_mapped_task(task_345, dagrun.run_id, session=session)[0]: - ti.refresh_from_task(task_345) - ti.run() - - query = XCom.get_many(run_id=dagrun.run_id, task_ids=["add_one__1"], session=session) - assert [x.value for x in query.order_by(None).order_by(XCom.map_index)] == [3, 4, 5] - - -def test_mapped_upstream_return_none_should_skip(dag_maker, session): - results = set() - - with dag_maker(dag_id="test_mapped_upstream_return_none_should_skip", session=session) as dag: - - @dag.task() - def transform(value): - if value == "b": # Now downstream doesn't map against this! - return None - return value - - @dag.task() - def pull(value): - results.add(value) - - original = ["a", "b", "c"] - transformed = transform.expand(value=original) # ["a", None, "c"] - pull.expand(value=transformed) # ["a", "c"] - - dr = dag_maker.create_dagrun() - - decision = dr.task_instance_scheduling_decisions(session=session) - tis = {(ti.task_id, ti.map_index): ti for ti in decision.schedulable_tis} - assert sorted(tis) == [("transform", 0), ("transform", 1), ("transform", 2)] - for ti in tis.values(): - ti.run() - - decision = dr.task_instance_scheduling_decisions(session=session) - tis = {(ti.task_id, ti.map_index): ti for ti in decision.schedulable_tis} - assert sorted(tis) == [("pull", 0), ("pull", 1)] - for ti in tis.values(): - ti.run() - - assert results == {"a", "c"} - - def test_expand_non_templated_field(dag_maker, session): """Test expand on non-templated fields sets upstream deps properly.""" diff --git a/tests/models/test_taskmap.py b/tests/models/test_taskmap.py new file mode 100644 index 0000000000000..10fb4d99bba09 --- /dev/null +++ b/tests/models/test_taskmap.py @@ -0,0 +1,76 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+from __future__ import annotations + +import pytest + +from airflow.models.taskinstance import TaskInstance +from airflow.models.taskmap import TaskMap, TaskMapVariant +from airflow.providers.standard.operators.empty import EmptyOperator + +pytestmark = pytest.mark.db_test + + +def test_task_map_from_task_instance_xcom(): + task = EmptyOperator(task_id="test_task") + ti = TaskInstance(task=task, run_id="test_run", map_index=0) + ti.dag_id = "test_dag" + value = {"key1": "value1", "key2": "value2"} + + # Test case where run_id is not None + task_map = TaskMap.from_task_instance_xcom(ti, value) + assert task_map.dag_id == ti.dag_id + assert task_map.task_id == ti.task_id + assert task_map.run_id == ti.run_id + assert task_map.map_index == ti.map_index + assert task_map.length == len(value) + assert task_map.keys == list(value) + + # Test case where run_id is None + ti.run_id = None + with pytest.raises(ValueError, match="cannot record task map for unrun task instance"): + TaskMap.from_task_instance_xcom(ti, value) + + +def test_task_map_with_invalid_task_instance(): + task = EmptyOperator(task_id="test_task") + ti = TaskInstance(task=task, run_id=None, map_index=0) + ti.dag_id = "test_dag" + + # Define some arbitrary XCom-like value data + value = {"example_key": "example_value"} + + with pytest.raises(ValueError, match="cannot record task map for unrun task instance"): + TaskMap.from_task_instance_xcom(ti, value) + + +def test_task_map_variant(): + # Test case where keys is None + task_map = TaskMap( + dag_id="test_dag", + task_id="test_task", + run_id="test_run", + map_index=0, + length=3, + keys=None, + ) + assert task_map.variant == TaskMapVariant.LIST + + # Test case where keys is not None + task_map.keys = ["key1", "key2"] + assert task_map.variant == TaskMapVariant.DICT diff --git a/tests/models/test_xcom_arg_map.py b/tests/models/test_xcom_arg_map.py deleted file mode 100644 index 6d1ea09f3b217..0000000000000 --- a/tests/models/test_xcom_arg_map.py +++ /dev/null @@ -1,433 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. 
-from __future__ import annotations - -import pytest - -from airflow.exceptions import AirflowSkipException -from airflow.models.taskinstance import TaskInstance -from airflow.models.taskmap import TaskMap, TaskMapVariant -from airflow.providers.standard.operators.empty import EmptyOperator -from airflow.utils.state import TaskInstanceState -from airflow.utils.trigger_rule import TriggerRule - -pytestmark = pytest.mark.db_test - - -def test_xcom_map(dag_maker, session): - results = set() - with dag_maker(session=session) as dag: - - @dag.task - def push(): - return ["a", "b", "c"] - - @dag.task - def pull(value): - results.add(value) - - pull.expand_kwargs(push().map(lambda v: {"value": v * 2})) - - # The function passed to "map" is *NOT* a task. - assert set(dag.task_dict) == {"push", "pull"} - - dr = dag_maker.create_dagrun(session=session) - - # Run "push". - decision = dr.task_instance_scheduling_decisions(session=session) - for ti in decision.schedulable_tis: - ti.run(session=session) - session.commit() - - # Run "pull". - decision = dr.task_instance_scheduling_decisions(session=session) - tis = {(ti.task_id, ti.map_index): ti for ti in decision.schedulable_tis} - assert sorted(tis) == [("pull", 0), ("pull", 1), ("pull", 2)] - for ti in tis.values(): - ti.run(session=session) - - assert results == {"aa", "bb", "cc"} - - -def test_xcom_map_transform_to_none(dag_maker, session): - results = set() - - with dag_maker(session=session) as dag: - - @dag.task() - def push(): - return ["a", "b", "c"] - - @dag.task() - def pull(value): - results.add(value) - - def c_to_none(v): - if v == "c": - return None - return v - - pull.expand(value=push().map(c_to_none)) - - dr = dag_maker.create_dagrun(session=session) - - # Run "push". - decision = dr.task_instance_scheduling_decisions(session=session) - for ti in decision.schedulable_tis: - ti.run(session=session) - - # Run "pull". This should automatically convert "c" to None. - decision = dr.task_instance_scheduling_decisions(session=session) - for ti in decision.schedulable_tis: - ti.run(session=session) - assert results == {"a", "b", None} - - -def test_xcom_convert_to_kwargs_fails_task(dag_maker, session): - results = set() - - with dag_maker(session=session) as dag: - - @dag.task() - def push(): - return ["a", "b", "c"] - - @dag.task() - def pull(value): - results.add(value) - - def c_to_none(v): - if v == "c": - return None - return {"value": v} - - pull.expand_kwargs(push().map(c_to_none)) - - dr = dag_maker.create_dagrun(session=session) - - # Run "push". - decision = dr.task_instance_scheduling_decisions(session=session) - for ti in decision.schedulable_tis: - ti.run(session=session) - - # Prepare to run "pull"... - decision = dr.task_instance_scheduling_decisions(session=session) - tis = {(ti.task_id, ti.map_index): ti for ti in decision.schedulable_tis} - - # The first two "pull" tis should also succeed. - tis[("pull", 0)].run(session=session) - tis[("pull", 1)].run(session=session) - - # But the third one fails because the map() result cannot be used as kwargs. 
- with pytest.raises(ValueError) as ctx: - tis[("pull", 2)].run(session=session) - assert str(ctx.value) == "expand_kwargs() expects a list[dict], not list[None]" - - assert [tis[("pull", i)].state for i in range(3)] == [ - TaskInstanceState.SUCCESS, - TaskInstanceState.SUCCESS, - TaskInstanceState.FAILED, - ] - - -def test_xcom_map_error_fails_task(dag_maker, session): - with dag_maker(session=session) as dag: - - @dag.task() - def push(): - return ["a", "b", "c"] - - @dag.task() - def pull(value): - print(value) - - def does_not_work_with_c(v): - if v == "c": - raise ValueError("nope") - return {"value": v * 2} - - pull.expand_kwargs(push().map(does_not_work_with_c)) - - dr = dag_maker.create_dagrun(session=session) - - # The "push" task should not fail. - decision = dr.task_instance_scheduling_decisions(session=session) - for ti in decision.schedulable_tis: - ti.run(session=session) - assert [ti.state for ti in decision.schedulable_tis] == [TaskInstanceState.SUCCESS] - - # Prepare to run "pull"... - decision = dr.task_instance_scheduling_decisions(session=session) - tis = {(ti.task_id, ti.map_index): ti for ti in decision.schedulable_tis} - - # The first two "pull" tis should also succeed. - tis[("pull", 0)].run(session=session) - tis[("pull", 1)].run(session=session) - - # But the third one (for "c") will fail. - with pytest.raises(ValueError) as ctx: - tis[("pull", 2)].run(session=session) - assert str(ctx.value) == "nope" - - assert [tis[("pull", i)].state for i in range(3)] == [ - TaskInstanceState.SUCCESS, - TaskInstanceState.SUCCESS, - TaskInstanceState.FAILED, - ] - - -def test_task_map_from_task_instance_xcom(): - task = EmptyOperator(task_id="test_task") - ti = TaskInstance(task=task, run_id="test_run", map_index=0) - ti.dag_id = "test_dag" - value = {"key1": "value1", "key2": "value2"} - - # Test case where run_id is not None - task_map = TaskMap.from_task_instance_xcom(ti, value) - assert task_map.dag_id == ti.dag_id - assert task_map.task_id == ti.task_id - assert task_map.run_id == ti.run_id - assert task_map.map_index == ti.map_index - assert task_map.length == len(value) - assert task_map.keys == list(value) - - # Test case where run_id is None - ti.run_id = None - with pytest.raises(ValueError, match="cannot record task map for unrun task instance"): - TaskMap.from_task_instance_xcom(ti, value) - - -def test_task_map_with_invalid_task_instance(): - task = EmptyOperator(task_id="test_task") - ti = TaskInstance(task=task, run_id=None, map_index=0) - ti.dag_id = "test_dag" - - # Define some arbitrary XCom-like value data - value = {"example_key": "example_value"} - - with pytest.raises(ValueError, match="cannot record task map for unrun task instance"): - TaskMap.from_task_instance_xcom(ti, value) - - -def test_task_map_variant(): - # Test case where keys is None - task_map = TaskMap( - dag_id="test_dag", - task_id="test_task", - run_id="test_run", - map_index=0, - length=3, - keys=None, - ) - assert task_map.variant == TaskMapVariant.LIST - - # Test case where keys is not None - task_map.keys = ["key1", "key2"] - assert task_map.variant == TaskMapVariant.DICT - - -def test_xcom_map_raise_to_skip(dag_maker, session): - result = None - - with dag_maker(session=session) as dag: - - @dag.task() - def push(): - return ["a", "b", "c"] - - @dag.task() - def forward(value): - return value - - @dag.task(trigger_rule=TriggerRule.ALL_DONE) - def collect(value): - nonlocal result - result = list(value) - - def skip_c(v): - if v == "c": - raise AirflowSkipException - return {"value": 
v} - - collect(value=forward.expand_kwargs(push().map(skip_c))) - - dr = dag_maker.create_dagrun(session=session) - - # Run "push". - decision = dr.task_instance_scheduling_decisions(session=session) - for ti in decision.schedulable_tis: - ti.run(session=session) - - # Run "forward". This should automatically skip "c". - decision = dr.task_instance_scheduling_decisions(session=session) - for ti in decision.schedulable_tis: - ti.run(session=session) - - # Now "collect" should only get "a" and "b". - decision = dr.task_instance_scheduling_decisions(session=session) - for ti in decision.schedulable_tis: - ti.run(session=session) - assert result == ["a", "b"] - - -def test_xcom_map_nest(dag_maker, session): - results = set() - - with dag_maker(session=session) as dag: - - @dag.task() - def push(): - return ["a", "b", "c"] - - @dag.task() - def pull(value): - results.add(value) - - converted = push().map(lambda v: v * 2).map(lambda v: {"value": v}) - pull.expand_kwargs(converted) - - dr = dag_maker.create_dagrun(session=session) - - # Run "push". - decision = dr.task_instance_scheduling_decisions(session=session) - for ti in decision.schedulable_tis: - ti.run(session=session) - - session.flush() - session.commit() - - # Now "pull" should apply the mapping functions in order. - decision = dr.task_instance_scheduling_decisions(session=session) - for ti in decision.schedulable_tis: - ti.run(session=session) - assert results == {"aa", "bb", "cc"} - - -def test_xcom_map_zip_nest(dag_maker, session): - results = set() - - with dag_maker(session=session) as dag: - - @dag.task - def push_letters(): - return ["a", "b", "c", "d"] - - @dag.task - def push_numbers(): - return [1, 2, 3, 4] - - @dag.task - def pull(value): - results.add(value) - - doubled = push_numbers().map(lambda v: v * 2) - combined = doubled.zip(push_letters()) - - def convert_zipped(zipped): - letter, number = zipped - return letter * number - - pull.expand(value=combined.map(convert_zipped)) - - dr = dag_maker.create_dagrun(session=session) - - # Run "push_letters" and "push_numbers". - decision = dr.task_instance_scheduling_decisions(session=session) - assert decision.schedulable_tis - assert all(ti.task_id.startswith("push_") for ti in decision.schedulable_tis) - for ti in decision.schedulable_tis: - ti.run(session=session) - session.commit() - - # Run "pull". 
- decision = dr.task_instance_scheduling_decisions(session=session) - assert decision.schedulable_tis - assert all(ti.task_id == "pull" for ti in decision.schedulable_tis) - for ti in decision.schedulable_tis: - ti.run(session=session) - - assert results == {"aa", "bbbb", "cccccc", "dddddddd"} - - -def test_xcom_concat(dag_maker, session): - from airflow.sdk.definitions.xcom_arg import _ConcatResult - - agg_results = set() - all_results = None - - with dag_maker(session=session) as dag: - - @dag.task - def push_letters(): - return ["a", "b", "c"] - - @dag.task - def push_numbers(): - return [1, 2] - - @dag.task - def pull_one(value): - agg_results.add(value) - - @dag.task - def pull_all(value): - assert isinstance(value, _ConcatResult) - assert value[0] == "a" - assert value[1] == "b" - assert value[2] == "c" - assert value[3] == 1 - assert value[4] == 2 - with pytest.raises(IndexError): - value[5] - assert value[-5] == "a" - assert value[-4] == "b" - assert value[-3] == "c" - assert value[-2] == 1 - assert value[-1] == 2 - with pytest.raises(IndexError): - value[-6] - nonlocal all_results - all_results = list(value) - - pushed_values = push_letters().concat(push_numbers()) - - pull_one.expand(value=pushed_values) - pull_all(pushed_values) - - dr = dag_maker.create_dagrun(session=session) - - # Run "push_letters" and "push_numbers". - decision = dr.task_instance_scheduling_decisions(session=session) - assert len(decision.schedulable_tis) == 2 - assert all(ti.task_id.startswith("push_") for ti in decision.schedulable_tis) - for ti in decision.schedulable_tis: - ti.run(session=session) - session.commit() - - # Run "pull_one" and "pull_all". - decision = dr.task_instance_scheduling_decisions(session=session) - assert len(decision.schedulable_tis) == 6 - assert all(ti.task_id.startswith("pull_") for ti in decision.schedulable_tis) - for ti in decision.schedulable_tis: - ti.run(session=session) - - assert agg_results == {"a", "b", "c", 1, 2} - assert all_results == ["a", "b", "c", 1, 2] - - decision = dr.task_instance_scheduling_decisions(session=session) - assert not decision.schedulable_tis diff --git a/tests/serialization/test_dag_serialization.py b/tests/serialization/test_dag_serialization.py index 37364e2fcbbd1..281bcc73c346b 100644 --- a/tests/serialization/test_dag_serialization.py +++ b/tests/serialization/test_dag_serialization.py @@ -40,7 +40,7 @@ from typing import TYPE_CHECKING from unittest import mock -import attr +import attrs import pendulum import pytest from dateutil.relativedelta import FR, relativedelta @@ -694,7 +694,7 @@ def validate_deserialized_task( } else: # Promised to be mapped by the assert above. assert isinstance(serialized_task, MappedOperator) - fields_to_check = {f.name for f in attr.fields(MappedOperator)} + fields_to_check = {f.name for f in attrs.fields(MappedOperator)} fields_to_check -= { "map_index_template", # Matching logic in BaseOperator.get_serialized_fields(). @@ -707,6 +707,7 @@ def validate_deserialized_task( # Checked separately. 
"operator_class", "partial_kwargs", + "expand_input", } fields_to_check |= {"deps"} @@ -753,6 +754,9 @@ def validate_deserialized_task( original_partial_kwargs = {**default_partial_kwargs, **task.partial_kwargs} assert serialized_partial_kwargs == original_partial_kwargs + # ExpandInputs have different classes between scheduler and definition + assert attrs.asdict(serialized_task.expand_input) == attrs.asdict(task.expand_input) + @pytest.mark.parametrize( "dag_start_date, task_start_date, expected_task_start_date", [ @@ -2452,7 +2456,8 @@ def test_operator_expand_serde(): def test_operator_expand_xcomarg_serde(): - from airflow.models.xcom_arg import PlainXComArg, XComArg + from airflow.models.xcom_arg import SchedulerPlainXComArg + from airflow.sdk.definitions.xcom_arg import XComArg from airflow.serialization.serialized_objects import _XComRef with DAG("test-dag", schedule=None, start_date=datetime(2020, 1, 1)) as dag: @@ -2501,13 +2506,13 @@ def test_operator_expand_xcomarg_serde(): serialized_dag: DAG = SerializedDAG.from_dict(SerializedDAG.to_dict(dag)) xcom_arg = serialized_dag.task_dict["task_2"].expand_input.value["arg2"] - assert isinstance(xcom_arg, PlainXComArg) + assert isinstance(xcom_arg, SchedulerPlainXComArg) assert xcom_arg.operator is serialized_dag.task_dict["op1"] @pytest.mark.parametrize("strict", [True, False]) def test_operator_expand_kwargs_literal_serde(strict): - from airflow.models.xcom_arg import PlainXComArg, XComArg + from airflow.sdk.definitions.xcom_arg import XComArg from airflow.serialization.serialized_objects import _XComRef with DAG("test-dag", schedule=None, start_date=datetime(2020, 1, 1)) as dag: @@ -2567,15 +2572,16 @@ def test_operator_expand_kwargs_literal_serde(strict): serialized_dag: DAG = SerializedDAG.from_dict(SerializedDAG.to_dict(dag)) resolved_expand_value = serialized_dag.task_dict["task_2"].expand_input.value - resolved_expand_value == [ + assert resolved_expand_value == [ {"a": "x"}, - {"a": PlainXComArg(serialized_dag.task_dict["op1"])}, + {"a": _XComRef({"task_id": "op1", "key": "return_value"})}, ] @pytest.mark.parametrize("strict", [True, False]) def test_operator_expand_kwargs_xcomarg_serde(strict): - from airflow.models.xcom_arg import PlainXComArg, XComArg + from airflow.models.xcom_arg import SchedulerPlainXComArg + from airflow.sdk.definitions.xcom_arg import XComArg from airflow.serialization.serialized_objects import _XComRef with DAG("test-dag", schedule=None, start_date=datetime(2020, 1, 1)) as dag: @@ -2620,7 +2626,7 @@ def test_operator_expand_kwargs_xcomarg_serde(strict): serialized_dag: DAG = SerializedDAG.from_dict(SerializedDAG.to_dict(dag)) xcom_arg = serialized_dag.task_dict["task_2"].expand_input.value - assert isinstance(xcom_arg, PlainXComArg) + assert isinstance(xcom_arg, SchedulerPlainXComArg) assert xcom_arg.operator is serialized_dag.task_dict["op1"] @@ -2901,7 +2907,7 @@ def x(arg1, arg2, arg3): def test_mapped_task_group_serde(): from airflow.decorators.task_group import task_group - from airflow.models.expandinput import DictOfListsExpandInput + from airflow.models.expandinput import SchedulerDictOfListsExpandInput from airflow.sdk.definitions.taskgroup import MappedTaskGroup with DAG("test-dag", schedule=None, start_date=datetime(2020, 1, 1)) as dag: @@ -2944,7 +2950,7 @@ def tg(a: str) -> None: serde_dag = SerializedDAG.deserialize_dag(ser_dag[Encoding.VAR]) serde_tg = serde_dag.task_group.children["tg"] assert isinstance(serde_tg, MappedTaskGroup) - assert serde_tg._expand_input == 
DictOfListsExpandInput({"a": [".", ".."]}) + assert serde_tg._expand_input == SchedulerDictOfListsExpandInput({"a": [".", ".."]}) @pytest.mark.db_test diff --git a/tests/ti_deps/deps/test_mapped_task_upstream_dep.py b/tests/ti_deps/deps/test_mapped_task_upstream_dep.py index 3f1abdc3136ca..b47da12e672a1 100644 --- a/tests/ti_deps/deps/test_mapped_task_upstream_dep.py +++ b/tests/ti_deps/deps/test_mapped_task_upstream_dep.py @@ -27,8 +27,9 @@ from airflow.ti_deps.deps.base_ti_dep import TIDepStatus from airflow.ti_deps.deps.mapped_task_upstream_dep import MappedTaskUpstreamDep from airflow.utils.state import TaskInstanceState +from airflow.utils.xcom import XCOM_RETURN_KEY -pytestmark = pytest.mark.db_test +pytestmark = [pytest.mark.db_test, pytest.mark.need_serialized_dag] if TYPE_CHECKING: from sqlalchemy.orm.session import Session @@ -274,35 +275,6 @@ def tg(x, y): assert not schedulable_tis -def test_nested_mapped_task_groups(dag_maker, session: Session): - from airflow.decorators import task, task_group - - with dag_maker(session=session): - - @task - def t(): - return [[1, 2], [3, 4]] - - @task - def m(x): - return x - - @task_group - def g1(x): - @task_group - def g2(y): - return m(y) - - return g2.expand(y=x) - - g1.expand(x=t()) - - # Add a test once nested mapped task groups become supported - with pytest.raises(NotImplementedError) as ctx: - dag_maker.create_dagrun() - assert str(ctx.value) == "" - - def test_mapped_in_mapped_task_group(dag_maker, session: Session): from airflow.decorators import task, task_group @@ -391,7 +363,7 @@ def tg(): assert not get_dep_statuses(dr, "tg.op1", session) -@pytest.mark.parametrize("upstream_instance_state", [None, SKIPPED, FAILED]) +@pytest.mark.parametrize("upstream_instance_state", [SUCCESS, SKIPPED, FAILED]) @pytest.mark.parametrize("testcase", ["task", "group"]) def test_upstream_mapped_expanded( dag_maker, session: Session, upstream_instance_state: TaskInstanceState | None, testcase: str @@ -432,35 +404,46 @@ def tg(x): assert sorted(schedulable_tis) == [f"{mapped_task_1}_0", f"{mapped_task_1}_1", f"{mapped_task_1}_2"] assert not finished_tis_states - # Run expanded m1 tasks - schedulable_tis[f"{mapped_task_1}_1"].run() - schedulable_tis[f"{mapped_task_1}_2"].run() - if upstream_instance_state != FAILED: - schedulable_tis[f"{mapped_task_1}_0"].run() - else: - with pytest.raises(AirflowFailException): - schedulable_tis[f"{mapped_task_1}_0"].run() + # "Run" expanded m1 tasks + for ti, state in ( + (schedulable_tis[f"{mapped_task_1}_1"], SUCCESS), + (schedulable_tis[f"{mapped_task_1}_2"], SUCCESS), + (schedulable_tis[f"{mapped_task_1}_0"], upstream_instance_state), + ): + ti.state = state + if state == SUCCESS: + ti.xcom_push(XCOM_RETURN_KEY, "doesn't matter", session=session) + session.flush() schedulable_tis, finished_tis_states = _one_scheduling_decision_iteration(dr, session) # Expect that m2 can still be expanded since the dependency check does not fail. 
     # m1 tasks fails or is skipped, there is one fewer m2 expanded tasks
     expected_schedulable = [f"{mapped_task_2}_0", f"{mapped_task_2}_1"]
-    if upstream_instance_state is None:
+    if upstream_instance_state is SUCCESS:
         expected_schedulable.append(f"{mapped_task_2}_2")
     assert list(schedulable_tis.keys()) == expected_schedulable

     # Run the expanded m2 tasks
-    schedulable_tis[f"{mapped_task_2}_0"].run()
-    schedulable_tis[f"{mapped_task_2}_1"].run()
-    if upstream_instance_state is None:
-        schedulable_tis[f"{mapped_task_2}_2"].run()
+
+    to_run: tuple[tuple[TaskInstance, TaskInstanceState], ...] = (
+        (schedulable_tis[f"{mapped_task_2}_0"], SUCCESS),
+        (schedulable_tis[f"{mapped_task_2}_1"], SUCCESS),
+    )
+    if upstream_instance_state == SUCCESS:
+        to_run += ((schedulable_tis[f"{mapped_task_2}_2"], upstream_instance_state),)
+    for ti, state in to_run:
+        ti.state = state
+        if state is SUCCESS:
+            ti.xcom_push(XCOM_RETURN_KEY, "doesn't matter", session=session)
+    session.flush()

     schedulable_tis, finished_tis_states = _one_scheduling_decision_iteration(dr, session)
     assert not schedulable_tis
+
     expected_finished_tis_states = {
         ti: "success"
         for ti in (f"{mapped_task_1}_1", f"{mapped_task_1}_2", f"{mapped_task_2}_0", f"{mapped_task_2}_1")
     }
-    if upstream_instance_state is None:
+    if upstream_instance_state is SUCCESS:
         expected_finished_tis_states[f"{mapped_task_1}_0"] = "success"
         expected_finished_tis_states[f"{mapped_task_2}_2"] = "success"
     else:
diff --git a/tests_common/pytest_plugin.py b/tests_common/pytest_plugin.py
index 732714debd9d6..8b2ee9e6646a3 100644
--- a/tests_common/pytest_plugin.py
+++ b/tests_common/pytest_plugin.py
@@ -544,8 +544,7 @@ def skip_db_test(item):
         # also automatically skip tests marked with `backend` marker as they are implicitly
         # db tests
         pytest.skip(
-            f"The test is skipped as it is DB test "
-            f"and --skip-db-tests is flag is passed to pytest. {item}"
+            f"The test is skipped as it is DB test and --skip-db-tests is flag is passed to pytest. {item}"
{item}" ) @@ -788,7 +787,12 @@ def __enter__(self): self.dag.__enter__() if self.want_serialized: - return lazy_object_proxy.Proxy(self._serialized_dag) + + class DAGProxy(lazy_object_proxy.Proxy): + # Make `@dag.task` decorator work when need_serialized_dag marker is set + task = self.dag.task + + return DAGProxy(self._serialized_dag) return self.dag def _serialized_dag(self): @@ -868,6 +872,10 @@ def __exit__(self, type, value, traceback): if self.want_activate_assets: self._activate_assets() if sdm: + sdm._SerializedDagModel__data_cache = ( + self.serialized_model._SerializedDagModel__data_cache + ) + sdm._data = self.serialized_model._data self.serialized_model = sdm else: self.session.merge(self.serialized_model) @@ -937,9 +945,13 @@ def create_dagrun(self, *, logical_date=None, **kwargs): kwargs.pop("dag_version", None) kwargs.pop("triggered_by", None) kwargs["execution_date"] = logical_date + + if self.want_serialized: + dag = self.serialized_model.dag self.dag_run = dag.create_dagrun(**kwargs) for ti in self.dag_run.task_instances: - ti.refresh_from_task(dag.get_task(ti.task_id)) + # This need to always operate on the _real_ dag + ti.refresh_from_task(self.dag.get_task(ti.task_id)) if self.want_serialized: self.session.commit() return self.dag_run diff --git a/tests_common/test_utils/mock_context.py b/tests_common/test_utils/mock_context.py index 3c8a101ce2058..4391490dfb4a2 100644 --- a/tests_common/test_utils/mock_context.py +++ b/tests_common/test_utils/mock_context.py @@ -56,14 +56,16 @@ def xcom_pull( default: Any = None, run_id: str | None = None, ) -> Any: - if map_indexes: - return values.get( - f"{task_ids or self.task_id}_{dag_id or self.dag_id}_{key}_{map_indexes}", default - ) - return values.get(f"{task_ids or self.task_id}_{dag_id or self.dag_id}_{key}", default) + key = f"{self.task_id}_{self.dag_id}_{key}" + if map_indexes is not None and (not isinstance(map_indexes, int) or map_indexes >= 0): + key += f"_{map_indexes}" + return values.get(key, default) def xcom_push(self, key: str, value: Any, session: Session = NEW_SESSION, **kwargs) -> None: - values[f"{self.task_id}_{self.dag_id}_{key}_{self.map_index}"] = value + key = f"{self.task_id}_{self.dag_id}_{key}" + if self.map_index is not None and self.map_index >= 0: + key += f"_{self.map_index}" + values[key] = value values["ti"] = MockedTaskInstance(task=task) diff --git a/tests_common/test_utils/mock_operators.py b/tests_common/test_utils/mock_operators.py index 9c9e3cfb5488f..81f53abf648ce 100644 --- a/tests_common/test_utils/mock_operators.py +++ b/tests_common/test_utils/mock_operators.py @@ -17,7 +17,7 @@ from __future__ import annotations from collections.abc import Sequence -from typing import TYPE_CHECKING, Any +from typing import TYPE_CHECKING import attr @@ -27,8 +27,6 @@ from tests_common.test_utils.compat import BaseOperatorLink if TYPE_CHECKING: - import jinja2 - from airflow.sdk.definitions.context import Context @@ -46,48 +44,6 @@ def execute(self, context: Context): pass -class NestedFields: - """Nested fields for testing purposes.""" - - def __init__(self, field_1, field_2): - self.field_1 = field_1 - self.field_2 = field_2 - - -class MockOperatorWithNestedFields(BaseOperator): - """Operator with nested fields for testing purposes.""" - - template_fields: Sequence[str] = ("arg1", "arg2") - - def __init__(self, arg1: str = "", arg2: NestedFields | None = None, **kwargs): - super().__init__(**kwargs) - self.arg1 = arg1 - self.arg2 = arg2 - - def _render_nested_template_fields( - self, - content: 
-        content: Any,
-        context: Context,
-        jinja_env: jinja2.Environment,
-        seen_oids: set,
-    ) -> None:
-        if id(content) not in seen_oids:
-            template_fields: tuple | None = None
-
-            if isinstance(content, NestedFields):
-                template_fields = ("field_1", "field_2")
-
-            if template_fields:
-                seen_oids.add(id(content))
-                self._do_render_template_fields(content, template_fields, context, jinja_env, seen_oids)
-                return
-
-        super()._render_nested_template_fields(content, context, jinja_env, seen_oids)
-
-    def execute(self, context: Context):
-        pass
-
-
 class AirflowLink(BaseOperatorLink):
     """Operator Link for Apache Airflow Website."""