Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

AIP-72: Move DAG Params to Task SDK #46127

Closed
wants to merge 2 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -203,7 +203,7 @@ class TaskInstance(BaseModel):
dag_id: str
run_id: str
try_number: int
map_index: int = -1
map_index: int | None = None
hostname: str | None = None


Expand Down
2 changes: 1 addition & 1 deletion airflow/cli/commands/remote_commands/task_command.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,8 +45,8 @@
from airflow.models import TaskInstance
from airflow.models.dag import DAG, _run_inline_trigger
from airflow.models.dagrun import DagRun
from airflow.models.param import ParamsDict
from airflow.models.taskinstance import TaskReturnCode
from airflow.sdk.definitions.param import ParamsDict
from airflow.settings import IS_EXECUTOR_CONTAINER, IS_K8S_EXECUTOR_POD
from airflow.ti_deps.dep_context import DepContext
from airflow.ti_deps.dependencies_deps import SCHEDULER_QUEUED_DEPS
Expand Down
10 changes: 5 additions & 5 deletions airflow/decorators/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,8 +55,6 @@
from airflow.utils.types import NOTSET

if TYPE_CHECKING:
from sqlalchemy.orm import Session

from airflow.models.expandinput import (
ExpandInput,
OperatorExpandArgument,
Expand Down Expand Up @@ -184,7 +182,9 @@ def __init__(
kwargs_to_upstream: dict[str, Any] | None = None,
**kwargs,
) -> None:
task_id = get_unique_task_id(task_id, kwargs.get("dag"), kwargs.get("task_group"))
if not getattr(self, "_BaseOperator__from_mapped", False):
# If we are being created from calling unmap(), then don't mangle the task id
task_id = get_unique_task_id(task_id, kwargs.get("dag"), kwargs.get("task_group"))
self.python_callable = python_callable
kwargs_to_upstream = kwargs_to_upstream or {}
op_args = op_args or []
Expand Down Expand Up @@ -569,12 +569,12 @@ def __attrs_post_init__(self):
XComArg.apply_upstream_relationship(self, self.op_kwargs_expand_input.value)

def _expand_mapped_kwargs(
self, context: Mapping[str, Any], session: Session, *, include_xcom: bool
self, context: Mapping[str, Any], *, include_xcom: bool
) -> tuple[Mapping[str, Any], set[int]]:
# We only use op_kwargs_expand_input so this must always be empty.
if self.expand_input is not EXPAND_INPUT_EMPTY:
raise AssertionError(f"unexpected expand_input: {self.expand_input}")
op_kwargs, resolved_oids = super()._expand_mapped_kwargs(context, session, include_xcom=include_xcom)
op_kwargs, resolved_oids = super()._expand_mapped_kwargs(context, include_xcom=include_xcom)
return {"op_kwargs": op_kwargs}, resolved_oids

def _get_unmap_kwargs(self, mapped_kwargs: Mapping[str, Any], *, strict: bool) -> dict[str, Any]:
Expand Down
2 changes: 1 addition & 1 deletion airflow/example_dags/example_params_trigger_ui.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@

from airflow.decorators import task
from airflow.models.dag import DAG
from airflow.models.param import Param, ParamsDict
from airflow.sdk.definitions.param import Param, ParamsDict
from airflow.utils.trigger_rule import TriggerRule

# [START params_trigger]
Expand Down
2 changes: 1 addition & 1 deletion airflow/example_dags/example_params_ui_tutorial.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@

from airflow.decorators import task
from airflow.models.dag import DAG
from airflow.models.param import Param, ParamsDict
from airflow.sdk.definitions.param import Param, ParamsDict

with (
DAG(
Expand Down
5 changes: 3 additions & 2 deletions airflow/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,8 @@ def __getattr__(name):
"Log": "airflow.models.log",
"MappedOperator": "airflow.models.mappedoperator",
"Operator": "airflow.models.operator",
"Param": "airflow.models.param",
"Param": "airflow.sdk.definitions.param",
"ParamsDict": "airflow.sdk.definitions.param",
"Pool": "airflow.models.pool",
"RenderedTaskInstanceFields": "airflow.models.renderedtifields",
"SkipMixin": "airflow.models.skipmixin",
Expand Down Expand Up @@ -128,7 +129,6 @@ def __getattr__(name):
from airflow.models.log import Log
from airflow.models.mappedoperator import MappedOperator
from airflow.models.operator import Operator
from airflow.models.param import Param
from airflow.models.pool import Pool
from airflow.models.renderedtifields import RenderedTaskInstanceFields
from airflow.models.skipmixin import SkipMixin
Expand All @@ -138,3 +138,4 @@ def __getattr__(name):
from airflow.models.trigger import Trigger
from airflow.models.variable import Variable
from airflow.models.xcom import XCom
from airflow.sdk.definitions.param import Param
2 changes: 1 addition & 1 deletion airflow/models/abstractoperator.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,6 @@

from airflow.configuration import conf
from airflow.exceptions import AirflowException
from airflow.models.expandinput import NotFullyPopulated
from airflow.sdk.definitions._internal.abstractoperator import (
AbstractOperator as TaskSDKAbstractOperator,
NotMapped as NotMapped, # Re-export this for compat
Expand Down Expand Up @@ -237,6 +236,7 @@ def expand_mapped_task(self, run_id: str, *, session: Session) -> tuple[Sequence
)

from airflow.models.baseoperator import BaseOperator as DBBaseOperator
from airflow.models.expandinput import NotFullyPopulated

try:
total_length: int | None = DBBaseOperator.get_mapped_ti_count(self, run_id, session=session)
Expand Down
11 changes: 10 additions & 1 deletion airflow/models/baseoperator.py
Original file line number Diff line number Diff line change
Expand Up @@ -847,11 +847,20 @@ def _(cls, task: TaskSDKAbstractOperator, run_id: str, *, session: Session) -> i
@get_mapped_ti_count.register(MappedOperator)
@classmethod
def _(cls, task: MappedOperator, run_id: str, *, session: Session) -> int:
from airflow.serialization.serialized_objects import _ExpandInputRef
from airflow.serialization.serialized_objects import BaseSerialization, _ExpandInputRef

exp_input = task._get_specified_expand_input()
if isinstance(exp_input, _ExpandInputRef):
exp_input = exp_input.deref(task.dag)
# TODO: TaskSDK This is only needed to support `dag.test()` etc until we port it over ot use the
# task sdk runner.
if not hasattr(exp_input, "get_total_map_length"):
exp_input = _ExpandInputRef(
type(exp_input).EXPAND_INPUT_TYPE,
BaseSerialization.deserialize(BaseSerialization.serialize(exp_input.value)),
)
exp_input = exp_input.deref(task.dag)

current_count = exp_input.get_total_map_length(run_id, session=session)

group = task.get_closest_mapped_task_group()
Expand Down
4 changes: 3 additions & 1 deletion airflow/models/dagrun.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,6 @@
from airflow.models.backfill import Backfill
from airflow.models.base import Base, StringID
from airflow.models.dag_version import DagVersion
from airflow.models.expandinput import NotFullyPopulated
from airflow.models.taskinstance import TaskInstance as TI
from airflow.models.tasklog import LogTemplate
from airflow.models.taskmap import TaskMap
Expand Down Expand Up @@ -1330,6 +1329,7 @@ def _check_for_removed_or_restored_tasks(

"""
from airflow.models.baseoperator import BaseOperator
from airflow.models.expandinput import NotFullyPopulated

tis = self.get_task_instances(session=session)

Expand Down Expand Up @@ -1467,6 +1467,7 @@ def _create_tasks(
:param task_creator: Function to create task instances
"""
from airflow.models.baseoperator import BaseOperator
from airflow.models.expandinput import NotFullyPopulated

map_indexes: Iterable[int]
for task in tasks:
Expand Down Expand Up @@ -1538,6 +1539,7 @@ def _revise_map_indexes_if_mapped(self, task: Operator, *, session: Session) ->
for more details.
"""
from airflow.models.baseoperator import BaseOperator
from airflow.models.expandinput import NotFullyPopulated
from airflow.settings import task_instance_mutation_hook

try:
Expand Down
Loading