Skip to content

Commit

Permalink
Merge pull request #335 from doronz88/refactor/fs-pull-push
Browse files Browse the repository at this point in the history
fs: improve pull and push
  • Loading branch information
doronz88 authored Dec 26, 2023
2 parents 18ff583 + 2dc0dfe commit 9ef4301
Show file tree
Hide file tree
Showing 4 changed files with 208 additions and 104 deletions.
22 changes: 20 additions & 2 deletions src/rpcclient/rpcclient/darwin/fs.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
from typing import List, Mapping
from pathlib import Path
from typing import Any, List, Mapping

from parameter_decorators import path_to_str

from rpcclient.darwin.structs import stat64, statfs64
from rpcclient.fs import Fs
from rpcclient.fs import Fs, RemotePath


def do_stat(client, stat_name, filename: str):
Expand All @@ -15,6 +16,20 @@ def do_stat(client, stat_name, filename: str):
return stat64.parse_stream(buf)


class DarwinRemotePath(RemotePath):
def __init__(self, path: str, client) -> None:
super().__init__(path, client)

def stat(self):
return do_stat(self._client, 'stat64', self._path)

def lstat(self):
return do_stat(self._client, 'lstat64', self._path)

def __truediv__(self, key: Path) -> Any:
return DarwinRemotePath(str(super().__truediv__(key)), self._client)


class DarwinFs(Fs):
@path_to_str('path')
def stat(self, path: str):
Expand Down Expand Up @@ -80,3 +95,6 @@ def chflags(self, path: str, flags: int) -> None:
""" call chflags(path, flags) at remote. see manpage for more info """
if 0 != self._client.symbols.chflags(path, flags):
self._client.raise_errno_exception(f'chflags failed for: {path}')

def _remote_path(self, path: str) -> DarwinRemotePath:
return DarwinRemotePath(path, self._client)
222 changes: 130 additions & 92 deletions src/rpcclient/rpcclient/fs.py
Original file line number Diff line number Diff line change
@@ -1,21 +1,23 @@
import contextlib
import logging
import os
import posixpath
import stat
import tempfile
from contextlib import nullcontext
from pathlib import Path
from typing import List
from pathlib import Path, PosixPath
from typing import Any, List, Union

from click import progressbar
from parameter_decorators import path_to_str

from rpcclient.allocated import Allocated
from rpcclient.darwin.structs import MAXPATHLEN
from rpcclient.darwin.symbol import DarwinSymbol
from rpcclient.exceptions import ArgumentError, BadReturnValueError, RpcClientException, RpcFileExistsError, \
RpcFileNotFoundError, RpcIsADirectoryError
from rpcclient.exceptions import ArgumentError, BadReturnValueError, RpcFileExistsError, RpcFileNotFoundError, \
RpcIsADirectoryError
from rpcclient.structs.consts import DT_DIR, DT_LNK, DT_REG, DT_UNKNOWN, O_CREAT, O_RDONLY, O_RDWR, O_TRUNC, O_WRONLY, \
R_OK, S_IFDIR, S_IFLNK, S_IFMT, S_IFREG, SEEK_CUR, SEEK_END, SEEK_SET
R_OK, S_IFDIR, S_IFLNK, S_IFMT, S_IFREG, SEEK_CUR

logger = logging.getLogger(__name__)


class DirEntry:
Expand Down Expand Up @@ -130,11 +132,12 @@ def _write(self, buf: bytes) -> int:
self._client.raise_errno_exception(f'failed to write on fd: {self.fd}')
return n

def write(self, buf: bytes):
def write(self, buf: bytes) -> int:
""" continue call write() until """
while buf:
err = self._write(buf)
buf = buf[err:]
return len(buf)

def _read(self, buf: DarwinSymbol, size: int) -> bytes:
""" read file at remote """
Expand Down Expand Up @@ -200,12 +203,119 @@ def __repr__(self):
return f'<{self.__class__.__name__} FD:{self.fd}>'


class RemotePath(PosixPath):
def __init__(self, path: str, client) -> None:
super().__init__()
self._path = path
self._client = client

def __new__(cls, path: str, client):
return super().__new__(cls, *[path])

def chmod(self, mode: int):
return self._client.fs.chmod(self._path, mode)

def readlink(self):
return self._client.fs.readlink(self._path)

def exists(self) -> bool:
try:
self.stat()
return True
except Exception:
return False

def is_dir(self) -> bool:
return bool(self.stat().st_mode & S_IFDIR)

def lstat(self):
return self._client.fs.lstat(self._path)

def mkdir(self, mode: int):
self._client.fs.mkdir(self._path, mode)

def read_bytes(self) -> bytes:
with self._open('r') as f:
return f.read()

def stat(self):
raise self._client.fs.stat(self._path)

def symlink_to(self, target: Path, target_is_directory: bool = False) -> None:
return self._client.fs.symlink(target, self._path)

def write_bytes(self, buf: bytes) -> int:
with self._open('w') as f:
return f.write(buf)

def _open(self, mode: str, access: int = 0o777) -> File:
return self._client.fs.open(self._path, mode, access)

def __truediv__(self, key: Path) -> Any:
return RemotePath(str(super().__truediv__(key)), self._client)


class Fs:
""" filesystem utils """

def __init__(self, client):
self._client = client

def _cp_dir(self, source: Path, dest: Path, force: bool):
if not dest.exists():
dest.mkdir(source.lstat().st_mode)

files = self.listdir(str(source))
for file in files:
src_file = source / file
dest_file = dest / file

src_lstat = src_file.lstat()
if stat.S_ISDIR(src_lstat.st_mode):
self._cp_dir(src_file, dest_file, force)
elif stat.S_ISLNK(src_lstat.st_mode):
symlink_full = src_file.readlink()
dest_file.symlink_to(symlink_full)
elif dest_file.exists() and not force:
pass
else:
dest_file.write_bytes(src_file.read_bytes())

def _cp(self, sources: List[Path], dest: Path, recursive: bool, force: bool):
dest_exists = dest.exists()
is_dest_dir = dest_exists and dest.is_dir()

if ((not dest_exists or not is_dest_dir) and len(sources) > 1):
raise ArgumentError(f'target {dest} is not a directory')

if recursive:
if not dest_exists:
try:
dest.mkdir(0o777)
except Exception as e:
if not dest.exists():
raise e

for source in sources:
if not source.exists():
raise ArgumentError(f'cannot stat {source}: No such file or directory')

if source.is_dir():
if not recursive:
logger.info(f'omitting directory {source}')
else:
cur_dest = dest / source.name
source_mode = source.stat().st_mode
if not cur_dest.exists():
cur_dest.mkdir(source_mode)
self._cp_dir(source, cur_dest, force)
else: # source is a file
cur_dest = dest
if dest.exists() and dest.is_dir():
cur_dest = dest / source.name
if not cur_dest.exists() or force:
cur_dest.write_bytes(source.read_bytes())

@path_to_str('path')
def is_file(self, path: str) -> bool:
""" Return True if the entry is a file """
Expand Down Expand Up @@ -368,96 +478,24 @@ def read_file(self, file: str) -> bytes:
with self.open(file, 'r') as f:
return f.read()

@path_to_str('remote')
@path_to_str('local')
def _pull_file(self, remote: str, local: str, with_progress: bool):
with open(local, 'wb') as local_file, self.open(remote, 'r') as remote_file, \
self._client.safe_malloc(File.CHUNK_SIZE) as chunk:
remote_file.seek(0, SEEK_END)
remote_file_size = remote_file.tell()
remote_file.seek(0, SEEK_SET)

if with_progress:
progress_bar = progressbar(length=remote_file_size)
else:
progress_bar = nullcontext()
progress_bar.length = remote_file_size

with progress_bar:
buf = remote_file.read(File.CHUNK_SIZE, File.CHUNK_SIZE, chunk)
if with_progress:
progress_bar.update(len(buf))
while len(buf) > 0:
local_file.write(buf)
buf = remote_file.read(File.CHUNK_SIZE, File.CHUNK_SIZE, chunk)
if with_progress:
progress_bar.update(len(buf))

@path_to_str('remote')
@path_to_str('local')
def _push_file(self, local: str, remote: str):
with open(local, 'rb') as f:
self.write_file(remote, f.read())
def _remote_path(self, path: str) -> RemotePath:
return RemotePath(path, self._client)

@path_to_str('remote')
@path_to_str('remotes')
@path_to_str('local')
def pull(self, remote: str, local: str, onerror=None, with_progress=False):
def pull(self, remotes: Union[List[str], str], local: str, recursive: bool = False, force: bool = False):
""" pull complete directory tree """
if self.is_file(remote):
self._pull_file(remote, local, with_progress)
return

cwd = os.getcwd()
remote = Path(remote)
local = Path(local)

try:
for root, dirs, files in self.walk(remote, topdown=True, onerror=onerror):
local_root = local / Path(root).relative_to(remote)
local_root.mkdir(exist_ok=True)
os.chdir(str(local_root))
for name in dirs:
Path(name).mkdir(exist_ok=True)
for name in files:
try:
self._pull_file(os.path.join(root, name), name, with_progress)
except RpcClientException as e:
if onerror:
onerror(e)
else:
raise
finally:
os.chdir(cwd)
if not isinstance(remotes, list):
remotes = [remotes]
self._cp([self._remote_path(remote) for remote in remotes], Path(str(local)), recursive, force)

@path_to_str('local')
@path_to_str('locals')
@path_to_str('remote')
def push(self, local: str, remote: str, onerror=None):
def push(self, locals: Union[List[str], str], remote: str, recursive: bool = False, force: bool = False):
""" push complete directory tree """
cwd = self.pwd()
remote = Path(remote)
local = Path(local)

if local.is_file():
self._push_file(local, remote)
return

try:
for root, dirs, files in os.walk(local, topdown=True, onerror=onerror):
remote_root = remote / Path(root).relative_to(local)
self.mkdir(remote_root, exist_ok=True)
self.chdir(remote_root)
for name in dirs:
self.mkdir(name, exist_ok=True)
for name in files:
try:
self._push_file(os.path.join(root, name), name)
except RpcClientException as e:
if onerror:
onerror(e)
else:
raise
finally:
self.chdir(cwd)
if not isinstance(locals, list):
locals = [locals]
self._cp([Path(str(local)) for local in locals], self._remote_path(remote), recursive, force)

@path_to_str('file')
def touch(self, file: str, mode: int = None):
Expand Down
46 changes: 36 additions & 10 deletions src/rpcclient/rpcclient/xonshrc.py
Original file line number Diff line number Diff line change
Expand Up @@ -543,17 +543,43 @@ def _rpc_bat(self, filename: Annotated[List[str], Arg(completer=path_completer)]
with self._remote_file(filename) as f:
os.system(f'bat "{f}"')

def _rpc_pull(self, remote: Annotated[str, Arg(completer=path_completer)], local: str):
def _rpc_pull(
self, files: Annotated[List[str],
Arg(nargs='+', completer=path_completer)],
recursive: bool = False, force: bool = False):
"""
pull a file from remote
pull files from remote
Parameters
----------
files :
remote files
recursive : -r, --recursive
remove recursively
force : -f, --force
ignore errors
"""
return self._pull(remote, local)
local = files.pop()
return self._pull(files, local, recursive, force)

def _rpc_push(self, local: str, remote: Annotated[str, Arg(completer=path_completer)]):
def _rpc_push(
self, files: Annotated[List[str],
Arg(nargs='+', completer=path_completer)],
recursive: bool = False, force: bool = False):
"""
push a file into remote
push files to remote
Parameters
----------
files :
remote files
recursive : -r, --recursive
remove recursively
force : -f, --force
ignore errors
"""
return self._push(local, remote)
remote = files.pop()
return self._push(files, remote, recursive, force)

def _rpc_chmod(self, mode: str, filename: Annotated[str, Arg(completer=path_completer)], recursive=False):
"""
Expand Down Expand Up @@ -687,11 +713,11 @@ def _relative_path(self, filename):
def _listdir(self, path: str) -> List[str]:
return self.client.fs.listdir(path)

def _pull(self, remote_filename, local_filename):
self.client.fs.pull(remote_filename, local_filename, onerror=lambda x: None, with_progress=True)
def _pull(self, remote_filename, local_filename, recursive: bool = False, force: bool = False):
self.client.fs.pull(remote_filename, local_filename, recursive, force)

def _push(self, local_filename, remote_filename):
self.client.fs.push(local_filename, remote_filename, onerror=lambda x: None)
def _push(self, local_filename, remote_filename, recursive: bool = False, force: bool = False):
self.client.fs.push(local_filename, remote_filename, recursive, force)


# actual RC contents
Expand Down
Loading

0 comments on commit 9ef4301

Please sign in to comment.