Skip to content

Commit

Permalink
AIP-72: Add Taskflow API support & template rendering in Task SDK
Browse files Browse the repository at this point in the history
  • Loading branch information
kaxil committed Jan 6, 2025
1 parent 5581e65 commit a8b529b
Show file tree
Hide file tree
Showing 26 changed files with 409 additions and 381 deletions.
40 changes: 3 additions & 37 deletions airflow/macros/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
import json # noqa: F401
import time # noqa: F401
import uuid # noqa: F401
from datetime import datetime, timedelta
from datetime import datetime
from random import random # noqa: F401
from typing import TYPE_CHECKING, Any

Expand All @@ -29,47 +29,12 @@
from babel.dates import LC_TIME, format_datetime

import airflow.utils.yaml as yaml # noqa: F401
from airflow.sdk.definitions.macros import ds_add, ds_format # noqa: F401

if TYPE_CHECKING:
from pendulum import DateTime


def ds_add(ds: str, days: int) -> str:
    """
    Add or subtract days from a YYYY-MM-DD.

    :param ds: anchor date in ``YYYY-MM-DD`` format to add to
    :param days: number of days to add to the ds, you can use negative values
    :return: the shifted date, formatted as ``YYYY-MM-DD``
    :raises ValueError: if ``days`` is non-zero and ``ds`` does not match ``YYYY-MM-DD``

    >>> ds_add("2015-01-01", 5)
    '2015-01-06'
    >>> ds_add("2015-01-06", -5)
    '2015-01-01'
    """
    # A zero (or otherwise falsy) offset returns the input as a string without
    # parsing it, so the value is not validated on this path.
    if not days:
        return str(ds)
    dt = datetime.strptime(str(ds), "%Y-%m-%d") + timedelta(days=days)
    return dt.strftime("%Y-%m-%d")


def ds_format(ds: str, input_format: str, output_format: str) -> str:
    """
    Output datetime string in a given format.

    :param ds: Input string which contains a date.
    :param input_format: Input string format (e.g., '%Y-%m-%d').
    :param output_format: Output string format (e.g., '%Y-%m-%d').
    :return: ``ds`` parsed with ``input_format`` and re-rendered with ``output_format``.
    :raises ValueError: if ``ds`` does not match ``input_format``.

    >>> ds_format("2015-01-01", "%Y-%m-%d", "%m-%d-%y")
    '01-01-15'
    >>> ds_format("1/5/2015", "%m/%d/%Y", "%Y-%m-%d")
    '2015-01-05'
    >>> ds_format("12/07/2024", "%d/%m/%Y", "%Y-%m-%d")
    '2024-07-12'
    """
    # This function takes exactly three arguments; it has no locale parameter.
    return datetime.strptime(str(ds), input_format).strftime(output_format)


def ds_format_locale(
ds: str, input_format: str, output_format: str, locale: Locale | str | None = None
) -> str:
Expand Down Expand Up @@ -99,6 +64,7 @@ def ds_format_locale(
)


# TODO: Task SDK: Move this to the Task SDK once we evaluate "pendulum"'s dependency
def datetime_diff_for_humans(dt: Any, since: DateTime | None = None) -> str:
"""
Return a human-readable/approximate difference between datetimes.
Expand Down
74 changes: 2 additions & 72 deletions airflow/models/abstractoperator.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,10 +30,9 @@
from airflow.exceptions import AirflowException
from airflow.models.expandinput import NotFullyPopulated
from airflow.sdk.definitions.abstractoperator import AbstractOperator as TaskSDKAbstractOperator
from airflow.template.templater import Templater
from airflow.utils.context import Context
from airflow.utils.db import exists_query
from airflow.utils.log.secrets_masker import redact
from airflow.utils.log.logging_mixin import LoggingMixin
from airflow.utils.setup_teardown import SetupTeardownContext
from airflow.utils.sqlalchemy import with_row_locks
from airflow.utils.state import State, TaskInstanceState
Expand All @@ -42,8 +41,6 @@
from airflow.utils.weight_rule import WeightRule, db_safe_priority

if TYPE_CHECKING:
from collections.abc import Mapping

import jinja2 # Slow import.
from sqlalchemy.orm import Session

Expand All @@ -52,7 +49,6 @@
from airflow.models.mappedoperator import MappedOperator
from airflow.models.taskinstance import TaskInstance
from airflow.sdk.definitions.baseoperator import BaseOperator
from airflow.sdk.definitions.dag import DAG
from airflow.sdk.definitions.node import DAGNode
from airflow.task.priority_strategy import PriorityWeightStrategy
from airflow.triggers.base import StartTriggerArgs
Expand Down Expand Up @@ -88,7 +84,7 @@ class NotMapped(Exception):
"""Raise if a task is neither mapped nor has any parent mapped groups."""


class AbstractOperator(Templater, TaskSDKAbstractOperator):
class AbstractOperator(LoggingMixin, TaskSDKAbstractOperator):
"""
Common implementation for operators, including unmapped and mapped.
Expand Down Expand Up @@ -128,72 +124,6 @@ def on_failure_fail_dagrun(self, value):
)
self._on_failure_fail_dagrun = value

def get_template_env(self, dag: DAG | None = None) -> jinja2.Environment:
    """
    Get the template environment for rendering templates.

    :param dag: DAG to build the environment for. When ``None``, the result
        of ``self.get_dag()`` is used instead.
    :return: the environment returned by the parent class's
        ``get_template_env`` for that DAG.
    """
    if dag is None:
        dag = self.get_dag()
    return super().get_template_env(dag=dag)

def _render(self, template, context, dag: DAG | None = None):
    """
    Render ``template`` with ``context`` by delegating to the parent class.

    :param template: template object forwarded unchanged to the parent ``_render``.
    :param context: rendering context forwarded unchanged to the parent ``_render``.
    :param dag: DAG passed to the parent implementation. When ``None``, the
        result of ``self.get_dag()`` is used instead.
    :return: whatever the parent class's ``_render`` returns.
    """
    if dag is None:
        dag = self.get_dag()
    return super()._render(template, context, dag=dag)

def _do_render_template_fields(
    self,
    parent: Any,
    template_fields: Iterable[str],
    context: Mapping[str, Any],
    jinja_env: jinja2.Environment,
    seen_oids: set[int],
) -> None:
    """
    Override the base to use custom error logging.

    For each name in ``template_fields``, read the attribute from ``parent``,
    render it, and write the rendered value back onto ``parent``.

    :param parent: object whose attributes are read and overwritten.
    :param template_fields: names of the attributes on ``parent`` to render.
    :param context: context passed to the callable or to ``self.render_template``.
    :param jinja_env: Jinja environment passed to the callable or to
        ``self.render_template``.
    :param seen_oids: set forwarded to ``self.render_template``.
    :raises AttributeError: if ``parent`` has no attribute for a listed field.
    :raises Exception: any error raised while rendering is logged (with the
        template value passed through ``redact``) and then re-raised.
    """
    for attr_name in template_fields:
        try:
            value = getattr(parent, attr_name)
        except AttributeError:
            raise AttributeError(
                f"{attr_name!r} is configured as a template field "
                f"but {parent.task_type} does not have this attribute."
            )
        try:
            # Falsy values are skipped: the attribute is left untouched.
            if not value:
                continue
        except Exception:
            # This may happen if the templated field points to a class which does not support `__bool__`,
            # such as Pandas DataFrames:
            # https://github.com/pandas-dev/pandas/blob/9135c3aaf12d26f857fcc787a5b64d521c51e379/pandas/core/generic.py#L1465
            # We may still want to render custom classes which do not support __bool__,
            # so only log and fall through to rendering.
            self.log.info(
                "Unable to check if the value of type '%s' is False for task '%s', field '%s'.",
                type(value).__name__,
                self.task_id,
                attr_name,
            )

        try:
            if callable(value):
                # Callable fields render themselves from the context and environment.
                rendered_content = value(context=context, jinja_env=jinja_env)
            else:
                rendered_content = self.render_template(
                    value,
                    context,
                    jinja_env,
                    seen_oids,
                )
        except Exception:
            # Pass the template through `redact` before it is written to the logs.
            value_masked = redact(name=attr_name, value=value)
            self.log.exception(
                "Exception rendering Jinja template for task '%s', field '%s'. Template: %r",
                self.task_id,
                attr_name,
                value_masked,
            )
            raise
        else:
            setattr(parent, attr_name, rendered_content)

def _iter_all_mapped_downstreams(self) -> Iterator[MappedOperator | MappedTaskGroup]:
"""
Return mapped nodes that are direct dependencies of the current task.
Expand Down
18 changes: 0 additions & 18 deletions airflow/models/baseoperator.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,6 @@
if TYPE_CHECKING:
from types import ClassMethodDescriptorType

import jinja2 # Slow import.
from sqlalchemy.orm import Session

from airflow.models.abstractoperator import TaskStateChangeCallback
Expand Down Expand Up @@ -738,23 +737,6 @@ def post_execute(self, context: Any, result: Any = None):
logger=self.log,
).run(context, result)

def render_template_fields(
    self,
    context: Context,
    jinja_env: jinja2.Environment | None = None,
) -> None:
    """
    Template all attributes listed in *self.template_fields*.

    This mutates the attributes in-place and is irreversible.

    :param context: Context dict with values to apply on content.
    :param jinja_env: Jinja's environment to use for rendering. When falsy,
        the result of ``self.get_template_env()`` is used.
    """
    if not jinja_env:
        jinja_env = self.get_template_env()
    # `self` is both the renderer and the object whose fields are rewritten;
    # a fresh set is passed as the `seen_oids` argument for this run.
    self._do_render_template_fields(self, self.template_fields, context, jinja_env, set())

@provide_session
def clear(
self,
Expand Down
2 changes: 1 addition & 1 deletion airflow/models/expandinput.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@

import attr

from airflow.utils.mixins import ResolveMixin
from airflow.sdk.definitions.mixins import ResolveMixin
from airflow.utils.session import NEW_SESSION, provide_session

if TYPE_CHECKING:
Expand Down
2 changes: 1 addition & 1 deletion airflow/models/param.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
from typing import TYPE_CHECKING, Any, ClassVar

from airflow.exceptions import AirflowException, ParamValidationError
from airflow.utils.mixins import ResolveMixin
from airflow.sdk.definitions.mixins import ResolveMixin
from airflow.utils.types import NOTSET, ArgNotSet

if TYPE_CHECKING:
Expand Down
2 changes: 1 addition & 1 deletion airflow/models/taskinstance.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,10 +107,10 @@
from airflow.models.xcom import LazyXComSelectSequence, XCom
from airflow.plugins_manager import integrate_macros_plugins
from airflow.sdk.definitions.asset import Asset, AssetAlias, AssetNameRef, AssetUniqueKey, AssetUriRef
from airflow.sdk.definitions.templater import SandboxedEnvironment
from airflow.sentry import Sentry
from airflow.settings import task_instance_mutation_hook
from airflow.stats import Stats
from airflow.templates import SandboxedEnvironment
from airflow.ti_deps.dep_context import DepContext
from airflow.ti_deps.dependencies_deps import REQUEUEABLE_DEPS, RUNNING_DEPS
from airflow.traces.tracer import Trace
Expand Down
11 changes: 5 additions & 6 deletions airflow/models/xcom_arg.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,9 +29,9 @@
from airflow.models import MappedOperator, TaskInstance
from airflow.models.abstractoperator import AbstractOperator
from airflow.models.taskmixin import DependencyMixin
from airflow.sdk.definitions.mixins import ResolveMixin
from airflow.sdk.types import NOTSET, ArgNotSet
from airflow.utils.db import exists_query
from airflow.utils.mixins import ResolveMixin
from airflow.utils.session import NEW_SESSION, provide_session
from airflow.utils.setup_teardown import SetupTeardownContext
from airflow.utils.state import State
Expand Down Expand Up @@ -206,8 +206,7 @@ def get_task_map_length(self, run_id: str, *, session: Session) -> int | None:
"""
raise NotImplementedError()

@provide_session
def resolve(self, context: Context, session: Session = NEW_SESSION, *, include_xcom: bool = True) -> Any:
def resolve(self, context: Context, session: Session | None = None, *, include_xcom: bool = True) -> Any:
"""
Pull XCom value.
Expand Down Expand Up @@ -420,8 +419,8 @@ def get_task_map_length(self, run_id: str, *, session: Session) -> int | None:
)
return session.scalar(query)

@provide_session
def resolve(self, context: Context, session: Session = NEW_SESSION, *, include_xcom: bool = True) -> Any:
# TODO: Task-SDK: Remove session argument once everything is ported over to Task SDK
def resolve(self, context: Context, session: Session | None = None, *, include_xcom: bool = True) -> Any:
ti = context["ti"]
if TYPE_CHECKING:
assert isinstance(ti, TaskInstance)
Expand All @@ -431,12 +430,12 @@ def resolve(self, context: Context, session: Session = NEW_SESSION, *, include_x
context["expanded_ti_count"],
session=session,
)

result = ti.xcom_pull(
task_ids=task_id,
map_indexes=map_indexes,
key=self.key,
default=NOTSET,
session=session,
)
if not isinstance(result, ArgNotSet):
return result
Expand Down
5 changes: 3 additions & 2 deletions airflow/notifications/basenotifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,9 @@
from collections.abc import Sequence
from typing import TYPE_CHECKING

from airflow.template.templater import Templater
from airflow.sdk.definitions.templater import Templater
from airflow.utils.context import context_merge
from airflow.utils.log.logging_mixin import LoggingMixin

if TYPE_CHECKING:
import jinja2
Expand All @@ -31,7 +32,7 @@
from airflow.utils.context import Context


class BaseNotifier(Templater):
class BaseNotifier(LoggingMixin, Templater):
"""BaseNotifier class for sending notifications."""

template_fields: Sequence[str] = ()
Expand Down
94 changes: 0 additions & 94 deletions airflow/templates.py

This file was deleted.

Loading

0 comments on commit a8b529b

Please sign in to comment.