From f820dc6209508a1df8eb297d2da5e44ff126b614 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quang=20Ng=C3=B4?= Date: Sun, 28 Nov 2021 19:22:46 +0700 Subject: [PATCH] Handle CSRF protection according to `2.3.1. CSRF Protection` spec. --- examples/blocklist-update.rs | 2 +- examples/free-space.rs | 2 +- examples/port-test.rs | 2 +- examples/session-close.rs | 2 +- examples/session-get.rs | 2 +- examples/session-stats.rs | 2 +- examples/torrent-action.rs | 2 +- examples/torrent-add.rs | 2 +- examples/torrent-get.rs | 2 +- examples/torrent-remove.rs | 2 +- examples/torrent-rename-path.rs | 2 +- examples/torrent-set-location.rs | 2 +- src/lib.rs | 169 ++++++++++++++++++------------- 13 files changed, 109 insertions(+), 84 deletions(-) diff --git a/examples/blocklist-update.rs b/examples/blocklist-update.rs index e7c82f2..078a079 100644 --- a/examples/blocklist-update.rs +++ b/examples/blocklist-update.rs @@ -10,7 +10,7 @@ async fn main() -> Result<()> { dotenv().ok(); env_logger::init(); let url = env::var("TURL")?; - let client; + let mut client; if let (Ok(user), Ok(password)) = (env::var("TUSER"), env::var("TPWD")) { client = TransClient::with_auth(&url, BasicAuth {user, password}); } else { diff --git a/examples/free-space.rs b/examples/free-space.rs index f6188c5..0df5fc3 100644 --- a/examples/free-space.rs +++ b/examples/free-space.rs @@ -11,7 +11,7 @@ async fn main() -> Result<()> { env_logger::init(); let url = env::var("TURL")?; let dir = env::var("TDIR")?; - let client; + let mut client; if let (Ok(user), Ok(password)) = (env::var("TUSER"), env::var("TPWD")) { client = TransClient::with_auth(&url, BasicAuth {user, password}); } else { diff --git a/examples/port-test.rs b/examples/port-test.rs index fb33e45..b4a2da5 100644 --- a/examples/port-test.rs +++ b/examples/port-test.rs @@ -10,7 +10,7 @@ async fn main() -> Result<()> { dotenv().ok(); env_logger::init(); let url = env::var("TURL")?; - let client; + let mut client; if let (Ok(user), Ok(password)) = (env::var("TUSER"), env::var("TPWD")) { client = TransClient::with_auth(&url, BasicAuth {user, password}); } else { diff --git a/examples/session-close.rs b/examples/session-close.rs index 88f76c2..82776ea 100644 --- a/examples/session-close.rs +++ b/examples/session-close.rs @@ -10,7 +10,7 @@ async fn main() -> Result<()> { dotenv().ok(); env_logger::init(); let url = env::var("TURL")?; - let client; + let mut client; if let (Ok(user), Ok(password)) = (env::var("TUSER"), env::var("TPWD")) { client = TransClient::with_auth(&url, BasicAuth {user, password}); } else { diff --git a/examples/session-get.rs b/examples/session-get.rs index e8b4193..9dbeaa6 100644 --- a/examples/session-get.rs +++ b/examples/session-get.rs @@ -10,7 +10,7 @@ async fn main() -> Result<()> { dotenv().ok(); env_logger::init(); let url = env::var("TURL")?; - let client; + let mut client; if let (Ok(user), Ok(password)) = (env::var("TUSER"), env::var("TPWD")) { client = TransClient::with_auth(&url, BasicAuth {user, password}); } else { diff --git a/examples/session-stats.rs b/examples/session-stats.rs index c162fe7..a4c4a17 100644 --- a/examples/session-stats.rs +++ b/examples/session-stats.rs @@ -10,7 +10,7 @@ async fn main() -> Result<()> { dotenv().ok(); env_logger::init(); let url = env::var("TURL")?; - let client; + let mut client; if let (Ok(user), Ok(password)) = (env::var("TUSER"), env::var("TPWD")) { client = TransClient::with_auth(&url, BasicAuth {user, password}); } else { diff --git a/examples/torrent-action.rs b/examples/torrent-action.rs index d247756..855082c 100644 --- a/examples/torrent-action.rs +++ b/examples/torrent-action.rs @@ -11,7 +11,7 @@ async fn main() -> Result<()> { dotenv().ok(); env_logger::init(); let url = env::var("TURL")?; - let client; + let mut client; if let (Ok(user), Ok(password)) = (env::var("TUSER"), env::var("TPWD")) { client = TransClient::with_auth(&url, BasicAuth {user, password}); } else { diff --git a/examples/torrent-add.rs b/examples/torrent-add.rs index 843a068..6a5f5a6 100644 --- a/examples/torrent-add.rs +++ b/examples/torrent-add.rs @@ -11,7 +11,7 @@ async fn main() -> Result<()> { dotenv().ok(); env_logger::init(); let url = env::var("TURL")?; - let client; + let mut client; if let (Ok(user), Ok(password)) = (env::var("TUSER"), env::var("TPWD")) { client = TransClient::with_auth(&url, BasicAuth {user, password}); } else { diff --git a/examples/torrent-get.rs b/examples/torrent-get.rs index 01c6b45..fb94d8c 100644 --- a/examples/torrent-get.rs +++ b/examples/torrent-get.rs @@ -11,7 +11,7 @@ async fn main() -> Result<()> { dotenv().ok(); env_logger::init(); let url = env::var("TURL")?; - let client; + let mut client; if let (Ok(user), Ok(password)) = (env::var("TUSER"), env::var("TPWD")) { client = TransClient::with_auth(&url, BasicAuth {user, password}); } else { diff --git a/examples/torrent-remove.rs b/examples/torrent-remove.rs index 9b8e2cc..060c95c 100644 --- a/examples/torrent-remove.rs +++ b/examples/torrent-remove.rs @@ -11,7 +11,7 @@ async fn main() -> Result<()> { dotenv().ok(); env_logger::init(); let url = env::var("TURL")?; - let client; + let mut client; if let (Ok(user), Ok(password)) = (env::var("TUSER"), env::var("TPWD")) { client = TransClient::with_auth(&url, BasicAuth {user, password}); } else { diff --git a/examples/torrent-rename-path.rs b/examples/torrent-rename-path.rs index 9e50e46..59ee498 100644 --- a/examples/torrent-rename-path.rs +++ b/examples/torrent-rename-path.rs @@ -12,7 +12,7 @@ async fn main() -> Result<()> { env_logger::init(); let url= env::var("TURL")?; let basic_auth = BasicAuth{user: env::var("TUSER")?, password: env::var("TPWD")?}; - let client = TransClient::with_auth(&url, basic_auth); + let mut client = TransClient::with_auth(&url, basic_auth); let res: RpcResponse = client.torrent_rename_path(vec![Id::Id(1)], String::from("Folder/OldFile.jpg"), String::from("NewFile.jpg")).await?; println!("rename-path result: {:#?}", res); diff --git a/examples/torrent-set-location.rs b/examples/torrent-set-location.rs index 881f25c..91f05b3 100644 --- a/examples/torrent-set-location.rs +++ b/examples/torrent-set-location.rs @@ -11,7 +11,7 @@ async fn main() -> Result<()> { dotenv().ok(); env_logger::init(); let url = env::var("TURL")?; - let client; + let mut client; if let (Ok(user), Ok(password)) = (env::var("TUSER"), env::var("TPWD")) { client = TransClient::with_auth(&url, BasicAuth {user, password}); } else { diff --git a/src/lib.rs b/src/lib.rs index 608b090..af917e9 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -4,6 +4,7 @@ extern crate log; extern crate reqwest; use reqwest::header::CONTENT_TYPE; +use reqwest::{Client, StatusCode}; use serde::de::DeserializeOwned; pub mod types; @@ -20,9 +21,30 @@ use types::{Id, Torrent, TorrentGetField, Torrents}; use types::{Nothing, Result, RpcRequest, RpcResponse, RpcResponseArgument, TorrentRenamePath}; use types::{TorrentAddArgs, TorrentAdded}; +const MAX_RETRIES: usize = 5; + +#[derive(Clone, Debug)] +enum TransError { + MaxRetriesReached, + NoSessionIdReceived, +} + +impl std::fmt::Display for TransError { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match *self { + TransError::MaxRetriesReached => write!(f, "Max retries reached!"), + TransError::NoSessionIdReceived => write!(f, "No session id received!"), + } + } +} + +impl std::error::Error for TransError {} + pub struct TransClient { url: String, auth: Option, + session_id: Option, + client: Client, } impl TransClient { @@ -31,6 +53,8 @@ impl TransClient { TransClient { url: url.to_string(), auth: Some(basic_auth), + session_id: None, + client: Client::new(), } } @@ -39,45 +63,19 @@ impl TransClient { TransClient { url: url.to_string(), auth: None, + session_id: None, + client: Client::new(), } } /// Prepares a request for provided server and auth fn rpc_request(&self) -> reqwest::RequestBuilder { - let client = reqwest::Client::new(); if let Some(auth) = &self.auth { - client - .post(&self.url) + self.client.post(&self.url) .basic_auth(&auth.user, Some(&auth.password)) } else { - client.post(&self.url) - } - .header(CONTENT_TYPE, "application/json") - } - - /// Performs session-get call and takes the x-transmission-session-id - /// header to perform calls, using it's value - /// - /// # Errors - /// - /// If response is impossible to unwrap then it will return an empty session_id - async fn get_session_id(&self) -> String { - info!("Requesting session id info"); - let response: reqwest::Result = self - .rpc_request() - .json(&RpcRequest::session_get()) - .send() - .await; - let session_id = match response { - Ok(ref resp) => match resp.headers().get("x-transmission-session-id") { - Some(res) => res.to_str().expect("header value should be a string"), - _ => "", - }, - _ => "", - } - .to_owned(); - info!("Received session id: {}", session_id); - session_id + self.client.post(&self.url) + }.header(CONTENT_TYPE, "application/json") } /// Performs a session get call @@ -102,7 +100,7 @@ impl TransClient { /// env_logger::init(); /// let url= env::var("TURL")?; /// let basic_auth = BasicAuth{user: env::var("TUSER")?, password: env::var("TPWD")?}; - /// let client = TransClient::with_auth(&url, basic_auth); + /// let mut client = TransClient::with_auth(&url, basic_auth); /// let response: Result> = client.session_get().await; /// match response { /// Ok(_) => println!("Yay!"), @@ -112,7 +110,7 @@ impl TransClient { /// Ok(()) /// } /// ``` - pub async fn session_get(&self) -> Result> { + pub async fn session_get(&mut self) -> Result> { self.call(RpcRequest::session_get()).await } @@ -138,7 +136,7 @@ impl TransClient { /// env_logger::init(); /// let url= env::var("TURL")?; /// let basic_auth = BasicAuth{user: env::var("TUSER")?, password: env::var("TPWD")?}; - /// let client = TransClient::with_auth(&url, basic_auth); + /// let mut client = TransClient::with_auth(&url, basic_auth); /// let response: Result> = client.session_stats().await; /// match response { /// Ok(_) => println!("Yay!"), @@ -148,7 +146,7 @@ impl TransClient { /// Ok(()) /// } /// ``` - pub async fn session_stats(&self) -> Result> { + pub async fn session_stats(&mut self) -> Result> { self.call(RpcRequest::session_stats()).await } @@ -174,7 +172,7 @@ impl TransClient { /// env_logger::init(); /// let url= env::var("TURL")?; /// let basic_auth = BasicAuth{user: env::var("TUSER")?, password: env::var("TPWD")?}; - /// let client = TransClient::with_auth(&url, basic_auth); + /// let mut client = TransClient::with_auth(&url, basic_auth); /// let response: Result> = client.session_close().await; /// match response { /// Ok(_) => println!("Yay!"), @@ -184,7 +182,7 @@ impl TransClient { /// Ok(()) /// } /// ``` - pub async fn session_close(&self) -> Result> { + pub async fn session_close(&mut self) -> Result> { self.call(RpcRequest::session_close()).await } @@ -210,7 +208,7 @@ impl TransClient { /// env_logger::init(); /// let url= env::var("TURL")?; /// let basic_auth = BasicAuth{user: env::var("TUSER")?, password: env::var("TPWD")?}; - /// let client = TransClient::with_auth(&url, basic_auth); + /// let mut client = TransClient::with_auth(&url, basic_auth); /// let response: Result> = client.blocklist_update().await; /// match response { /// Ok(_) => println!("Yay!"), @@ -220,7 +218,7 @@ impl TransClient { /// Ok(()) /// } /// ``` - pub async fn blocklist_update(&self) -> Result> { + pub async fn blocklist_update(&mut self) -> Result> { self.call(RpcRequest::blocklist_update()).await } @@ -247,7 +245,7 @@ impl TransClient { /// let url= env::var("TURL")?; /// let dir = env::var("TDIR")?; /// let basic_auth = BasicAuth{user: env::var("TUSER")?, password: env::var("TPWD")?}; - /// let client = TransClient::with_auth(&url, basic_auth); + /// let mut client = TransClient::with_auth(&url, basic_auth); /// let response: Result> = client.free_space(dir).await; /// match response { /// Ok(_) => println!("Yay!"), @@ -257,7 +255,7 @@ impl TransClient { /// Ok(()) /// } /// ``` - pub async fn free_space(&self, path: String) -> Result> { + pub async fn free_space(&mut self, path: String) -> Result> { self.call(RpcRequest::free_space(path)).await } @@ -283,7 +281,7 @@ impl TransClient { /// env_logger::init(); /// let url= env::var("TURL")?; /// let basic_auth = BasicAuth{user: env::var("TUSER")?, password: env::var("TPWD")?}; - /// let client = TransClient::with_auth(&url, basic_auth); + /// let mut client = TransClient::with_auth(&url, basic_auth); /// let response: Result> = client.port_test().await; /// match response { /// Ok(_) => println!("Yay!"), @@ -293,7 +291,7 @@ impl TransClient { /// Ok(()) /// } /// ``` - pub async fn port_test(&self) -> Result> { + pub async fn port_test(&mut self) -> Result> { self.call(RpcRequest::port_test()).await } @@ -322,7 +320,7 @@ impl TransClient { /// env_logger::init(); /// let url= env::var("TURL")?; /// let basic_auth = BasicAuth{user: env::var("TUSER")?, password: env::var("TPWD")?}; - /// let client = TransClient::with_auth(&url, basic_auth); + /// let mut client = TransClient::with_auth(&url, basic_auth); /// /// let res: RpcResponse> = client.torrent_get(None, None).await?; /// let names: Vec<&String> = res.arguments.torrents.iter().map(|it| it.name.as_ref().unwrap()).collect(); @@ -348,7 +346,7 @@ impl TransClient { /// } /// ``` pub async fn torrent_get( - &self, + &mut self, fields: Option>, ids: Option>, ) -> Result>> { @@ -378,7 +376,7 @@ impl TransClient { /// env_logger::init(); /// let url= env::var("TURL")?; /// let basic_auth = BasicAuth{user: env::var("TUSER")?, password: env::var("TPWD")?}; - /// let client = TransClient::with_auth(&url, basic_auth); + /// let mut client = TransClient::with_auth(&url, basic_auth); /// let res1: RpcResponse = client.torrent_action(TorrentAction::Start, vec![Id::Id(1)]).await?; /// println!("Start result: {:?}", &res1.is_ok()); /// let res2: RpcResponse = client.torrent_action(TorrentAction::Stop, vec![Id::Id(1)]).await?; @@ -388,7 +386,7 @@ impl TransClient { /// } /// ``` pub async fn torrent_action( - &self, + &mut self, action: TorrentAction, ids: Vec, ) -> Result> { @@ -418,7 +416,7 @@ impl TransClient { /// env_logger::init(); /// let url= env::var("TURL")?; /// let basic_auth = BasicAuth{user: env::var("TUSER")?, password: env::var("TPWD")?}; - /// let client = TransClient::with_auth(&url, basic_auth); + /// let mut client = TransClient::with_auth(&url, basic_auth); /// let res: RpcResponse = client.torrent_remove(vec![Id::Id(1)], false).await?; /// println!("Remove result: {:?}", &res.is_ok()); /// @@ -426,7 +424,7 @@ impl TransClient { /// } /// ``` pub async fn torrent_remove( - &self, + &mut self, ids: Vec, delete_local_data: bool, ) -> Result> { @@ -457,7 +455,7 @@ impl TransClient { /// env_logger::init(); /// let url= env::var("TURL")?; /// let basic_auth = BasicAuth{user: env::var("TUSER")?, password: env::var("TPWD")?}; - /// let client = TransClient::with_auth(&url, basic_auth); + /// let mut client = TransClient::with_auth(&url, basic_auth); /// let res: RpcResponse = client.torrent_set_location(vec![Id::Id(1)], String::from("/new/location"), Option::from(false)).await?; /// println!("Set-location result: {:?}", &res.is_ok()); /// @@ -465,7 +463,7 @@ impl TransClient { /// } /// ``` pub async fn torrent_set_location( - &self, + &mut self, ids: Vec, location: String, move_from: Option, @@ -497,7 +495,7 @@ impl TransClient { /// env_logger::init(); /// let url= env::var("TURL")?; /// let basic_auth = BasicAuth{user: env::var("TUSER")?, password: env::var("TPWD")?}; - /// let client = TransClient::with_auth(&url, basic_auth); + /// let mut client = TransClient::with_auth(&url, basic_auth); /// let res: RpcResponse = client.torrent_rename_path(vec![Id::Id(1)], String::from("Folder/OldFile.jpg"), String::from("NewFile.jpg")).await?; /// println!("rename-path result: {:#?}", res); /// @@ -505,7 +503,7 @@ impl TransClient { /// } /// ``` pub async fn torrent_rename_path( - &self, + &mut self, ids: Vec, path: String, name: String, @@ -537,7 +535,7 @@ impl TransClient { /// env_logger::init(); /// let url= env::var("TURL")?; /// let basic_auth = BasicAuth{user: env::var("TUSER")?, password: env::var("TPWD")?}; - /// let client = TransClient::with_auth(&url, basic_auth); + /// let mut client = TransClient::with_auth(&url, basic_auth); /// let add: TorrentAddArgs = TorrentAddArgs { /// filename: Some("https://releases.ubuntu.com/20.04/ubuntu-20.04-desktop-amd64.iso.torrent".to_string()), /// ..TorrentAddArgs::default() @@ -549,7 +547,7 @@ impl TransClient { /// Ok(()) /// } /// ``` - pub async fn torrent_add(&self, add: TorrentAddArgs) -> Result> { + pub async fn torrent_add(&mut self, add: TorrentAddArgs) -> Result> { if add.metainfo == None && add.filename == None { panic!("Metainfo or Filename should be provided") } @@ -561,25 +559,52 @@ impl TransClient { /// # Errors /// /// Any IO Error or Deserialization error - async fn call(&self, request: RpcRequest) -> Result> + async fn call(&mut self, request: RpcRequest) -> Result> where RS: RpcResponseArgument + DeserializeOwned + std::fmt::Debug, { - info!("Loaded auth: {:?}", &self.auth); - let rq: reqwest::RequestBuilder = self - .rpc_request() - .header("X-Transmission-Session-Id", self.get_session_id().await) - .json(&request); - info!( - "Request body: {:?}", - rq.try_clone() - .expect("Unable to get the request body") - .body_string()? - ); - let resp: reqwest::Response = rq.send().await?; - let rpc_response: RpcResponse = resp.json().await?; - info!("Response body: {:#?}", rpc_response); - Ok(rpc_response) + let mut remaining_retries = MAX_RETRIES; + loop { + if remaining_retries <= 0 { + return Err(From::from(TransError::MaxRetriesReached)); + } + remaining_retries -= 1; + + info!("Loaded auth: {:?}", &self.auth); + let rq = match &self.session_id { + None => self.rpc_request(), + Some(id) => { + self.rpc_request().header("X-Transmission-Session-Id", id) + } + }.json(&request); + + info!( + "Request body: {:?}", + rq.try_clone() + .expect("Unable to get the request body") + .body_string()? + ); + + let rsp: reqwest::Response = rq.send().await?; + match rsp.status() { + StatusCode::CONFLICT => { + let session_id = rsp.headers() + .get("X-Transmission-Session-Id") + .ok_or(TransError::NoSessionIdReceived)? + .to_str()?; + self.session_id = Some(String::from(session_id)); + + info!("Got new session_id: {}. Retrying request.", session_id); + continue; + } + _ => { + let rpc_response: RpcResponse = rsp.json().await?; + info!("Response body: {:#?}", rpc_response); + + return Ok(rpc_response) + } + } + } } } @@ -606,7 +631,7 @@ mod tests { dotenv().ok(); env_logger::init(); let url = env::var("TURL")?; - let client; + let mut client; if let (Ok(user), Ok(password)) = (env::var("TUSER"), env::var("TPWD")) { client = TransClient::with_auth(&url, BasicAuth {user, password}); } else {