Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Various fixes #29

Open
wants to merge 11 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
83 changes: 66 additions & 17 deletions scp.py
Original file line number Diff line number Diff line change
Expand Up @@ -190,11 +190,22 @@ def get(self, remote_path, local_path='',
self.channel = self._open()
self._pushed = 0
self.channel.settimeout(self.socket_timeout)
self.channel.exec_command(b"scp" +
rcsv +
prsv +
b" -f " +
b' '.join(remote_path))
try:
self.channel.exec_command(b"scp" +
rcsv +
prsv +
b" -f " +
b' '.join(remote_path))
self._recv_all()
except:
# Check to see if we have some data on the channel.
data = self.channel.recv(self.buff_size)
if not data:
raise

code = data[0]
message = data[1:]
raise SCPException('%s %s' % (code, message))
self._recv_all()
self.close()

Expand Down Expand Up @@ -230,11 +241,14 @@ def _send_files(self, files):
self._send_time(mtime, atime)
file_hdl = open(name, 'rb')

send_name = basename
if os.name == 'nt':
send_name = send_name.encode('utf-8')
# The protocol can't handle \n in the filename.
# Quote them as the control sequence \^J for now,
# which is how openssh handles it.
self.channel.sendall(("C%s %d " % (mode, size)).encode('ascii') +
basename.replace(b'\n', b'\\^J') + b"\n")
send_name.replace(b'\n', b'\\^J') + b"\n")
self._recv_confirm()
file_pos = 0
if self._progress:
Expand Down Expand Up @@ -333,19 +347,27 @@ def _recv_all(self):
b'T': self._set_time,
b'D': self._recv_pushd,
b'E': self._recv_popd}
while not self.channel.closed:
# wait for command as long as we're open
self.channel.sendall('\x00')
msg = self.channel.recv(1024)
if not msg: # chan closed while recving
while True:
# Read next command
data = self.channel.recv(1024)
if not data:
# No more data to receive.
break
assert msg[-1:] == b'\n'
msg = msg[:-1]
code = msg[0:1]

if '\n' not in data:
# Command is not yet completely read.
data += self.channel.recv(1024)

code = data[0:1]
try:
command[code](msg[1:])
command[code](data[1:])
except KeyError:
raise SCPException(asunicode(msg[1:]))
raise SCPException(asunicode(data[1:]))

if not self.channel.closed:
# Confirm command end.
self.channel.sendall('\x00')

# directory times can't be set until we're done writing files
self._set_dirtimes()

Expand All @@ -362,6 +384,13 @@ def _set_time(self, cmd):

def _recv_file(self, cmd):
chan = self.channel
# Command might already contain data from the file.
# Command end with a new line then data follows.
parts = cmd.split('\n', 1)
cmd = parts[0]
already_received = ''
if len(parts) > 1:
already_received = parts[1]
parts = cmd.strip().split(b' ', 2)

try:
Expand Down Expand Up @@ -397,17 +426,37 @@ def _recv_file(self, cmd):
buff_size = self.buff_size
pos = 0
chan.send(b'\x00')

# Write data which was already received buffer.
if already_received:
data = already_received[:size]
already_received = already_received[size:]
file_hdl.write(data)
pos = file_hdl.tell()

try:
while pos < size:
# we have to make sure we don't read the final byte
if size - pos <= buff_size:
buff_size = size - pos
file_hdl.write(chan.recv(buff_size))
data = chan.recv(buff_size)

# Channel return empty string when no more data can be
# received (ex channel close, eof received)
if len(data) == 0:
pos = size

file_hdl.write(data)
pos = file_hdl.tell()
if self._progress:
self._progress(path, size, pos)

# Check final response code.
msg = chan.recv(512)
if already_received:
# We might already have received the status in the initial
# read. Ex for small files.
msg = already_received + msg
if msg and msg[0:1] != b'\x00':
raise SCPException(asunicode(msg[1:]))
except SocketTimeout:
Expand Down
1 change: 1 addition & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,4 +19,5 @@
long_description=description,
py_modules = ['scp'],
install_requires = ['paramiko'],
test_suite = 'test'
)
5 changes: 3 additions & 2 deletions test.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
'username': os.environ.get('SCPPY_USERNAME', None),
}

TIMEOUT = 0.5

# Environment info
PY3 = sys.version_info >= (3,)
Expand Down Expand Up @@ -107,7 +108,7 @@ def download_test(self, filename, recursive, destination=None,
os.mkdir(temp_in)
os.chdir(temp_in)
try:
with SCPClient(self.ssh.get_transport()) as scp:
with SCPClient(self.ssh.get_transport(), socket_timeout=TIMEOUT) as scp:
scp.get(filename,
destination if destination is not None else u'.',
preserve_times=True, recursive=recursive)
Expand Down Expand Up @@ -206,7 +207,7 @@ def upload_test(self, filenames, recursive, expected=[]):
previous = os.getcwd()
try:
os.chdir(self._temp)
with SCPClient(self.ssh.get_transport()) as scp:
with SCPClient(self.ssh.get_transport(), socket_timeout=TIMEOUT) as scp:
scp.put(filenames, destination, recursive)

chan = self.ssh.get_transport().open_session()
Expand Down