From 1654ddcde86e58d46553858c59c14e3adba9bdf1 Mon Sep 17 00:00:00 2001 From: "jj.lee" <63435794+jx2lee@users.noreply.github.com> Date: Thu, 6 Feb 2025 00:34:05 +0900 Subject: [PATCH] Enable to add inline ssh key in GitHook (#46181) --- airflow/dag_processing/bundles/git.py | 48 ++++++++++++--- tests/dag_processing/test_dag_bundles.py | 78 +++++++++++++++++++++--- 2 files changed, 110 insertions(+), 16 deletions(-) diff --git a/airflow/dag_processing/bundles/git.py b/airflow/dag_processing/bundles/git.py index abebbb4a33820..60da6a678ed53 100644 --- a/airflow/dag_processing/bundles/git.py +++ b/airflow/dag_processing/bundles/git.py @@ -17,8 +17,10 @@ from __future__ import annotations +import contextlib import json import os +import tempfile from typing import TYPE_CHECKING, Any from urllib.parse import urlparse @@ -60,6 +62,7 @@ def get_ui_field_behaviour(cls) -> dict[str, Any]: "extra": json.dumps( { "key_file": "optional/path/to/keyfile", + "private_key": "optional inline private key", } ) }, @@ -70,15 +73,22 @@ def __init__(self, git_conn_id="git_default", *args, **kwargs): connection = self.get_connection(git_conn_id) self.repo_url = connection.host self.auth_token = connection.password + self.private_key = connection.extra_dejson.get("private_key") self.key_file = connection.extra_dejson.get("key_file") - strict_host_key_checking = connection.extra_dejson.get("strict_host_key_checking", "no") + self.strict_host_key_checking = connection.extra_dejson.get("strict_host_key_checking", "no") self.env: dict[str, str] = {} - if self.key_file: - self.env["GIT_SSH_COMMAND"] = ( - f"ssh -i {self.key_file} -o IdentitiesOnly=yes -o StrictHostKeyChecking={strict_host_key_checking}" - ) + + if self.key_file and self.private_key: + raise AirflowException("Both 'key_file' and 'private_key' cannot be provided at the same time") self._process_git_auth_url() + def _build_ssh_command(self, key_path: str) -> str: + return ( + f"ssh -i {key_path} " + f"-o IdentitiesOnly=yes " + f"-o StrictHostKeyChecking={self.strict_host_key_checking}" + ) + def _process_git_auth_url(self): if not isinstance(self.repo_url, str): return @@ -87,6 +97,22 @@ def _process_git_auth_url(self): elif not self.repo_url.startswith("git@") or not self.repo_url.startswith("https://"): self.repo_url = os.path.expanduser(self.repo_url) + def set_git_env(self, key: str) -> None: + self.env["GIT_SSH_COMMAND"] = self._build_ssh_command(key) + + @contextlib.contextmanager + def configure_hook_env(self): + if self.private_key: + with tempfile.NamedTemporaryFile(mode="w", delete=True) as tmp_keyfile: + tmp_keyfile.write(self.private_key) + tmp_keyfile.flush() + os.chmod(tmp_keyfile.name, 0o600) + self.set_git_env(tmp_keyfile.name) + yield + else: + self.set_git_env(self.key_file) + yield + class GitDagBundle(BaseDagBundle, LoggingMixin): """ @@ -128,8 +154,10 @@ def __init__( self.log.warning("Could not create GitHook for connection %s : %s", self.git_conn_id, e) def _initialize(self): - self._clone_bare_repo_if_required() - self._ensure_version_in_bare_repo() + with self.hook.configure_hook_env(): + self._clone_bare_repo_if_required() + self._ensure_version_in_bare_repo() + self._clone_repo_if_required() self.repo.git.checkout(self.tracking_ref) if self.version: @@ -230,8 +258,10 @@ def _fetch_bare_repo(self): def refresh(self) -> None: if self.version: raise AirflowException("Refreshing a specific version is not supported") - self._fetch_bare_repo() - self.repo.remotes.origin.pull() + + with self.hook.configure_hook_env(): + self._fetch_bare_repo() + self.repo.remotes.origin.pull() @staticmethod def _convert_git_ssh_url_to_https(url: str) -> str: diff --git a/tests/dag_processing/test_dag_bundles.py b/tests/dag_processing/test_dag_bundles.py index 6ab1f3ea68fb6..a9a48139ba9eb 100644 --- a/tests/dag_processing/test_dag_bundles.py +++ b/tests/dag_processing/test_dag_bundles.py @@ -17,6 +17,7 @@ from __future__ import annotations +import os import re import tempfile from pathlib import Path @@ -108,6 +109,8 @@ def git_repo(tmp_path_factory): CONN_HTTPS = "my_git_conn" CONN_HTTPS_PASSWORD = "my_git_conn_https_password" CONN_ONLY_PATH = "my_git_conn_only_path" +CONN_ONLY_INLINE_KEY = "my_git_conn_only_inline_key" +CONN_BOTH_PATH_INLINE = "my_git_conn_both_path_inline" CONN_NO_REPO_URL = "my_git_conn_no_repo_url" @@ -146,6 +149,16 @@ def setup_class(cls) -> None: conn_type="git", ) ) + db.merge_conn( + Connection( + conn_id=CONN_ONLY_INLINE_KEY, + host="path/to/repo", + conn_type="git", + extra={ + "private_key": "inline_key", + }, + ) + ) @pytest.mark.parametrize( "conn_id, expected_repo_url", @@ -160,11 +173,12 @@ def test_correct_repo_urls(self, conn_id, expected_repo_url): hook = GitHook(git_conn_id=conn_id) assert hook.repo_url == expected_repo_url - def test_env_var(self, session): - hook = GitHook(git_conn_id=CONN_DEFAULT) - assert hook.env == { - "GIT_SSH_COMMAND": "ssh -i /files/pkey.pem -o IdentitiesOnly=yes -o StrictHostKeyChecking=no" - } + def test_env_var_with_configure_hook_env(self, session): + default_hook = GitHook(git_conn_id=CONN_DEFAULT) + with default_hook.configure_hook_env(): + assert default_hook.env == { + "GIT_SSH_COMMAND": "ssh -i /files/pkey.pem -o IdentitiesOnly=yes -o StrictHostKeyChecking=no" + } db.merge_conn( Connection( conn_id="my_git_conn_strict", @@ -174,11 +188,61 @@ def test_env_var(self, session): ) ) - hook = GitHook(git_conn_id="my_git_conn_strict") + strict_default_hook = GitHook(git_conn_id="my_git_conn_strict") + with strict_default_hook.configure_hook_env(): + assert strict_default_hook.env == { + "GIT_SSH_COMMAND": "ssh -i /files/pkey.pem -o IdentitiesOnly=yes -o StrictHostKeyChecking=yes" + } + + def test_given_both_private_key_and_key_file(self): + db.merge_conn( + Connection( + conn_id=CONN_BOTH_PATH_INLINE, + host="path/to/repo", + conn_type="git", + extra={ + "key_file": "path/to/key", + "private_key": "inline_key", + }, + ) + ) + + with pytest.raises( + AirflowException, match="Both 'key_file' and 'private_key' cannot be provided at the same time" + ): + GitHook(git_conn_id=CONN_BOTH_PATH_INLINE) + + def test_key_file_git_hook_has_env_with_configure_hook_env(self): + hook = GitHook(git_conn_id=CONN_DEFAULT) + + assert hasattr(hook, "env") + with hook.configure_hook_env(): + assert hook.env == { + "GIT_SSH_COMMAND": "ssh -i /files/pkey.pem -o IdentitiesOnly=yes -o StrictHostKeyChecking=no" + } + + def test_private_key_lazy_env_var(self): + hook = GitHook(git_conn_id=CONN_ONLY_INLINE_KEY) + assert hook.env == {} + + hook.set_git_env("dummy_inline_key") assert hook.env == { - "GIT_SSH_COMMAND": "ssh -i /files/pkey.pem -o IdentitiesOnly=yes -o StrictHostKeyChecking=yes" + "GIT_SSH_COMMAND": "ssh -i dummy_inline_key -o IdentitiesOnly=yes -o StrictHostKeyChecking=no" } + def test_configure_hook_env(self): + hook = GitHook(git_conn_id=CONN_ONLY_INLINE_KEY) + assert hasattr(hook, "private_key") + + hook.set_git_env("dummy_inline_key") + + with hook.configure_hook_env(): + command = hook.env.get("GIT_SSH_COMMAND") + temp_key_path = command.split()[2] + assert os.path.exists(temp_key_path) + + assert not os.path.exists(temp_key_path) + class TestGitDagBundle: @classmethod