Skip to content

Commit

Permalink
Render a pod spec using the pod_template_file override, if passed to …
Browse files Browse the repository at this point in the history
…the executor

If a task was created by a custom `executor_options['pod_template_file']` option,
we make sure to render the `TaskInstance`'s associated `k8s_pod_spec`
with this specific `pod_template_file`, to avoid seeing discrepancies
between the spec visible in airflow and the one deployed to Kubernetes.

Signed-off-by: Balthazar Rouberol <[email protected]>
  • Loading branch information
brouberol committed Feb 3, 2025
1 parent 8a3757b commit 1f71b35
Show file tree
Hide file tree
Showing 2 changed files with 34 additions and 2 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,14 @@
def render_k8s_pod_yaml(task_instance: TaskInstance) -> dict | None:
"""Render k8s pod yaml."""
kube_config = KubeConfig()
if "pod_template_file" in (task_instance.executor_config or {}):
# If a specific pod_template_file was passed to the executor, we make
# sure to render the k8s pod spec using this one, and not the default one.
pod_template_file = task_instance.executor_config["pod_template_file"]
else:
# If no such pod_template_file override was passed, we can simply render
# The pod spec using the default template.
pod_template_file = kube_config.pod_template_file
pod = PodGenerator.construct_pod(
dag_id=task_instance.dag_id,
run_id=task_instance.run_id,
Expand All @@ -48,7 +56,7 @@ def render_k8s_pod_yaml(task_instance: TaskInstance) -> dict | None:
pod_override_object=PodGenerator.from_obj(task_instance.executor_config),
scheduler_job_id="0",
namespace=kube_config.executor_namespace,
base_worker_pod=PodGenerator.deserialize_model_file(kube_config.pod_template_file),
base_worker_pod=PodGenerator.deserialize_model_file(pod_template_file),
with_mutation_hook=True,
)
sanitized_pod = ApiClient().sanitize_for_serialization(pod)
Expand Down
26 changes: 25 additions & 1 deletion providers/tests/cncf/kubernetes/test_template_rendering.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from unittest import mock

import pytest
import yaml
from sqlalchemy.orm import make_transient

from airflow.models.renderedtifields import RenderedTaskInstanceFields, RenderedTaskInstanceFields as RTIF
Expand Down Expand Up @@ -84,11 +85,34 @@ def test_render_k8s_pod_yaml(pod_mutation_hook, create_task_instance):
]
},
}

assert render_k8s_pod_yaml(ti) == expected_pod_spec
pod_mutation_hook.assert_called_once_with(mock.ANY)


@mock.patch.dict(os.environ, {"AIRFLOW_IS_K8S_EXECUTOR_POD": "True"})
@mock.patch("airflow.settings.pod_mutation_hook")
def test_render_k8s_pod_yaml_with_custom_pod_template(pod_mutation_hook, create_task_instance, tmp_path):
with open(f"{tmp_path}/custom_pod_template.yaml", "w") as ptf:
template = {
"apiVersion": "v1",
"kind": "Pod",
"metadata": {"labels": {"custom_label": "custom_value"}},
}
ptf.write(yaml.dump(template))

ti = create_task_instance(
dag_id="test_render_k8s_pod_yaml",
run_id="test_run_id",
task_id="op1",
logical_date=DEFAULT_DATE,
executor_config={"pod_template_file": f"{tmp_path}/custom_pod_template.yaml"},
)

ti_pod_yaml = render_k8s_pod_yaml(ti)
assert "custom_label" in ti_pod_yaml["metadata"]["labels"]
assert ti_pod_yaml["metadata"]["labels"]["custom_label"] == "custom_value"


@mock.patch.dict(os.environ, {"AIRFLOW_IS_K8S_EXECUTOR_POD": "True"})
@mock.patch.object(RenderedTaskInstanceFields, "get_k8s_pod_yaml")
@mock.patch("airflow.providers.cncf.kubernetes.template_rendering.render_k8s_pod_yaml")
Expand Down

0 comments on commit 1f71b35

Please sign in to comment.