From 843bd7c2a5fba4a63fe6afc13712a868808a62a2 Mon Sep 17 00:00:00 2001 From: Serhii Volovyk Date: Wed, 5 Apr 2023 01:13:56 +0300 Subject: [PATCH 01/12] access token verfy trait, google and test implementations --- Cargo.lock | 130 +++++++++++++++++++++++++++++++ mpc-recovery/Cargo.toml | 3 +- mpc-recovery/src/main.rs | 2 + mpc-recovery/src/ouath.rs | 83 ++++++++++++++++++++ mpc-recovery/tests/test-oauth.rs | 7 ++ 5 files changed, 224 insertions(+), 1 deletion(-) create mode 100644 mpc-recovery/src/ouath.rs create mode 100644 mpc-recovery/tests/test-oauth.rs diff --git a/Cargo.lock b/Cargo.lock index 6d96c0f71..b4f954b1b 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -463,6 +463,15 @@ version = "1.8.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "7fcaabb2fef8c910e7f4c7ce9f67a1283a1715879a7c230ca9d6d1ae31f16d91" +[[package]] +name = "encoding_rs" +version = "0.8.32" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "071a31f4ee85403370b58aca746f01041ede6f0da2730960ad001edc2b71b394" +dependencies = [ + "cfg-if", +] + [[package]] name = "errno" version = "0.3.0" @@ -678,8 +687,10 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "c05aeb6a22b8f62540c194aac980f2115af067bfe15a0734d7277a768d396b31" dependencies = [ "cfg-if", + "js-sys", "libc", "wasi 0.11.0+wasi-snapshot-preview1", + "wasm-bindgen", ] [[package]] @@ -815,6 +826,19 @@ dependencies = [ "want", ] +[[package]] +name = "hyper-rustls" +version = "0.23.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1788965e61b367cd03a62950836d5cd41560c3577d90e40e0819373194d1661c" +dependencies = [ + "http", + "hyper", + "rustls", + "tokio", + "tokio-rustls", +] + [[package]] name = "hyperlocal" version = "0.8.0" @@ -893,6 +917,12 @@ dependencies = [ "windows-sys", ] +[[package]] +name = "ipnet" +version = "2.7.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "12b6ee2129af8d4fb011108c73d99a1b83a85977f23b82460c0ae2e25bb4b57f" + [[package]] name = "is-terminal" version = "0.4.6" @@ -1034,6 +1064,7 @@ dependencies = [ "clap", "futures", "hex", + "oauth2", "ractor", "ractor_cluster", "rand 0.7.3", @@ -1118,6 +1149,26 @@ dependencies = [ "libc", ] +[[package]] +name = "oauth2" +version = "4.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "eeaf26a72311c087f8c5ba617c96fac67a5c04f430e716ac8d8ab2de62e23368" +dependencies = [ + "base64 0.13.1", + "chrono", + "getrandom 0.2.8", + "http", + "rand 0.8.5", + "reqwest", + "serde", + "serde_json", + "serde_path_to_error", + "sha2", + "thiserror", + "url", +] + [[package]] name = "object" version = "0.30.3" @@ -1497,6 +1548,45 @@ version = "0.6.29" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "f162c6dd7b008981e4d40210aca20b4bd0f9b60ca9271061b07f78537722f2e1" +[[package]] +name = "reqwest" +version = "0.11.16" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "27b71749df584b7f4cac2c426c127a7c785a5106cc98f7a8feb044115f0fa254" +dependencies = [ + "base64 0.21.0", + "bytes", + "encoding_rs", + "futures-core", + "futures-util", + "h2", + "http", + "http-body", + "hyper", + "hyper-rustls", + "ipnet", + "js-sys", + "log", + "mime", + "once_cell", + "percent-encoding", + "pin-project-lite", + "rustls", + "rustls-pemfile", + "serde", + "serde_json", + "serde_urlencoded", + "tokio", + "tokio-rustls", + "tower-service", + "url", + "wasm-bindgen", + "wasm-bindgen-futures", + "web-sys", + "webpki-roots", + "winreg", +] + [[package]] name = "ring" version = "0.16.20" @@ -1544,6 +1634,15 @@ dependencies = [ "webpki", ] +[[package]] +name = "rustls-pemfile" +version = "1.0.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d194b56d58803a43635bdc398cd17e383d6f71f9182b9a192c127ca42494a59b" +dependencies = [ + "base64 0.21.0", +] + [[package]] name = "rustversion" version = "1.0.12" @@ -2091,6 +2190,7 @@ dependencies = [ "form_urlencoded", "idna", "percent-encoding", + "serde", ] [[package]] @@ -2158,6 +2258,18 @@ dependencies = [ "wasm-bindgen-shared", ] +[[package]] +name = "wasm-bindgen-futures" +version = "0.4.34" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f219e0d211ba40266969f6dbdd90636da12f75bee4fc9d6c23d1260dadb51454" +dependencies = [ + "cfg-if", + "js-sys", + "wasm-bindgen", + "web-sys", +] + [[package]] name = "wasm-bindgen-macro" version = "0.2.84" @@ -2207,6 +2319,15 @@ dependencies = [ "untrusted", ] +[[package]] +name = "webpki-roots" +version = "0.22.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b6c71e40d7d2c34a5106301fb632274ca37242cd0c9d3e64dbece371a40a2d87" +dependencies = [ + "webpki", +] + [[package]] name = "which" version = "4.4.0" @@ -2381,6 +2502,15 @@ version = "0.47.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "9d6e62c256dc6d40b8c8707df17df8d774e60e39db723675241e7c15e910bce7" +[[package]] +name = "winreg" +version = "0.10.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "80d0f4e272c85def139476380b12f9ac60926689dd2e01d4923222f40580869d" +dependencies = [ + "winapi", +] + [[package]] name = "zeroize" version = "1.4.3" diff --git a/mpc-recovery/Cargo.toml b/mpc-recovery/Cargo.toml index a77b20a23..beed3f728 100644 --- a/mpc-recovery/Cargo.toml +++ b/mpc-recovery/Cargo.toml @@ -32,4 +32,5 @@ serde = "1" serde_json = "1" tokio = { version = "1.0", features = ["full"] } tracing = "0.1" -tracing-subscriber = { version = "0.3", features = ["env-filter"] } \ No newline at end of file +tracing-subscriber = { version = "0.3", features = ["env-filter"] } +oauth2 = "4.3.0" \ No newline at end of file diff --git a/mpc-recovery/src/main.rs b/mpc-recovery/src/main.rs index a296629fb..f1f83fa19 100644 --- a/mpc-recovery/src/main.rs +++ b/mpc-recovery/src/main.rs @@ -1,3 +1,5 @@ +mod ouath; + use clap::Parser; use threshold_crypto::{serde_impl::SerdeSecret, PublicKeySet, SecretKeyShare}; diff --git a/mpc-recovery/src/ouath.rs b/mpc-recovery/src/ouath.rs new file mode 100644 index 000000000..c03dcd17f --- /dev/null +++ b/mpc-recovery/src/ouath.rs @@ -0,0 +1,83 @@ +use oauth2::basic::BasicClient; +use oauth2::reqwest::http_client; +use oauth2::{AccessToken, AuthUrl, ClientId, ClientSecret, RedirectUrl, TokenResponse, TokenUrl}; +use std::error::Error; + +pub trait OAuthTokenVerifier { + fn verify_token(&self, token: &str) -> Option; +} + +/* Google verifier */ +pub struct GoogleTokenVerifier { + client_id: ClientId, // TODO: do we need this field? +} + +impl GoogleTokenVerifier { + pub fn new(client_id: &str) -> GoogleTokenVerifier { // TODO: do we need this function? + GoogleTokenVerifier { + client_id: ClientId::new(client_id.to_owned()), + } + } + + pub fn verify_token(&self, token: &str) -> Result> { + let client_secret = ClientSecret::new("".to_owned()); + let auth_url = AuthUrl::new("https://accounts.google.com/o/oauth2/auth".to_owned())?; + let token_url = TokenUrl::new("https://accounts.google.com/o/oauth2/token".to_owned())?; + let redirect_url = RedirectUrl::new("http://localhost:8080".to_owned())?; + let client = BasicClient::new( + self.client_id.clone(), + Some(client_secret), + auth_url, + Some(token_url), + ) + .set_redirect_uri(redirect_url); + + let token = AccessToken::new(token.to_owned()); + let token_info_url = "https://www.googleapis.com/oauth2/v3/tokeninfo".parse()?; + let token_info_request = + client.request::<_, TokenResponse>(http_client, reqwest::Method::GET, token_info_url) + .unwrap() + .bearer_auth(token.secret()); + let token_info = token_info_request.send()?.json::()?; + + if let Some(aud) = token_info.get("aud") { + if let Some(client_id) = aud.as_str() { + if client_id == self.client_id.secret() { + if let Some(sub) = token_info.get("sub") { + if let Some(account_id) = sub.as_str() { + return Ok(account_id.to_owned()); + } + } + } + } + } + + Err("Invalid token".into()) + } +} + +/* Test verifier */ +pub struct TestTokenVerifier { + client_id: String, +} + +impl TestTokenVerifier { + pub fn new(client_id: String) -> Self { + Self { + client_id, + } + } + + fn verify_test_token(&self, token: &str) -> Option { + match token { + "valid" => Some(ClientId::new("testAccountId".to_owned())), // TODO: add prefix? + _ => None, + } + } +} + +impl OAuthTokenVerifier for TestTokenVerifier { + fn verify_token(&self, token: &str) -> Option { + self.verify_test_token(token) + } +} diff --git a/mpc-recovery/tests/test-oauth.rs b/mpc-recovery/tests/test-oauth.rs new file mode 100644 index 000000000..a16df46c2 --- /dev/null +++ b/mpc-recovery/tests/test-oauth.rs @@ -0,0 +1,7 @@ +#[cfg(test)] +mod tests { + #[test] + fn test_addition() { + assert_eq!(2 + 2, 4); + } +} From d08c8d0c1aca5690de63e34f0ac1a54acf9d41d4 Mon Sep 17 00:00:00 2001 From: Serhii Volovyk Date: Wed, 5 Apr 2023 11:40:21 +0300 Subject: [PATCH 02/12] separate test file deleted --- mpc-recovery/tests/test-oauth.rs | 7 ------- 1 file changed, 7 deletions(-) delete mode 100644 mpc-recovery/tests/test-oauth.rs diff --git a/mpc-recovery/tests/test-oauth.rs b/mpc-recovery/tests/test-oauth.rs deleted file mode 100644 index a16df46c2..000000000 --- a/mpc-recovery/tests/test-oauth.rs +++ /dev/null @@ -1,7 +0,0 @@ -#[cfg(test)] -mod tests { - #[test] - fn test_addition() { - assert_eq!(2 + 2, 4); - } -} From d7d93e0cf0b47da3162910c155d023660be7bc26 Mon Sep 17 00:00:00 2001 From: Serhii Volovyk Date: Wed, 5 Apr 2023 12:05:08 +0300 Subject: [PATCH 03/12] simple TestTokenVerifier tests added --- mpc-recovery/src/ouath.rs | 112 ++++++++++++++++++-------------------- 1 file changed, 53 insertions(+), 59 deletions(-) diff --git a/mpc-recovery/src/ouath.rs b/mpc-recovery/src/ouath.rs index c03dcd17f..57330003b 100644 --- a/mpc-recovery/src/ouath.rs +++ b/mpc-recovery/src/ouath.rs @@ -1,83 +1,77 @@ use oauth2::basic::BasicClient; use oauth2::reqwest::http_client; use oauth2::{AccessToken, AuthUrl, ClientId, ClientSecret, RedirectUrl, TokenResponse, TokenUrl}; -use std::error::Error; pub trait OAuthTokenVerifier { - fn verify_token(&self, token: &str) -> Option; + fn verify_token(&self, token: &str) -> Option<&str>; } /* Google verifier */ -pub struct GoogleTokenVerifier { - client_id: ClientId, // TODO: do we need this field? -} +pub struct GoogleTokenVerifier {} -impl GoogleTokenVerifier { - pub fn new(client_id: &str) -> GoogleTokenVerifier { // TODO: do we need this function? - GoogleTokenVerifier { - client_id: ClientId::new(client_id.to_owned()), - } - } +impl OAuthTokenVerifier for GoogleTokenVerifier { + fn verify_token(&self, token: &str) -> Option<&str> { + // let client_secret = ClientSecret::new("".to_owned()); + // let auth_url = AuthUrl::new("https://accounts.google.com/o/oauth2/auth".to_owned())?; + // let token_url = TokenUrl::new("https://accounts.google.com/o/oauth2/token".to_owned())?; + // let redirect_url = RedirectUrl::new("http://localhost:8080".to_owned())?; + // let client = BasicClient::new( + // self.client_id.clone(), + // Some(client_secret), + // auth_url, + // Some(token_url), + // ) + // .set_redirect_uri(redirect_url); - pub fn verify_token(&self, token: &str) -> Result> { - let client_secret = ClientSecret::new("".to_owned()); - let auth_url = AuthUrl::new("https://accounts.google.com/o/oauth2/auth".to_owned())?; - let token_url = TokenUrl::new("https://accounts.google.com/o/oauth2/token".to_owned())?; - let redirect_url = RedirectUrl::new("http://localhost:8080".to_owned())?; - let client = BasicClient::new( - self.client_id.clone(), - Some(client_secret), - auth_url, - Some(token_url), - ) - .set_redirect_uri(redirect_url); + // let token = AccessToken::new(token.to_owned()); + // let token_info_url = "https://www.googleapis.com/oauth2/v3/tokeninfo".parse()?; + // let token_info_request = + // client.request::<_, TokenResponse>(http_client, reqwest::Method::GET, token_info_url) + // .unwrap() + // .bearer_auth(token.secret()); + // let token_info = token_info_request.send()?.json::()?; - let token = AccessToken::new(token.to_owned()); - let token_info_url = "https://www.googleapis.com/oauth2/v3/tokeninfo".parse()?; - let token_info_request = - client.request::<_, TokenResponse>(http_client, reqwest::Method::GET, token_info_url) - .unwrap() - .bearer_auth(token.secret()); - let token_info = token_info_request.send()?.json::()?; + // if let Some(aud) = token_info.get("aud") { + // if let Some(client_id) = aud.as_str() { + // if client_id == self.client_id.secret() { + // if let Some(sub) = token_info.get("sub") { + // if let Some(account_id) = sub.as_str() { + // return Ok(account_id.to_owned()); + // } + // } + // } + // } + // } - if let Some(aud) = token_info.get("aud") { - if let Some(client_id) = aud.as_str() { - if client_id == self.client_id.secret() { - if let Some(sub) = token_info.get("sub") { - if let Some(account_id) = sub.as_str() { - return Ok(account_id.to_owned()); - } - } - } - } - } - - Err("Invalid token".into()) + // Err("Invalid token".into()) + return Some("TODO: replae this with google verification"); } } /* Test verifier */ -pub struct TestTokenVerifier { - client_id: String, -} - -impl TestTokenVerifier { - pub fn new(client_id: String) -> Self { - Self { - client_id, - } - } +pub struct TestTokenVerifier {} - fn verify_test_token(&self, token: &str) -> Option { +impl OAuthTokenVerifier for TestTokenVerifier { + fn verify_token(&self, token: &str) -> Option<&str> { match token { - "valid" => Some(ClientId::new("testAccountId".to_owned())), // TODO: add prefix? + "valid" => Some("testAccountId"), _ => None, } } } -impl OAuthTokenVerifier for TestTokenVerifier { - fn verify_token(&self, token: &str) -> Option { - self.verify_test_token(token) - } +#[test] +fn test_verify_token_valid() { + let verifier = TestTokenVerifier {}; + let token = "valid"; + let account_id = verifier.verify_token(token).unwrap(); + assert_eq!(account_id, "testAccountId"); +} + +#[test] +fn test_verify_token_invalid() { + let verifier = TestTokenVerifier {}; + let token = "invalid"; + let account_id = verifier.verify_token(token); + assert_eq!(account_id, None); } From c0c7a47fff5842b9d579ce96707e74043825fa9e Mon Sep 17 00:00:00 2001 From: Serhii Volovyk Date: Wed, 5 Apr 2023 12:15:26 +0300 Subject: [PATCH 04/12] verify_token made async --- mpc-recovery/src/ouath.rs | 36 +++++++++++++++++++----------------- 1 file changed, 19 insertions(+), 17 deletions(-) diff --git a/mpc-recovery/src/ouath.rs b/mpc-recovery/src/ouath.rs index 57330003b..cc32c1ac2 100644 --- a/mpc-recovery/src/ouath.rs +++ b/mpc-recovery/src/ouath.rs @@ -1,16 +1,18 @@ -use oauth2::basic::BasicClient; -use oauth2::reqwest::http_client; -use oauth2::{AccessToken, AuthUrl, ClientId, ClientSecret, RedirectUrl, TokenResponse, TokenUrl}; +// use oauth2::basic::BasicClient; +// use oauth2::reqwest::http_client; +// use oauth2::{AccessToken, AuthUrl, ClientId, ClientSecret, RedirectUrl, TokenResponse, TokenUrl}; +#[async_trait::async_trait] pub trait OAuthTokenVerifier { - fn verify_token(&self, token: &str) -> Option<&str>; + async fn verify_token(&self, token: &str) -> Option<&str>; } /* Google verifier */ -pub struct GoogleTokenVerifier {} +// pub struct GoogleTokenVerifier {} -impl OAuthTokenVerifier for GoogleTokenVerifier { - fn verify_token(&self, token: &str) -> Option<&str> { +// #[async_trait::async_trait] +// impl OAuthTokenVerifier for GoogleTokenVerifier { +// async fn verify_token(&self, token: &str) -> Option<&str> { // let client_secret = ClientSecret::new("".to_owned()); // let auth_url = AuthUrl::new("https://accounts.google.com/o/oauth2/auth".to_owned())?; // let token_url = TokenUrl::new("https://accounts.google.com/o/oauth2/token".to_owned())?; @@ -44,15 +46,15 @@ impl OAuthTokenVerifier for GoogleTokenVerifier { // } // Err("Invalid token".into()) - return Some("TODO: replae this with google verification"); - } -} + // } +// } /* Test verifier */ pub struct TestTokenVerifier {} +#[async_trait::async_trait] impl OAuthTokenVerifier for TestTokenVerifier { - fn verify_token(&self, token: &str) -> Option<&str> { + async fn verify_token(&self, token: &str) -> Option<&str> { match token { "valid" => Some("testAccountId"), _ => None, @@ -60,18 +62,18 @@ impl OAuthTokenVerifier for TestTokenVerifier { } } -#[test] -fn test_verify_token_valid() { +#[tokio::test] +async fn test_verify_token_valid() { let verifier = TestTokenVerifier {}; let token = "valid"; - let account_id = verifier.verify_token(token).unwrap(); + let account_id = verifier.verify_token(token).await.unwrap(); assert_eq!(account_id, "testAccountId"); } -#[test] -fn test_verify_token_invalid() { +#[tokio::test] +async fn test_verify_token_invalid() { let verifier = TestTokenVerifier {}; let token = "invalid"; - let account_id = verifier.verify_token(token); + let account_id = verifier.verify_token(token).await; assert_eq!(account_id, None); } From 69bb7200c9b3ab91f928624393ca35db1e4493ce Mon Sep 17 00:00:00 2001 From: Serhii Volovyk Date: Wed, 5 Apr 2023 14:17:25 +0300 Subject: [PATCH 05/12] mocked token check added --- mpc-recovery/src/actor.rs | 184 +++++++++++++++++++++----------------- mpc-recovery/src/lib.rs | 1 + mpc-recovery/src/ouath.rs | 111 ++++++++++++++--------- 3 files changed, 170 insertions(+), 126 deletions(-) diff --git a/mpc-recovery/src/actor.rs b/mpc-recovery/src/actor.rs index ed367baf0..699c737fc 100644 --- a/mpc-recovery/src/actor.rs +++ b/mpc-recovery/src/actor.rs @@ -7,6 +7,7 @@ use ractor_cluster::RactorClusterMessage; use serde::{Deserialize, Serialize}; use threshold_crypto::{PublicKeySet, SecretKeyShare, Signature, SignatureShare}; +use crate::ouath::{OAuthTokenVerifier, UniversalTokenVerifier}; use crate::NodeId; const MPC_RECOVERY_GROUP: &str = "mpc-recovery"; @@ -108,7 +109,7 @@ impl Actor for NodeActor { } NodeMessage::SignRequest(payload, reply) => { tracing::debug!(?payload, "sign request"); - state.handle_signed_msg(payload, reply) + state.handle_signed_msg(payload, reply).await } }; Ok(()) @@ -124,17 +125,27 @@ impl NodeActorState { } #[tracing::instrument(level = "debug", skip_all)] - fn handle_signed_msg(&mut self, payload: Payload, reply: RpcReplyPort) { - // TODO: run some check that the msg.task.payload makes sense, fail if not - tracing::debug!("approved"); - - let response = self.sign(&payload); - tracing::debug!(?response, "replying"); - - match reply.send(response) { - Ok(()) => {} - Err(e) => tracing::error!("failed to respond: {}", e), - }; + async fn handle_signed_msg(&mut self, payload: Payload, reply: RpcReplyPort) { + // TODO: extract access token from payload + let access_token = "validToken"; + let access_token_verifier = UniversalTokenVerifier {}; + match access_token_verifier.verify_token(access_token).await { + Some(client_id) => { + tracing::debug!("approved, cleintId: {}", client_id); + + let response = self.sign(&payload); + tracing::debug!(?response, "replying"); + + match reply.send(response) { + Ok(()) => {} + Err(e) => tracing::error!("failed to respond: {}", e), + }; + } + None => { + tracing::error!("failed to verify access token"); + return; + } + } } async fn handle_new_request( @@ -143,77 +154,86 @@ impl NodeActorState { reply: RpcReplyPort, remote_actors: &Vec>, ) { - // TODO: run some check that the payload makes sense, fail if not - tracing::debug!("approved"); - - let mut futures = Vec::new(); - for actor in remote_actors { - tracing::debug!(actor = ?actor.get_id(), "asking actor"); - let future = actor - .call( - |tx| NodeMessage::SignRequest(payload.clone(), tx), - Some(Duration::from_millis(2000)), - ) - .map(|r| r.map_err(ractor::RactorErr::from)) - .map(|r| match r { - Ok(ractor::rpc::CallResult::Success(ok_value)) => Ok(ok_value), - Ok(cr) => Err(ractor::RactorErr::from(cr)), - Err(e) => Err(e), - }); - futures.push(future); - } - - // create unordered collection of futures - let futures = futures.into_iter().collect::>(); - - let mut responses = futures - .collect::>() - .await - .into_iter() - .filter_map(|r| r.ok()) - .collect::>(); - - let response = self.sign(&payload); - tracing::debug!(?response, "adding response from self"); - responses.push(response); - - tracing::debug!( - ?responses, - "got {} successful responses total", - responses.len() - ); - - let mut sig_shares = Vec::new(); - for sign_response in &responses { - if self - .pk_set - .public_key_share(sign_response.node_id) - .verify(&sign_response.sig_share, &payload) - { - sig_shares.push((sign_response.node_id, &sign_response.sig_share)); - } else { - tracing::error!(?sign_response, "received invalid signature",); + // TODO: extract access token from payload + let access_token = "validToken"; + let access_token_verifier = UniversalTokenVerifier {}; + match access_token_verifier.verify_token(access_token).await { + Some(client_id) => { + tracing::debug!("approved, cleintId: {}", client_id); + let mut futures = Vec::new(); + for actor in remote_actors { + tracing::debug!(actor = ?actor.get_id(), "asking actor"); + let future = actor + .call( + |tx| NodeMessage::SignRequest(payload.clone(), tx), + Some(Duration::from_millis(2000)), + ) + .map(|r| r.map_err(ractor::RactorErr::from)) + .map(|r| match r { + Ok(ractor::rpc::CallResult::Success(ok_value)) => Ok(ok_value), + Ok(cr) => Err(ractor::RactorErr::from(cr)), + Err(e) => Err(e), + }); + futures.push(future); + } + + // create unordered collection of futures + let futures = futures.into_iter().collect::>(); + + let mut responses = futures + .collect::>() + .await + .into_iter() + .filter_map(|r| r.ok()) + .collect::>(); + + let response = self.sign(&payload); + tracing::debug!(?response, "adding response from self"); + responses.push(response); + + tracing::debug!( + ?responses, + "got {} successful responses total", + responses.len() + ); + + let mut sig_shares = Vec::new(); + for sign_response in &responses { + if self + .pk_set + .public_key_share(sign_response.node_id) + .verify(&sign_response.sig_share, &payload) + { + sig_shares.push((sign_response.node_id, &sign_response.sig_share)); + } else { + tracing::error!(?sign_response, "received invalid signature",); + } + } + + tracing::debug!( + ?sig_shares, + "got {} valid signature shares total", + sig_shares.len() + ); + + if let Ok(sig) = self + .pk_set + .combine_signatures(sig_shares.clone().into_iter()) + { + tracing::debug!(?sig, "replying with full signature"); + reply.send(SignatureResponse { sig }).unwrap(); + } else { + tracing::error!( + "expected to get at least {} shares, but only got {}", + self.pk_set.threshold() + 1, + sig_shares.len() + ); + } + } + None => { + tracing::error!("failed to verify access token"); + return; } - } - - tracing::debug!( - ?sig_shares, - "got {} valid signature shares total", - sig_shares.len() - ); - - if let Ok(sig) = self - .pk_set - .combine_signatures(sig_shares.clone().into_iter()) - { - tracing::debug!(?sig, "replying with full signature"); - reply.send(SignatureResponse { sig }).unwrap(); - } else { - tracing::error!( - "expected to get at least {} shares, but only got {}", - self.pk_set.threshold() + 1, - sig_shares.len() - ); } } } diff --git a/mpc-recovery/src/lib.rs b/mpc-recovery/src/lib.rs index dd146c516..c40bd4587 100644 --- a/mpc-recovery/src/lib.rs +++ b/mpc-recovery/src/lib.rs @@ -5,6 +5,7 @@ use ractor_cluster::{node::NodeConnectionMode, NodeServer}; use threshold_crypto::{PublicKeySet, SecretKeySet, SecretKeyShare}; mod actor; +mod ouath; mod web; const COOKIE: &str = "mpc-recovery-cookie"; diff --git a/mpc-recovery/src/ouath.rs b/mpc-recovery/src/ouath.rs index cc32c1ac2..b96289ffe 100644 --- a/mpc-recovery/src/ouath.rs +++ b/mpc-recovery/src/ouath.rs @@ -1,53 +1,52 @@ -// use oauth2::basic::BasicClient; -// use oauth2::reqwest::http_client; -// use oauth2::{AccessToken, AuthUrl, ClientId, ClientSecret, RedirectUrl, TokenResponse, TokenUrl}; - #[async_trait::async_trait] pub trait OAuthTokenVerifier { async fn verify_token(&self, token: &str) -> Option<&str>; } -/* Google verifier */ -// pub struct GoogleTokenVerifier {} +pub enum SupportedTokenVerifiers { + GoogleTokenVerifier, + TestTokenVerifier, +} -// #[async_trait::async_trait] -// impl OAuthTokenVerifier for GoogleTokenVerifier { -// async fn verify_token(&self, token: &str) -> Option<&str> { - // let client_secret = ClientSecret::new("".to_owned()); - // let auth_url = AuthUrl::new("https://accounts.google.com/o/oauth2/auth".to_owned())?; - // let token_url = TokenUrl::new("https://accounts.google.com/o/oauth2/token".to_owned())?; - // let redirect_url = RedirectUrl::new("http://localhost:8080".to_owned())?; - // let client = BasicClient::new( - // self.client_id.clone(), - // Some(client_secret), - // auth_url, - // Some(token_url), - // ) - // .set_redirect_uri(redirect_url); +/* Universal token verifier */ +pub struct UniversalTokenVerifier {} + +#[async_trait::async_trait] +impl OAuthTokenVerifier for UniversalTokenVerifier { + async fn verify_token(&self, token: &str) -> Option<&str> { + // TODO: here we assume that verifier type can be determined from the token + match get_token_verifier_type(token) { + SupportedTokenVerifiers::GoogleTokenVerifier => { + return GoogleTokenVerifier {}.verify_token(token).await; + } + SupportedTokenVerifiers::TestTokenVerifier => { + return TestTokenVerifier {}.verify_token(token).await; + } + } + } +} - // let token = AccessToken::new(token.to_owned()); - // let token_info_url = "https://www.googleapis.com/oauth2/v3/tokeninfo".parse()?; - // let token_info_request = - // client.request::<_, TokenResponse>(http_client, reqwest::Method::GET, token_info_url) - // .unwrap() - // .bearer_auth(token.secret()); - // let token_info = token_info_request.send()?.json::()?; +fn get_token_verifier_type(token: &str) -> SupportedTokenVerifiers { + match token.len() { + // TODO: add real token type detection + 0 => SupportedTokenVerifiers::GoogleTokenVerifier, + _ => SupportedTokenVerifiers::TestTokenVerifier, + } +} - // if let Some(aud) = token_info.get("aud") { - // if let Some(client_id) = aud.as_str() { - // if client_id == self.client_id.secret() { - // if let Some(sub) = token_info.get("sub") { - // if let Some(account_id) = sub.as_str() { - // return Ok(account_id.to_owned()); - // } - // } - // } - // } - // } +/* Google verifier */ +pub struct GoogleTokenVerifier {} - // Err("Invalid token".into()) - // } -// } +#[async_trait::async_trait] +impl OAuthTokenVerifier for GoogleTokenVerifier { + // TODO: replace with real implementation + async fn verify_token(&self, token: &str) -> Option<&str> { + match token { + "validToken" => Some("testAccountId"), + _ => None, + } + } +} /* Test verifier */ pub struct TestTokenVerifier {} @@ -56,7 +55,7 @@ pub struct TestTokenVerifier {} impl OAuthTokenVerifier for TestTokenVerifier { async fn verify_token(&self, token: &str) -> Option<&str> { match token { - "valid" => Some("testAccountId"), + "validToken" => Some("testAccountId"), _ => None, } } @@ -65,15 +64,39 @@ impl OAuthTokenVerifier for TestTokenVerifier { #[tokio::test] async fn test_verify_token_valid() { let verifier = TestTokenVerifier {}; - let token = "valid"; + let token = "validToken"; let account_id = verifier.verify_token(token).await.unwrap(); assert_eq!(account_id, "testAccountId"); } #[tokio::test] -async fn test_verify_token_invalid() { +async fn test_verify_token_invalid_with_test_verifier() { let verifier = TestTokenVerifier {}; let token = "invalid"; let account_id = verifier.verify_token(token).await; assert_eq!(account_id, None); } + +#[tokio::test] +async fn test_verify_token_valid_with_test_verifier() { + let verifier = TestTokenVerifier {}; + let token = "validToken"; + let account_id = verifier.verify_token(token).await.unwrap(); + assert_eq!(account_id, "testAccountId"); +} + +#[tokio::test] +async fn test_verify_token_invalid_with_universal_verifier() { + let verifier = UniversalTokenVerifier {}; + let token = "invalid"; + let account_id = verifier.verify_token(token).await; + assert_eq!(account_id, None); +} + +#[tokio::test] +async fn test_verify_token_valid_with_universal_verifier() { + let verifier = UniversalTokenVerifier {}; + let token = "validToken"; + let account_id = verifier.verify_token(token).await.unwrap(); + assert_eq!(account_id, "testAccountId"); +} From d35e2ac0a05512a717d65dab9e075d6bd899e4f3 Mon Sep 17 00:00:00 2001 From: Serhii Volovyk Date: Wed, 5 Apr 2023 14:31:12 +0300 Subject: [PATCH 06/12] fmt --- mpc-recovery/src/actor.rs | 2 -- 1 file changed, 2 deletions(-) diff --git a/mpc-recovery/src/actor.rs b/mpc-recovery/src/actor.rs index 699c737fc..478424dfb 100644 --- a/mpc-recovery/src/actor.rs +++ b/mpc-recovery/src/actor.rs @@ -143,7 +143,6 @@ impl NodeActorState { } None => { tracing::error!("failed to verify access token"); - return; } } } @@ -232,7 +231,6 @@ impl NodeActorState { } None => { tracing::error!("failed to verify access token"); - return; } } } From f3285a91b42691e65843abd86952ca1c7d9915bd Mon Sep 17 00:00:00 2001 From: Daniyar Itegulov Date: Wed, 5 Apr 2023 23:27:35 +1000 Subject: [PATCH 07/12] make OAuthTokenVerifier a phantom trait --- mpc-recovery/src/actor.rs | 32 +++++++++++++++++++++----------- mpc-recovery/src/lib.rs | 13 +++++++++---- mpc-recovery/src/ouath.rs | 27 +++++++++++---------------- mpc-recovery/src/web.rs | 5 +++-- 4 files changed, 44 insertions(+), 33 deletions(-) diff --git a/mpc-recovery/src/actor.rs b/mpc-recovery/src/actor.rs index 478424dfb..293ff3978 100644 --- a/mpc-recovery/src/actor.rs +++ b/mpc-recovery/src/actor.rs @@ -1,3 +1,5 @@ +use std::marker::PhantomData; + use futures::prelude::*; use futures::stream::FuturesUnordered; use ractor::{ @@ -7,7 +9,7 @@ use ractor_cluster::RactorClusterMessage; use serde::{Deserialize, Serialize}; use threshold_crypto::{PublicKeySet, SecretKeyShare, Signature, SignatureShare}; -use crate::ouath::{OAuthTokenVerifier, UniversalTokenVerifier}; +use crate::ouath::OAuthTokenVerifier; use crate::NodeId; const MPC_RECOVERY_GROUP: &str = "mpc-recovery"; @@ -53,18 +55,27 @@ pub enum NodeMessage { SignRequest(Payload, RpcReplyPort), } -pub struct NodeActor; +pub struct NodeActor { + el: PhantomData, +} + +impl NodeActor { + pub fn new() -> NodeActor { + NodeActor { el: PhantomData } + } +} -pub struct NodeActorState { +pub struct NodeActorState { id: NodeId, pk_set: PublicKeySet, sk_share: SecretKeyShare, + el: PhantomData, } #[async_trait::async_trait] -impl Actor for NodeActor { +impl Actor for NodeActor { type Msg = NodeMessage; - type State = NodeActorState; + type State = NodeActorState; type Arguments = (NodeId, PublicKeySet, SecretKeyShare); #[tracing::instrument(level = "debug", skip_all, fields(id = args.0))] @@ -80,6 +91,7 @@ impl Actor for NodeActor { id: args.0, pk_set: args.1, sk_share: args.2, + el: PhantomData, }) } @@ -116,7 +128,7 @@ impl Actor for NodeActor { } } -impl NodeActorState { +impl NodeActorState { fn sign(&self, payload: &[u8]) -> SignResponse { SignResponse { node_id: self.id, @@ -128,8 +140,7 @@ impl NodeActorState { async fn handle_signed_msg(&mut self, payload: Payload, reply: RpcReplyPort) { // TODO: extract access token from payload let access_token = "validToken"; - let access_token_verifier = UniversalTokenVerifier {}; - match access_token_verifier.verify_token(access_token).await { + match O::verify_token(access_token).await { Some(client_id) => { tracing::debug!("approved, cleintId: {}", client_id); @@ -151,12 +162,11 @@ impl NodeActorState { &mut self, payload: Payload, reply: RpcReplyPort, - remote_actors: &Vec>, + remote_actors: &Vec>>, ) { // TODO: extract access token from payload let access_token = "validToken"; - let access_token_verifier = UniversalTokenVerifier {}; - match access_token_verifier.verify_token(access_token).await { + match O::verify_token(access_token).await { Some(client_id) => { tracing::debug!("approved, cleintId: {}", client_id); let mut futures = Vec::new(); diff --git a/mpc-recovery/src/lib.rs b/mpc-recovery/src/lib.rs index c40bd4587..19e26ff8f 100644 --- a/mpc-recovery/src/lib.rs +++ b/mpc-recovery/src/lib.rs @@ -1,5 +1,6 @@ use actix_rt::task::JoinHandle; use actor::NodeActor; +use ouath::UniversalTokenVerifier; use ractor::{Actor, ActorRef}; use ractor_cluster::{node::NodeConnectionMode, NodeServer}; use threshold_crypto::{PublicKeySet, SecretKeySet, SecretKeyShare}; @@ -44,7 +45,7 @@ async fn start_actor( node_id: u64, pk_set: PublicKeySet, sk_share: SecretKeyShare, -) -> anyhow::Result<(ActorRef, JoinHandle<()>)> { +) -> anyhow::Result<(ActorRef>, JoinHandle<()>)> { // Printing shortened hash should be enough for most use cases, but if you enable TRACE level // you can see the entire curve details. if tracing::level_enabled!(tracing::Level::TRACE) { @@ -52,9 +53,13 @@ async fn start_actor( } else { tracing::debug!(public_key = ?pk_set.public_key(), "starting node actor"); } - Actor::spawn(None, actor::NodeActor, (node_id, pk_set.clone(), sk_share)) - .await - .map_err(|_e| anyhow::anyhow!("failed to start actor")) + Actor::spawn( + None, + actor::NodeActor::new(), + (node_id, pk_set.clone(), sk_share), + ) + .await + .map_err(|_e| anyhow::anyhow!("failed to start actor")) } #[tracing::instrument(level = "debug", skip_all, fields(id = node_id))] diff --git a/mpc-recovery/src/ouath.rs b/mpc-recovery/src/ouath.rs index b96289ffe..598b03b36 100644 --- a/mpc-recovery/src/ouath.rs +++ b/mpc-recovery/src/ouath.rs @@ -1,6 +1,6 @@ #[async_trait::async_trait] pub trait OAuthTokenVerifier { - async fn verify_token(&self, token: &str) -> Option<&str>; + async fn verify_token(token: &str) -> Option<&str>; } pub enum SupportedTokenVerifiers { @@ -13,14 +13,14 @@ pub struct UniversalTokenVerifier {} #[async_trait::async_trait] impl OAuthTokenVerifier for UniversalTokenVerifier { - async fn verify_token(&self, token: &str) -> Option<&str> { + async fn verify_token(token: &str) -> Option<&str> { // TODO: here we assume that verifier type can be determined from the token match get_token_verifier_type(token) { SupportedTokenVerifiers::GoogleTokenVerifier => { - return GoogleTokenVerifier {}.verify_token(token).await; + return GoogleTokenVerifier::verify_token(token).await; } SupportedTokenVerifiers::TestTokenVerifier => { - return TestTokenVerifier {}.verify_token(token).await; + return TestTokenVerifier::verify_token(token).await; } } } @@ -40,7 +40,7 @@ pub struct GoogleTokenVerifier {} #[async_trait::async_trait] impl OAuthTokenVerifier for GoogleTokenVerifier { // TODO: replace with real implementation - async fn verify_token(&self, token: &str) -> Option<&str> { + async fn verify_token(token: &str) -> Option<&str> { match token { "validToken" => Some("testAccountId"), _ => None, @@ -53,7 +53,7 @@ pub struct TestTokenVerifier {} #[async_trait::async_trait] impl OAuthTokenVerifier for TestTokenVerifier { - async fn verify_token(&self, token: &str) -> Option<&str> { + async fn verify_token(token: &str) -> Option<&str> { match token { "validToken" => Some("testAccountId"), _ => None, @@ -63,40 +63,35 @@ impl OAuthTokenVerifier for TestTokenVerifier { #[tokio::test] async fn test_verify_token_valid() { - let verifier = TestTokenVerifier {}; let token = "validToken"; - let account_id = verifier.verify_token(token).await.unwrap(); + let account_id = TestTokenVerifier::verify_token(token).await.unwrap(); assert_eq!(account_id, "testAccountId"); } #[tokio::test] async fn test_verify_token_invalid_with_test_verifier() { - let verifier = TestTokenVerifier {}; let token = "invalid"; - let account_id = verifier.verify_token(token).await; + let account_id = TestTokenVerifier::verify_token(token).await; assert_eq!(account_id, None); } #[tokio::test] async fn test_verify_token_valid_with_test_verifier() { - let verifier = TestTokenVerifier {}; let token = "validToken"; - let account_id = verifier.verify_token(token).await.unwrap(); + let account_id = TestTokenVerifier::verify_token(token).await.unwrap(); assert_eq!(account_id, "testAccountId"); } #[tokio::test] async fn test_verify_token_invalid_with_universal_verifier() { - let verifier = UniversalTokenVerifier {}; let token = "invalid"; - let account_id = verifier.verify_token(token).await; + let account_id = UniversalTokenVerifier::verify_token(token).await; assert_eq!(account_id, None); } #[tokio::test] async fn test_verify_token_valid_with_universal_verifier() { - let verifier = UniversalTokenVerifier {}; let token = "validToken"; - let account_id = verifier.verify_token(token).await.unwrap(); + let account_id = UniversalTokenVerifier::verify_token(token).await.unwrap(); assert_eq!(account_id, "testAccountId"); } diff --git a/mpc-recovery/src/web.rs b/mpc-recovery/src/web.rs index 0b6901038..1b74d52ed 100644 --- a/mpc-recovery/src/web.rs +++ b/mpc-recovery/src/web.rs @@ -5,11 +5,12 @@ use std::net::SocketAddr; use crate::{ actor::{NodeActor, NodeMessage}, + ouath::UniversalTokenVerifier, NodeId, }; #[tracing::instrument(level = "debug", skip(node_actor))] -pub async fn serve(id: NodeId, port: u16, node_actor: ActorRef) { +pub async fn serve(id: NodeId, port: u16, node_actor: ActorRef>) { let state = AppState { id, node_actor }; let app = Router::new() @@ -32,7 +33,7 @@ struct SubmitPayload { #[derive(Clone)] struct AppState { id: NodeId, - node_actor: ActorRef, + node_actor: ActorRef>, } #[tracing::instrument(level = "debug", skip_all, fields(id = state.id))] From f0b58e5c08cf11afe029ad5144f7b2a5110eb035 Mon Sep 17 00:00:00 2001 From: Serhii Volovyk Date: Thu, 6 Apr 2023 09:41:18 +0300 Subject: [PATCH 08/12] deelte files to avoid conflict --- mpc-recovery/src/actor.rs | 247 -------------------------------------- mpc-recovery/src/web.rs | 76 ------------ 2 files changed, 323 deletions(-) delete mode 100644 mpc-recovery/src/actor.rs delete mode 100644 mpc-recovery/src/web.rs diff --git a/mpc-recovery/src/actor.rs b/mpc-recovery/src/actor.rs deleted file mode 100644 index 293ff3978..000000000 --- a/mpc-recovery/src/actor.rs +++ /dev/null @@ -1,247 +0,0 @@ -use std::marker::PhantomData; - -use futures::prelude::*; -use futures::stream::FuturesUnordered; -use ractor::{ - concurrency::Duration, Actor, ActorProcessingErr, ActorRef, BytesConvertable, RpcReplyPort, -}; -use ractor_cluster::RactorClusterMessage; -use serde::{Deserialize, Serialize}; -use threshold_crypto::{PublicKeySet, SecretKeyShare, Signature, SignatureShare}; - -use crate::ouath::OAuthTokenVerifier; -use crate::NodeId; - -const MPC_RECOVERY_GROUP: &str = "mpc-recovery"; - -type Payload = Vec; - -#[derive(Debug, Clone, Serialize, Deserialize)] -pub struct SignResponse { - node_id: NodeId, - sig_share: SignatureShare, -} - -impl BytesConvertable for SignResponse { - fn into_bytes(self) -> Vec { - serde_json::to_vec(&self).unwrap() - } - - fn from_bytes(bytes: Vec) -> Self { - serde_json::from_slice(&bytes).unwrap() - } -} - -#[derive(Debug, Clone, Serialize, Deserialize)] -pub struct SignatureResponse { - pub sig: Signature, -} - -impl BytesConvertable for SignatureResponse { - fn into_bytes(self) -> Vec { - serde_json::to_vec(&self).unwrap() - } - - fn from_bytes(bytes: Vec) -> Self { - serde_json::from_slice(&bytes).unwrap() - } -} - -#[derive(RactorClusterMessage, Debug)] -pub enum NodeMessage { - #[rpc] - NewRequest(Payload, RpcReplyPort), - #[rpc] - SignRequest(Payload, RpcReplyPort), -} - -pub struct NodeActor { - el: PhantomData, -} - -impl NodeActor { - pub fn new() -> NodeActor { - NodeActor { el: PhantomData } - } -} - -pub struct NodeActorState { - id: NodeId, - pk_set: PublicKeySet, - sk_share: SecretKeyShare, - el: PhantomData, -} - -#[async_trait::async_trait] -impl Actor for NodeActor { - type Msg = NodeMessage; - type State = NodeActorState; - type Arguments = (NodeId, PublicKeySet, SecretKeyShare); - - #[tracing::instrument(level = "debug", skip_all, fields(id = args.0))] - async fn pre_start( - &self, - myself: ActorRef, - args: (NodeId, PublicKeySet, SecretKeyShare), - ) -> Result { - tracing::debug!(group = MPC_RECOVERY_GROUP, "joining"); - ractor::pg::join(MPC_RECOVERY_GROUP.to_string(), vec![myself.get_cell()]); - // create the initial state - Ok(NodeActorState { - id: args.0, - pk_set: args.1, - sk_share: args.2, - el: PhantomData, - }) - } - - #[tracing::instrument(level = "debug", skip_all, fields(id = state.id, message))] - async fn handle( - &self, - _myself: ActorRef, - message: Self::Msg, - state: &mut Self::State, - ) -> Result<(), ActorProcessingErr> { - let remote_actors = ractor::pg::get_members(&MPC_RECOVERY_GROUP.to_string()) - .into_iter() - .filter(|actor| !actor.get_id().is_local()) - .map(ActorRef::::from) - .collect::>(); - tracing::debug!( - remote_actors = ?remote_actors.iter().map(|a| a.get_id()).collect::>(), - "connected to" - ); - - match message { - NodeMessage::NewRequest(payload, reply) => { - tracing::debug!(?payload, "new request"); - state - .handle_new_request(payload, reply, &remote_actors) - .await - } - NodeMessage::SignRequest(payload, reply) => { - tracing::debug!(?payload, "sign request"); - state.handle_signed_msg(payload, reply).await - } - }; - Ok(()) - } -} - -impl NodeActorState { - fn sign(&self, payload: &[u8]) -> SignResponse { - SignResponse { - node_id: self.id, - sig_share: self.sk_share.sign(payload), - } - } - - #[tracing::instrument(level = "debug", skip_all)] - async fn handle_signed_msg(&mut self, payload: Payload, reply: RpcReplyPort) { - // TODO: extract access token from payload - let access_token = "validToken"; - match O::verify_token(access_token).await { - Some(client_id) => { - tracing::debug!("approved, cleintId: {}", client_id); - - let response = self.sign(&payload); - tracing::debug!(?response, "replying"); - - match reply.send(response) { - Ok(()) => {} - Err(e) => tracing::error!("failed to respond: {}", e), - }; - } - None => { - tracing::error!("failed to verify access token"); - } - } - } - - async fn handle_new_request( - &mut self, - payload: Payload, - reply: RpcReplyPort, - remote_actors: &Vec>>, - ) { - // TODO: extract access token from payload - let access_token = "validToken"; - match O::verify_token(access_token).await { - Some(client_id) => { - tracing::debug!("approved, cleintId: {}", client_id); - let mut futures = Vec::new(); - for actor in remote_actors { - tracing::debug!(actor = ?actor.get_id(), "asking actor"); - let future = actor - .call( - |tx| NodeMessage::SignRequest(payload.clone(), tx), - Some(Duration::from_millis(2000)), - ) - .map(|r| r.map_err(ractor::RactorErr::from)) - .map(|r| match r { - Ok(ractor::rpc::CallResult::Success(ok_value)) => Ok(ok_value), - Ok(cr) => Err(ractor::RactorErr::from(cr)), - Err(e) => Err(e), - }); - futures.push(future); - } - - // create unordered collection of futures - let futures = futures.into_iter().collect::>(); - - let mut responses = futures - .collect::>() - .await - .into_iter() - .filter_map(|r| r.ok()) - .collect::>(); - - let response = self.sign(&payload); - tracing::debug!(?response, "adding response from self"); - responses.push(response); - - tracing::debug!( - ?responses, - "got {} successful responses total", - responses.len() - ); - - let mut sig_shares = Vec::new(); - for sign_response in &responses { - if self - .pk_set - .public_key_share(sign_response.node_id) - .verify(&sign_response.sig_share, &payload) - { - sig_shares.push((sign_response.node_id, &sign_response.sig_share)); - } else { - tracing::error!(?sign_response, "received invalid signature",); - } - } - - tracing::debug!( - ?sig_shares, - "got {} valid signature shares total", - sig_shares.len() - ); - - if let Ok(sig) = self - .pk_set - .combine_signatures(sig_shares.clone().into_iter()) - { - tracing::debug!(?sig, "replying with full signature"); - reply.send(SignatureResponse { sig }).unwrap(); - } else { - tracing::error!( - "expected to get at least {} shares, but only got {}", - self.pk_set.threshold() + 1, - sig_shares.len() - ); - } - } - None => { - tracing::error!("failed to verify access token"); - } - } - } -} diff --git a/mpc-recovery/src/web.rs b/mpc-recovery/src/web.rs deleted file mode 100644 index 1b74d52ed..000000000 --- a/mpc-recovery/src/web.rs +++ /dev/null @@ -1,76 +0,0 @@ -use axum::{extract::State, http::StatusCode, routing::post, Json, Router}; -use ractor::{concurrency::Duration, rpc::CallResult, ActorRef}; -use serde::Deserialize; -use std::net::SocketAddr; - -use crate::{ - actor::{NodeActor, NodeMessage}, - ouath::UniversalTokenVerifier, - NodeId, -}; - -#[tracing::instrument(level = "debug", skip(node_actor))] -pub async fn serve(id: NodeId, port: u16, node_actor: ActorRef>) { - let state = AppState { id, node_actor }; - - let app = Router::new() - .route("/submit", post(submit)) - .with_state(state); - - let addr = SocketAddr::from(([0, 0, 0, 0], port)); - tracing::debug!(?addr, "starting a web server"); - axum::Server::bind(&addr) - .serve(app.into_make_service()) - .await - .unwrap(); -} - -#[derive(Deserialize)] -struct SubmitPayload { - payload: String, -} - -#[derive(Clone)] -struct AppState { - id: NodeId, - node_actor: ActorRef>, -} - -#[tracing::instrument(level = "debug", skip_all, fields(id = state.id))] -async fn submit( - State(state): State, - Json(payload): Json, -) -> (StatusCode, Json) { - tracing::info!(payload = payload.payload, "submit request"); - - match state - .node_actor - .call( - |tx| NodeMessage::NewRequest(payload.payload.bytes().collect(), tx), - Some(Duration::from_millis(2000)), - ) - .await - { - Ok(call_result) => match call_result { - CallResult::Success(sig_response) => ( - StatusCode::OK, - Json(hex::encode(sig_response.sig.to_bytes())), - ), - CallResult::Timeout => { - tracing::error!("failed due to timeout"); - ( - StatusCode::INTERNAL_SERVER_ERROR, - Json("timeout".to_string()), - ) - } - CallResult::SenderError => { - tracing::error!("failed due to sender error (did not get a response)"); - (StatusCode::INTERNAL_SERVER_ERROR, Json("error".to_string())) - } - }, - Err(e) => { - tracing::error!("failed due to messaging error: {}", e); - (StatusCode::INTERNAL_SERVER_ERROR, Json("error".to_string())) - } - } -} From c2f94346ba9aaa73344a4068426c4ed12d49ec1a Mon Sep 17 00:00:00 2001 From: Serhii Volovyk Date: Thu, 6 Apr 2023 16:37:16 +0300 Subject: [PATCH 09/12] refactored for new arch --- mpc-recovery/src/leader_node/mod.rs | 51 ++++++++++++++++--------- mpc-recovery/src/lib.rs | 1 + mpc-recovery/src/main.rs | 2 +- mpc-recovery/src/msg.rs | 9 +++-- mpc-recovery/src/{ouath.rs => oauth.rs} | 0 mpc-recovery/src/sign_node/mod.rs | 31 ++++++++++----- 6 files changed, 63 insertions(+), 31 deletions(-) rename mpc-recovery/src/{ouath.rs => oauth.rs} (100%) diff --git a/mpc-recovery/src/leader_node/mod.rs b/mpc-recovery/src/leader_node/mod.rs index 380eaaa5a..d171fd83d 100644 --- a/mpc-recovery/src/leader_node/mod.rs +++ b/mpc-recovery/src/leader_node/mod.rs @@ -1,4 +1,5 @@ use crate::msg::{LeaderRequest, LeaderResponse, SigShareRequest, SigShareResponse}; +use crate::oauth::{OAuthTokenVerifier, UniversalTokenVerifier}; use crate::NodeId; use axum::{extract::State, http::StatusCode, routing::post, Json, Router}; use futures::stream::FuturesUnordered; @@ -32,7 +33,7 @@ pub async fn run( }; let app = Router::new() - .route("/submit", post(submit)) + .route("/submit", post(submit::)) .with_state(state); let addr = SocketAddr::from(([0, 0, 0, 0], port)); @@ -58,14 +59,24 @@ async fn parse(response_future: ResponseFuture) -> anyhow::Result( State(state): State, Json(request): Json, ) -> (StatusCode, Json) { tracing::info!(payload = request.payload, "submit request"); - // TODO: run some check that the payload makes sense, fail if not - tracing::debug!("approved"); + // TODO: extract access token from payload + let access_token = "validToken"; + match T::verify_token(access_token).await { + Some(_) => { + tracing::info!("access token is valid"); + // continue execution + } + None => { + tracing::error!("access token verification failed"); + return (StatusCode::UNAUTHORIZED, Json(LeaderResponse::Err)); + } + } let sig_share_request = SigShareRequest { payload: request.payload.clone(), @@ -100,42 +111,48 @@ async fn submit( let mut sig_shares = BTreeMap::new(); sig_shares.insert(state.id, state.sk_share.sign(&request.payload)); for response_future in response_futures { - let response = match parse(response_future).await { - Ok(response) => response, + let (node_id, sig_share) = match parse(response_future).await { + Ok(response) => match response { + SigShareResponse::Ok { node_id, sig_share } => (node_id, sig_share), + SigShareResponse::Err => { + tracing::error!("Received an error response"); + continue; + } + }, Err(err) => { - tracing::error!(%err, "failed to get response"); + tracing::error!(%err, "Failed to get response"); continue; } }; if state .pk_set - .public_key_share(response.node_id) - .verify(&response.sig_share, &request.payload) + .public_key_share(node_id) + .verify(&sig_share, &request.payload) { - match sig_shares.entry(response.node_id) { + match sig_shares.entry(node_id) { Entry::Vacant(e) => { - tracing::debug!(?response, "received valid signature share"); - e.insert(response.sig_share); + tracing::debug!(?sig_share, "received valid signature share"); + e.insert(sig_share); } - Entry::Occupied(e) if e.get() == &response.sig_share => { + Entry::Occupied(e) if e.get() == &sig_share => { tracing::error!( - node_id = response.node_id, + node_id, sig_share = ?e.get(), "received a duplicate share" ); } Entry::Occupied(e) => { tracing::error!( - node_id = response.node_id, + node_id = node_id, sig_share_1 = ?e.get(), - sig_share_2 = ?response.sig_share, + sig_share_2 = ?sig_share, "received two different valid shares for the same node (should be impossible)" ); } } } else { - tracing::error!(?response, "received invalid signature",); + tracing::error!("received invalid signature",); } if sig_shares.len() > state.pk_set.threshold() { diff --git a/mpc-recovery/src/lib.rs b/mpc-recovery/src/lib.rs index 855e9792b..7be1b3758 100644 --- a/mpc-recovery/src/lib.rs +++ b/mpc-recovery/src/lib.rs @@ -3,6 +3,7 @@ use threshold_crypto::{PublicKeySet, SecretKeySet, SecretKeyShare}; mod leader_node; pub mod msg; mod sign_node; +mod oauth; type NodeId = u64; diff --git a/mpc-recovery/src/main.rs b/mpc-recovery/src/main.rs index 79122a47f..4ada12afc 100644 --- a/mpc-recovery/src/main.rs +++ b/mpc-recovery/src/main.rs @@ -1,4 +1,4 @@ -mod ouath; +mod oauth; use clap::Parser; use threshold_crypto::{serde_impl::SerdeSecret, PublicKeySet, SecretKeyShare}; diff --git a/mpc-recovery/src/msg.rs b/mpc-recovery/src/msg.rs index ecddbe419..680dfc18e 100644 --- a/mpc-recovery/src/msg.rs +++ b/mpc-recovery/src/msg.rs @@ -26,9 +26,12 @@ pub struct SigShareRequest { } #[derive(Serialize, Deserialize, Debug)] -pub struct SigShareResponse { - pub node_id: NodeId, - pub sig_share: SignatureShare, +pub enum SigShareResponse { + Ok { + node_id: NodeId, + sig_share: SignatureShare, + }, + Err, } mod hex_sig_share { diff --git a/mpc-recovery/src/ouath.rs b/mpc-recovery/src/oauth.rs similarity index 100% rename from mpc-recovery/src/ouath.rs rename to mpc-recovery/src/oauth.rs diff --git a/mpc-recovery/src/sign_node/mod.rs b/mpc-recovery/src/sign_node/mod.rs index 9e00d699f..bcc4b9db1 100644 --- a/mpc-recovery/src/sign_node/mod.rs +++ b/mpc-recovery/src/sign_node/mod.rs @@ -1,4 +1,5 @@ use crate::msg::{SigShareRequest, SigShareResponse}; +use crate::oauth::{OAuthTokenVerifier, UniversalTokenVerifier}; use crate::NodeId; use axum::{extract::State, http::StatusCode, routing::post, Json, Router}; use std::net::SocketAddr; @@ -15,7 +16,9 @@ pub async fn run(id: NodeId, pk_set: PublicKeySet, sk_share: SecretKeyShare, por let state = SignNodeState { id, sk_share }; - let app = Router::new().route("/sign", post(sign)).with_state(state); + let app = Router::new() + .route("/sign", post(sign::)) + .with_state(state); let addr = SocketAddr::from(([0, 0, 0, 0], port)); tracing::debug!(?addr, "starting http server"); @@ -32,18 +35,26 @@ struct SignNodeState { } #[tracing::instrument(level = "debug", skip_all, fields(id = state.id))] -async fn sign( +async fn sign( State(state): State, Json(request): Json, ) -> (StatusCode, Json) { tracing::info!(payload = request.payload, "sign request"); - // TODO: run some check that the payload makes sense, fail if not - tracing::debug!("approved"); - - let response = SigShareResponse { - node_id: state.id, - sig_share: state.sk_share.sign(request.payload), - }; - (StatusCode::OK, Json(response)) + // TODO: extract access token from payload + let access_token = "validToken"; + match T::verify_token(access_token).await { + Some(_) => { + tracing::debug!("access token is valid"); + let response = SigShareResponse::Ok { + node_id: state.id, + sig_share: state.sk_share.sign(request.payload), + }; + (StatusCode::OK, Json(response)) + }, + None => { + tracing::debug!("access token verification failed"); + (StatusCode::UNAUTHORIZED, Json(SigShareResponse::Err)) + } + } } From 9d21188b92c06ee71fe490732bde13562e47f656 Mon Sep 17 00:00:00 2001 From: Serhii Volovyk Date: Thu, 6 Apr 2023 16:56:29 +0300 Subject: [PATCH 10/12] logs in verification --- mpc-recovery/src/oauth.rs | 30 ++++++++++++++++++++++++------ 1 file changed, 24 insertions(+), 6 deletions(-) diff --git a/mpc-recovery/src/oauth.rs b/mpc-recovery/src/oauth.rs index 598b03b36..72dbe9d8f 100644 --- a/mpc-recovery/src/oauth.rs +++ b/mpc-recovery/src/oauth.rs @@ -29,8 +29,14 @@ impl OAuthTokenVerifier for UniversalTokenVerifier { fn get_token_verifier_type(token: &str) -> SupportedTokenVerifiers { match token.len() { // TODO: add real token type detection - 0 => SupportedTokenVerifiers::GoogleTokenVerifier, - _ => SupportedTokenVerifiers::TestTokenVerifier, + 0 => { + tracing::info!("Using GoogleTokenVerifier"); + SupportedTokenVerifiers::GoogleTokenVerifier + }, + _ => { + tracing::info!("Using TestTokenVerifier"); + SupportedTokenVerifiers::TestTokenVerifier + }, } } @@ -42,8 +48,14 @@ impl OAuthTokenVerifier for GoogleTokenVerifier { // TODO: replace with real implementation async fn verify_token(token: &str) -> Option<&str> { match token { - "validToken" => Some("testAccountId"), - _ => None, + "validToken" => { + tracing::info!("GoogleTokenVerifier: access token is valid"); + Some("testAccountId") + }, + _ => { + tracing::info!("GoogleTokenVerifier: access token verification failed"); + None + } } } } @@ -55,8 +67,14 @@ pub struct TestTokenVerifier {} impl OAuthTokenVerifier for TestTokenVerifier { async fn verify_token(token: &str) -> Option<&str> { match token { - "validToken" => Some("testAccountId"), - _ => None, + "validToken" => { + tracing::info!("TestTokenVerifier: access token is valid"); + Some("testAccountId") + }, + _ => { + tracing::info!("TestTokenVerifier: access token verification failed"); + None + } } } } From 75a2b16e2df0424483fab26285b17af9360f472f Mon Sep 17 00:00:00 2001 From: Serhii Volovyk Date: Thu, 6 Apr 2023 17:02:37 +0300 Subject: [PATCH 11/12] fmt --- mpc-recovery/src/lib.rs | 2 +- mpc-recovery/src/oauth.rs | 8 ++++---- mpc-recovery/src/sign_node/mod.rs | 2 +- 3 files changed, 6 insertions(+), 6 deletions(-) diff --git a/mpc-recovery/src/lib.rs b/mpc-recovery/src/lib.rs index 7be1b3758..b21d640cc 100644 --- a/mpc-recovery/src/lib.rs +++ b/mpc-recovery/src/lib.rs @@ -2,8 +2,8 @@ use threshold_crypto::{PublicKeySet, SecretKeySet, SecretKeyShare}; mod leader_node; pub mod msg; -mod sign_node; mod oauth; +mod sign_node; type NodeId = u64; diff --git a/mpc-recovery/src/oauth.rs b/mpc-recovery/src/oauth.rs index 72dbe9d8f..d638a417e 100644 --- a/mpc-recovery/src/oauth.rs +++ b/mpc-recovery/src/oauth.rs @@ -32,11 +32,11 @@ fn get_token_verifier_type(token: &str) -> SupportedTokenVerifiers { 0 => { tracing::info!("Using GoogleTokenVerifier"); SupportedTokenVerifiers::GoogleTokenVerifier - }, + } _ => { tracing::info!("Using TestTokenVerifier"); SupportedTokenVerifiers::TestTokenVerifier - }, + } } } @@ -51,7 +51,7 @@ impl OAuthTokenVerifier for GoogleTokenVerifier { "validToken" => { tracing::info!("GoogleTokenVerifier: access token is valid"); Some("testAccountId") - }, + } _ => { tracing::info!("GoogleTokenVerifier: access token verification failed"); None @@ -70,7 +70,7 @@ impl OAuthTokenVerifier for TestTokenVerifier { "validToken" => { tracing::info!("TestTokenVerifier: access token is valid"); Some("testAccountId") - }, + } _ => { tracing::info!("TestTokenVerifier: access token verification failed"); None diff --git a/mpc-recovery/src/sign_node/mod.rs b/mpc-recovery/src/sign_node/mod.rs index bcc4b9db1..344e852db 100644 --- a/mpc-recovery/src/sign_node/mod.rs +++ b/mpc-recovery/src/sign_node/mod.rs @@ -51,7 +51,7 @@ async fn sign( sig_share: state.sk_share.sign(request.payload), }; (StatusCode::OK, Json(response)) - }, + } None => { tracing::debug!("access token verification failed"); (StatusCode::UNAUTHORIZED, Json(SigShareResponse::Err)) From 04be724221c2c8adbdb6984eeeddb5238e97cf9e Mon Sep 17 00:00:00 2001 From: Serhii Volovyk Date: Thu, 6 Apr 2023 17:12:59 +0300 Subject: [PATCH 12/12] clippy --- mpc-recovery/src/msg.rs | 1 + 1 file changed, 1 insertion(+) diff --git a/mpc-recovery/src/msg.rs b/mpc-recovery/src/msg.rs index 680dfc18e..9053c6911 100644 --- a/mpc-recovery/src/msg.rs +++ b/mpc-recovery/src/msg.rs @@ -26,6 +26,7 @@ pub struct SigShareRequest { } #[derive(Serialize, Deserialize, Debug)] +#[allow(clippy::large_enum_variant)] pub enum SigShareResponse { Ok { node_id: NodeId,