From c22fc000b6c0075429b9d1e51c9ee3d384141ff3 Mon Sep 17 00:00:00 2001 From: Daniel Standish <15932138+dstandish@users.noreply.github.com> Date: Fri, 30 Dec 2022 15:11:16 -0800 Subject: [PATCH] Use labels instead of pod name for pod log read in k8s exec (#28546) This means we don't have to use ti.hostname as a proxy for pod name, and allows us to lift the 63 character limit, which was a consequence of getting pod name through hostname. --- .../kubernetes/kubernetes_helper_functions.py | 9 +- airflow/kubernetes/pod_generator.py | 99 ++++++++++++++----- .../kubernetes/operators/kubernetes_pod.py | 2 +- airflow/utils/log/file_task_handler.py | 38 +++++-- .../test_kubernetes_helper_functions.py | 8 +- tests/kubernetes/test_pod_generator.py | 51 +++++++++- tests/utils/test_log_handlers.py | 50 ++++++---- 7 files changed, 199 insertions(+), 58 deletions(-) diff --git a/airflow/kubernetes/kubernetes_helper_functions.py b/airflow/kubernetes/kubernetes_helper_functions.py index ebe469342a879..7965e405acd04 100644 --- a/airflow/kubernetes/kubernetes_helper_functions.py +++ b/airflow/kubernetes/kubernetes_helper_functions.py @@ -48,16 +48,15 @@ def create_pod_id( dag_id: str | None = None, task_id: str | None = None, *, - max_length: int = 63, # must be 63 for now, see below + max_length: int = 80, unique: bool = True, ) -> str: """ Generates unique pod ID given a dag_id and / or task_id. - Because of the way that the task log handler reads from running k8s executor pods, - we must keep pod name <= 63 characters. The handler gets pod name from ti.hostname. - TI hostname is derived from the container hostname, which is truncated to 63 characters. - We could lift this limit by using label selectors instead of pod name to find the pod. + The default of 80 for max length is somewhat arbitrary, mainly a balance between + content and not overwhelming terminal windows of reasonable width. The true + upper limit is 253, and this is enforced in construct_pod. 
:param dag_id: DAG ID :param task_id: Task ID diff --git a/airflow/kubernetes/pod_generator.py b/airflow/kubernetes/pod_generator.py index 64b7965a4c0e2..f27dba7e95302 100644 --- a/airflow/kubernetes/pod_generator.py +++ b/airflow/kubernetes/pod_generator.py @@ -339,8 +339,9 @@ def reconcile_containers( base_containers[1:], client_containers[1:] ) - @staticmethod + @classmethod def construct_pod( + cls, dag_id: str, task_id: str, pod_id: str, @@ -370,15 +371,6 @@ def construct_pod( "pod_id supplied is longer than 253 characters; truncating and adding unique suffix." ) pod_id = add_pod_suffix(pod_name=pod_id, max_len=253) - if len(pod_id) > 63: - # because in task handler we get pod name from ti hostname (which truncates - # pod_id to 63 characters) we won't be able to find the pod unless it is <= 63 characters. - # our code creates pod names shorter than this so this warning should not normally be triggered. - warnings.warn( - "Supplied pod_id is longer than 63 characters. Due to implementation details, the webserver " - "may not be able to stream logs while task is running. Please choose a shorter pod name." 
- ) - try: image = pod_override_object.spec.containers[0].image # type: ignore if not image: @@ -391,30 +383,27 @@ def construct_pod( "task_id": task_id, "try_number": str(try_number), } - labels = { - "airflow-worker": make_safe_label_value(scheduler_job_id), - "dag_id": make_safe_label_value(dag_id), - "task_id": make_safe_label_value(task_id), - "try_number": str(try_number), - "airflow_version": airflow_version.replace("+", "-"), - "kubernetes_executor": "True", - } if map_index >= 0: annotations["map_index"] = str(map_index) - labels["map_index"] = str(map_index) if date: annotations["execution_date"] = date.isoformat() - labels["execution_date"] = datetime_to_label_safe_datestring(date) if run_id: annotations["run_id"] = run_id - labels["run_id"] = make_safe_label_value(run_id) dynamic_pod = k8s.V1Pod( metadata=k8s.V1ObjectMeta( namespace=namespace, annotations=annotations, name=pod_id, - labels=labels, + labels=cls.build_labels_for_k8s_executor_pod( + dag_id=dag_id, + task_id=task_id, + try_number=try_number, + airflow_worker=scheduler_job_id, + map_index=map_index, + execution_date=date, + run_id=run_id, + ), ), spec=k8s.V1PodSpec( containers=[ @@ -447,6 +436,72 @@ def construct_pod( return pod + @classmethod + def build_selector_for_k8s_executor_pod( + cls, + *, + dag_id, + task_id, + try_number, + map_index=None, + execution_date=None, + run_id=None, + airflow_worker=None, + ): + """ + Generate selector for kubernetes executor pod + + :meta private: + """ + labels = cls.build_labels_for_k8s_executor_pod( + dag_id=dag_id, + task_id=task_id, + try_number=try_number, + map_index=map_index, + execution_date=execution_date, + run_id=run_id, + airflow_worker=airflow_worker, + ) + label_strings = [f"{label_id}={label}" for label_id, label in sorted(labels.items())] + selector = ",".join(label_strings) + if not airflow_worker: # this filters out KPO pods even when we don't know the scheduler job id + selector += ",airflow-worker" + return selector + + 
@classmethod + def build_labels_for_k8s_executor_pod( + cls, + *, + dag_id, + task_id, + try_number, + airflow_worker=None, + map_index=None, + execution_date=None, + run_id=None, + ): + """ + Generate labels for kubernetes executor pod + + :meta private: + """ + labels = { + "dag_id": make_safe_label_value(dag_id), + "task_id": make_safe_label_value(task_id), + "try_number": str(try_number), + "kubernetes_executor": "True", + "airflow_version": airflow_version.replace("+", "-"), + } + if airflow_worker is not None: + labels["airflow-worker"] = make_safe_label_value(str(airflow_worker)) + if map_index is not None and map_index >= 0: + labels["map_index"] = str(map_index) + if execution_date: + labels["execution_date"] = datetime_to_label_safe_datestring(execution_date) + if run_id: + labels["run_id"] = make_safe_label_value(run_id) + return labels + @staticmethod def serialize_pod(pod: k8s.V1Pod) -> dict: """ diff --git a/airflow/providers/cncf/kubernetes/operators/kubernetes_pod.py b/airflow/providers/cncf/kubernetes/operators/kubernetes_pod.py index 62d4262a84bdb..c34fc5c02bbbf 100644 --- a/airflow/providers/cncf/kubernetes/operators/kubernetes_pod.py +++ b/airflow/providers/cncf/kubernetes/operators/kubernetes_pod.py @@ -92,7 +92,7 @@ def _create_pod_id( dag_id: str | None = None, task_id: str | None = None, *, - max_length: int = 63, + max_length: int = 80, unique: bool = True, ) -> str: """ diff --git a/airflow/utils/log/file_task_handler.py b/airflow/utils/log/file_task_handler.py index b8feb2997bbc7..cce6b39984777 100644 --- a/airflow/utils/log/file_task_handler.py +++ b/airflow/utils/log/file_task_handler.py @@ -21,6 +21,7 @@ import logging import os import warnings +from contextlib import suppress from pathlib import Path from typing import TYPE_CHECKING, Any from urllib.parse import urljoin @@ -191,19 +192,32 @@ def _read(self, ti: TaskInstance, try_number: int, metadata: dict[str, Any] | No log += f"*** {str(e)}\n" return log, {"end_of_log": True} elif 
self._should_check_k8s(ti.queue): - pod_override = ti.executor_config.get("pod_override") - if pod_override and pod_override.metadata and pod_override.metadata.namespace: - namespace = pod_override.metadata.namespace - else: - namespace = conf.get("kubernetes_executor", "namespace") try: from airflow.kubernetes.kube_client import get_kube_client + from airflow.kubernetes.pod_generator import PodGenerator - kube_client = get_kube_client() + client = get_kube_client() log += f"*** Trying to get logs (last 100 lines) from worker pod {ti.hostname} ***\n\n" - res = kube_client.read_namespaced_pod_log( - name=ti.hostname, + selector = PodGenerator.build_selector_for_k8s_executor_pod( + dag_id=ti.dag_id, + task_id=ti.task_id, + try_number=ti.try_number, + map_index=ti.map_index, + run_id=ti.run_id, + airflow_worker=ti.queued_by_job_id, + ) + namespace = self._get_pod_namespace(ti) + pod_list = client.list_namespaced_pod( + namespace=namespace, + label_selector=selector, + ).items + if not pod_list: + raise RuntimeError("Cannot find pod for ti %s", ti) + elif len(pod_list) > 1: + raise RuntimeError("Found multiple pods for ti %s: %s", ti, pod_list) + res = client.read_namespaced_pod_log( + name=pod_list[0].metadata.name, namespace=namespace, container="base", follow=False, @@ -272,6 +286,14 @@ def _read(self, ti: TaskInstance, try_number: int, metadata: dict[str, Any] | No return log, {"end_of_log": end_of_log, "log_pos": log_pos} + @staticmethod + def _get_pod_namespace(ti: TaskInstance): + pod_override = ti.executor_config.get("pod_override") + namespace = None + with suppress(Exception): + namespace = pod_override.metadata.namespace + return namespace or conf.get("kubernetes_executor", "namespace", fallback="default") + @staticmethod def _get_log_retrieval_url(ti: TaskInstance, log_relative_path: str) -> str: url = urljoin( diff --git a/tests/kubernetes/test_kubernetes_helper_functions.py b/tests/kubernetes/test_kubernetes_helper_functions.py index 
52a453e2e8f01..76512c657aa0c 100644 --- a/tests/kubernetes/test_kubernetes_helper_functions.py +++ b/tests/kubernetes/test_kubernetes_helper_functions.py @@ -88,14 +88,14 @@ def test_create_pod_id_dag_and_task(self, dag_id, task_id, expected, create_pod_ def test_create_pod_id_dag_too_long_with_suffix(self, create_pod_id): actual = create_pod_id("0" * 254) - assert len(actual) == 63 - assert re.match(r"0{54}-[a-z0-9]{8}", actual) + assert len(actual) == 80 + assert re.match(r"0{71}-[a-z0-9]{8}", actual) assert re.match(pod_name_regex, actual) def test_create_pod_id_dag_too_long_non_unique(self, create_pod_id): actual = create_pod_id("0" * 254, unique=False) - assert len(actual) == 63 - assert re.match(r"0{63}", actual) + assert len(actual) == 80 + assert re.match(r"0{80}", actual) assert re.match(pod_name_regex, actual) @pytest.mark.parametrize("unique", [True, False]) diff --git a/tests/kubernetes/test_pod_generator.py b/tests/kubernetes/test_pod_generator.py index 1314f763e8086..578e8e5eddf3d 100644 --- a/tests/kubernetes/test_pod_generator.py +++ b/tests/kubernetes/test_pod_generator.py @@ -22,6 +22,7 @@ from unittest import mock from unittest.mock import MagicMock +import pendulum import pytest from dateutil import parser from kubernetes.client import ApiClient, models as k8s @@ -38,6 +39,8 @@ ) from airflow.kubernetes.secret import Secret +now = pendulum.now("UTC") + class TestPodGenerator: def setup_method(self): @@ -476,7 +479,7 @@ def test_construct_pod_mapped_task(self): result_dict = self.k8s_client.sanitize_for_serialization(result) expected_dict = self.k8s_client.sanitize_for_serialization(expected) - assert expected_dict == result_dict + assert result_dict == expected_dict def test_construct_pod_empty_executor_config(self): path = sys.path[0] + "/tests/kubernetes/pod_generator_base_with_secrets.yaml" @@ -772,3 +775,49 @@ def test_validate_pod_generator(self): PodGenerator() PodGenerator(pod_template_file="tests/kubernetes/pod.yaml") 
PodGenerator(pod=k8s.V1Pod()) + + @pytest.mark.parametrize( + "extra, extra_expected", + [ + param(dict(), {}, id="base"), + param(dict(airflow_worker=2), {"airflow-worker": "2"}, id="worker"), + param(dict(map_index=2), {"map_index": "2"}, id="map_index"), + param(dict(run_id="2"), {"run_id": "2"}, id="run_id"), + param( + dict(execution_date=now), + {"execution_date": datetime_to_label_safe_datestring(now)}, + id="date", + ), + param( + dict(airflow_worker=2, map_index=2, run_id="2", execution_date=now), + { + "airflow-worker": "2", + "map_index": "2", + "run_id": "2", + "execution_date": datetime_to_label_safe_datestring(now), + }, + id="all", + ), + ], + ) + def test_build_labels_for_k8s_executor_pod(self, extra, extra_expected): + from airflow.version import version as airflow_version + + kwargs = dict( + dag_id="dag*", + task_id="task*", + try_number=1, + ) + expected = dict( + dag_id="dag-6b24921d4", + task_id="task-b6aca8991", + try_number="1", + airflow_version=airflow_version, + kubernetes_executor="True", + ) + labels = PodGenerator.build_labels_for_k8s_executor_pod(**kwargs, **extra) + assert labels == {**expected, **extra_expected} + exp_selector = ",".join([f"{k}={v}" for k, v in sorted(labels.items())]) + if "airflow_worker" not in extra: + exp_selector += ",airflow-worker" + assert PodGenerator.build_selector_for_k8s_executor_pod(**kwargs, **extra) == exp_selector diff --git a/tests/utils/test_log_handlers.py b/tests/utils/test_log_handlers.py index 8b7f0145dee12..3a49a47f6bcd4 100644 --- a/tests/utils/test_log_handlers.py +++ b/tests/utils/test_log_handlers.py @@ -21,7 +21,7 @@ import logging.config import os import re -from unittest.mock import MagicMock, patch +from unittest.mock import patch import pytest from kubernetes.client import models as k8s @@ -235,42 +235,58 @@ def task_callable(ti): def test_read_from_k8s_under_multi_namespace_mode( self, mock_kube_client, pod_override, namespace_to_call ): - mock_read_namespaced_pod_log = MagicMock() 
- mock_kube_client.return_value.read_namespaced_pod_log = mock_read_namespaced_pod_log + mock_read_log = mock_kube_client.return_value.read_namespaced_pod_log + mock_list_pod = mock_kube_client.return_value.list_namespaced_pod def task_callable(ti): ti.log.info("test") - dag = DAG("dag_for_testing_file_task_handler", start_date=DEFAULT_DATE) + with DAG("dag_for_testing_file_task_handler", start_date=DEFAULT_DATE) as dag: + task = PythonOperator( + task_id="task_for_testing_file_log_handler", + python_callable=task_callable, + executor_config={"pod_override": pod_override}, + ) dagrun = dag.create_dagrun( run_type=DagRunType.MANUAL, state=State.RUNNING, execution_date=DEFAULT_DATE, ) - executor_config_pod = pod_override - task = PythonOperator( - task_id="task_for_testing_file_log_handler", - dag=dag, - python_callable=task_callable, - executor_config={"pod_override": executor_config_pod}, - ) ti = TaskInstance(task=task, run_id=dagrun.run_id) ti.try_number = 3 logger = ti.log ti.log.disabled = False - file_handler = next( - (handler for handler in logger.handlers if handler.name == FILE_TASK_HANDLER), None - ) + file_handler = next((h for h in logger.handlers if h.name == FILE_TASK_HANDLER), None) set_context(logger, ti) ti.run(ignore_ti_state=True) file_handler.read(ti, 3) - # Check if kube_client.read_namespaced_pod_log() is called with the namespace we expect - mock_read_namespaced_pod_log.assert_called_once_with( - name=ti.hostname, + # first we find pod name + mock_list_pod.assert_called_once() + actual_kwargs = mock_list_pod.call_args[1] + assert actual_kwargs["namespace"] == namespace_to_call + actual_selector = actual_kwargs["label_selector"] + assert re.match( + ",".join( + [ + "airflow_version=.+?", + "dag_id=dag_for_testing_file_task_handler", + "kubernetes_executor=True", + "run_id=manual__2016-01-01T0000000000-2b88d1d57", + "task_id=task_for_testing_file_log_handler", + "try_number=.+?", + "airflow-worker", + ] + ), + actual_selector, + ) + + # then we 
read log + mock_read_log.assert_called_once_with( + name=mock_list_pod.return_value.items[0].metadata.name, namespace=namespace_to_call, container="base", follow=False,