diff --git a/airflow/dag_processing/bundles/git.py b/airflow/dag_processing/bundles/git.py index 40f0516fcf203e..bc746e063c2002 100644 --- a/airflow/dag_processing/bundles/git.py +++ b/airflow/dag_processing/bundles/git.py @@ -106,14 +106,16 @@ def set_git_env(self, key: str) -> dict[str, str]: return self.env @contextlib.contextmanager - def setup_inline_key(self): + def configure_hook_env(self): if self.private_key: with tempfile.NamedTemporaryFile(mode="w", delete=True) as tmp_keyfile: tmp_keyfile.write(self.private_key) tmp_keyfile.flush() os.chmod(tmp_keyfile.name, 0o600) - yield tmp_keyfile.name + self.set_git_env(tmp_keyfile.name) + yield else: + self.set_git_env(self.private_key) yield @@ -157,7 +159,7 @@ def __init__( self.log.warning("Could not create GitHook for connection %s : %s", self.git_conn_id, e) def _initialize(self): - with self.hook.setup_inline_key() as tmp_keyfile: + with self.hook.configure_hook_env() as tmp_keyfile: self.hook.env = self.hook.set_git_env(tmp_keyfile) self._clone_bare_repo_if_required() self._ensure_version_in_bare_repo() @@ -263,8 +265,7 @@ def refresh(self) -> None: if self.version: raise AirflowException("Refreshing a specific version is not supported") - with self.hook.setup_inline_key() as tmp_keyfile: - self.hook.env = self.hook.set_git_env(tmp_keyfile) + with self.hook.configure_hook_env(): self._fetch_bare_repo() self.repo.remotes.origin.pull() diff --git a/tests/dag_processing/test_dag_bundles.py b/tests/dag_processing/test_dag_bundles.py index 01309896599bce..c4d4ff59ba2449 100644 --- a/tests/dag_processing/test_dag_bundles.py +++ b/tests/dag_processing/test_dag_bundles.py @@ -227,16 +227,18 @@ def test_private_key_lazy_env_var(self): "GIT_SSH_COMMAND": "ssh -i dummy_inline_key -o IdentitiesOnly=yes -o StrictHostKeyChecking=no" } - def test_setup_inline_key(self): + def test_configure_hook_env(self): hook = GitHook(git_conn_id=CONN_ONLY_INLINE_KEY) assert hasattr(hook, "private_key") hook.set_git_env("dummy_inline_key") - with hook.setup_inline_key() as tmp_keyfile: - assert os.path.exists(tmp_keyfile) + with hook.configure_hook_env(): + command = hook.env.get("GIT_SSH_COMMAND") + temp_key_path = command.split()[2] + assert os.path.exists(temp_key_path) - assert not os.path.exists(tmp_keyfile) + assert not os.path.exists(temp_key_path) class TestGitDagBundle: