diff --git a/src/server/mod.rs b/src/server/mod.rs index efbd3a92d7..08b92e3ef1 100644 --- a/src/server/mod.rs +++ b/src/server/mod.rs @@ -33,7 +33,7 @@ pub use net::{Fresh, Streaming}; use Error; use buffer::BufReader; -use header::{Headers, Expect}; +use header::{Headers, Expect, Connection}; use http; use method::Method; use net::{NetworkListener, NetworkStream, HttpListener}; @@ -142,7 +142,7 @@ L: NetworkListener + Send + 'static { debug!("threads = {:?}", threads); let pool = ListenerPool::new(listener.clone()); - let work = move |mut stream| handle_connection(&mut stream, &handler); + let work = move |mut stream| Worker(&handler).handle_connection(&mut stream); let guard = thread::spawn(move || pool.accept(work, threads)); @@ -152,62 +152,95 @@ L: NetworkListener + Send + 'static { }) } -fn handle_connection<'h, S, H>(mut stream: &mut S, handler: &'h H) -where S: NetworkStream + Clone, H: Handler { - debug!("Incoming stream"); - let addr = match stream.peer_addr() { - Ok(addr) => addr, - Err(e) => { - error!("Peer Name error: {:?}", e); - return; - } - }; - - // FIXME: Use Type ascription - let stream_clone: &mut NetworkStream = &mut stream.clone(); - let mut rdr = BufReader::new(stream_clone); - let mut wrt = BufWriter::new(stream); - - let mut keep_alive = true; - while keep_alive { - let req = match Request::new(&mut rdr, addr) { - Ok(req) => req, - Err(Error::Io(ref e)) if e.kind() == ErrorKind::ConnectionAborted => { - trace!("tcp closed, cancelling keep-alive loop"); - break; +struct Worker<'a, H: Handler + 'static>(&'a H); + +impl<'a, H: Handler + 'static> Worker<'a, H> { + + fn handle_connection(&self, mut stream: &mut S) where S: NetworkStream + Clone { + debug!("Incoming stream"); + let addr = match stream.peer_addr() { + Ok(addr) => addr, + Err(e) => { + error!("Peer Name error: {:?}", e); + return; } - Err(Error::Io(e)) => { - debug!("ioerror in keepalive loop = {:?}", e); + }; + + // FIXME: Use Type ascription + let stream_clone: &mut NetworkStream = &mut stream.clone(); + let rdr = BufReader::new(stream_clone); + let wrt = BufWriter::new(stream); + + self.keep_alive_loop(rdr, wrt, addr); + debug!("keep_alive loop ending for {}", addr); + } + + fn keep_alive_loop(&self, mut rdr: BufReader<&mut NetworkStream>, mut wrt: W, addr: SocketAddr) { + let mut keep_alive = true; + while keep_alive { + let req = match Request::new(&mut rdr, addr) { + Ok(req) => req, + Err(Error::Io(ref e)) if e.kind() == ErrorKind::ConnectionAborted => { + trace!("tcp closed, cancelling keep-alive loop"); + break; + } + Err(Error::Io(e)) => { + debug!("ioerror in keepalive loop = {:?}", e); + break; + } + Err(e) => { + //TODO: send a 400 response + error!("request error = {:?}", e); + break; + } + }; + + + if !self.handle_expect(&req, &mut wrt) { break; } - Err(e) => { - //TODO: send a 400 response - error!("request error = {:?}", e); - break; + + keep_alive = http::should_keep_alive(req.version, &req.headers); + let version = req.version; + let mut res_headers = Headers::new(); + if !keep_alive { + res_headers.set(Connection::close()); + } + { + let mut res = Response::new(&mut wrt, &mut res_headers); + res.version = version; + self.0.handle(req, res); } - }; - if req.version == Http11 && req.headers.get() == Some(&Expect::Continue) { - let status = handler.check_continue((&req.method, &req.uri, &req.headers)); - match write!(&mut wrt, "{} {}\r\n\r\n", Http11, status) { + // if the request was keep-alive, we need to check that the server agrees + // if it wasn't, then the server cannot force it to be true anyways + if keep_alive { + keep_alive = http::should_keep_alive(version, &res_headers); + } + + debug!("keep_alive = {:?} for {}", keep_alive, addr); + } + + } + + fn handle_expect(&self, req: &Request, wrt: &mut W) -> bool { + if req.version == Http11 && req.headers.get() == Some(&Expect::Continue) { + let status = self.0.check_continue((&req.method, &req.uri, &req.headers)); + match write!(wrt, "{} {}\r\n\r\n", Http11, status) { Ok(..) => (), Err(e) => { error!("error writing 100-continue: {:?}", e); - break; + return false; } } if status != StatusCode::Continue { debug!("non-100 status ({}) for Expect 100 request", status); - break; + return false; } } - keep_alive = http::should_keep_alive(req.version, &req.headers); - let mut res = Response::new(&mut wrt); - res.version = req.version; - handler.handle(req, res); - debug!("keep_alive = {:?}", keep_alive); + true } } @@ -270,7 +303,7 @@ mod tests { use status::StatusCode; use uri::RequestUri; - use super::{Request, Response, Fresh, Handler, handle_connection}; + use super::{Request, Response, Fresh, Handler, Worker}; #[test] fn test_check_continue_default() { @@ -287,7 +320,7 @@ mod tests { res.start().unwrap().end().unwrap(); } - handle_connection(&mut mock, &handle); + Worker(&handle).handle_connection(&mut mock); let cont = b"HTTP/1.1 100 Continue\r\n\r\n"; assert_eq!(&mock.write[..cont.len()], cont); let res = b"HTTP/1.1 200 OK\r\n"; @@ -316,7 +349,7 @@ mod tests { 1234567890\ "); - handle_connection(&mut mock, &Reject); + Worker(&Reject).handle_connection(&mut mock); assert_eq!(mock.write, &b"HTTP/1.1 417 Expectation Failed\r\n\r\n"[..]); } } diff --git a/src/server/response.rs b/src/server/response.rs index b608a6a31f..ad6ab69255 100644 --- a/src/server/response.rs +++ b/src/server/response.rs @@ -28,7 +28,7 @@ pub struct Response<'a, W: Any = Fresh> { // The status code for the request. status: status::StatusCode, // The outgoing headers on this response. - headers: header::Headers, + headers: &'a mut header::Headers, _writing: PhantomData } @@ -39,13 +39,13 @@ impl<'a, W: Any> Response<'a, W> { pub fn status(&self) -> status::StatusCode { self.status } /// The headers of this response. - pub fn headers(&self) -> &header::Headers { &self.headers } + pub fn headers(&self) -> &header::Headers { &*self.headers } /// Construct a Response from its constituent parts. pub fn construct(version: version::HttpVersion, body: HttpWriter<&'a mut (Write + 'a)>, status: status::StatusCode, - headers: header::Headers) -> Response<'a, Fresh> { + headers: &'a mut header::Headers) -> Response<'a, Fresh> { Response { status: status, version: version, @@ -57,7 +57,7 @@ impl<'a, W: Any> Response<'a, W> { /// Deconstruct this Response into its constituent parts. pub fn deconstruct(self) -> (version::HttpVersion, HttpWriter<&'a mut (Write + 'a)>, - status::StatusCode, header::Headers) { + status::StatusCode, &'a mut header::Headers) { unsafe { let parts = ( self.version, @@ -114,11 +114,11 @@ impl<'a, W: Any> Response<'a, W> { impl<'a> Response<'a, Fresh> { /// Creates a new Response that can be used to write to a network stream. #[inline] - pub fn new(stream: &'a mut (Write + 'a)) -> Response<'a, Fresh> { + pub fn new(stream: &'a mut (Write + 'a), headers: &'a mut header::Headers) -> Response<'a, Fresh> { Response { status: status::StatusCode::Ok, version: version::HttpVersion::Http11, - headers: header::Headers::new(), + headers: headers, body: ThroughWriter(stream), _writing: PhantomData, } @@ -165,7 +165,7 @@ impl<'a> Response<'a, Fresh> { /// Get a mutable reference to the Headers. #[inline] - pub fn headers_mut(&mut self) -> &mut header::Headers { &mut self.headers } + pub fn headers_mut(&mut self) -> &mut header::Headers { self.headers } } @@ -231,6 +231,7 @@ impl<'a, T: Any> Drop for Response<'a, T> { #[cfg(test)] mod tests { + use header::Headers; use mock::MockStream; use super::Response; @@ -252,9 +253,10 @@ mod tests { #[test] fn test_fresh_start() { + let mut headers = Headers::new(); let mut stream = MockStream::new(); { - let res = Response::new(&mut stream); + let res = Response::new(&mut stream, &mut headers); res.start().unwrap().deconstruct(); } @@ -268,9 +270,10 @@ mod tests { #[test] fn test_streaming_end() { + let mut headers = Headers::new(); let mut stream = MockStream::new(); { - let res = Response::new(&mut stream); + let res = Response::new(&mut stream, &mut headers); res.start().unwrap().end().unwrap(); } @@ -287,9 +290,10 @@ mod tests { #[test] fn test_fresh_drop() { use status::StatusCode; + let mut headers = Headers::new(); let mut stream = MockStream::new(); { - let mut res = Response::new(&mut stream); + let mut res = Response::new(&mut stream, &mut headers); *res.status_mut() = StatusCode::NotFound; } @@ -307,9 +311,10 @@ mod tests { fn test_streaming_drop() { use std::io::Write; use status::StatusCode; + let mut headers = Headers::new(); let mut stream = MockStream::new(); { - let mut res = Response::new(&mut stream); + let mut res = Response::new(&mut stream, &mut headers); *res.status_mut() = StatusCode::NotFound; let mut stream = res.start().unwrap(); stream.write_all(b"foo").unwrap();