Skip to content

Commit

Permalink
Improve error handling
Browse files Browse the repository at this point in the history
This change introduces a ton more situations where errors are passed up
instead of resulting in an unrecoverable panic.
  • Loading branch information
JacobCallahan committed Dec 24, 2024
1 parent 6244edd commit 5fc4e76
Showing 1 changed file with 107 additions and 43 deletions.
150 changes: 107 additions & 43 deletions src/connection.rs
Original file line number Diff line number Diff line change
Expand Up @@ -55,13 +55,14 @@
//!
//! Note: The `read` method sends an EOF to the shell, so you won't be able to send more commands after calling `read`. If you want to send more commands, you would need to create a new `InteractiveShell` instance.
use pyo3::create_exception;
use pyo3::exceptions::PyTimeoutError;
use pyo3::prelude::*;
use ssh2::{Channel, Session};
use std::io::{BufReader, BufWriter, Read, Seek, Write};
use std::net::TcpStream;
use std::path::Path;

use pyo3::exceptions::{PyIOError, PyTimeoutError};

const MAX_BUFF_SIZE: usize = 65536;
create_exception!(
connection,
Expand Down Expand Up @@ -323,23 +324,33 @@ impl Connection {
/// Otherwise, the contents of the file are returned as a string.
#[pyo3(signature = (remote_path, local_path=None))]
fn scp_read(&self, remote_path: String, local_path: Option<String>) -> PyResult<String> {
let (mut remote_file, stat) = self.session.scp_recv(Path::new(&remote_path)).unwrap();
let (mut remote_file, stat) = self
.session
.scp_recv(Path::new(&remote_path))
.map_err(|e| PyErr::new::<PyIOError, _>(format!("Failed scp_recv: {}", e)))?;
match local_path {
Some(local_path) => {
let mut local_file = std::fs::File::create(local_path).unwrap();
let mut local_file = std::fs::File::create(&local_path)
.map_err(|e| PyErr::new::<PyIOError, _>(format!("File create error: {}", e)))?;
let mut buffer = vec![0; std::cmp::min(stat.size() as usize, MAX_BUFF_SIZE)];
loop {
let len = remote_file.read(&mut buffer).unwrap();
let len = remote_file
.read(&mut buffer)
.map_err(|e| PyErr::new::<PyIOError, _>(format!("Read error: {}", e)))?;
if len == 0 {
break;
}
local_file.write_all(&buffer[..len]).unwrap();
local_file
.write_all(&buffer[..len])
.map_err(|e| PyErr::new::<PyIOError, _>(format!("Write error: {}", e)))?;
}
Ok("Ok".to_string())
}
None => {
let mut contents = String::new();
remote_file.read_to_string(&mut contents).unwrap();
remote_file.read_to_string(&mut contents).map_err(|e| {
PyErr::new::<PyIOError, _>(format!("Read to string failed: {}", e))
})?;
Ok(contents)
}
}
Expand All @@ -361,21 +372,28 @@ impl Connection {
} else {
remote_path
};
let mut local_file = std::fs::File::open(&local_path).unwrap();
let mut local_file = std::fs::File::open(&local_path)
.map_err(|e| PyErr::new::<PyIOError, _>(format!("Local file open error: {}", e)))?;
let metadata = local_file.metadata().unwrap();
// TODO: better handle permissions. Perhaps from metadata.permissions()?
let mut remote_file = self
.session
.scp_send(Path::new(&remote_path), 0o644, metadata.len(), None)
.unwrap();
.map_err(|e| PyErr::new::<PyIOError, _>(format!("scp_send error: {}", e)))?;
// create a variable-sized buffer to read the file and loop until EOF
let mut read_buffer = vec![0; std::cmp::min(metadata.len() as usize, MAX_BUFF_SIZE)];
loop {
let bytes_read = local_file.read(&mut read_buffer).unwrap();
let bytes_read = local_file
.read(&mut read_buffer)
.map_err(|e| PyErr::new::<PyIOError, _>(format!("File read error: {}", e)))?;
if bytes_read == 0 {
break;
}
remote_file.write_all(&read_buffer[..bytes_read]).unwrap();
remote_file
.write_all(&read_buffer[..bytes_read])
.map_err(|e| {
PyErr::new::<PyIOError, _>(format!("Remote file write error: {}", e))
})?;
}
remote_file.flush().unwrap();
remote_file.send_eof().unwrap();
Expand All @@ -390,8 +408,10 @@ impl Connection {
let mut remote_file = self
.session
.scp_send(Path::new(&remote_path), 0o644, data.len() as u64, None)
.unwrap();
remote_file.write_all(data.as_bytes()).unwrap();
.map_err(|e| PyErr::new::<PyIOError, _>(format!("scp_send error: {}", e)))?;
remote_file
.write_all(data.as_bytes())
.map_err(|e| PyErr::new::<PyIOError, _>(format!("Data write error: {}", e)))?;
remote_file.send_eof().unwrap();
remote_file.wait_eof().unwrap();
remote_file.close().unwrap();
Expand All @@ -404,56 +424,83 @@ impl Connection {
/// Otherwise, the contents of the file are returned as a string.
#[pyo3(signature = (remote_path, local_path=None))]
fn sftp_read(&mut self, remote_path: String, local_path: Option<String>) -> PyResult<String> {
let mut remote_file = BufReader::new(self.sftp().open(Path::new(&remote_path)).unwrap());
let mut remote_file = BufReader::new(
self.sftp()
.open(Path::new(&remote_path))
.map_err(|e| PyErr::new::<PyIOError, _>(format!("SFTP open error: {}", e)))?,
);
match local_path {
Some(local_path) => {
let local_file = std::fs::File::create(local_path)?;
let local_file = std::fs::File::create(&local_path)
.map_err(|e| PyErr::new::<PyIOError, _>(format!("File create error: {}", e)))?;
let mut writer = BufWriter::new(local_file);
let mut buffer = vec![0; MAX_BUFF_SIZE];
loop {
let len = remote_file.read(&mut buffer)?;
let len = remote_file.read(&mut buffer).map_err(|e| {
PyErr::new::<PyIOError, _>(format!("File read error: {}", e))
})?;
if len == 0 {
break;
}
writer.write_all(&buffer[..len])?;
writer.write_all(&buffer[..len]).map_err(|e| {
PyErr::new::<PyIOError, _>(format!("File write error: {}", e))
})?;
}
writer.flush()?;
writer
.flush()
.map_err(|e| PyErr::new::<PyIOError, _>(format!("Flush error: {}", e)))?;
Ok("Ok".to_string())
}
None => {
let mut contents = String::new();
remote_file.read_to_string(&mut contents)?;
remote_file.read_to_string(&mut contents).map_err(|e| {
PyErr::new::<PyIOError, _>(format!("Read to string failed: {}", e))
})?;
Ok(contents)
}
}
}

/// Writes a file over SFTP.
/// If `remote_path` is not provided, the local file is written to the same path on the remote system.
/// Writes a file over SFTP. If `remote_path` is not provided, the local file is written to the same path on the remote system.
#[pyo3(signature = (local_path, remote_path=None))]
fn sftp_write(&mut self, local_path: String, remote_path: Option<String>) -> PyResult<()> {
let mut local_file = std::fs::File::open(&local_path).unwrap();
let mut local_file = std::fs::File::open(&local_path)
.map_err(|e| PyErr::new::<PyIOError, _>(format!("Local file open error: {}", e)))?;
let remote_path = remote_path.unwrap_or_else(|| local_path.clone());
let metadata = local_file.metadata().unwrap();
let mut remote_file = self.sftp().create(Path::new(&remote_path)).unwrap();
let mut remote_file = self.sftp().create(Path::new(&remote_path)).map_err(|e| {
PyErr::new::<PyIOError, _>(format!("Remote file creation error: {}", e))
})?;
// create a variable-sized buffer to read the file and loop until EOF
let mut read_buffer = vec![0; std::cmp::min(metadata.len() as usize, MAX_BUFF_SIZE)];
loop {
let bytes_read = local_file.read(&mut read_buffer)?;
let bytes_read = local_file
.read(&mut read_buffer)
.map_err(|e| PyErr::new::<PyIOError, _>(format!("File read error: {}", e)))?;
if bytes_read == 0 {
break;
}
remote_file.write_all(&read_buffer[..bytes_read])?;
remote_file
.write_all(&read_buffer[..bytes_read])
.map_err(|e| {
PyErr::new::<PyIOError, _>(format!("Remote file write error: {}", e))
})?;
}
remote_file.close().unwrap();
Ok(())
}

/// Writes data over SFTP.
fn sftp_write_data(&mut self, data: String, remote_path: String) -> PyResult<()> {
let mut remote_file = self.sftp().create(Path::new(&remote_path)).unwrap();
remote_file.write_all(data.as_bytes()).unwrap();
remote_file.close().unwrap();
let mut remote_file = self.sftp().create(Path::new(&remote_path)).map_err(|e| {
PyErr::new::<PyIOError, _>(format!("Remote file creation error: {}", e))
})?;
remote_file
.write_all(data.as_bytes())
.map_err(|e| PyErr::new::<PyIOError, _>(format!("Data write error: {}", e)))?;
remote_file
.close()
.map_err(|e| PyErr::new::<PyIOError, _>(format!("Close error: {}", e)))?;
Ok(())
}

Expand All @@ -468,19 +515,26 @@ impl Connection {
let mut remote_file = BufReader::new(
self.session
.sftp()
.unwrap()
.map_err(|e| PyErr::new::<PyIOError, _>(format!("SFTP error: {}", e)))?
.open(Path::new(&source_path))
.unwrap(),
.map_err(|e| PyErr::new::<PyIOError, _>(format!("Remote open error: {}", e)))?,
);
let dest_path = dest_path.unwrap_or_else(|| source_path.clone());
let mut other_file = dest_conn.sftp().create(Path::new(&dest_path)).unwrap();
let mut other_file = dest_conn
.sftp()
.create(Path::new(&dest_path))
.map_err(|e| PyErr::new::<PyIOError, _>(format!("Dest file creation error: {}", e)))?;
let mut buffer = vec![0; MAX_BUFF_SIZE];
loop {
let len = remote_file.read(&mut buffer).unwrap();
let len = remote_file
.read(&mut buffer)
.map_err(|e| PyErr::new::<PyIOError, _>(format!("File read error: {}", e)))?;
if len == 0 {
break;
}
other_file.write_all(&buffer[..len]).unwrap();
other_file
.write_all(&buffer[..len])
.map_err(|e| PyErr::new::<PyIOError, _>(format!("File write error: {}", e)))?;
}
Ok(())
}
Expand Down Expand Up @@ -586,12 +640,20 @@ impl InteractiveShell {
/// Reads the output from the shell and returns an `SSHResult`.
/// Note: This sends an EOF to the shell, so you won't be able to send more commands after calling `read`.
fn read(&mut self) -> PyResult<SSHResult> {
self.channel.channel.flush().unwrap();
self.channel.channel.send_eof().unwrap();
self.channel
.channel
.flush()
.map_err(|e| PyErr::new::<PyTimeoutError, _>(format!("Channel flush error: {}", e)))?;
self.channel
.channel
.send_eof()
.map_err(|e| PyErr::new::<PyTimeoutError, _>(format!("Send EOF error: {}", e)))?;
match read_from_channel(&mut self.channel.channel) {
Ok(result) => Ok(result),
Err(e) => {
self.channel.channel.close().unwrap();
self.channel.channel.close().map_err(|e| {
PyErr::new::<PyTimeoutError, _>(format!("Channel close error: {}", e))
})?;
self.result = None;
Err(e)
}
Expand Down Expand Up @@ -685,24 +747,26 @@ impl FileTailer {

// Determine the current end of the remote file
fn seek_end(&mut self) -> PyResult<Option<u64>> {
let len = self
let metadata = self
.sftp_conn
.stat(Path::new(&self.remote_file))
.unwrap()
.size;
self.last_pos = len.unwrap();
.map_err(|e| PyErr::new::<PyIOError, _>(format!("Stat error: {}", e)))?;
self.last_pos = metadata.size.unwrap_or(0);
if self.init_pos.is_none() {
self.init_pos = len;
self.init_pos = metadata.size;
}
Ok(len)
Ok(metadata.size)
}

// Read the contents of the remote file from a given position
#[pyo3(signature = (from_pos=None))]
fn read(&mut self, from_pos: Option<u64>) -> String {
let from_pos = from_pos.unwrap_or(self.last_pos);
let mut remote_file =
BufReader::new(self.sftp_conn.open(Path::new(&self.remote_file)).unwrap());
let mut remote_file = BufReader::new(
self.sftp_conn
.open(Path::new(&self.remote_file))
.expect("Opening remote file failed"),
);
remote_file
.seek(std::io::SeekFrom::Start(from_pos))
.unwrap();
Expand Down

0 comments on commit 5fc4e76

Please sign in to comment.