diff --git a/Cargo.toml b/Cargo.toml index ae74ba0..1e92cbe 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -24,9 +24,12 @@ cfg-if = "1.0.0" axum_crate = { package = "axum", version = "0.6.1", features = ["ws"], optional = true } tokio-tungstenite = { version = "0.18.0", optional = true } +actix-web_crate = { package = "actix-web", version = "4.3.1", optional = true } +actix-http = { version = "3.3.1", optional = true } tokio-rustls = { version = "0.23.4", optional = true } tokio-native-tls = { version = "0.3.1", optional = true } + [features] default = ["client", "server"] @@ -35,6 +38,7 @@ client = ["tokio-tungstenite"] server = [] tungstenite = ["server", "tokio-tungstenite"] axum = ["server", "axum_crate"] +actix-web = ["actix-web_crate", "actix-http"] tls = [] native-tls = ["tls", "tokio-native-tls", "tokio-tungstenite/native-tls"] @@ -56,6 +60,7 @@ members = [ "examples/chat-client", "examples/chat-server", "examples/chat-server-axum", + "examples/chat-server-actix-web", "examples/echo-server", "examples/echo-server-native-tls", "examples/simple-client", diff --git a/examples/chat-server-actix-web/Cargo.toml b/examples/chat-server-actix-web/Cargo.toml new file mode 100644 index 0000000..2d1db5e --- /dev/null +++ b/examples/chat-server-actix-web/Cargo.toml @@ -0,0 +1,14 @@ +[package] +name = "ezsockets-chat-actix-web" +version = "0.1.0" +edition = "2021" + +# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html + +[dependencies] +actix-web = "4.3.1" +async-trait = "0.1.52" +ezsockets = { path = "../../", features = ["actix-web"] } +tokio = { version = "1.17.0", features = ["full"] } +tracing = "0.1.32" +tracing-subscriber = "0.3.9" diff --git a/examples/chat-server-actix-web/src/main.rs b/examples/chat-server-actix-web/src/main.rs new file mode 100644 index 0000000..1f59928 --- /dev/null +++ b/examples/chat-server-actix-web/src/main.rs @@ -0,0 +1,133 @@ +use actix_web::App; +use actix_web::HttpRequest; +use actix_web::HttpResponse; +use actix_web::HttpServer; +use actix_web::web; +use async_trait::async_trait; +use ezsockets::Error; +use ezsockets::Server; +use ezsockets::Socket; +use std::collections::HashMap; +use std::net::SocketAddr; + +type SessionID = u16; +type Session = ezsockets::Session; + +#[derive(Debug)] +enum ChatMessage { + Send { from: SessionID, text: String }, +} + +struct ChatServer { + sessions: HashMap, + handle: Server, +} + +#[async_trait] +impl ezsockets::ServerExt for ChatServer { + type Session = ChatSession; + type Call = ChatMessage; + + async fn on_connect( + &mut self, + socket: Socket, + _address: SocketAddr, + _args: ::Args, + ) -> Result { + let id = (0..).find(|i| !self.sessions.contains_key(i)).unwrap_or(0); + let session = Session::create( + |_| ChatSession { + id, + server: self.handle.clone(), + }, + id, + socket, + ); + self.sessions.insert(id, session.clone()); + Ok(session) + } + + async fn on_disconnect( + &mut self, + id: ::ID, + ) -> Result<(), Error> { + assert!(self.sessions.remove(&id).is_some()); + Ok(()) + } + + async fn on_call(&mut self, call: Self::Call) -> Result<(), Error> { + match call { + ChatMessage::Send { text, from } => { + let sessions = self.sessions.iter().filter(|(id, _)| from != **id); + let text = format!("from {from}: {text}"); + for (id, handle) in sessions { + tracing::info!("sending {text} to {id}"); + handle.text(text.clone()); + } + } + }; + Ok(()) + } +} + +struct ChatSession { + id: SessionID, + server: Server, +} + +#[async_trait] +impl ezsockets::SessionExt for ChatSession { + type ID = SessionID; + type Args = (); + type Call = (); + + fn id(&self) -> &Self::ID { + &self.id + } + async fn on_text(&mut self, text: String) -> Result<(), Error> { + tracing::info!("received: {text}"); + self.server.call(ChatMessage::Send { + from: self.id, + text, + }); + Ok(()) + } + + async fn on_binary(&mut self, _bytes: Vec) -> Result<(), Error> { + unimplemented!() + } + + async fn on_call(&mut self, call: Self::Call) -> Result<(), Error> { + let () = call; + Ok(()) + } +} + +struct AppState { + server: Server, +} + +#[actix_web::main] +async fn main() -> std::io::Result<()> { + tracing_subscriber::fmt::init(); + let (server, _) = Server::create(|handle| ChatServer { + sessions: HashMap::new(), + handle, + }); + HttpServer::new(move || { + App::new() + .route("/ws", web::get().to(index)) + .app_data(web::Data::new(AppState { server: server.clone() })) + }) + .bind(("127.0.0.1", 8080))? + .run() + .await +} + + + +async fn index(req: HttpRequest, stream: web::Payload, data: web::Data) -> Result { + let (resp, id) = ezsockets::actix_web::accept(req, stream, &data.server, ()).await?; + tracing::info!(%id, ?resp, "new connection"); + Ok(resp) +} \ No newline at end of file diff --git a/src/actix_web.rs b/src/actix_web.rs new file mode 100644 index 0000000..2227ebe --- /dev/null +++ b/src/actix_web.rs @@ -0,0 +1,133 @@ +// This code comes mostly from https://github.com/actix/actix-web and actix-web-actors crate + +use std::net::{Ipv4Addr, SocketAddr, SocketAddrV4}; + +use actix_http::ws::hash_key; +pub use actix_http::ws::{CloseCode, CloseReason, Frame, HandshakeError, Message, ProtocolError}; +use actix_web::{ + error::Error, + http::{ + header::{self, HeaderValue}, + Method, StatusCode, + }, + web, HttpRequest, HttpResponse, +}; +use actix_web_crate as actix_web; +use tokio::net::TcpStream; +use tokio_tungstenite::tungstenite; + +use crate::{socket::Config, Server, ServerExt, SessionExt, Socket}; + +pub async fn accept( + req: HttpRequest, + payload: web::Payload, + server: &Server, + args: ::Args, +) -> Result<(HttpResponse, SX::ID), Error> +where + SE: ServerExt, + SX: SessionExt, +{ + // WebSocket accepts only GET + if *req.method() != Method::GET { + Err(HandshakeError::GetMethodRequired)?; + } + + // check for "UPGRADE" to WebSocket header + let has_hdr = if let Some(hdr) = req.headers().get(&header::UPGRADE) { + if let Ok(s) = hdr.to_str() { + s.to_ascii_lowercase().contains("websocket") + } else { + false + } + } else { + false + }; + if !has_hdr { + Err(HandshakeError::NoWebsocketUpgrade)? + } + + // Upgrade connection + if !req.head().upgrade() { + Err(HandshakeError::NoConnectionUpgrade)? + } + + // check supported version + if !req.headers().contains_key(&header::SEC_WEBSOCKET_VERSION) { + Err(HandshakeError::NoVersionHeader)? + } + let supported_ver = { + if let Some(hdr) = req.headers().get(&header::SEC_WEBSOCKET_VERSION) { + hdr == "13" || hdr == "8" || hdr == "7" + } else { + false + } + }; + if !supported_ver { + Err(HandshakeError::UnsupportedVersion)? + } + + // check client handshake for validity + if !req.headers().contains_key(&header::SEC_WEBSOCKET_KEY) { + Err(HandshakeError::BadWebsocketKey)? + } + let key = { + let key = req.headers().get(&header::SEC_WEBSOCKET_KEY).unwrap(); + hash_key(key.as_ref()) + }; + + // TODO: Remove this + let protocols: &[&'static str] = &[]; + // check requested protocols + let protocol = req + .headers() + .get(&header::SEC_WEBSOCKET_PROTOCOL) + .and_then(|req_protocols| { + let req_protocols = req_protocols.to_str().ok()?; + req_protocols + .split(',') + .map(|req_p| req_p.trim()) + .find(|req_p| protocols.iter().any(|p| p == req_p)) + }); + + let mut response = HttpResponse::build(StatusCode::SWITCHING_PROTOCOLS) + .upgrade("websocket") + .insert_header(( + header::SEC_WEBSOCKET_ACCEPT, + // key is known to be header value safe ascii + HeaderValue::from_bytes(&key).unwrap(), + )) + .take(); + + if let Some(protocol) = protocol { + response.insert_header((header::SEC_WEBSOCKET_PROTOCOL, protocol)); + } + + // TODO: Somehow construct a stream that satisfies AsyncRead + AsyncWrite + Unpin + let stream = (|| todo!())(); + // The TcpStream is just for now, to satisfy the trait bounds + let websocket_stream = tokio_tungstenite::WebSocketStream::::from_raw_socket( + stream, + tungstenite::protocol::Role::Server, + None, + ) + .await; + + let socket = Socket::new(websocket_stream, Config::default()); + + let address = req + .peer_addr() + .or_else(|| { + // Using this random address, because the `peer_addr()` is going to return `None` only during the unit test anyways + Some(SocketAddr::V4(SocketAddrV4::new( + Ipv4Addr::new(123, 123, 123, 123), + 1234, + ))) + }) + .unwrap(); + + let session_id = server.accept(socket, address, args).await; + + let response = response.await?; + Ok((response, session_id)) +} diff --git a/src/lib.rs b/src/lib.rs index 9f37645..6f82ad3 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -18,6 +18,9 @@ pub use socket::Stream; #[cfg(feature = "axum")] pub mod axum; +#[cfg(feature = "actix-web")] +pub mod actix_web; + #[cfg(feature = "tokio-tungstenite")] pub mod tungstenite;