Skip to content

Commit

Permalink
AIP-72: Add Taskflow API support & template rendering in Task SDK (ap…
Browse files Browse the repository at this point in the history
…ache#45444)

closes apache#45232
part of apache#44481

The Templater class has been moved to the Task SDK to align with the language-specific aspects of template rendering. Templating logic is inherently tied to Python constructs. By keeping the Templater class within the Task SDK, we ensure that the core templating logic remains coupled with language-specific implementations.

The other options were keeping it on the Scheduler or on the Execution side of the Task SDK, neither of which is ideal, as we would want to change the definition code (such as DAG and Operator) along with how it renders.

With [`tutorial_taskflow_api`](https://github.com/apache/airflow/blob/5581e65fd5575364fbf2c0e5c8cf4f4afe2b841b/airflow/example_dags/tutorial_taskflow_api.py#L38):

<img width="1705" alt="image" src="https://github.com/user-attachments/assets/c84327ed-5956-4f48-ab32-97a77ae44016" />

---
With [`example_xcom_args`](https://github.com/apache/airflow/blob/5581e65fd5575364fbf2c0e5c8cf4f4afe2b841b/airflow/example_dags/example_xcomargs.py):

<img width="1720" alt="image" src="https://github.com/user-attachments/assets/f9e0190f-1030-437d-ab6b-8247a5f8cdb0" />
  • Loading branch information
kaxil authored and HariGS-DB committed Jan 16, 2025
1 parent ec731df commit bf08777
Show file tree
Hide file tree
Showing 38 changed files with 949 additions and 746 deletions.
2 changes: 1 addition & 1 deletion airflow/decorators/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -578,7 +578,7 @@ def __attrs_post_init__(self):
XComArg.apply_upstream_relationship(self, self.op_kwargs_expand_input.value)

def _expand_mapped_kwargs(
self, context: Context, session: Session, *, include_xcom: bool
self, context: Mapping[str, Any], session: Session, *, include_xcom: bool
) -> tuple[Mapping[str, Any], set[int]]:
# We only use op_kwargs_expand_input so this must always be empty.
if self.expand_input is not EXPAND_INPUT_EMPTY:
Expand Down
44 changes: 3 additions & 41 deletions airflow/macros/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,59 +17,20 @@
# under the License.
from __future__ import annotations

import json # noqa: F401
import time # noqa: F401
import uuid # noqa: F401
from datetime import datetime, timedelta
from random import random # noqa: F401
from datetime import datetime
from typing import TYPE_CHECKING, Any

import dateutil # noqa: F401
from babel import Locale
from babel.dates import LC_TIME, format_datetime

import airflow.utils.yaml as yaml # noqa: F401
from airflow.sdk.definitions.macros import ds_add, ds_format, json, time, uuid # noqa: F401

if TYPE_CHECKING:
from pendulum import DateTime


def ds_add(ds: str, days: int) -> str:
"""
Add or subtract days from a YYYY-MM-DD.
:param ds: anchor date in ``YYYY-MM-DD`` format to add to
:param days: number of days to add to the ds, you can use negative values
>>> ds_add("2015-01-01", 5)
'2015-01-06'
>>> ds_add("2015-01-06", -5)
'2015-01-01'
"""
if not days:
return str(ds)
dt = datetime.strptime(str(ds), "%Y-%m-%d") + timedelta(days=days)
return dt.strftime("%Y-%m-%d")


def ds_format(ds: str, input_format: str, output_format: str) -> str:
"""
Output datetime string in a given format.
:param ds: Input string which contains a date.
:param input_format: Input string format (e.g., '%Y-%m-%d').
:param output_format: Output string format (e.g., '%Y-%m-%d').
>>> ds_format("2015-01-01", "%Y-%m-%d", "%m-%d-%y")
'01-01-15'
>>> ds_format("1/5/2015", "%m/%d/%Y", "%Y-%m-%d")
'2015-01-05'
>>> ds_format("12/07/2024", "%d/%m/%Y", "%A %d %B %Y", "en_US")
'Friday 12 July 2024'
"""
return datetime.strptime(str(ds), input_format).strftime(output_format)


def ds_format_locale(
ds: str, input_format: str, output_format: str, locale: Locale | str | None = None
) -> str:
Expand Down Expand Up @@ -99,6 +60,7 @@ def ds_format_locale(
)


# TODO: Task SDK: Move this to the Task SDK once we evaluate "pendulum"'s dependency
def datetime_diff_for_humans(dt: Any, since: DateTime | None = None) -> str:
"""
Return a human-readable/approximate difference between datetimes.
Expand Down
78 changes: 4 additions & 74 deletions airflow/models/abstractoperator.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@

import datetime
import inspect
from collections.abc import Iterable, Iterator, Sequence
from collections.abc import Iterable, Iterator, Mapping, Sequence
from functools import cached_property
from typing import TYPE_CHECKING, Any, Callable

Expand All @@ -30,10 +30,9 @@
from airflow.exceptions import AirflowException
from airflow.models.expandinput import NotFullyPopulated
from airflow.sdk.definitions.abstractoperator import AbstractOperator as TaskSDKAbstractOperator
from airflow.template.templater import Templater
from airflow.utils.context import Context
from airflow.utils.db import exists_query
from airflow.utils.log.secrets_masker import redact
from airflow.utils.log.logging_mixin import LoggingMixin
from airflow.utils.setup_teardown import SetupTeardownContext
from airflow.utils.sqlalchemy import with_row_locks
from airflow.utils.state import State, TaskInstanceState
Expand All @@ -42,8 +41,6 @@
from airflow.utils.weight_rule import WeightRule, db_safe_priority

if TYPE_CHECKING:
from collections.abc import Mapping

import jinja2 # Slow import.
from sqlalchemy.orm import Session

Expand All @@ -52,7 +49,6 @@
from airflow.models.mappedoperator import MappedOperator
from airflow.models.taskinstance import TaskInstance
from airflow.sdk.definitions.baseoperator import BaseOperator
from airflow.sdk.definitions.dag import DAG
from airflow.sdk.definitions.node import DAGNode
from airflow.task.priority_strategy import PriorityWeightStrategy
from airflow.triggers.base import StartTriggerArgs
Expand Down Expand Up @@ -88,7 +84,7 @@ class NotMapped(Exception):
"""Raise if a task is neither mapped nor has any parent mapped groups."""


class AbstractOperator(Templater, TaskSDKAbstractOperator):
class AbstractOperator(LoggingMixin, TaskSDKAbstractOperator):
"""
Common implementation for operators, including unmapped and mapped.
Expand Down Expand Up @@ -128,72 +124,6 @@ def on_failure_fail_dagrun(self, value):
)
self._on_failure_fail_dagrun = value

def get_template_env(self, dag: DAG | None = None) -> jinja2.Environment:
"""Get the template environment for rendering templates."""
if dag is None:
dag = self.get_dag()
return super().get_template_env(dag=dag)

def _render(self, template, context, dag: DAG | None = None):
if dag is None:
dag = self.get_dag()
return super()._render(template, context, dag=dag)

def _do_render_template_fields(
self,
parent: Any,
template_fields: Iterable[str],
context: Mapping[str, Any],
jinja_env: jinja2.Environment,
seen_oids: set[int],
) -> None:
"""Override the base to use custom error logging."""
for attr_name in template_fields:
try:
value = getattr(parent, attr_name)
except AttributeError:
raise AttributeError(
f"{attr_name!r} is configured as a template field "
f"but {parent.task_type} does not have this attribute."
)
try:
if not value:
continue
except Exception:
# This may happen if the templated field points to a class which does not support `__bool__`,
# such as Pandas DataFrames:
# https://github.com/pandas-dev/pandas/blob/9135c3aaf12d26f857fcc787a5b64d521c51e379/pandas/core/generic.py#L1465
self.log.info(
"Unable to check if the value of type '%s' is False for task '%s', field '%s'.",
type(value).__name__,
self.task_id,
attr_name,
)
# We may still want to render custom classes which do not support __bool__
pass

try:
if callable(value):
rendered_content = value(context=context, jinja_env=jinja_env)
else:
rendered_content = self.render_template(
value,
context,
jinja_env,
seen_oids,
)
except Exception:
value_masked = redact(name=attr_name, value=value)
self.log.exception(
"Exception rendering Jinja template for task '%s', field '%s'. Template: %r",
self.task_id,
attr_name,
value_masked,
)
raise
else:
setattr(parent, attr_name, rendered_content)

def _iter_all_mapped_downstreams(self) -> Iterator[MappedOperator | MappedTaskGroup]:
"""
Return mapped nodes that are direct dependencies of the current task.
Expand Down Expand Up @@ -582,7 +512,7 @@ def expand_mapped_task(self, run_id: str, *, session: Session) -> tuple[Sequence

def render_template_fields(
self,
context: Context,
context: Mapping[str, Any],
jinja_env: jinja2.Environment | None = None,
) -> None:
"""
Expand Down
18 changes: 0 additions & 18 deletions airflow/models/baseoperator.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,6 @@
if TYPE_CHECKING:
from types import ClassMethodDescriptorType

import jinja2 # Slow import.
from sqlalchemy.orm import Session

from airflow.models.abstractoperator import TaskStateChangeCallback
Expand Down Expand Up @@ -738,23 +737,6 @@ def post_execute(self, context: Any, result: Any = None):
logger=self.log,
).run(context, result)

def render_template_fields(
self,
context: Context,
jinja_env: jinja2.Environment | None = None,
) -> None:
"""
Template all attributes listed in *self.template_fields*.
This mutates the attributes in-place and is irreversible.
:param context: Context dict with values to apply on content.
:param jinja_env: Jinja's environment to use for rendering.
"""
if not jinja_env:
jinja_env = self.get_template_env()
self._do_render_template_fields(self, self.template_fields, context, jinja_env, set())

@provide_session
def clear(
self,
Expand Down
13 changes: 7 additions & 6 deletions airflow/models/expandinput.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@

import attr

from airflow.utils.mixins import ResolveMixin
from airflow.sdk.definitions.mixins import ResolveMixin
from airflow.utils.session import NEW_SESSION, provide_session

if TYPE_CHECKING:
Expand All @@ -35,7 +35,6 @@
from airflow.models.xcom_arg import XComArg
from airflow.serialization.serialized_objects import _ExpandInputRef
from airflow.typing_compat import TypeGuard
from airflow.utils.context import Context

ExpandInput = Union["DictOfListsExpandInput", "ListOfDictsExpandInput"]

Expand Down Expand Up @@ -69,7 +68,9 @@ def iter_references(self) -> Iterable[tuple[Operator, str]]:
yield from self._input.iter_references()

@provide_session
def resolve(self, context: Context, *, include_xcom: bool = True, session: Session = NEW_SESSION) -> Any:
def resolve(
self, context: Mapping[str, Any], *, include_xcom: bool = True, session: Session = NEW_SESSION
) -> Any:
data, _ = self._input.resolve(context, session=session, include_xcom=include_xcom)
return data[self._key]

Expand Down Expand Up @@ -166,7 +167,7 @@ def get_total_map_length(self, run_id: str, *, session: Session) -> int:
return functools.reduce(operator.mul, (lengths[name] for name in self.value), 1)

def _expand_mapped_field(
self, key: str, value: Any, context: Context, *, session: Session, include_xcom: bool
self, key: str, value: Any, context: Mapping[str, Any], *, session: Session, include_xcom: bool
) -> Any:
if _needs_run_time_resolution(value):
value = (
Expand Down Expand Up @@ -210,7 +211,7 @@ def iter_references(self) -> Iterable[tuple[Operator, str]]:
yield from x.iter_references()

def resolve(
self, context: Context, session: Session, *, include_xcom: bool = True
self, context: Mapping[str, Any], session: Session, *, include_xcom: bool = True
) -> tuple[Mapping[str, Any], set[int]]:
data = {
k: self._expand_mapped_field(k, v, context, session=session, include_xcom=include_xcom)
Expand Down Expand Up @@ -260,7 +261,7 @@ def iter_references(self) -> Iterable[tuple[Operator, str]]:
yield from x.iter_references()

def resolve(
self, context: Context, session: Session, *, include_xcom: bool = True
self, context: Mapping[str, Any], session: Session, *, include_xcom: bool = True
) -> tuple[Mapping[str, Any], set[int]]:
map_index = context["ti"].map_index
if map_index < 0:
Expand Down
4 changes: 2 additions & 2 deletions airflow/models/mappedoperator.py
Original file line number Diff line number Diff line change
Expand Up @@ -671,7 +671,7 @@ def serialize_for_task_group(self) -> tuple[DagAttributeTypes, Any]:
return DagAttributeTypes.OP, self.task_id

def _expand_mapped_kwargs(
self, context: Context, session: Session, *, include_xcom: bool
self, context: Mapping[str, Any], session: Session, *, include_xcom: bool
) -> tuple[Mapping[str, Any], set[int]]:
"""
Get the kwargs to create the unmapped operator.
Expand Down Expand Up @@ -869,7 +869,7 @@ def get_mapped_ti_count(self, run_id: str, *, session: Session) -> int:

def render_template_fields(
self,
context: Context,
context: Mapping[str, Any],
jinja_env: jinja2.Environment | None = None,
) -> None:
"""
Expand Down
7 changes: 3 additions & 4 deletions airflow/models/param.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,18 +20,17 @@
import copy
import json
import logging
from collections.abc import ItemsView, Iterable, MutableMapping, ValuesView
from collections.abc import ItemsView, Iterable, Mapping, MutableMapping, ValuesView
from typing import TYPE_CHECKING, Any, ClassVar

from airflow.exceptions import AirflowException, ParamValidationError
from airflow.utils.mixins import ResolveMixin
from airflow.sdk.definitions.mixins import ResolveMixin
from airflow.utils.types import NOTSET, ArgNotSet

if TYPE_CHECKING:
from airflow.models.dagrun import DagRun
from airflow.models.operator import Operator
from airflow.sdk.definitions.dag import DAG
from airflow.utils.context import Context

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -295,7 +294,7 @@ def __init__(self, current_dag: DAG, name: str, default: Any = NOTSET):
def iter_references(self) -> Iterable[tuple[Operator, str]]:
return ()

def resolve(self, context: Context, *, include_xcom: bool = True) -> Any:
def resolve(self, context: Mapping[str, Any], *, include_xcom: bool = True) -> Any:
"""Pull DagParam value from DagRun context. This method is run during ``op.execute()``."""
with contextlib.suppress(KeyError):
return context["dag_run"].conf[self._name]
Expand Down
2 changes: 1 addition & 1 deletion airflow/models/taskinstance.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,10 +107,10 @@
from airflow.models.xcom import LazyXComSelectSequence, XCom
from airflow.plugins_manager import integrate_macros_plugins
from airflow.sdk.definitions.asset import Asset, AssetAlias, AssetNameRef, AssetUniqueKey, AssetUriRef
from airflow.sdk.definitions.templater import SandboxedEnvironment
from airflow.sentry import Sentry
from airflow.settings import task_instance_mutation_hook
from airflow.stats import Stats
from airflow.templates import SandboxedEnvironment
from airflow.ti_deps.dep_context import DepContext
from airflow.ti_deps.dependencies_deps import REQUEUEABLE_DEPS, RUNNING_DEPS
from airflow.traces.tracer import Trace
Expand Down
Loading

0 comments on commit bf08777

Please sign in to comment.