Skip to content

Commit

Permalink
Render a pod spec using the pod_template_file override, if passed to …
Browse files Browse the repository at this point in the history
…the executor

If a task was created by a custom `executor_options['pod_template_file']` option,
we make sure to render the `TaskInstance`'s associated `k8s_pod_spec`
with this specific `pod_template_file`, to avoid seeing discrepancies
between the spec visible in airflow and the one deployed to Kubernetes.

Signed-off-by: Balthazar Rouberol <[email protected]>
  • Loading branch information
brouberol committed Feb 7, 2025
1 parent 1cde11a commit dba7c1d
Show file tree
Hide file tree
Showing 2 changed files with 70 additions and 2 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,14 @@
def render_k8s_pod_yaml(task_instance: TaskInstance) -> dict | None:
"""Render k8s pod yaml."""
kube_config = KubeConfig()
if task_instance.executor_config and task_instance.executor_config.get("pod_template_file"):
# If a specific pod_template_file was passed to the executor, we make
# sure to render the k8s pod spec using this one, and not the default one.
pod_template_file = task_instance.executor_config["pod_template_file"]
else:
# If no such pod_template_file override was passed, we can simply render
# The pod spec using the default template.
pod_template_file = kube_config.pod_template_file
pod = PodGenerator.construct_pod(
dag_id=task_instance.dag_id,
run_id=task_instance.run_id,
Expand All @@ -48,7 +56,7 @@ def render_k8s_pod_yaml(task_instance: TaskInstance) -> dict | None:
pod_override_object=PodGenerator.from_obj(task_instance.executor_config),
scheduler_job_id="0",
namespace=kube_config.executor_namespace,
base_worker_pod=PodGenerator.deserialize_model_file(kube_config.pod_template_file),
base_worker_pod=PodGenerator.deserialize_model_file(pod_template_file),
with_mutation_hook=True,
)
sanitized_pod = ApiClient().sanitize_for_serialization(pod)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@
from unittest import mock

import pytest
import yaml
from kubernetes.client import models as k8s
from sqlalchemy.orm import make_transient

from airflow.models.renderedtifields import RenderedTaskInstanceFields, RenderedTaskInstanceFields as RTIF
Expand Down Expand Up @@ -84,11 +86,69 @@ def test_render_k8s_pod_yaml(pod_mutation_hook, create_task_instance):
]
},
}

assert render_k8s_pod_yaml(ti) == expected_pod_spec
pod_mutation_hook.assert_called_once_with(mock.ANY)


@mock.patch.dict(os.environ, {"AIRFLOW_IS_K8S_EXECUTOR_POD": "True"})
@mock.patch("airflow.settings.pod_mutation_hook")
def test_render_k8s_pod_yaml_with_custom_pod_template(pod_mutation_hook, create_task_instance, tmp_path):
with open(f"{tmp_path}/custom_pod_template.yaml", "w") as ptf:
template = {
"apiVersion": "v1",
"kind": "Pod",
"metadata": {"labels": {"custom_label": "custom_value"}},
}
ptf.write(yaml.dump(template))

ti = create_task_instance(
dag_id="test_render_k8s_pod_yaml",
run_id="test_run_id",
task_id="op1",
logical_date=DEFAULT_DATE,
executor_config={"pod_template_file": f"{tmp_path}/custom_pod_template.yaml"},
)

ti_pod_yaml = render_k8s_pod_yaml(ti)
assert "custom_label" in ti_pod_yaml["metadata"]["labels"]
assert ti_pod_yaml["metadata"]["labels"]["custom_label"] == "custom_value"


@mock.patch.dict(os.environ, {"AIRFLOW_IS_K8S_EXECUTOR_POD": "True"})
@mock.patch("airflow.settings.pod_mutation_hook")
def test_render_k8s_pod_yaml_with_custom_pod_template_and_pod_override(
pod_mutation_hook, create_task_instance, tmp_path
):
with open(f"{tmp_path}/custom_pod_template.yaml", "w") as ptf:
template = {
"apiVersion": "v1",
"kind": "Pod",
"metadata": {"labels": {"custom_label": "custom_value"}},
}
ptf.write(yaml.dump(template))

pod_override = k8s.V1Pod(
metadata=k8s.V1ObjectMeta(annotations={"test": "annotation"}, labels={"custom_label": "override"})
)
ti = create_task_instance(
dag_id="test_render_k8s_pod_yaml",
run_id="test_run_id",
task_id="op1",
logical_date=DEFAULT_DATE,
executor_config={
"pod_template_file": f"{tmp_path}/custom_pod_template.yaml",
"pod_override": pod_override,
},
)

ti_pod_yaml = render_k8s_pod_yaml(ti)
assert "custom_label" in ti_pod_yaml["metadata"]["labels"]
# The initial value associated with the custom_label label in the pod_template_file
# was overriden by the pod_override
assert ti_pod_yaml["metadata"]["labels"]["custom_label"] == "override"
assert ti_pod_yaml["metadata"]["annotations"]["test"] == "annotation"


@mock.patch.dict(os.environ, {"AIRFLOW_IS_K8S_EXECUTOR_POD": "True"})
@mock.patch.object(RenderedTaskInstanceFields, "get_k8s_pod_yaml")
@mock.patch("airflow.providers.cncf.kubernetes.template_rendering.render_k8s_pod_yaml")
Expand Down

0 comments on commit dba7c1d

Please sign in to comment.