Skip to content

Commit

Permalink
Enable to add inline ssh key in GitHook (apache#46181)
Browse files Browse the repository at this point in the history
  • Loading branch information
jx2lee authored and insomnes committed Feb 6, 2025
1 parent 532a7bf commit 1654ddc
Show file tree
Hide file tree
Showing 2 changed files with 110 additions and 16 deletions.
48 changes: 39 additions & 9 deletions airflow/dag_processing/bundles/git.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,10 @@

from __future__ import annotations

import contextlib
import json
import os
import tempfile
from typing import TYPE_CHECKING, Any
from urllib.parse import urlparse

Expand Down Expand Up @@ -60,6 +62,7 @@ def get_ui_field_behaviour(cls) -> dict[str, Any]:
"extra": json.dumps(
{
"key_file": "optional/path/to/keyfile",
"private_key": "optional inline private key",
}
)
},
Expand All @@ -70,15 +73,22 @@ def __init__(self, git_conn_id="git_default", *args, **kwargs):
connection = self.get_connection(git_conn_id)
self.repo_url = connection.host
self.auth_token = connection.password
self.private_key = connection.extra_dejson.get("private_key")
self.key_file = connection.extra_dejson.get("key_file")
strict_host_key_checking = connection.extra_dejson.get("strict_host_key_checking", "no")
self.strict_host_key_checking = connection.extra_dejson.get("strict_host_key_checking", "no")
self.env: dict[str, str] = {}
if self.key_file:
self.env["GIT_SSH_COMMAND"] = (
f"ssh -i {self.key_file} -o IdentitiesOnly=yes -o StrictHostKeyChecking={strict_host_key_checking}"
)

if self.key_file and self.private_key:
raise AirflowException("Both 'key_file' and 'private_key' cannot be provided at the same time")
self._process_git_auth_url()

def _build_ssh_command(self, key_path: str | None) -> str:
    """
    Build the ssh command line stored in ``GIT_SSH_COMMAND``.

    :param key_path: Path of the private key file to pass to ``ssh -i``.
        When falsy (e.g. the connection defines no key), the ``-i`` option
        is left out instead of rendering the literal text ``-i None``.
    :return: The ssh command string, including ``IdentitiesOnly=yes`` and the
        ``StrictHostKeyChecking`` value read from ``self.strict_host_key_checking``.
    """
    # Function-scope import so this method does not depend on the module's import block.
    import shlex

    parts = ["ssh"]
    if key_path:
        # Git hands GIT_SSH_COMMAND to the shell, so a path containing spaces or
        # shell metacharacters must be quoted to stay a single argument.
        parts.append(f"-i {shlex.quote(str(key_path))}")
    parts.append("-o IdentitiesOnly=yes")
    # The value originates from connection extras; quote it for the same reason.
    parts.append(f"-o StrictHostKeyChecking={shlex.quote(str(self.strict_host_key_checking))}")
    return " ".join(parts)

def _process_git_auth_url(self):
if not isinstance(self.repo_url, str):
return
Expand All @@ -87,6 +97,22 @@ def _process_git_auth_url(self):
elif not self.repo_url.startswith("git@") or not self.repo_url.startswith("https://"):
self.repo_url = os.path.expanduser(self.repo_url)

def set_git_env(self, key: str) -> None:
    """
    Store the ssh command for ``key`` under ``GIT_SSH_COMMAND`` in ``self.env``.

    :param key: Path of the private key file handed to ``_build_ssh_command``.
    """
    ssh_command = self._build_ssh_command(key)
    self.env.update(GIT_SSH_COMMAND=ssh_command)

@contextlib.contextmanager
def configure_hook_env(self):
    """
    Context manager that prepares ``self.env`` for git commands run over ssh.

    With an inline ``private_key``, the key is written to a temporary file that
    exists only inside the ``with`` block, and ``GIT_SSH_COMMAND`` points at it.
    On exit the previous ``GIT_SSH_COMMAND`` value (or its absence) is restored,
    so ``self.env`` never keeps a reference to the deleted file.

    Without an inline key, ``GIT_SSH_COMMAND`` is built from ``self.key_file``
    and left in place after the block, as before.
    """
    if not self.private_key:
        self.set_git_env(self.key_file)
        yield
        return

    # OpenSSH can fail to load a key file that does not end with a newline,
    # which is easy to lose when a key is pasted into connection extras.
    key_material = self.private_key
    if not key_material.endswith("\n"):
        key_material = f"{key_material}\n"

    previous_command = self.env.get("GIT_SSH_COMMAND")
    with tempfile.NamedTemporaryFile(mode="w", delete=True) as tmp_keyfile:
        tmp_keyfile.write(key_material)
        tmp_keyfile.flush()
        os.chmod(tmp_keyfile.name, 0o600)
        self.set_git_env(tmp_keyfile.name)
        try:
            yield
        finally:
            # The temp file is removed when the enclosing ``with`` exits; do not
            # leave GIT_SSH_COMMAND pointing at a path that no longer exists.
            if previous_command is None:
                self.env.pop("GIT_SSH_COMMAND", None)
            else:
                self.env["GIT_SSH_COMMAND"] = previous_command


class GitDagBundle(BaseDagBundle, LoggingMixin):
"""
Expand Down Expand Up @@ -128,8 +154,10 @@ def __init__(
self.log.warning("Could not create GitHook for connection %s : %s", self.git_conn_id, e)

def _initialize(self):
self._clone_bare_repo_if_required()
self._ensure_version_in_bare_repo()
with self.hook.configure_hook_env():
self._clone_bare_repo_if_required()
self._ensure_version_in_bare_repo()

self._clone_repo_if_required()
self.repo.git.checkout(self.tracking_ref)
if self.version:
Expand Down Expand Up @@ -230,8 +258,10 @@ def _fetch_bare_repo(self):
def refresh(self) -> None:
if self.version:
raise AirflowException("Refreshing a specific version is not supported")
self._fetch_bare_repo()
self.repo.remotes.origin.pull()

with self.hook.configure_hook_env():
self._fetch_bare_repo()
self.repo.remotes.origin.pull()

@staticmethod
def _convert_git_ssh_url_to_https(url: str) -> str:
Expand Down
78 changes: 71 additions & 7 deletions tests/dag_processing/test_dag_bundles.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

from __future__ import annotations

import os
import re
import tempfile
from pathlib import Path
Expand Down Expand Up @@ -108,6 +109,8 @@ def git_repo(tmp_path_factory):
CONN_HTTPS = "my_git_conn"
CONN_HTTPS_PASSWORD = "my_git_conn_https_password"
CONN_ONLY_PATH = "my_git_conn_only_path"
CONN_ONLY_INLINE_KEY = "my_git_conn_only_inline_key"
CONN_BOTH_PATH_INLINE = "my_git_conn_both_path_inline"
CONN_NO_REPO_URL = "my_git_conn_no_repo_url"


Expand Down Expand Up @@ -146,6 +149,16 @@ def setup_class(cls) -> None:
conn_type="git",
)
)
db.merge_conn(
Connection(
conn_id=CONN_ONLY_INLINE_KEY,
host="path/to/repo",
conn_type="git",
extra={
"private_key": "inline_key",
},
)
)

@pytest.mark.parametrize(
"conn_id, expected_repo_url",
Expand All @@ -160,11 +173,12 @@ def test_correct_repo_urls(self, conn_id, expected_repo_url):
hook = GitHook(git_conn_id=conn_id)
assert hook.repo_url == expected_repo_url

def test_env_var(self, session):
hook = GitHook(git_conn_id=CONN_DEFAULT)
assert hook.env == {
"GIT_SSH_COMMAND": "ssh -i /files/pkey.pem -o IdentitiesOnly=yes -o StrictHostKeyChecking=no"
}
def test_env_var_with_configure_hook_env(self, session):
default_hook = GitHook(git_conn_id=CONN_DEFAULT)
with default_hook.configure_hook_env():
assert default_hook.env == {
"GIT_SSH_COMMAND": "ssh -i /files/pkey.pem -o IdentitiesOnly=yes -o StrictHostKeyChecking=no"
}
db.merge_conn(
Connection(
conn_id="my_git_conn_strict",
Expand All @@ -174,11 +188,61 @@ def test_env_var(self, session):
)
)

hook = GitHook(git_conn_id="my_git_conn_strict")
strict_default_hook = GitHook(git_conn_id="my_git_conn_strict")
with strict_default_hook.configure_hook_env():
assert strict_default_hook.env == {
"GIT_SSH_COMMAND": "ssh -i /files/pkey.pem -o IdentitiesOnly=yes -o StrictHostKeyChecking=yes"
}

def test_given_both_private_key_and_key_file(self):
    """GitHook construction must fail when the connection sets both ``key_file`` and ``private_key``."""
    conflicting_conn = Connection(
        conn_id=CONN_BOTH_PATH_INLINE,
        host="path/to/repo",
        conn_type="git",
        extra={
            "key_file": "path/to/key",
            "private_key": "inline_key",
        },
    )
    db.merge_conn(conflicting_conn)

    expected_message = "Both 'key_file' and 'private_key' cannot be provided at the same time"
    with pytest.raises(AirflowException, match=expected_message):
        GitHook(git_conn_id=CONN_BOTH_PATH_INLINE)

def test_key_file_git_hook_has_env_with_configure_hook_env(self):
    """A hook built from ``CONN_DEFAULT`` exposes ``env`` and fills it inside ``configure_hook_env``."""
    # NOTE(review): presumes CONN_DEFAULT is registered with key_file=/files/pkey.pem
    # in setup_class (not visible here) - confirm.
    expected_env = {
        "GIT_SSH_COMMAND": "ssh -i /files/pkey.pem -o IdentitiesOnly=yes -o StrictHostKeyChecking=no"
    }
    hook = GitHook(git_conn_id=CONN_DEFAULT)

    assert hasattr(hook, "env")
    with hook.configure_hook_env():
        assert hook.env == expected_env

def test_private_key_lazy_env_var(self):
    """
    With only an inline key, ``env`` starts empty and is filled on ``set_git_env``.

    The connection used here sets no ``strict_host_key_checking`` extra, so the
    hook's default of ``no`` is what the built command must carry.
    """
    hook = GitHook(git_conn_id=CONN_ONLY_INLINE_KEY)
    # Nothing is written to env at construction time for inline keys.
    assert hook.env == {}

    hook.set_git_env("dummy_inline_key")
    assert hook.env == {
        "GIT_SSH_COMMAND": "ssh -i dummy_inline_key -o IdentitiesOnly=yes -o StrictHostKeyChecking=no"
    }

def test_configure_hook_env(self):
    """The temporary key file named in ``GIT_SSH_COMMAND`` exists only inside ``configure_hook_env``."""
    hook = GitHook(git_conn_id=CONN_ONLY_INLINE_KEY)
    assert hasattr(hook, "private_key")

    hook.set_git_env("dummy_inline_key")

    with hook.configure_hook_env():
        ssh_command = hook.env.get("GIT_SSH_COMMAND")
        # Command layout is "ssh -i <path> ...", so the key path is the third token.
        tokens = ssh_command.split()
        temp_key_path = tokens[2]
        assert os.path.exists(temp_key_path)

    assert not os.path.exists(temp_key_path)


class TestGitDagBundle:
@classmethod
Expand Down

0 comments on commit 1654ddc

Please sign in to comment.