From 7a9dc6fe08da906d5eea018fa0f8327c87e0b026 Mon Sep 17 00:00:00 2001 From: PramUkesh Date: Mon, 22 Jan 2024 12:02:49 +0530 Subject: [PATCH] for testing --- src/core/user.rs | 1 + src/environment.rs | 4 + src/plugins/websocket/events.rs | 4 +- src/plugins/websocket/handlers/sites.rs | 187 ++++++++++++++++++------ src/plugins/websocket/handlers/users.rs | 9 +- src/plugins/websocket/mod.rs | 103 ++++++++++--- src/plugins/websocket/request.rs | 25 +++- src/plugins/websocket/response.rs | 17 ++- 8 files changed, 269 insertions(+), 81 deletions(-) diff --git a/src/core/user.rs b/src/core/user.rs index dcf6d4f..6cba115 100644 --- a/src/core/user.rs +++ b/src/core/user.rs @@ -36,6 +36,7 @@ pub mod models { #[derive(Serialize, Deserialize, Clone, PartialEq, Eq, Debug)] pub struct Cert { + #[serde(flatten)] auth_pair: AuthPair, pub auth_type: String, pub auth_user_name: String, diff --git a/src/environment.rs b/src/environment.rs index 94e7202..651fa23 100644 --- a/src/environment.rs +++ b/src/environment.rs @@ -104,6 +104,8 @@ pub struct Environment { pub size_limit: usize, pub file_size_limit: usize, pub site_peers_need: usize, + #[cfg(debug_assertions)] + pub debug: bool, } fn get_matches() -> ArgMatches { @@ -351,6 +353,8 @@ pub fn get_env(matches: &ArgMatches) -> Result { .get_one::("SITE_PEERS_NEED") .unwrap() .parse()?, + #[cfg(debug_assertions)] + debug: true, }; Ok(env) } diff --git a/src/plugins/websocket/events.rs b/src/plugins/websocket/events.rs index 19c755d..fe31123 100644 --- a/src/plugins/websocket/events.rs +++ b/src/plugins/websocket/events.rs @@ -38,8 +38,8 @@ impl Handler for WebsocketController { #[serde(rename_all = "camelCase")] pub enum ServerEvent { Event { cmd: String, params: EventType }, - Notification { cmd: String, params: Value }, - Confirm { cmd: String, params: Value }, + Notification { cmd: String, id: usize, params: Value }, + Confirm { cmd: String, id: usize, params: Value }, } #[allow(clippy::enum_variant_names)] diff --git a/src/plugins/websocket/handlers/sites.rs b/src/plugins/websocket/handlers/sites.rs index 45d78df..44423fc 100644 --- a/src/plugins/websocket/handlers/sites.rs +++ b/src/plugins/websocket/handlers/sites.rs @@ -2,15 +2,18 @@ use actix::AsyncContext; use actix_web_actors::ws::WebsocketContext; use futures::executor::block_on; use log::*; -use serde::Serialize; -use serde_json::{Value, json}; +use serde::{Deserialize, Serialize}; +use serde_json::{json, Value}; -use super::super::{error::Error, request::Command, response::Message, ZeruWebsocket}; +use super::{ + super::{error::Error, request::Command, response::Message, ZeruWebsocket}, + users::{get_current_user, handle_cert_set}, +}; use crate::{ environment::SITE_PERMISSIONS_DETAILS, plugins::site_server::handlers::{ sites::{DBQueryRequest, SiteInfoListRequest, SiteInfoRequest}, - users::UserSiteData, + users::{UserCertAddRequest, UserCertDeleteRequest, UserSetSiteCertRequest, UserSiteData}, }, plugins::{ site_server::handlers::{ @@ -24,26 +27,105 @@ use crate::{ }, }; -pub fn handle_cert_add( - _: &ZeruWebsocket, - _: &mut WebsocketContext, - _: &Command, -) -> Result { - unimplemented!("Please File a Bug Report") +pub fn handle_cert_add(ws: &mut ZeruWebsocket, command: &Command) -> Result { + let mut msg: UserCertAddRequest = serde_json::from_value(command.params.clone()).unwrap(); + let domain = msg.domain.clone(); + msg.user_addr = String::from("current"); + msg.site_addr = ws.address.address.clone(); + let res = block_on(ws.user_controller.send(msg.clone()))?; + match res { + Err(_) => command.respond("Not changed"), + Ok(false) => { + let user = get_current_user(ws)?; + let current_cert = user.certs.get(&domain).unwrap(); + let body = format!( + "Your current certificate: {}/{}@{}", + current_cert.auth_type, current_cert.auth_user_name, domain, + ); + let txt = format!( + "Change it to {}/{}@{}", + msg.auth_type, msg.auth_user_name, domain + ); + let _ = ws.cmd( + "confirm", + json!([body, txt,]), + Some(Box::new(move |ws, cmd| cert_add_confirm(ws, cmd))), + Some(command.params.clone()), + ); + command.command() + } + Ok(true) => { + let _ = ws.cmd( + "notification", + json!([ + "done", + format!( + "New certificate added: {}/{}@{}", + msg.auth_type, msg.auth_user_name, domain + ) + ]), + None, + None, + ); + let msg = UserSetSiteCertRequest { + user_addr: String::from("current"), + site_addr: ws.address.address.clone(), + provider: domain.clone(), + }; + let _ = block_on(ws.user_controller.send(msg))?; + ws.update_websocket(Some(json!(vec!["cert_changed", &domain]))); + command.respond("ok") + } + } } -pub fn handle_cert_select(ws: &mut ZeruWebsocket, cmd: &Command) -> Result { - let params = cmd.params.as_array().unwrap(); - let accepted_providers = params[0].as_array(); - let accepted_pattern = if let Some(Value::String(pattern)) = params.get(2) { - Some(pattern) - } else { - None +fn cert_add_confirm(ws: &mut ZeruWebsocket, cmd: &Command) -> Option> { + let params = cmd.params.clone(); + let user = String::from("current"); + let mut add_msg: UserCertAddRequest = serde_json::from_value(params.clone()).unwrap(); + add_msg.user_addr = user.clone(); + add_msg.site_addr = ws.address.address.clone(); + + let msg = UserCertDeleteRequest { + user_addr: user.clone(), + domain: add_msg.domain.clone(), }; - let accept_any = if let Some(Value::Bool(value)) = params.get(1) { - *value - } else { - accepted_providers.is_none() || accepted_pattern.is_none() + let _ = block_on(ws.user_controller.send(msg)).unwrap(); + let res = block_on(ws.user_controller.send(add_msg.clone())).unwrap(); + assert!(res.is_ok()); + assert!(res.unwrap()); + let _ = ws.cmd( + "notification", + json!([ + "done", + format!( + "Certificate changed to: {}/{}@{}", + add_msg.auth_type, add_msg.auth_user_name, add_msg.domain + ) + ]), + None, + None, + ); + ws.update_websocket(Some(json!(vec!["cert_changed", &add_msg.domain]))); + Some(cmd.respond("ok")) +} + +#[derive(Deserialize, Debug)] +struct CertSelectRequest { + #[serde(default)] + accepted_providers: Vec, + accepted_pattern: Option, + accept_any: bool, +} + +pub fn handle_cert_select(ws: &mut ZeruWebsocket, cmd: &Command) -> Result { + let CertSelectRequest { + accepted_providers, + accepted_pattern, + mut accept_any, + } = serde_json::from_value(cmd.params.clone()).unwrap(); + if !accept_any { + accept_any = accepted_providers.is_empty() || accepted_pattern.is_none(); }; let site_data = block_on(ws.user_controller.send(UserSiteData { user_addr: String::from("current"), @@ -68,7 +150,7 @@ pub fn handle_cert_select(ws: &mut ZeruWebsocket, cmd: &Command) -> Result Result Result + + accepted_providers.iter().for_each(|provider| { + if !user.certs.contains_key(provider.as_str()) { + body += "
"; + body += &format!( + " Register »{}", - provider, provider - ); - body += "
"; - } - } - }); - } - let script = format!( + provider, provider + ); + body += ""; + } + }); + + let _ = ws.cmd( + "notification", + json!(["ask", body]), + Some(Box::new(move |ws, cmd| Some(handle_cert_set(ws, cmd)))), + None, + ); + let script = notification_script_template(ws.next_message_id - 1); + cmd.inject_script(ws.next_message_id as isize, script) +} + +fn notification_script_template(id: usize) -> String { + format!( " $(\".notification .select.cert\").on(\"click\", function() {{ $(\".notification .select\").removeClass('active') zeroframe.response({}, this.title) return false - }}) - ", - ws.next_message_id - ); - ws.send_notification(json!(["ask", body])); //TODO!: Need callback for response - cmd.inject_script(script) + }})", + id + ) } pub fn handle_site_info(ws: &ZeruWebsocket, command: &Command) -> Result { @@ -165,7 +251,12 @@ pub fn handle_site_info(ws: &ZeruWebsocket, command: &Command) -> Result Result { +pub fn handle_cert_set(ws: &mut ZeruWebsocket, command: &Command) -> Result { + error!("Handling CertSet with command: {:?}", command); let site = ws.address.address.clone(); - let provider = command.params[0].as_str().unwrap().to_string(); + let provider = command.params.as_str().unwrap().to_string(); let _ = block_on(ws.user_controller.send(UserSetSiteCertRequest { user_addr: String::from("current"), site_addr: site, - provider, + provider: provider.clone(), }))?; + error!("Cert set: {}", provider); + ws.update_websocket(Some(json!(vec!["cert_changed", &provider]))); command.respond("ok") } diff --git a/src/plugins/websocket/mod.rs b/src/plugins/websocket/mod.rs index 718c141..204c3a2 100644 --- a/src/plugins/websocket/mod.rs +++ b/src/plugins/websocket/mod.rs @@ -21,7 +21,7 @@ use serde::{Deserialize, Serialize}; use self::{ events::{EventType, ServerEvent, WebsocketController}, handlers::{files::*, sites::*, tracker::*, users::*}, - request::CommandType, + request::{CommandResponse, CommandType}, }; use crate::{ controllers::{sites::SitesController, users::UserController}, @@ -48,6 +48,7 @@ pub fn register_site_plugins(app: App) -> App { let websocket_controller = WebsocketController { listeners: vec![] }.start(); app.app_data(Data::new(websocket_controller)) .service(scope("/ZeroNet-Internal").route("/Websocket", get().to(serve_websocket))) + .route("/Websocket", get().to(serve_websocket)) } pub async fn serve_websocket( @@ -71,14 +72,17 @@ pub async fn serve_websocket( .get(header_name!("host")) .and_then(|h| h.to_str().ok()) .unwrap_or(""); - let origin_host = origin.split("://").collect::>()[1]; - if origin_host != host { - //TODO!: and origin_host not in allowed_ws_origins - let msg = format!( - "Invalid origin: {} (host: {}, allowed: missing_impl)", - origin, host - ); - return Ok(error403(&req, Some(&msg))); + let origin_host = origin.split("://").collect::>(); + if !ENV.debug { + let origin_host = origin_host[1]; + if origin_host != host { + //TODO!: and origin_host not in allowed_ws_origins + let msg = format!( + "Invalid origin: {} (host: {}, allowed: missing_impl)", + origin, host + ); + return Ok(error403(&req, Some(&msg))); + } } let wrapper_key = query.get("wrapper_key").unwrap(); let future = data @@ -100,8 +104,9 @@ pub async fn serve_websocket( site_addr: addr, address, channels: vec![], - next_message_id: 0, + next_message_id: 1, waiting_callbacks: HashMap::new(), + callback_data: HashMap::new(), }; let (addr, res) = WsResponseBuilder::new(websocket, &req, stream) .start_with_addr() @@ -110,7 +115,8 @@ pub async fn serve_websocket( Ok(res) } -type WaitingCallback = Box ()>; +type WaitingCallback = + Box Option>>; pub struct ZeruWebsocket { site_controller: Addr, @@ -121,6 +127,7 @@ pub struct ZeruWebsocket { channels: Vec, next_message_id: usize, waiting_callbacks: HashMap, + callback_data: HashMap, } impl Actor for ZeruWebsocket { @@ -136,13 +143,21 @@ impl StreamHandler> for ZeruWebsocket { match msg.unwrap() { ws::Message::Ping(msg) => ctx.pong(&msg), ws::Message::Text(text) => { + error!("Incoming message: {:?}", text); let command: Command = match serde_json::from_str(&text) { Ok(c) => c, - Err(e) => { - error!( - "Could not deserialize incoming message: {:?} ({:?})", - text, e - ); + Err(_) => { + let cmd_res: CommandResponse = match serde_json::from_str(&text) { + Ok(cmd_res) => cmd_res, + Err(e) => { + error!( + "Could not deserialize incoming message: {:?} ({:?})", + text, e + ); + return; + } + }; + self.handle_response(&cmd_res).unwrap(); return; } }; @@ -295,24 +310,35 @@ impl ZeruWebsocket { cmd: &str, params: Value, callback: Option, + callback_data: Option, ) -> Result<(), Error> { let id = self.next_message_id; self.next_message_id += 1; if let Some(callback) = callback { + error!("Adding callback for id: {}", id); self.waiting_callbacks.insert(id, callback); } + if let Some(callback_data) = callback_data { + trace!("Adding callback data for id: {}", id); + self.callback_data.insert(id, callback_data); + } match cmd { "confirm" => { - self.confirm(params); + self.confirm(id, params); + return Ok(()); + } + "notification" => { + self.send_notification(id, params); return Ok(()); } _ => unimplemented!("Command not implemented: {}", cmd), } } - fn confirm(&mut self, params: Value) { + fn confirm(&mut self, id: usize, params: Value) { self.ws_controller.do_send(ServerEvent::Confirm { cmd: "confirm".to_string(), + id, params, }); } @@ -327,6 +353,28 @@ impl ZeruWebsocket { Ok(res) } + fn handle_response(&mut self, response: &CommandResponse) -> Result<(), Error> { + error!("Handling response: {:?}", response); + let id = response.to as usize; + let callback = self.waiting_callbacks.remove(&id); + if let Some(callback) = callback { + let data = self.callback_data.remove(&id).unwrap_or(response.result.clone()); + let command = Command { + cmd: CommandType::UiServer(SiteInfo), + params: data, + id: response.id, + wrapper_nonce: String::new(), + }; + let res = callback(self, &command); + if let Some(res) = res { + let _ = command.respond(res?); + } + } else { + error!("No callback found for response: {:?}", response); + } + return Ok(()); + } + fn handle_command( &mut self, ctx: &mut ws::WebsocketContext, @@ -334,14 +382,13 @@ impl ZeruWebsocket { ) -> Result<(), Error> { trace!( "Handling command: {:?} with params: {:?}", - command.cmd, - command.params + command.cmd, command.params ); let response = if let CommandType::UiServer(cmd) = &command.cmd { match cmd { Ping => handle_ping(command), ServerInfo => handle_server_info(self, ctx, command), - CertAdd => handle_cert_add(self, ctx, command), + CertAdd => handle_cert_add(self, command), CertSelect => handle_cert_select(self, command), SiteInfo => handle_site_info(self, command), SiteSign => handle_site_sign(self, ctx, command), @@ -393,6 +440,16 @@ impl ZeruWebsocket { }); } } + } else if let CommandType::Response(response) = &command.cmd { + error!("Unhandled Response: {:?}", response); + let callback = self.waiting_callbacks.remove(&0); + if let Some(callback) = callback { + let res = callback(self, command); + if let Some(res) = res { + let _ = command.respond(res?); + } + } + return Ok(()); } else { debug!("Unhandled Plugin command: {:?}", command.cmd); command.respond("ok") @@ -417,6 +474,7 @@ impl ZeruWebsocket { } fn on_event(&mut self, channel: &str, params: &serde_json::Value) -> Result<(), Error> { + error!("Handling event: {} with params: {:?}", channel, params); if !self.channels.contains(&channel.to_string()) { return Ok(()); } @@ -463,9 +521,10 @@ impl ZeruWebsocket { Ok(()) } - fn send_notification(&mut self, params: serde_json::Value) { + fn send_notification(&mut self, id: usize, params: serde_json::Value) { let _ = self.ws_controller.do_send(ServerEvent::Notification { cmd: "notification".to_string(), + id, params, }); } diff --git a/src/plugins/websocket/request.rs b/src/plugins/websocket/request.rs index 919550a..03783ed 100644 --- a/src/plugins/websocket/request.rs +++ b/src/plugins/websocket/request.rs @@ -6,11 +6,18 @@ use crate::utils::is_default; #[derive(Serialize, Deserialize, Debug, Clone)] #[serde(untagged)] pub enum CommandType { + Response(ResponseCommandType), UiServer(UiServerCommandType), Admin(AdminCommandType), Plugin(PluginCommands), } +#[derive(Serialize, Deserialize, Debug, Clone)] +#[serde(rename = "camelCase")] +pub enum ResponseCommandType { + Response, +} + #[derive(Serialize, Deserialize, Debug, Clone)] #[serde(rename_all = "camelCase")] pub enum UiServerCommandType { @@ -133,8 +140,8 @@ impl Command { Ok(resp) } - pub fn inject_script(&self, body: T) -> Result { - let resp = Message::inject_script(self.id, serde_json::to_value(body)?); + pub fn inject_script(&self, id: isize, body: T) -> Result { + let resp = Message::inject_script(id, serde_json::to_value(body)?); Ok(resp) } @@ -143,3 +150,17 @@ impl Command { Ok(resp) } } +#[derive(Serialize, Deserialize, Debug)] +pub struct CommandResponse { + pub cmd: String, + pub id: isize, + pub to: isize, + pub result: serde_json::Value, +} + +impl CommandResponse { + pub fn respond(&self, body: T) -> Result { + let resp = Message::new(self.id, serde_json::to_value(body)?); + Ok(resp) + } +} diff --git a/src/plugins/websocket/response.rs b/src/plugins/websocket/response.rs index f92e9d9..9529579 100644 --- a/src/plugins/websocket/response.rs +++ b/src/plugins/websocket/response.rs @@ -3,17 +3,19 @@ use serde_json::json; use crate::utils::is_default; -#[derive(Serialize, Deserialize)] +#[derive(Serialize, Deserialize, Debug)] pub struct Message { cmd: MessageType, - #[serde(skip_serializing_if = "is_default")] pub id: usize, #[serde(skip_serializing_if = "is_default")] to: isize, + #[serde(skip_serializing_if = "is_default")] result: serde_json::Value, + #[serde(skip_serializing_if = "is_default")] + params: serde_json::Value, } -#[derive(Serialize, Deserialize, PartialEq)] +#[derive(Serialize, Deserialize, PartialEq, Debug)] #[serde(rename_all = "camelCase")] pub enum MessageType { Command, @@ -30,6 +32,7 @@ impl Message { to: id, result: body, id: 0, + params: json!(null), } } @@ -37,8 +40,9 @@ impl Message { Message { cmd: MessageType::InjectScript, to: id, - result: body, + result: json!(null), id: 0, + params: body, } } @@ -46,11 +50,16 @@ impl Message { self.cmd == MessageType::Command } + pub fn is_inject_script(&self) -> bool { + self.cmd == MessageType::InjectScript + } + pub fn command() -> Message { Message { cmd: MessageType::Command, to: 0, result: json!(null), + params: json!(null), id: 0, } }