Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Enable to add inline ssh key in GitHook #46181

Merged
merged 14 commits into from
Feb 5, 2025
48 changes: 39 additions & 9 deletions airflow/dag_processing/bundles/git.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,10 @@

from __future__ import annotations

import contextlib
import json
import os
import tempfile
from typing import TYPE_CHECKING, Any
from urllib.parse import urlparse

Expand Down Expand Up @@ -60,6 +62,7 @@ def get_ui_field_behaviour(cls) -> dict[str, Any]:
"extra": json.dumps(
{
"key_file": "optional/path/to/keyfile",
"private_key": "optional inline private key",
}
)
},
Expand All @@ -70,15 +73,22 @@ def __init__(self, git_conn_id="git_default", *args, **kwargs):
connection = self.get_connection(git_conn_id)
self.repo_url = connection.host
self.auth_token = connection.password
self.private_key = connection.extra_dejson.get("private_key")
self.key_file = connection.extra_dejson.get("key_file")
strict_host_key_checking = connection.extra_dejson.get("strict_host_key_checking", "no")
self.strict_host_key_checking = connection.extra_dejson.get("strict_host_key_checking", "no")
self.env: dict[str, str] = {}
if self.key_file:
self.env["GIT_SSH_COMMAND"] = (
f"ssh -i {self.key_file} -o IdentitiesOnly=yes -o StrictHostKeyChecking={strict_host_key_checking}"
)

if self.key_file and self.private_key:
raise AirflowException("Both 'key_file' and 'private_key' cannot be provided at the same time")
self._process_git_auth_url()

def _build_ssh_command(self, key_path: str) -> str:
return (
f"ssh -i {key_path} "
f"-o IdentitiesOnly=yes "
f"-o StrictHostKeyChecking={self.strict_host_key_checking}"
)

def _process_git_auth_url(self):
if not isinstance(self.repo_url, str):
return
Expand All @@ -87,6 +97,22 @@ def _process_git_auth_url(self):
elif not self.repo_url.startswith("git@") or not self.repo_url.startswith("https://"):
self.repo_url = os.path.expanduser(self.repo_url)

def set_git_env(self, key: str) -> None:
self.env["GIT_SSH_COMMAND"] = self._build_ssh_command(key)

@contextlib.contextmanager
def configure_hook_env(self):
if self.private_key:
with tempfile.NamedTemporaryFile(mode="w", delete=True) as tmp_keyfile:
tmp_keyfile.write(self.private_key)
tmp_keyfile.flush()
os.chmod(tmp_keyfile.name, 0o600)
self.set_git_env(tmp_keyfile.name)
yield
else:
self.set_git_env(self.key_file)
yield


class GitDagBundle(BaseDagBundle, LoggingMixin):
"""
Expand Down Expand Up @@ -128,8 +154,10 @@ def __init__(
self.log.warning("Could not create GitHook for connection %s : %s", self.git_conn_id, e)

def _initialize(self):
self._clone_bare_repo_if_required()
self._ensure_version_in_bare_repo()
with self.hook.configure_hook_env():
self._clone_bare_repo_if_required()
self._ensure_version_in_bare_repo()

self._clone_repo_if_required()
self.repo.git.checkout(self.tracking_ref)
if self.version:
Expand Down Expand Up @@ -230,8 +258,10 @@ def _fetch_bare_repo(self):
def refresh(self) -> None:
if self.version:
raise AirflowException("Refreshing a specific version is not supported")
self._fetch_bare_repo()
self.repo.remotes.origin.pull()

with self.hook.configure_hook_env():
self._fetch_bare_repo()
self.repo.remotes.origin.pull()

@staticmethod
def _convert_git_ssh_url_to_https(url: str) -> str:
Expand Down
78 changes: 71 additions & 7 deletions tests/dag_processing/test_dag_bundles.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

from __future__ import annotations

import os
import re
import tempfile
from pathlib import Path
Expand Down Expand Up @@ -108,6 +109,8 @@ def git_repo(tmp_path_factory):
CONN_HTTPS = "my_git_conn"
CONN_HTTPS_PASSWORD = "my_git_conn_https_password"
CONN_ONLY_PATH = "my_git_conn_only_path"
CONN_ONLY_INLINE_KEY = "my_git_conn_only_inline_key"
CONN_BOTH_PATH_INLINE = "my_git_conn_both_path_inline"
CONN_NO_REPO_URL = "my_git_conn_no_repo_url"


Expand Down Expand Up @@ -146,6 +149,16 @@ def setup_class(cls) -> None:
conn_type="git",
)
)
db.merge_conn(
Connection(
conn_id=CONN_ONLY_INLINE_KEY,
host="path/to/repo",
conn_type="git",
extra={
"private_key": "inline_key",
},
)
)

@pytest.mark.parametrize(
"conn_id, expected_repo_url",
Expand All @@ -160,11 +173,12 @@ def test_correct_repo_urls(self, conn_id, expected_repo_url):
hook = GitHook(git_conn_id=conn_id)
assert hook.repo_url == expected_repo_url

def test_env_var(self, session):
hook = GitHook(git_conn_id=CONN_DEFAULT)
assert hook.env == {
"GIT_SSH_COMMAND": "ssh -i /files/pkey.pem -o IdentitiesOnly=yes -o StrictHostKeyChecking=no"
}
def test_env_var_with_configure_hook_env(self, session):
default_hook = GitHook(git_conn_id=CONN_DEFAULT)
with default_hook.configure_hook_env():
assert default_hook.env == {
"GIT_SSH_COMMAND": "ssh -i /files/pkey.pem -o IdentitiesOnly=yes -o StrictHostKeyChecking=no"
}
db.merge_conn(
Connection(
conn_id="my_git_conn_strict",
Expand All @@ -174,11 +188,61 @@ def test_env_var(self, session):
)
)

hook = GitHook(git_conn_id="my_git_conn_strict")
strict_default_hook = GitHook(git_conn_id="my_git_conn_strict")
with strict_default_hook.configure_hook_env():
assert strict_default_hook.env == {
"GIT_SSH_COMMAND": "ssh -i /files/pkey.pem -o IdentitiesOnly=yes -o StrictHostKeyChecking=yes"
}

def test_given_both_private_key_and_key_file(self):
db.merge_conn(
Connection(
conn_id=CONN_BOTH_PATH_INLINE,
host="path/to/repo",
conn_type="git",
extra={
"key_file": "path/to/key",
"private_key": "inline_key",
},
)
)

with pytest.raises(
AirflowException, match="Both 'key_file' and 'private_key' cannot be provided at the same time"
):
GitHook(git_conn_id=CONN_BOTH_PATH_INLINE)

def test_key_file_git_hook_has_env_with_configure_hook_env(self):
hook = GitHook(git_conn_id=CONN_DEFAULT)

assert hasattr(hook, "env")
with hook.configure_hook_env():
assert hook.env == {
"GIT_SSH_COMMAND": "ssh -i /files/pkey.pem -o IdentitiesOnly=yes -o StrictHostKeyChecking=no"
}

def test_private_key_lazy_env_var(self):
hook = GitHook(git_conn_id=CONN_ONLY_INLINE_KEY)
assert hook.env == {}

hook.set_git_env("dummy_inline_key")
assert hook.env == {
"GIT_SSH_COMMAND": "ssh -i /files/pkey.pem -o IdentitiesOnly=yes -o StrictHostKeyChecking=yes"
"GIT_SSH_COMMAND": "ssh -i dummy_inline_key -o IdentitiesOnly=yes -o StrictHostKeyChecking=no"
}

def test_configure_hook_env(self):
hook = GitHook(git_conn_id=CONN_ONLY_INLINE_KEY)
assert hasattr(hook, "private_key")

hook.set_git_env("dummy_inline_key")

with hook.configure_hook_env():
command = hook.env.get("GIT_SSH_COMMAND")
temp_key_path = command.split()[2]
assert os.path.exists(temp_key_path)

assert not os.path.exists(temp_key_path)


class TestGitDagBundle:
@classmethod
Expand Down