diff --git a/src/mock.rs b/src/mock.rs index 1fb80b8a25..a459348521 100644 --- a/src/mock.rs +++ b/src/mock.rs @@ -91,6 +91,10 @@ impl AsyncIo { AsyncIo::new(Buf::wrap(buf.into()), bytes) } + pub fn new_eof() -> AsyncIo { + AsyncIo::new(Buf::wrap(Vec::new().into()), 1) + } + pub fn flushed(&self) -> bool { self.flushed } diff --git a/src/proto/conn.rs b/src/proto/conn.rs index 1a2d25d3d8..31bc06a972 100644 --- a/src/proto/conn.rs +++ b/src/proto/conn.rs @@ -86,6 +86,9 @@ where I: AsyncRead + AsyncWrite, }); } else { trace!("poll when on keep-alive"); + if !T::should_read_first() { + self.try_empty_read()?; + } self.maybe_park_read(); return Ok(Async::NotReady); } @@ -102,7 +105,17 @@ where I: AsyncRead + AsyncWrite, pub fn can_read_head(&self) -> bool { match self.state.reading { - Reading::Init => true, + //Reading::Init => true, + Reading::Init => { + if T::should_read_first() { + true + } else { + match self.state.writing { + Writing::Init => false, + _ => true, + } + } + }, _ => false, } } @@ -219,6 +232,41 @@ where I: AsyncRead + AsyncWrite, } } + // This will check to make sure the io object read is empty. + // + // This should only be called for Clients wanting to enter the idle + // state. + pub fn try_empty_read(&mut self) -> io::Result<()> { + assert!(!self.can_read_head() && !self.can_read_body()); + + if !self.io.read_buf().is_empty() { + Err(io::Error::new(io::ErrorKind::InvalidData, "unexpected bytes after message ended")) + } else { + match self.io.read_from_io() { + Ok(Async::Ready(0)) => { + self.state.close_read(); + let must_error = !self.state.is_idle() && T::should_error_on_parse_eof(); + if must_error { + Err(io::ErrorKind::UnexpectedEof.into()) + } else { + Ok(()) + } + }, + Ok(Async::Ready(_)) => { + Err(io::Error::new(io::ErrorKind::InvalidData, "unexpected bytes after message ended")) + }, + Ok(Async::NotReady) => { + trace!("try_empty_read; read blocked"); + Ok(()) + }, + Err(e) => { + self.state.close(); + Err(e) + } + } + } + } + fn maybe_notify(&mut self) { // its possible that we returned NotReady from poll() without having // exhausted the underlying Io. We would have done this when we @@ -882,7 +930,7 @@ mod tests { fn test_conn_init_read_eof_busy() { let _: Result<(), ()> = future::lazy(|| { // server ignores - let io = AsyncIo::new_buf(vec![], 1); + let io = AsyncIo::new_eof(); let mut conn = Conn::<_, proto::Chunk, ServerTransaction>::new(io, Default::default()); conn.state.busy(); @@ -892,7 +940,7 @@ mod tests { } // client - let io = AsyncIo::new_buf(vec![], 1); + let io = AsyncIo::new_eof(); let mut conn = Conn::<_, proto::Chunk, ClientTransaction>::new(io, Default::default()); conn.state.busy(); diff --git a/src/proto/dispatch.rs b/src/proto/dispatch.rs index f309ca6e45..b7de2f0858 100644 --- a/src/proto/dispatch.rs +++ b/src/proto/dispatch.rs @@ -137,6 +137,9 @@ where } else { let _ = body.close(); } + } else if !T::should_read_first() { + self.conn.try_empty_read()?; + return Ok(Async::NotReady); } else { self.conn.maybe_park_read(); return Ok(Async::Ready(())); @@ -188,13 +191,6 @@ where } fn is_done(&self) -> bool { - trace!( - "is_done; read={}, write={}, should_poll={}, body={}", - self.conn.is_read_closed(), - self.conn.is_write_closed(), - self.dispatch.should_poll(), - self.body_rx.is_some(), - ); let read_done = self.conn.is_read_closed(); if !T::should_read_first() && read_done { @@ -223,6 +219,7 @@ where #[inline] fn poll(&mut self) -> Poll { + trace!("Dispatcher::poll"); self.poll_read()?; self.poll_write()?; self.poll_flush()?; diff --git a/tests/client.rs b/tests/client.rs index eb72a2177e..17db0748b7 100644 --- a/tests/client.rs +++ b/tests/client.rs @@ -653,7 +653,6 @@ mod dispatch_impl { assert_eq!(closes.load(Ordering::Relaxed), 1); } - #[test] fn no_keep_alive_closes_connection() { // https://github.com/hyperium/hyper/issues/1383 @@ -694,6 +693,47 @@ mod dispatch_impl { assert_eq!(closes.load(Ordering::Relaxed), 1); } + #[test] + fn socket_disconnect_closes_idle_conn() { + // notably when keep-alive is enabled + let _ = pretty_env_logger::init(); + + let server = TcpListener::bind("127.0.0.1:0").unwrap(); + let addr = server.local_addr().unwrap(); + let mut core = Core::new().unwrap(); + let handle = core.handle(); + let closes = Arc::new(AtomicUsize::new(0)); + + let (tx1, rx1) = oneshot::channel(); + + thread::spawn(move || { + let mut sock = server.accept().unwrap().0; + sock.set_read_timeout(Some(Duration::from_secs(5))).unwrap(); + sock.set_write_timeout(Some(Duration::from_secs(5))).unwrap(); + let mut buf = [0; 4096]; + sock.read(&mut buf).expect("read 1"); + sock.write_all(b"HTTP/1.1 200 OK\r\nContent-Length: 0\r\n\r\n").unwrap(); + let _ = tx1.send(()); + }); + + let uri = format!("http://{}/a", addr).parse().unwrap(); + + let client = Client::configure() + .connector(DebugConnector(HttpConnector::new(1, &handle), closes.clone())) + .no_proto() + .build(&handle); + let res = client.get(uri).and_then(move |res| { + assert_eq!(res.status(), hyper::StatusCode::Ok); + res.body().concat2() + }); + let rx = rx1.map_err(|_| hyper::Error::Io(io::Error::new(io::ErrorKind::Other, "thread panicked"))); + + let timeout = Timeout::new(Duration::from_millis(200), &handle).unwrap(); + let rx = rx.and_then(move |_| timeout.map_err(|e| e.into())); + core.run(res.join(rx).map(|r| r.0)).unwrap(); + + assert_eq!(closes.load(Ordering::Relaxed), 1); + } struct DebugConnector(HttpConnector, Arc);