diff --git a/Cargo.toml b/Cargo.toml index 23bb8f5..aff168f 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -14,8 +14,8 @@ An implementation of SSL streams for Tokio backed by OpenSSL [dependencies] openssl = "0.10.30" openssl-sys = "0.9.58" -tokio = "0.2" +tokio = "0.3" [dev-dependencies] futures = "0.3" -tokio = { version = "0.2", features = ["full"] } +tokio = { version = "0.3", features = ["full"] } diff --git a/src/lib.rs b/src/lib.rs index ba42fe9..c90c454 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -24,8 +24,7 @@ use std::future::Future; use std::io::{self, Read, Write}; use std::pin::Pin; use std::task::{Context, Poll}; -use tokio::io::{AsyncRead, AsyncWrite}; -use std::mem::MaybeUninit; +use tokio::io::{AsyncRead, AsyncWrite, ReadBuf}; /// Asynchronously performs a client-side TLS handshake over the provided stream. pub async fn connect( @@ -99,10 +98,13 @@ where S: AsyncRead + Unpin, { fn read(&mut self, buf: &mut [u8]) -> io::Result { - match self.with_context(|ctx, stream| stream.poll_read(ctx, buf)) { - Poll::Ready(r) => r, - Poll::Pending => Err(io::Error::from(io::ErrorKind::WouldBlock)), - } + self.with_context(|ctx, stream| { + let mut buf = ReadBuf::new(buf); + match stream.poll_read(ctx, &mut buf)? { + Poll::Ready(()) => Ok(buf.filled().len()), + Poll::Pending => Err(io::Error::from(io::ErrorKind::WouldBlock)), + } + }) } } @@ -182,10 +184,7 @@ where /// /// The caller must ensure the pointer is valid. pub unsafe fn from_raw_parts(ssl: *mut ffi::SSL, stream: S) -> Self { - let stream = StreamWrapper { - stream, - context: 0, - }; + let stream = StreamWrapper { stream, context: 0 }; SslStream(ssl::SslStream::from_raw_parts(ssl, stream)) } } @@ -194,19 +193,18 @@ impl AsyncRead for SslStream where S: AsyncRead + AsyncWrite + Unpin, { - unsafe fn prepare_uninitialized_buffer(&self, _: &mut [MaybeUninit]) -> bool { - // Note that this does not forward to `S` because the buffer is - // unconditionally filled in by OpenSSL, not the actual object `S`. - // We're decrypting bytes from `S` into the buffer above! - false - } - fn poll_read( mut self: Pin<&mut Self>, ctx: &mut Context<'_>, - buf: &mut [u8], - ) -> Poll> { - self.with_context(ctx, |s| cvt(s.read(buf))) + buf: &mut ReadBuf<'_>, + ) -> Poll> { + self.with_context(ctx, |s| match cvt(s.read(buf.initialize_unfilled()))? { + Poll::Ready(nread) => { + buf.advance(nread); + Poll::Ready(Ok(())) + } + Poll::Pending => Poll::Pending, + }) } } diff --git a/tests/google.rs b/tests/google.rs index 42b1f00..92060b8 100644 --- a/tests/google.rs +++ b/tests/google.rs @@ -33,7 +33,7 @@ async fn google() { #[tokio::test] async fn server() { - let mut listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); let addr = listener.local_addr().unwrap(); let server = async move {