From 2dc0dfeb2fe2d16b5a7427ef15bb01acaa24f75d Mon Sep 17 00:00:00 2001 From: Lori Witt Date: Mon, 25 Dec 2023 14:08:06 +0200 Subject: [PATCH] fs: improve pull and push --- src/rpcclient/rpcclient/darwin/fs.py | 22 ++- src/rpcclient/rpcclient/fs.py | 222 ++++++++++++++++----------- src/rpcclient/rpcclient/xonshrc.py | 46 ++++-- src/rpcclient/tests/test_fs.py | 22 +++ 4 files changed, 208 insertions(+), 104 deletions(-) diff --git a/src/rpcclient/rpcclient/darwin/fs.py b/src/rpcclient/rpcclient/darwin/fs.py index 31072c96..15c4854d 100644 --- a/src/rpcclient/rpcclient/darwin/fs.py +++ b/src/rpcclient/rpcclient/darwin/fs.py @@ -1,9 +1,10 @@ -from typing import List, Mapping +from pathlib import Path +from typing import Any, List, Mapping from parameter_decorators import path_to_str from rpcclient.darwin.structs import stat64, statfs64 -from rpcclient.fs import Fs +from rpcclient.fs import Fs, RemotePath def do_stat(client, stat_name, filename: str): @@ -15,6 +16,20 @@ def do_stat(client, stat_name, filename: str): return stat64.parse_stream(buf) +class DarwinRemotePath(RemotePath): + def __init__(self, path: str, client) -> None: + super().__init__(path, client) + + def stat(self): + return do_stat(self._client, 'stat64', self._path) + + def lstat(self): + return do_stat(self._client, 'lstat64', self._path) + + def __truediv__(self, key: Path) -> Any: + return DarwinRemotePath(str(super().__truediv__(key)), self._client) + + class DarwinFs(Fs): @path_to_str('path') def stat(self, path: str): @@ -80,3 +95,6 @@ def chflags(self, path: str, flags: int) -> None: """ call chflags(path, flags) at remote. see manpage for more info """ if 0 != self._client.symbols.chflags(path, flags): self._client.raise_errno_exception(f'chflags failed for: {path}') + + def _remote_path(self, path: str) -> DarwinRemotePath: + return DarwinRemotePath(path, self._client) diff --git a/src/rpcclient/rpcclient/fs.py b/src/rpcclient/rpcclient/fs.py index 49745c0b..f0711d6e 100644 --- a/src/rpcclient/rpcclient/fs.py +++ b/src/rpcclient/rpcclient/fs.py @@ -1,21 +1,23 @@ import contextlib +import logging import os import posixpath +import stat import tempfile -from contextlib import nullcontext -from pathlib import Path -from typing import List +from pathlib import Path, PosixPath +from typing import Any, List, Union -from click import progressbar from parameter_decorators import path_to_str from rpcclient.allocated import Allocated from rpcclient.darwin.structs import MAXPATHLEN from rpcclient.darwin.symbol import DarwinSymbol -from rpcclient.exceptions import ArgumentError, BadReturnValueError, RpcClientException, RpcFileExistsError, \ - RpcFileNotFoundError, RpcIsADirectoryError +from rpcclient.exceptions import ArgumentError, BadReturnValueError, RpcFileExistsError, RpcFileNotFoundError, \ + RpcIsADirectoryError from rpcclient.structs.consts import DT_DIR, DT_LNK, DT_REG, DT_UNKNOWN, O_CREAT, O_RDONLY, O_RDWR, O_TRUNC, O_WRONLY, \ - R_OK, S_IFDIR, S_IFLNK, S_IFMT, S_IFREG, SEEK_CUR, SEEK_END, SEEK_SET + R_OK, S_IFDIR, S_IFLNK, S_IFMT, S_IFREG, SEEK_CUR + +logger = logging.getLogger(__name__) class DirEntry: @@ -130,11 +132,12 @@ def _write(self, buf: bytes) -> int: self._client.raise_errno_exception(f'failed to write on fd: {self.fd}') return n - def write(self, buf: bytes): + def write(self, buf: bytes) -> int: """ continue call write() until """ while buf: err = self._write(buf) buf = buf[err:] + return len(buf) def _read(self, buf: DarwinSymbol, size: int) -> bytes: """ read file at remote """ @@ -200,12 +203,119 @@ def __repr__(self): return f'<{self.__class__.__name__} FD:{self.fd}>' +class RemotePath(PosixPath): + def __init__(self, path: str, client) -> None: + super().__init__() + self._path = path + self._client = client + + def __new__(cls, path: str, client): + return super().__new__(cls, *[path]) + + def chmod(self, mode: int): + return self._client.fs.chmod(self._path, mode) + + def readlink(self): + return self._client.fs.readlink(self._path) + + def exists(self) -> bool: + try: + self.stat() + return True + except Exception: + return False + + def is_dir(self) -> bool: + return bool(self.stat().st_mode & S_IFDIR) + + def lstat(self): + return self._client.fs.lstat(self._path) + + def mkdir(self, mode: int): + self._client.fs.mkdir(self._path, mode) + + def read_bytes(self) -> bytes: + with self._open('r') as f: + return f.read() + + def stat(self): + raise self._client.fs.stat(self._path) + + def symlink_to(self, target: Path, target_is_directory: bool = False) -> None: + return self._client.fs.symlink(target, self._path) + + def write_bytes(self, buf: bytes) -> int: + with self._open('w') as f: + return f.write(buf) + + def _open(self, mode: str, access: int = 0o777) -> File: + return self._client.fs.open(self._path, mode, access) + + def __truediv__(self, key: Path) -> Any: + return RemotePath(str(super().__truediv__(key)), self._client) + + class Fs: """ filesystem utils """ def __init__(self, client): self._client = client + def _cp_dir(self, source: Path, dest: Path, force: bool): + if not dest.exists(): + dest.mkdir(source.lstat().st_mode) + + files = self.listdir(str(source)) + for file in files: + src_file = source / file + dest_file = dest / file + + src_lstat = src_file.lstat() + if stat.S_ISDIR(src_lstat.st_mode): + self._cp_dir(src_file, dest_file, force) + elif stat.S_ISLNK(src_lstat.st_mode): + symlink_full = src_file.readlink() + dest_file.symlink_to(symlink_full) + elif dest_file.exists() and not force: + pass + else: + dest_file.write_bytes(src_file.read_bytes()) + + def _cp(self, sources: List[Path], dest: Path, recursive: bool, force: bool): + dest_exists = dest.exists() + is_dest_dir = dest_exists and dest.is_dir() + + if ((not dest_exists or not is_dest_dir) and len(sources) > 1): + raise ArgumentError(f'target {dest} is not a directory') + + if recursive: + if not dest_exists: + try: + dest.mkdir(0o777) + except Exception as e: + if not dest.exists(): + raise e + + for source in sources: + if not source.exists(): + raise ArgumentError(f'cannot stat {source}: No such file or directory') + + if source.is_dir(): + if not recursive: + logger.info(f'omitting directory {source}') + else: + cur_dest = dest / source.name + source_mode = source.stat().st_mode + if not cur_dest.exists(): + cur_dest.mkdir(source_mode) + self._cp_dir(source, cur_dest, force) + else: # source is a file + cur_dest = dest + if dest.exists() and dest.is_dir(): + cur_dest = dest / source.name + if not cur_dest.exists() or force: + cur_dest.write_bytes(source.read_bytes()) + @path_to_str('path') def is_file(self, path: str) -> bool: """ Return True if the entry is a file """ @@ -368,96 +478,24 @@ def read_file(self, file: str) -> bytes: with self.open(file, 'r') as f: return f.read() - @path_to_str('remote') - @path_to_str('local') - def _pull_file(self, remote: str, local: str, with_progress: bool): - with open(local, 'wb') as local_file, self.open(remote, 'r') as remote_file, \ - self._client.safe_malloc(File.CHUNK_SIZE) as chunk: - remote_file.seek(0, SEEK_END) - remote_file_size = remote_file.tell() - remote_file.seek(0, SEEK_SET) - - if with_progress: - progress_bar = progressbar(length=remote_file_size) - else: - progress_bar = nullcontext() - progress_bar.length = remote_file_size - - with progress_bar: - buf = remote_file.read(File.CHUNK_SIZE, File.CHUNK_SIZE, chunk) - if with_progress: - progress_bar.update(len(buf)) - while len(buf) > 0: - local_file.write(buf) - buf = remote_file.read(File.CHUNK_SIZE, File.CHUNK_SIZE, chunk) - if with_progress: - progress_bar.update(len(buf)) - - @path_to_str('remote') - @path_to_str('local') - def _push_file(self, local: str, remote: str): - with open(local, 'rb') as f: - self.write_file(remote, f.read()) + def _remote_path(self, path: str) -> RemotePath: + return RemotePath(path, self._client) - @path_to_str('remote') + @path_to_str('remotes') @path_to_str('local') - def pull(self, remote: str, local: str, onerror=None, with_progress=False): + def pull(self, remotes: Union[List[str], str], local: str, recursive: bool = False, force: bool = False): """ pull complete directory tree """ - if self.is_file(remote): - self._pull_file(remote, local, with_progress) - return - - cwd = os.getcwd() - remote = Path(remote) - local = Path(local) - - try: - for root, dirs, files in self.walk(remote, topdown=True, onerror=onerror): - local_root = local / Path(root).relative_to(remote) - local_root.mkdir(exist_ok=True) - os.chdir(str(local_root)) - for name in dirs: - Path(name).mkdir(exist_ok=True) - for name in files: - try: - self._pull_file(os.path.join(root, name), name, with_progress) - except RpcClientException as e: - if onerror: - onerror(e) - else: - raise - finally: - os.chdir(cwd) + if not isinstance(remotes, list): + remotes = [remotes] + self._cp([self._remote_path(remote) for remote in remotes], Path(str(local)), recursive, force) - @path_to_str('local') + @path_to_str('locals') @path_to_str('remote') - def push(self, local: str, remote: str, onerror=None): + def push(self, locals: Union[List[str], str], remote: str, recursive: bool = False, force: bool = False): """ push complete directory tree """ - cwd = self.pwd() - remote = Path(remote) - local = Path(local) - - if local.is_file(): - self._push_file(local, remote) - return - - try: - for root, dirs, files in os.walk(local, topdown=True, onerror=onerror): - remote_root = remote / Path(root).relative_to(local) - self.mkdir(remote_root, exist_ok=True) - self.chdir(remote_root) - for name in dirs: - self.mkdir(name, exist_ok=True) - for name in files: - try: - self._push_file(os.path.join(root, name), name) - except RpcClientException as e: - if onerror: - onerror(e) - else: - raise - finally: - self.chdir(cwd) + if not isinstance(locals, list): + locals = [locals] + self._cp([Path(str(local)) for local in locals], self._remote_path(remote), recursive, force) @path_to_str('file') def touch(self, file: str, mode: int = None): diff --git a/src/rpcclient/rpcclient/xonshrc.py b/src/rpcclient/rpcclient/xonshrc.py index 20c8ca58..4389ff29 100644 --- a/src/rpcclient/rpcclient/xonshrc.py +++ b/src/rpcclient/rpcclient/xonshrc.py @@ -543,17 +543,43 @@ def _rpc_bat(self, filename: Annotated[List[str], Arg(completer=path_completer)] with self._remote_file(filename) as f: os.system(f'bat "{f}"') - def _rpc_pull(self, remote: Annotated[str, Arg(completer=path_completer)], local: str): + def _rpc_pull( + self, files: Annotated[List[str], + Arg(nargs='+', completer=path_completer)], + recursive: bool = False, force: bool = False): """ - pull a file from remote + pull files from remote + + Parameters + ---------- + files : + remote files + recursive : -r, --recursive + remove recursively + force : -f, --force + ignore errors """ - return self._pull(remote, local) + local = files.pop() + return self._pull(files, local, recursive, force) - def _rpc_push(self, local: str, remote: Annotated[str, Arg(completer=path_completer)]): + def _rpc_push( + self, files: Annotated[List[str], + Arg(nargs='+', completer=path_completer)], + recursive: bool = False, force: bool = False): """ - push a file into remote + push files to remote + + Parameters + ---------- + files : + remote files + recursive : -r, --recursive + remove recursively + force : -f, --force + ignore errors """ - return self._push(local, remote) + remote = files.pop() + return self._push(files, remote, recursive, force) def _rpc_chmod(self, mode: str, filename: Annotated[str, Arg(completer=path_completer)], recursive=False): """ @@ -687,11 +713,11 @@ def _relative_path(self, filename): def _listdir(self, path: str) -> List[str]: return self.client.fs.listdir(path) - def _pull(self, remote_filename, local_filename): - self.client.fs.pull(remote_filename, local_filename, onerror=lambda x: None, with_progress=True) + def _pull(self, remote_filename, local_filename, recursive: bool = False, force: bool = False): + self.client.fs.pull(remote_filename, local_filename, recursive, force) - def _push(self, local_filename, remote_filename): - self.client.fs.push(local_filename, remote_filename, onerror=lambda x: None) + def _push(self, local_filename, remote_filename, recursive: bool = False, force: bool = False): + self.client.fs.push(local_filename, remote_filename, recursive, force) # actual RC contents diff --git a/src/rpcclient/tests/test_fs.py b/src/rpcclient/tests/test_fs.py index 11697bfc..1ce136b9 100644 --- a/src/rpcclient/tests/test_fs.py +++ b/src/rpcclient/tests/test_fs.py @@ -1,3 +1,4 @@ +import tempfile from pathlib import Path from stat import S_IMODE @@ -83,6 +84,7 @@ def test_push_pull_with_different_sizes(client, tmp_path, file_size): f.write(b'\0' * file_size) client.fs.push(local, remote) + assert client.fs.lstat(remote).st_size == file_size client.fs.pull(remote, local_pull) assert local_pull.stat().st_size == file_size @@ -90,6 +92,26 @@ def test_push_pull_with_different_sizes(client, tmp_path, file_size): local_pull.unlink(missing_ok=True) +def test_pull(client, tmp_path): + client.fs.touch(tmp_path / 'a') + with tempfile.TemporaryDirectory() as local_dir: + local_dir = Path(local_dir) + client.fs.pull(tmp_path / 'a', local_dir) + assert (local_dir / 'a').exists() + with tempfile.TemporaryDirectory() as local_dir: + local_dir = Path(local_dir) + client.fs.pull(tmp_path / 'a', local_dir / 'a') + assert (local_dir / 'a').exists() + + client.fs.mkdir(tmp_path / 'b') + with tempfile.TemporaryDirectory() as local_dir: + local_dir = Path(local_dir) + client.fs.pull(tmp_path / 'b', local_dir, recursive=True) + assert (local_dir / 'b').exists() + client.fs.pull(tmp_path / 'b', local_dir / 'b', recursive=True) + assert (local_dir / 'b' / 'b').exists() + + def test_scandir_sanity(client, tmp_path): entries = [e for e in client.fs.scandir(tmp_path)] assert not entries