Skip to content

Commit

Permalink
fs: improve pull file speed and add progress bar
Browse files Browse the repository at this point in the history
  • Loading branch information
loriwitt committed Nov 30, 2023
1 parent c032f03 commit bc264d6
Show file tree
Hide file tree
Showing 2 changed files with 43 additions and 18 deletions.
59 changes: 42 additions & 17 deletions src/rpcclient/rpcclient/fs.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,11 @@
import os
import posixpath
import tempfile
from contextlib import nullcontext
from pathlib import Path
from typing import List

from click import progressbar
from parameter_decorators import path_to_str

from rpcclient.allocated import Allocated
Expand All @@ -13,7 +15,7 @@
from rpcclient.exceptions import ArgumentError, BadReturnValueError, RpcClientException, RpcFileExistsError, \
RpcFileNotFoundError, RpcIsADirectoryError
from rpcclient.structs.consts import DT_DIR, DT_LNK, DT_REG, DT_UNKNOWN, O_CREAT, O_RDONLY, O_RDWR, O_TRUNC, O_WRONLY, \
R_OK, S_IFDIR, S_IFLNK, S_IFMT, S_IFREG, SEEK_CUR
R_OK, S_IFDIR, S_IFLNK, S_IFMT, S_IFREG, SEEK_CUR, SEEK_END, SEEK_SET


class DirEntry:
Expand Down Expand Up @@ -141,19 +143,27 @@ def _read(self, buf: DarwinSymbol, size: int) -> bytes:
self._client.raise_errno_exception(f'read() failed for fd: {self.fd}')
return buf.peek(err)

def read(self, size: int = -1, chunk_size: int = CHUNK_SIZE) -> bytes:
def read_using_chunk(self, chunk: DarwinSymbol, chunk_size: int, size: int) -> bytes:
buf = b''
while size == -1 or len(buf) < size:
read_chunk = self._read(chunk, chunk_size)
if not read_chunk:
# EOF
break
buf += read_chunk
return buf

def read(self, size: int = -1, chunk_size: int = CHUNK_SIZE, chunk: DarwinSymbol = None) -> bytes:
""" read file at remote """
if size != -1 and size < chunk_size:
chunk_size = size

buf = b''
with self._client.safe_malloc(chunk_size) as chunk:
while size == -1 or len(buf) < size:
read_chunk = self._read(chunk, chunk_size)
if not read_chunk:
# EOF
break
buf += read_chunk
if chunk:
return self.read_using_chunk(chunk, size, chunk_size)
else:
with self._client.safe_malloc(chunk_size) as temp_chunk:
return self.read_using_chunk(temp_chunk, size, chunk_size)
return buf

def pread(self, length: int, offset: int) -> bytes:
Expand Down Expand Up @@ -360,13 +370,28 @@ def read_file(self, file: str) -> bytes:

@path_to_str('remote')
@path_to_str('local')
def _pull_file(self, remote: str, local: str):
def _pull_file(self, remote: str, local: str, with_progress: bool):
with open(local, 'wb') as local_file:
with self.open(remote, 'r') as remote_file:
buf = remote_file.read(File.CHUNK_SIZE)
while len(buf) > 0:
local_file.write(buf)
buf = remote_file.read(File.CHUNK_SIZE)
remote_file.seek(0, SEEK_END)
remote_file_size = remote_file.tell()
remote_file.seek(0, SEEK_SET)

if with_progress:
progress_bar = progressbar(length=remote_file_size)
else:
progress_bar = nullcontext()

with progress_bar:
with self._client.safe_malloc(File.CHUNK_SIZE) as chunk:
buf = remote_file.read(File.CHUNK_SIZE, File.CHUNK_SIZE, chunk)
if with_progress:
progress_bar.update(len(buf))
while len(buf) > 0:
local_file.write(buf)
buf = remote_file.read(File.CHUNK_SIZE, File.CHUNK_SIZE, chunk)
if with_progress:
progress_bar.update(len(buf))

@path_to_str('remote')
@path_to_str('local')
Expand All @@ -376,10 +401,10 @@ def _push_file(self, local: str, remote: str):

@path_to_str('remote')
@path_to_str('local')
def pull(self, remote: str, local: str, onerror=None):
def pull(self, remote: str, local: str, onerror=None, with_progress=False):
""" pull complete directory tree """
if self.is_file(remote):
self._pull_file(remote, local)
self._pull_file(remote, local, with_progress)
return

cwd = os.getcwd()
Expand All @@ -395,7 +420,7 @@ def pull(self, remote: str, local: str, onerror=None):
Path(name).mkdir(exist_ok=True)
for name in files:
try:
self._pull_file(os.path.join(root, name), name)
self._pull_file(os.path.join(root, name), name, with_progress)
except RpcClientException as e:
if onerror:
onerror(e)
Expand Down
2 changes: 1 addition & 1 deletion src/rpcclient/rpcclient/xonshrc.py
Original file line number Diff line number Diff line change
Expand Up @@ -685,7 +685,7 @@ def _listdir(self, path: str) -> List[str]:
return self.client.fs.listdir(path)

def _pull(self, remote_filename, local_filename):
self.client.fs.pull(remote_filename, local_filename, onerror=lambda x: None)
self.client.fs.pull(remote_filename, local_filename, onerror=lambda x: None, with_progress=True)

def _push(self, local_filename, remote_filename):
self.client.fs.push(local_filename, remote_filename, onerror=lambda x: None)
Expand Down

0 comments on commit bc264d6

Please sign in to comment.