Skip to content

Commit

Permalink
Use labels instead of pod name for pod log read in k8s exec (#28546)
Browse files Browse the repository at this point in the history
This means we don't have to use ti.hostname as a proxy for the pod name, and it allows us to lift the 63-character limit, which was a consequence of getting the pod name through the hostname.
  • Loading branch information
dstandish authored Dec 30, 2022
1 parent 0688862 commit c22fc00
Show file tree
Hide file tree
Showing 7 changed files with 199 additions and 58 deletions.
9 changes: 4 additions & 5 deletions airflow/kubernetes/kubernetes_helper_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,16 +48,15 @@ def create_pod_id(
dag_id: str | None = None,
task_id: str | None = None,
*,
max_length: int = 63, # must be 63 for now, see below
max_length: int = 80,
unique: bool = True,
) -> str:
"""
Generates unique pod ID given a dag_id and / or task_id.
Because of the way that the task log handler reads from running k8s executor pods,
we must keep pod name <= 63 characters. The handler gets pod name from ti.hostname.
TI hostname is derived from the container hostname, which is truncated to 63 characters.
We could lift this limit by using label selectors instead of pod name to find the pod.
The default of 80 for max length is somewhat arbitrary, mainly a balance between
content and not overwhelming terminal windows of reasonable width. The true
upper limit is 253, and this is enforced in construct_pod.
:param dag_id: DAG ID
:param task_id: Task ID
Expand Down
99 changes: 77 additions & 22 deletions airflow/kubernetes/pod_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -339,8 +339,9 @@ def reconcile_containers(
base_containers[1:], client_containers[1:]
)

@staticmethod
@classmethod
def construct_pod(
cls,
dag_id: str,
task_id: str,
pod_id: str,
Expand Down Expand Up @@ -370,15 +371,6 @@ def construct_pod(
"pod_id supplied is longer than 253 characters; truncating and adding unique suffix."
)
pod_id = add_pod_suffix(pod_name=pod_id, max_len=253)
if len(pod_id) > 63:
# because in task handler we get pod name from ti hostname (which truncates
# pod_id to 63 characters) we won't be able to find the pod unless it is <= 63 characters.
# our code creates pod names shorter than this so this warning should not normally be triggered.
warnings.warn(
"Supplied pod_id is longer than 63 characters. Due to implementation details, the webserver "
"may not be able to stream logs while task is running. Please choose a shorter pod name."
)

try:
image = pod_override_object.spec.containers[0].image # type: ignore
if not image:
Expand All @@ -391,30 +383,27 @@ def construct_pod(
"task_id": task_id,
"try_number": str(try_number),
}
labels = {
"airflow-worker": make_safe_label_value(scheduler_job_id),
"dag_id": make_safe_label_value(dag_id),
"task_id": make_safe_label_value(task_id),
"try_number": str(try_number),
"airflow_version": airflow_version.replace("+", "-"),
"kubernetes_executor": "True",
}
if map_index >= 0:
annotations["map_index"] = str(map_index)
labels["map_index"] = str(map_index)
if date:
annotations["execution_date"] = date.isoformat()
labels["execution_date"] = datetime_to_label_safe_datestring(date)
if run_id:
annotations["run_id"] = run_id
labels["run_id"] = make_safe_label_value(run_id)

dynamic_pod = k8s.V1Pod(
metadata=k8s.V1ObjectMeta(
namespace=namespace,
annotations=annotations,
name=pod_id,
labels=labels,
labels=cls.build_labels_for_k8s_executor_pod(
dag_id=dag_id,
task_id=task_id,
try_number=try_number,
airflow_worker=scheduler_job_id,
map_index=map_index,
execution_date=date,
run_id=run_id,
),
),
spec=k8s.V1PodSpec(
containers=[
Expand Down Expand Up @@ -447,6 +436,72 @@ def construct_pod(

return pod

@classmethod
def build_selector_for_k8s_executor_pod(
    cls,
    *,
    dag_id,
    task_id,
    try_number,
    map_index=None,
    execution_date=None,
    run_id=None,
    airflow_worker=None,
):
    """
    Build the label selector string for a kubernetes executor pod.

    The labels returned by ``build_labels_for_k8s_executor_pod`` are rendered as
    ``key=value`` clauses, ordered by key and joined with commas. When no
    ``airflow_worker`` is given, a bare ``airflow-worker`` clause is appended.

    :meta private:
    """
    pod_labels = cls.build_labels_for_k8s_executor_pod(
        dag_id=dag_id,
        task_id=task_id,
        try_number=try_number,
        map_index=map_index,
        execution_date=execution_date,
        run_id=run_id,
        airflow_worker=airflow_worker,
    )
    selector = ",".join(f"{key}={value}" for key, value in sorted(pod_labels.items()))
    if airflow_worker:
        return selector
    # this filters out KPO pods even when we don't know the scheduler job id
    return selector + ",airflow-worker"

@classmethod
def build_labels_for_k8s_executor_pod(
    cls,
    *,
    dag_id,
    task_id,
    try_number,
    airflow_worker=None,
    map_index=None,
    execution_date=None,
    run_id=None,
):
    """
    Build the label mapping attached to a kubernetes executor pod.

    ``dag_id``, ``task_id``, ``try_number``, ``kubernetes_executor`` and
    ``airflow_version`` are always present. ``airflow-worker``, ``map_index``,
    ``execution_date`` and ``run_id`` are added only when a usable value is supplied.

    :meta private:
    """
    pod_labels = dict(
        dag_id=make_safe_label_value(dag_id),
        task_id=make_safe_label_value(task_id),
        try_number=str(try_number),
        kubernetes_executor="True",
        # "+" is swapped for "-" before the version is used as a label value
        airflow_version=airflow_version.replace("+", "-"),
    )

    if airflow_worker is not None:
        pod_labels["airflow-worker"] = make_safe_label_value(str(airflow_worker))

    # negative map indexes are not turned into a label
    has_map_index = map_index is not None and map_index >= 0
    if has_map_index:
        pod_labels["map_index"] = str(map_index)

    if execution_date:
        pod_labels["execution_date"] = datetime_to_label_safe_datestring(execution_date)

    if run_id:
        pod_labels["run_id"] = make_safe_label_value(run_id)

    return pod_labels

@staticmethod
def serialize_pod(pod: k8s.V1Pod) -> dict:
"""
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ def _create_pod_id(
dag_id: str | None = None,
task_id: str | None = None,
*,
max_length: int = 63,
max_length: int = 80,
unique: bool = True,
) -> str:
"""
Expand Down
38 changes: 30 additions & 8 deletions airflow/utils/log/file_task_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
import logging
import os
import warnings
from contextlib import suppress
from pathlib import Path
from typing import TYPE_CHECKING, Any
from urllib.parse import urljoin
Expand Down Expand Up @@ -191,19 +192,32 @@ def _read(self, ti: TaskInstance, try_number: int, metadata: dict[str, Any] | No
log += f"*** {str(e)}\n"
return log, {"end_of_log": True}
elif self._should_check_k8s(ti.queue):
pod_override = ti.executor_config.get("pod_override")
if pod_override and pod_override.metadata and pod_override.metadata.namespace:
namespace = pod_override.metadata.namespace
else:
namespace = conf.get("kubernetes_executor", "namespace")
try:
from airflow.kubernetes.kube_client import get_kube_client
from airflow.kubernetes.pod_generator import PodGenerator

kube_client = get_kube_client()
client = get_kube_client()

log += f"*** Trying to get logs (last 100 lines) from worker pod {ti.hostname} ***\n\n"
res = kube_client.read_namespaced_pod_log(
name=ti.hostname,
selector = PodGenerator.build_selector_for_k8s_executor_pod(
dag_id=ti.dag_id,
task_id=ti.task_id,
try_number=ti.try_number,
map_index=ti.map_index,
run_id=ti.run_id,
airflow_worker=ti.queued_by_job_id,
)
namespace = self._get_pod_namespace(ti)
pod_list = client.list_namespaced_pod(
namespace=namespace,
label_selector=selector,
).items
if not pod_list:
raise RuntimeError("Cannot find pod for ti %s", ti)
elif len(pod_list) > 1:
raise RuntimeError("Found multiple pods for ti %s: %s", ti, pod_list)
res = client.read_namespaced_pod_log(
name=pod_list[0].metadata.name,
namespace=namespace,
container="base",
follow=False,
Expand Down Expand Up @@ -272,6 +286,14 @@ def _read(self, ti: TaskInstance, try_number: int, metadata: dict[str, Any] | No

return log, {"end_of_log": end_of_log, "log_pos": log_pos}

@staticmethod
def _get_pod_namespace(ti: TaskInstance):
    """
    Return the namespace to look in for the task instance's worker pod.

    The namespace set on the ``pod_override`` entry of ``ti.executor_config`` is used
    when present. Otherwise the ``[kubernetes_executor] namespace`` option is read,
    falling back to ``"default"``.

    :param ti: task instance whose pod namespace is wanted
    :return: namespace name
    """
    # Tolerate a missing executor_config instead of failing on ``.get``.
    executor_config = ti.executor_config or {}
    pod_override = executor_config.get("pod_override")
    # Walk the attributes with explicit defaults rather than suppressing every
    # exception: a missing override, missing metadata, or an override object without
    # these attributes all yield None, while unrelated errors still surface.
    metadata = getattr(pod_override, "metadata", None)
    namespace = getattr(metadata, "namespace", None)
    return namespace or conf.get("kubernetes_executor", "namespace", fallback="default")

@staticmethod
def _get_log_retrieval_url(ti: TaskInstance, log_relative_path: str) -> str:
url = urljoin(
Expand Down
8 changes: 4 additions & 4 deletions tests/kubernetes/test_kubernetes_helper_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,14 +88,14 @@ def test_create_pod_id_dag_and_task(self, dag_id, task_id, expected, create_pod_

def test_create_pod_id_dag_too_long_with_suffix(self, create_pod_id):
actual = create_pod_id("0" * 254)
assert len(actual) == 63
assert re.match(r"0{54}-[a-z0-9]{8}", actual)
assert len(actual) == 80
assert re.match(r"0{71}-[a-z0-9]{8}", actual)
assert re.match(pod_name_regex, actual)

def test_create_pod_id_dag_too_long_non_unique(self, create_pod_id):
actual = create_pod_id("0" * 254, unique=False)
assert len(actual) == 63
assert re.match(r"0{63}", actual)
assert len(actual) == 80
assert re.match(r"0{80}", actual)
assert re.match(pod_name_regex, actual)

@pytest.mark.parametrize("unique", [True, False])
Expand Down
51 changes: 50 additions & 1 deletion tests/kubernetes/test_pod_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
from unittest import mock
from unittest.mock import MagicMock

import pendulum
import pytest
from dateutil import parser
from kubernetes.client import ApiClient, models as k8s
Expand All @@ -38,6 +39,8 @@
)
from airflow.kubernetes.secret import Secret

now = pendulum.now("UTC")


class TestPodGenerator:
def setup_method(self):
Expand Down Expand Up @@ -476,7 +479,7 @@ def test_construct_pod_mapped_task(self):
result_dict = self.k8s_client.sanitize_for_serialization(result)
expected_dict = self.k8s_client.sanitize_for_serialization(expected)

assert expected_dict == result_dict
assert result_dict == expected_dict

def test_construct_pod_empty_executor_config(self):
path = sys.path[0] + "/tests/kubernetes/pod_generator_base_with_secrets.yaml"
Expand Down Expand Up @@ -772,3 +775,49 @@ def test_validate_pod_generator(self):
PodGenerator()
PodGenerator(pod_template_file="tests/kubernetes/pod.yaml")
PodGenerator(pod=k8s.V1Pod())

@pytest.mark.parametrize(
    "extra, extra_expected",
    [
        param({}, {}, id="base"),
        param({"airflow_worker": 2}, {"airflow-worker": "2"}, id="worker"),
        param({"map_index": 2}, {"map_index": "2"}, id="map_index"),
        param({"run_id": "2"}, {"run_id": "2"}, id="run_id"),
        param(
            {"execution_date": now},
            {"execution_date": datetime_to_label_safe_datestring(now)},
            id="date",
        ),
        param(
            {"airflow_worker": 2, "map_index": 2, "run_id": "2", "execution_date": now},
            {
                "airflow-worker": "2",
                "map_index": "2",
                "run_id": "2",
                "execution_date": datetime_to_label_safe_datestring(now),
            },
            id="all",
        ),
    ],
)
def test_build_labels_for_k8s_executor_pod(self, extra, extra_expected):
    """Check the labels and the selector string built for each set of optional arguments."""
    from airflow.version import version as airflow_version

    required_kwargs = {
        "dag_id": "dag*",
        "task_id": "task*",
        "try_number": 1,
    }
    # labels that are always emitted, whatever optional arguments are passed
    always_present = {
        "dag_id": "dag-6b24921d4",
        "task_id": "task-b6aca8991",
        "try_number": "1",
        "airflow_version": airflow_version,
        "kubernetes_executor": "True",
    }

    labels = PodGenerator.build_labels_for_k8s_executor_pod(**required_kwargs, **extra)
    assert labels == {**always_present, **extra_expected}

    # the selector is the sorted key=value pairs, plus a bare "airflow-worker"
    # clause whenever no worker id was supplied
    selector_clauses = [f"{k}={v}" for k, v in sorted(labels.items())]
    if "airflow_worker" not in extra:
        selector_clauses.append("airflow-worker")
    exp_selector = ",".join(selector_clauses)
    assert PodGenerator.build_selector_for_k8s_executor_pod(**required_kwargs, **extra) == exp_selector
Loading

0 comments on commit c22fc00

Please sign in to comment.