diff --git a/providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/template_rendering.py b/providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/template_rendering.py index 6abeeddde7884d..ed2fa7e92f4773 100644 --- a/providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/template_rendering.py +++ b/providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/template_rendering.py @@ -35,6 +35,14 @@ def render_k8s_pod_yaml(task_instance: TaskInstance) -> dict | None: """Render k8s pod yaml.""" kube_config = KubeConfig() + if task_instance.executor_config and task_instance.executor_config.get("pod_template_file"): + # If a specific pod_template_file was passed to the executor, we make + # sure to render the k8s pod spec using this one, and not the default one. + pod_template_file = task_instance.executor_config["pod_template_file"] + else: + # If no such pod_template_file override was passed, we can simply render + # The pod spec using the default template. + pod_template_file = kube_config.pod_template_file pod = PodGenerator.construct_pod( dag_id=task_instance.dag_id, run_id=task_instance.run_id, @@ -48,7 +56,7 @@ def render_k8s_pod_yaml(task_instance: TaskInstance) -> dict | None: pod_override_object=PodGenerator.from_obj(task_instance.executor_config), scheduler_job_id="0", namespace=kube_config.executor_namespace, - base_worker_pod=PodGenerator.deserialize_model_file(kube_config.pod_template_file), + base_worker_pod=PodGenerator.deserialize_model_file(pod_template_file), with_mutation_hook=True, ) sanitized_pod = ApiClient().sanitize_for_serialization(pod) diff --git a/providers/cncf/kubernetes/tests/provider_tests/cncf/kubernetes/test_template_rendering.py b/providers/cncf/kubernetes/tests/provider_tests/cncf/kubernetes/test_template_rendering.py index 6f512cdfe805ea..a956977b43937a 100644 --- a/providers/cncf/kubernetes/tests/provider_tests/cncf/kubernetes/test_template_rendering.py +++ b/providers/cncf/kubernetes/tests/provider_tests/cncf/kubernetes/test_template_rendering.py @@ -20,6 +20,8 @@ from unittest import mock import pytest +import yaml +from kubernetes.client import models as k8s from sqlalchemy.orm import make_transient from airflow.models.renderedtifields import RenderedTaskInstanceFields, RenderedTaskInstanceFields as RTIF @@ -84,11 +86,69 @@ def test_render_k8s_pod_yaml(pod_mutation_hook, create_task_instance): ] }, } - assert render_k8s_pod_yaml(ti) == expected_pod_spec pod_mutation_hook.assert_called_once_with(mock.ANY) +@mock.patch.dict(os.environ, {"AIRFLOW_IS_K8S_EXECUTOR_POD": "True"}) +@mock.patch("airflow.settings.pod_mutation_hook") +def test_render_k8s_pod_yaml_with_custom_pod_template(pod_mutation_hook, create_task_instance, tmp_path): + with open(f"{tmp_path}/custom_pod_template.yaml", "w") as ptf: + template = { + "apiVersion": "v1", + "kind": "Pod", + "metadata": {"labels": {"custom_label": "custom_value"}}, + } + ptf.write(yaml.dump(template)) + + ti = create_task_instance( + dag_id="test_render_k8s_pod_yaml", + run_id="test_run_id", + task_id="op1", + logical_date=DEFAULT_DATE, + executor_config={"pod_template_file": f"{tmp_path}/custom_pod_template.yaml"}, + ) + + ti_pod_yaml = render_k8s_pod_yaml(ti) + assert "custom_label" in ti_pod_yaml["metadata"]["labels"] + assert ti_pod_yaml["metadata"]["labels"]["custom_label"] == "custom_value" + + +@mock.patch.dict(os.environ, {"AIRFLOW_IS_K8S_EXECUTOR_POD": "True"}) +@mock.patch("airflow.settings.pod_mutation_hook") +def test_render_k8s_pod_yaml_with_custom_pod_template_and_pod_override( + pod_mutation_hook, create_task_instance, tmp_path +): + with open(f"{tmp_path}/custom_pod_template.yaml", "w") as ptf: + template = { + "apiVersion": "v1", + "kind": "Pod", + "metadata": {"labels": {"custom_label": "custom_value"}}, + } + ptf.write(yaml.dump(template)) + + pod_override = k8s.V1Pod( + metadata=k8s.V1ObjectMeta(annotations={"test": "annotation"}, labels={"custom_label": "override"}) + ) + ti = create_task_instance( + dag_id="test_render_k8s_pod_yaml", + run_id="test_run_id", + task_id="op1", + logical_date=DEFAULT_DATE, + executor_config={ + "pod_template_file": f"{tmp_path}/custom_pod_template.yaml", + "pod_override": pod_override, + }, + ) + + ti_pod_yaml = render_k8s_pod_yaml(ti) + assert "custom_label" in ti_pod_yaml["metadata"]["labels"] + # The initial value associated with the custom_label label in the pod_template_file + # was overriden by the pod_override + assert ti_pod_yaml["metadata"]["labels"]["custom_label"] == "override" + assert ti_pod_yaml["metadata"]["annotations"]["test"] == "annotation" + + @mock.patch.dict(os.environ, {"AIRFLOW_IS_K8S_EXECUTOR_POD": "True"}) @mock.patch.object(RenderedTaskInstanceFields, "get_k8s_pod_yaml") @mock.patch("airflow.providers.cncf.kubernetes.template_rendering.render_k8s_pod_yaml")