Skip to content

Commit

Permalink
Start porting mapped task to SDK (#45627)
Browse files Browse the repository at this point in the history
This PR restructures the Mapped Operator and Mapped Task Group code to live in
the Task SDK at definition time.

The big thing this change _does not do_ is make it possible to execute mapped
tasks via the Task Execution API server etc -- that is up next.

There were some unavoidable changes to the scheduler/expansion part of mapped
tasks here. Of note:

`BaseOperator.get_mapped_ti_count` has moved from being an instance method on
BaseOperator to being a class method. The reason for this is that, with the move
of more and more of the "definition time" code into the TaskSDK BaseOperator
and AbstractOperator, it is no longer possible to add DB-accessing code to a
base class and have it apply to the subclasses (i.e.
`airflow.models.abstractoperator.AbstractOperator` is now _not always_ in the
MRO for tasks; eventually that class will be deleted, but not yet).

In a similar vein, XComArg's `get_task_map_length` has also moved to a
single-dispatch class method on the TaskMap model, since the definition-time
objects now live in the TaskSDK and there is no realistic way to get a per-type
subclass with DB logic (i.e. it's very complex to end up with a
PlainDBXComArg, a MapDBXComArg, etc. that we can attach the method to).

For those who aren't aware, singledispatch (and singledispatchmethod) are a
part of the standard library where the type of the first argument is used to
determine which implementation to call. If you are familiar with C++ or Java,
this is very similar to method overloading; the one caveat is that it _only_
examines the type of the first argument, not the full signature.
  • Loading branch information
ashb authored Jan 21, 2025
1 parent f3913bc commit d1b2a44
Show file tree
Hide file tree
Showing 70 changed files with 2,911 additions and 2,484 deletions.
3 changes: 3 additions & 0 deletions airflow/api/common/mark_tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,9 @@ def set_state(
if not run_id:
raise ValueError("Received tasks with no run_id")

if TYPE_CHECKING:
assert isinstance(dag, DAG)

dag_run_ids = get_run_ids(dag, run_id, future, past, session=session)
task_id_map_index_list = list(find_task_relatives(tasks, downstream, upstream))
# now look for the task instances that are affected
Expand Down
4 changes: 2 additions & 2 deletions airflow/api_connexion/schemas/common_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
from dateutil import relativedelta
from marshmallow import Schema, fields, validate

from airflow.models.mappedoperator import MappedOperator
from airflow.sdk.definitions.mappedoperator import MappedOperator
from airflow.serialization.serialized_objects import SerializedBaseOperator


Expand Down Expand Up @@ -128,7 +128,7 @@ def _get_module(self, obj):

def _get_class_name(self, obj):
if isinstance(obj, (MappedOperator, SerializedBaseOperator)):
return obj._task_type
return obj.task_type
if isinstance(obj, type):
return obj.__name__
return type(obj).__name__
Expand Down
7 changes: 1 addition & 6 deletions airflow/api_connexion/schemas/task_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,6 @@
TimeDeltaSchema,
WeightRuleField,
)
from airflow.models.mappedoperator import MappedOperator

if TYPE_CHECKING:
from airflow.models.operator import Operator
Expand Down Expand Up @@ -62,7 +61,7 @@ class TaskSchema(Schema):
template_fields = fields.List(fields.String(), dump_only=True)
downstream_task_ids = fields.List(fields.String(), dump_only=True)
params = fields.Method("_get_params", dump_only=True)
is_mapped = fields.Method("_get_is_mapped", dump_only=True)
is_mapped = fields.Boolean(dump_only=True)
doc_md = fields.String(dump_only=True)

@staticmethod
Expand All @@ -80,10 +79,6 @@ def _get_params(obj):
params = obj.params
return {k: v.dump() for k, v in params.items()}

@staticmethod
def _get_is_mapped(obj):
return isinstance(obj, MappedOperator)


class TaskCollection(NamedTuple):
"""List of Tasks with metadata."""
Expand Down
23 changes: 5 additions & 18 deletions airflow/api_fastapi/core_api/datamodels/tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,29 +26,18 @@

from airflow.api_fastapi.common.types import TimeDeltaWithValidation
from airflow.api_fastapi.core_api.base import BaseModel
from airflow.models.mappedoperator import MappedOperator
from airflow.serialization.serialized_objects import SerializedBaseOperator, encode_priority_weight_strategy
from airflow.serialization.serialized_objects import encode_priority_weight_strategy
from airflow.task.priority_strategy import PriorityWeightStrategy


def _get_class_ref(obj) -> dict[str, str | None]:
"""Return the class_ref dict for obj."""
is_mapped_or_serialized = isinstance(obj, (MappedOperator, SerializedBaseOperator))

module_path = None
if is_mapped_or_serialized:
module_path = obj._task_module
else:
module_path = getattr(obj, "_task_module", None)
if module_path is None:
module_type = inspect.getmodule(obj)
module_path = module_type.__name__ if module_type else None

class_name = None
if is_mapped_or_serialized:
class_name = obj._task_type
elif obj.__class__ is type:
class_name = obj.__name__
else:
class_name = type(obj).__name__
class_name = obj.task_type

return {
"module_path": module_path,
Expand Down Expand Up @@ -89,9 +78,7 @@ class TaskResponse(BaseModel):
@model_validator(mode="before")
@classmethod
def validate_model(cls, task: Any) -> Any:
task.__dict__.update(
{"class_ref": _get_class_ref(task), "is_mapped": isinstance(task, MappedOperator)}
)
task.__dict__.update({"class_ref": _get_class_ref(task), "is_mapped": task.is_mapped})
return task

@field_validator("weight_rule", mode="before")
Expand Down
9 changes: 5 additions & 4 deletions airflow/api_fastapi/core_api/services/ui/grid.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,11 +32,12 @@
)
from airflow.configuration import conf
from airflow.exceptions import AirflowConfigException
from airflow.models import MappedOperator
from airflow.models.baseoperator import BaseOperator
from airflow.models.baseoperator import BaseOperator as DBBaseOperator
from airflow.models.taskmap import TaskMap
from airflow.sdk import BaseOperator
from airflow.sdk.definitions.mappedoperator import MappedOperator
from airflow.sdk.definitions.taskgroup import MappedTaskGroup, TaskGroup
from airflow.utils.state import TaskInstanceState
from airflow.utils.task_group import MappedTaskGroup, TaskGroup


@cache
Expand Down Expand Up @@ -139,7 +140,7 @@ def _get_total_task_count(
node
if isinstance(node, int)
else (
node.get_mapped_ti_count(run_id=run_id, session=session)
DBBaseOperator.get_mapped_ti_count(node, run_id=run_id, session=session) or 0
if isinstance(node, (MappedTaskGroup, MappedOperator))
else node
)
Expand Down
5 changes: 5 additions & 0 deletions airflow/cli/commands/remote_commands/task_command.py
Original file line number Diff line number Diff line change
Expand Up @@ -206,6 +206,11 @@ def _get_ti(
dag = task.dag
if dag is None:
raise ValueError("Cannot get task instance for a task not assigned to a DAG")
if not isinstance(dag, DAG):
# TODO: Task-SDK: Shouldn't really happen, and this command will go away before 3.0
raise ValueError(
f"We need a {DAG.__module__}.DAG, but we got {type(dag).__module__}.{type(dag).__name__}!"
)

# this check is imperfect because diff dags could have tasks with same name
# but in a task, dag_id is a property that accesses its dag, and we don't
Expand Down
6 changes: 3 additions & 3 deletions airflow/decorators/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,11 +41,11 @@
ListOfDictsExpandInput,
is_mappable,
)
from airflow.models.mappedoperator import MappedOperator, ensure_xcomarg_return_value
from airflow.models.xcom_arg import XComArg
from airflow.sdk.definitions._internal.contextmanager import DagContext, TaskGroupContext
from airflow.sdk.definitions.asset import Asset
from airflow.sdk.definitions.baseoperator import BaseOperator as TaskSDKBaseOperator
from airflow.sdk.definitions.mappedoperator import MappedOperator, ensure_xcomarg_return_value
from airflow.sdk.definitions.xcom_arg import XComArg
from airflow.typing_compat import ParamSpec
from airflow.utils import timezone
from airflow.utils.context import KNOWN_CONTEXT_KEYS
Expand All @@ -62,9 +62,9 @@
OperatorExpandArgument,
OperatorExpandKwargsArgument,
)
from airflow.models.mappedoperator import ValidationSource
from airflow.sdk.definitions.context import Context
from airflow.sdk.definitions.dag import DAG
from airflow.sdk.definitions.mappedoperator import ValidationSource
from airflow.utils.task_group import TaskGroup


Expand Down
Loading

0 comments on commit d1b2a44

Please sign in to comment.