Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve error handling #35

Merged
merged 1 commit into from
Dec 24, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -176,7 +176,6 @@ There isn't even proper exception handling, so expect some Rust panics to fall t
With that said, try it out and let me know your thoughts!

# Future Features
- Proper exception handling
- Concurrent actions class
- Async Connection class
- Low level bindings
Expand Down
151 changes: 107 additions & 44 deletions src/connection.rs
Original file line number Diff line number Diff line change
Expand Up @@ -55,13 +55,14 @@
//!
//! Note: The `read` method sends an EOF to the shell, so you won't be able to send more commands after calling `read`. If you want to send more commands, you would need to create a new `InteractiveShell` instance.
use pyo3::create_exception;
use pyo3::exceptions::PyTimeoutError;
use pyo3::prelude::*;
use ssh2::{Channel, Session};
use std::io::{BufReader, BufWriter, Read, Seek, Write};
use std::net::TcpStream;
use std::path::Path;

use pyo3::exceptions::{PyIOError, PyTimeoutError};

const MAX_BUFF_SIZE: usize = 65536;
create_exception!(
connection,
Expand All @@ -70,7 +71,6 @@ create_exception!(
);

fn read_from_channel(channel: &mut Channel) -> Result<SSHResult, PyErr> {
// TODO: handle errors better instead of just raising a PyTimeoutError
let mut stdout = String::new();
channel
.read_to_string(&mut stdout)
Expand Down Expand Up @@ -323,23 +323,33 @@ impl Connection {
/// Otherwise, the contents of the file are returned as a string.
#[pyo3(signature = (remote_path, local_path=None))]
fn scp_read(&self, remote_path: String, local_path: Option<String>) -> PyResult<String> {
let (mut remote_file, stat) = self.session.scp_recv(Path::new(&remote_path)).unwrap();
let (mut remote_file, stat) = self
.session
.scp_recv(Path::new(&remote_path))
.map_err(|e| PyErr::new::<PyIOError, _>(format!("Failed scp_recv: {}", e)))?;
match local_path {
Some(local_path) => {
let mut local_file = std::fs::File::create(local_path).unwrap();
let mut local_file = std::fs::File::create(&local_path)
.map_err(|e| PyErr::new::<PyIOError, _>(format!("File create error: {}", e)))?;
let mut buffer = vec![0; std::cmp::min(stat.size() as usize, MAX_BUFF_SIZE)];
loop {
let len = remote_file.read(&mut buffer).unwrap();
let len = remote_file
.read(&mut buffer)
.map_err(|e| PyErr::new::<PyIOError, _>(format!("Read error: {}", e)))?;
if len == 0 {
break;
}
local_file.write_all(&buffer[..len]).unwrap();
local_file
.write_all(&buffer[..len])
.map_err(|e| PyErr::new::<PyIOError, _>(format!("Write error: {}", e)))?;
}
Ok("Ok".to_string())
}
None => {
let mut contents = String::new();
remote_file.read_to_string(&mut contents).unwrap();
remote_file.read_to_string(&mut contents).map_err(|e| {
PyErr::new::<PyIOError, _>(format!("Read to string failed: {}", e))
})?;
Ok(contents)
}
}
Expand All @@ -361,21 +371,28 @@ impl Connection {
} else {
remote_path
};
let mut local_file = std::fs::File::open(&local_path).unwrap();
let mut local_file = std::fs::File::open(&local_path)
.map_err(|e| PyErr::new::<PyIOError, _>(format!("Local file open error: {}", e)))?;
let metadata = local_file.metadata().unwrap();
// TODO: better handle permissions. Perhaps from metadata.permissions()?
let mut remote_file = self
.session
.scp_send(Path::new(&remote_path), 0o644, metadata.len(), None)
.unwrap();
.map_err(|e| PyErr::new::<PyIOError, _>(format!("scp_send error: {}", e)))?;
// create a variable-sized buffer to read the file and loop until EOF
let mut read_buffer = vec![0; std::cmp::min(metadata.len() as usize, MAX_BUFF_SIZE)];
loop {
let bytes_read = local_file.read(&mut read_buffer).unwrap();
let bytes_read = local_file
.read(&mut read_buffer)
.map_err(|e| PyErr::new::<PyIOError, _>(format!("File read error: {}", e)))?;
if bytes_read == 0 {
break;
}
remote_file.write_all(&read_buffer[..bytes_read]).unwrap();
remote_file
.write_all(&read_buffer[..bytes_read])
.map_err(|e| {
PyErr::new::<PyIOError, _>(format!("Remote file write error: {}", e))
})?;
}
remote_file.flush().unwrap();
remote_file.send_eof().unwrap();
Expand All @@ -390,8 +407,10 @@ impl Connection {
let mut remote_file = self
.session
.scp_send(Path::new(&remote_path), 0o644, data.len() as u64, None)
.unwrap();
remote_file.write_all(data.as_bytes()).unwrap();
.map_err(|e| PyErr::new::<PyIOError, _>(format!("scp_send error: {}", e)))?;
remote_file
.write_all(data.as_bytes())
.map_err(|e| PyErr::new::<PyIOError, _>(format!("Data write error: {}", e)))?;
remote_file.send_eof().unwrap();
remote_file.wait_eof().unwrap();
remote_file.close().unwrap();
Expand All @@ -404,56 +423,83 @@ impl Connection {
/// Otherwise, the contents of the file are returned as a string.
#[pyo3(signature = (remote_path, local_path=None))]
fn sftp_read(&mut self, remote_path: String, local_path: Option<String>) -> PyResult<String> {
let mut remote_file = BufReader::new(self.sftp().open(Path::new(&remote_path)).unwrap());
let mut remote_file = BufReader::new(
self.sftp()
.open(Path::new(&remote_path))
.map_err(|e| PyErr::new::<PyIOError, _>(format!("SFTP open error: {}", e)))?,
);
match local_path {
Some(local_path) => {
let local_file = std::fs::File::create(local_path)?;
let local_file = std::fs::File::create(&local_path)
.map_err(|e| PyErr::new::<PyIOError, _>(format!("File create error: {}", e)))?;
let mut writer = BufWriter::new(local_file);
let mut buffer = vec![0; MAX_BUFF_SIZE];
loop {
let len = remote_file.read(&mut buffer)?;
let len = remote_file.read(&mut buffer).map_err(|e| {
PyErr::new::<PyIOError, _>(format!("File read error: {}", e))
})?;
if len == 0 {
break;
}
writer.write_all(&buffer[..len])?;
writer.write_all(&buffer[..len]).map_err(|e| {
PyErr::new::<PyIOError, _>(format!("File write error: {}", e))
})?;
}
writer.flush()?;
writer
.flush()
.map_err(|e| PyErr::new::<PyIOError, _>(format!("Flush error: {}", e)))?;
Ok("Ok".to_string())
}
None => {
let mut contents = String::new();
remote_file.read_to_string(&mut contents)?;
remote_file.read_to_string(&mut contents).map_err(|e| {
PyErr::new::<PyIOError, _>(format!("Read to string failed: {}", e))
})?;
Ok(contents)
}
}
}

/// Writes a file over SFTP.
/// If `remote_path` is not provided, the local file is written to the same path on the remote system.
/// Writes a file over SFTP. If `remote_path` is not provided, the local file is written to the same path on the remote system.
#[pyo3(signature = (local_path, remote_path=None))]
fn sftp_write(&mut self, local_path: String, remote_path: Option<String>) -> PyResult<()> {
let mut local_file = std::fs::File::open(&local_path).unwrap();
let mut local_file = std::fs::File::open(&local_path)
.map_err(|e| PyErr::new::<PyIOError, _>(format!("Local file open error: {}", e)))?;
let remote_path = remote_path.unwrap_or_else(|| local_path.clone());
let metadata = local_file.metadata().unwrap();
let mut remote_file = self.sftp().create(Path::new(&remote_path)).unwrap();
let mut remote_file = self.sftp().create(Path::new(&remote_path)).map_err(|e| {
PyErr::new::<PyIOError, _>(format!("Remote file creation error: {}", e))
})?;
// create a variable-sized buffer to read the file and loop until EOF
let mut read_buffer = vec![0; std::cmp::min(metadata.len() as usize, MAX_BUFF_SIZE)];
loop {
let bytes_read = local_file.read(&mut read_buffer)?;
let bytes_read = local_file
.read(&mut read_buffer)
.map_err(|e| PyErr::new::<PyIOError, _>(format!("File read error: {}", e)))?;
if bytes_read == 0 {
break;
}
remote_file.write_all(&read_buffer[..bytes_read])?;
remote_file
.write_all(&read_buffer[..bytes_read])
.map_err(|e| {
PyErr::new::<PyIOError, _>(format!("Remote file write error: {}", e))
})?;
}
remote_file.close().unwrap();
Ok(())
}

/// Writes data over SFTP.
fn sftp_write_data(&mut self, data: String, remote_path: String) -> PyResult<()> {
let mut remote_file = self.sftp().create(Path::new(&remote_path)).unwrap();
remote_file.write_all(data.as_bytes()).unwrap();
remote_file.close().unwrap();
let mut remote_file = self.sftp().create(Path::new(&remote_path)).map_err(|e| {
PyErr::new::<PyIOError, _>(format!("Remote file creation error: {}", e))
})?;
remote_file
.write_all(data.as_bytes())
.map_err(|e| PyErr::new::<PyIOError, _>(format!("Data write error: {}", e)))?;
remote_file
.close()
.map_err(|e| PyErr::new::<PyIOError, _>(format!("Close error: {}", e)))?;
Ok(())
}

Expand All @@ -468,19 +514,26 @@ impl Connection {
let mut remote_file = BufReader::new(
self.session
.sftp()
.unwrap()
.map_err(|e| PyErr::new::<PyIOError, _>(format!("SFTP error: {}", e)))?
.open(Path::new(&source_path))
.unwrap(),
.map_err(|e| PyErr::new::<PyIOError, _>(format!("Remote open error: {}", e)))?,
);
let dest_path = dest_path.unwrap_or_else(|| source_path.clone());
let mut other_file = dest_conn.sftp().create(Path::new(&dest_path)).unwrap();
let mut other_file = dest_conn
.sftp()
.create(Path::new(&dest_path))
.map_err(|e| PyErr::new::<PyIOError, _>(format!("Dest file creation error: {}", e)))?;
let mut buffer = vec![0; MAX_BUFF_SIZE];
loop {
let len = remote_file.read(&mut buffer).unwrap();
let len = remote_file
.read(&mut buffer)
.map_err(|e| PyErr::new::<PyIOError, _>(format!("File read error: {}", e)))?;
if len == 0 {
break;
}
other_file.write_all(&buffer[..len]).unwrap();
other_file
.write_all(&buffer[..len])
.map_err(|e| PyErr::new::<PyIOError, _>(format!("File write error: {}", e)))?;
}
Ok(())
}
Expand Down Expand Up @@ -586,12 +639,20 @@ impl InteractiveShell {
/// Reads the output from the shell and returns an `SSHResult`.
/// Note: This sends an EOF to the shell, so you won't be able to send more commands after calling `read`.
fn read(&mut self) -> PyResult<SSHResult> {
self.channel.channel.flush().unwrap();
self.channel.channel.send_eof().unwrap();
self.channel
.channel
.flush()
.map_err(|e| PyErr::new::<PyTimeoutError, _>(format!("Channel flush error: {}", e)))?;
self.channel
.channel
.send_eof()
.map_err(|e| PyErr::new::<PyTimeoutError, _>(format!("Send EOF error: {}", e)))?;
match read_from_channel(&mut self.channel.channel) {
Ok(result) => Ok(result),
Err(e) => {
self.channel.channel.close().unwrap();
self.channel.channel.close().map_err(|e| {
PyErr::new::<PyTimeoutError, _>(format!("Channel close error: {}", e))
})?;
self.result = None;
Err(e)
}
Expand Down Expand Up @@ -685,24 +746,26 @@ impl FileTailer {

// Determine the current end of the remote file
fn seek_end(&mut self) -> PyResult<Option<u64>> {
let len = self
let metadata = self
.sftp_conn
.stat(Path::new(&self.remote_file))
.unwrap()
.size;
self.last_pos = len.unwrap();
.map_err(|e| PyErr::new::<PyIOError, _>(format!("Stat error: {}", e)))?;
self.last_pos = metadata.size.unwrap_or(0);
if self.init_pos.is_none() {
self.init_pos = len;
self.init_pos = metadata.size;
}
Ok(len)
Ok(metadata.size)
}

// Read the contents of the remote file from a given position
#[pyo3(signature = (from_pos=None))]
fn read(&mut self, from_pos: Option<u64>) -> String {
let from_pos = from_pos.unwrap_or(self.last_pos);
let mut remote_file =
BufReader::new(self.sftp_conn.open(Path::new(&self.remote_file)).unwrap());
let mut remote_file = BufReader::new(
self.sftp_conn
.open(Path::new(&self.remote_file))
.expect("Opening remote file failed"),
);
remote_file
.seek(std::io::SeekFrom::Start(from_pos))
.unwrap();
Expand Down
45 changes: 33 additions & 12 deletions tests/test_connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -188,18 +188,6 @@ def test_hangup_shell_context(conn):
assert sh.result.stdout


def test_session_timeout():
"""Test that we can trigger a timeout on session handshake."""
with pytest.raises(TimeoutError):
Connection(host="localhost", port=8022, password="toor", timeout=10)


def test_command_timeout(conn):
"""Test that we can trigger a timeout on command execution."""
with pytest.raises(TimeoutError):
conn.execute("sleep 5", timeout=3000)


def test_remote_copy(conn, run_second_server):
"""Test that we can copy a file from one server to another."""
# First copy the test file to the first server
Expand All @@ -220,3 +208,36 @@ def test_tail(conn):
assert tf.last_pos == len(TEST_STR)
conn.execute("echo goodbye >> /root/hello.txt")
assert tf.contents == "goodbye\n"


# ------------- Negative Tests -------------


def test_session_timeout():
"""Test that we can trigger a timeout on session handshake."""
with pytest.raises(TimeoutError):
Connection(host="localhost", port=8022, password="toor", timeout=10)


def test_command_timeout(conn):
"""Test that we can trigger a timeout on command execution."""
with pytest.raises(TimeoutError):
conn.execute("sleep 5", timeout=3000)


def test_scp_write_missing_directory(conn):
"""Test that IOError is raised if scp_write attempts to write to a missing directory."""
with pytest.raises(IOError): # noqa: PT011
JacobCallahan marked this conversation as resolved.
Show resolved Hide resolved
conn.scp_write_data("data", "/no_such_dir/test.txt")


def test_sftp_read_invalid_path(conn):
"""Test that IOError is raised if sftp_read is given an invalid remote path."""
with pytest.raises(IOError): # noqa: PT011
conn.sftp_read("/invalid/path/file.txt")


def test_scp_read_directory_as_file(conn):
"""Test that IOError is raised if scp_read tries to read a directory as a file."""
with pytest.raises(IOError): # noqa: PT011
conn.scp_read("/root")
Loading