Skip to content

Commit

Permalink
Improve SFTP hook's directory transfer to use a single connection for…
Browse files Browse the repository at this point in the history
… multiple files (#46582)

* Improve SFTP directory transfer to use a single connection in multiple files

* Add with_conn wrapper

* Fix delete_directory

* Delete wrapper and update get_conn

* Add test code
  • Loading branch information
Dawnpool authored Mar 1, 2025
1 parent 28436fb commit 998fcd6
Show file tree
Hide file tree
Showing 2 changed files with 56 additions and 23 deletions.
69 changes: 48 additions & 21 deletions providers/sftp/src/airflow/providers/sftp/hooks/sftp.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
import stat
import warnings
from collections.abc import Generator, Sequence
from contextlib import closing, contextmanager
from contextlib import contextmanager
from fnmatch import fnmatch
from io import BytesIO
from pathlib import Path
Expand All @@ -38,6 +38,7 @@
from airflow.providers.ssh.hooks.ssh import SSHHook

if TYPE_CHECKING:
from paramiko import SSHClient
from paramiko.sftp_attr import SFTPAttributes
from paramiko.sftp_client import SFTPClient

Expand Down Expand Up @@ -110,6 +111,10 @@ def __init__(
kwargs["host_proxy_cmd"] = host_proxy_cmd
self.ssh_conn_id = ssh_conn_id

self._ssh_conn: SSHClient | None = None
self._sftp_conn: SFTPClient | None = None
self._conn_count = 0

super().__init__(*args, **kwargs)

def get_conn(self) -> SFTPClient: # type: ignore[override]
Expand All @@ -127,9 +132,25 @@ def close_conn(self) -> None:
@contextmanager
def get_managed_conn(self) -> Generator[SFTPClient, None, None]:
"""Context manager that closes the connection after use."""
with closing(super().get_conn()) as conn:
with closing(conn.open_sftp()) as sftp:
yield sftp
if self._sftp_conn is None:
ssh_conn: SSHClient = super().get_conn()
self._ssh_conn = ssh_conn
self._sftp_conn = ssh_conn.open_sftp()
self._conn_count += 1

try:
yield self._sftp_conn
finally:
self._conn_count -= 1
if self._conn_count == 0 and self._ssh_conn is not None and self._sftp_conn is not None:
self._sftp_conn.close()
self._sftp_conn = None
self._ssh_conn.close()
self._ssh_conn = None

def get_conn_count(self) -> int:
"""Get the number of open connections."""
return self._conn_count

def describe_directory(self, path: str) -> dict[str, dict[str, str | int | None]]:
"""
Expand Down Expand Up @@ -309,13 +330,14 @@ def retrieve_directory(self, remote_full_path: str, local_full_path: str, prefet
if Path(local_full_path).exists():
raise AirflowException(f"{local_full_path} already exists")
Path(local_full_path).mkdir(parents=True)
files, dirs, _ = self.get_tree_map(remote_full_path)
for dir_path in dirs:
new_local_path = os.path.join(local_full_path, os.path.relpath(dir_path, remote_full_path))
Path(new_local_path).mkdir(parents=True, exist_ok=True)
for file_path in files:
new_local_path = os.path.join(local_full_path, os.path.relpath(file_path, remote_full_path))
self.retrieve_file(file_path, new_local_path, prefetch)
with self.get_conn():
files, dirs, _ = self.get_tree_map(remote_full_path)
for dir_path in dirs:
new_local_path = os.path.join(local_full_path, os.path.relpath(dir_path, remote_full_path))
Path(new_local_path).mkdir(parents=True, exist_ok=True)
for file_path in files:
new_local_path = os.path.join(local_full_path, os.path.relpath(file_path, remote_full_path))
self.retrieve_file(file_path, new_local_path, prefetch)

def store_directory(self, remote_full_path: str, local_full_path: str, confirm: bool = True) -> None:
"""
Expand All @@ -329,16 +351,21 @@ def store_directory(self, remote_full_path: str, local_full_path: str, confirm:
"""
if self.path_exists(remote_full_path):
raise AirflowException(f"{remote_full_path} already exists")
self.create_directory(remote_full_path)
for root, dirs, files in os.walk(local_full_path):
for dir_name in dirs:
dir_path = os.path.join(root, dir_name)
new_remote_path = os.path.join(remote_full_path, os.path.relpath(dir_path, local_full_path))
self.create_directory(new_remote_path)
for file_name in files:
file_path = os.path.join(root, file_name)
new_remote_path = os.path.join(remote_full_path, os.path.relpath(file_path, local_full_path))
self.store_file(new_remote_path, file_path, confirm)
with self.get_conn():
self.create_directory(remote_full_path)
for root, dirs, files in os.walk(local_full_path):
for dir_name in dirs:
dir_path = os.path.join(root, dir_name)
new_remote_path = os.path.join(
remote_full_path, os.path.relpath(dir_path, local_full_path)
)
self.create_directory(new_remote_path)
for file_name in files:
file_path = os.path.join(root, file_name)
new_remote_path = os.path.join(
remote_full_path, os.path.relpath(file_path, local_full_path)
)
self.store_file(new_remote_path, file_path, confirm)

def get_mod_time(self, path: str) -> str:
"""
Expand Down
10 changes: 8 additions & 2 deletions providers/sftp/tests/unit/sftp/hooks/test_sftp.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,8 +114,14 @@ def test_close_conn(self):
assert self.hook.conn is None

def test_get_managed_conn(self):
with self.hook.get_managed_conn() as conn:
assert isinstance(conn, paramiko.SFTPClient)
with self.hook.get_managed_conn() as conn1:
assert isinstance(conn1, paramiko.SFTPClient)
with self.hook.get_managed_conn() as conn2:
assert conn1 == conn2
assert self.hook.get_conn_count() == 2
assert self.hook.get_conn_count() == 1
assert self.hook.get_conn_count() == 0
assert self.hook.conn is None

@patch("airflow.providers.ssh.hooks.ssh.SSHHook.get_conn")
def test_get_close_conn(self, mock_get_conn):
Expand Down

0 comments on commit 998fcd6

Please sign in to comment.