Skip to content

Commit

Permalink
hyperium#34 properly implement TLS-1.3 shutdown behavior
Browse files Browse the repository at this point in the history
  • Loading branch information
Yan Zhai committed Apr 19, 2019
1 parent b6e3945 commit 87916da
Show file tree
Hide file tree
Showing 3 changed files with 130 additions and 95 deletions.
66 changes: 33 additions & 33 deletions src/client.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
use super::*;
use std::io::Write;
use rustls::Session;

use std::io::Write;

/// A wrapper around an underlying raw stream which implements the TLS or SSL
/// protocol.
Expand All @@ -12,21 +11,14 @@ pub struct TlsStream<IO> {
pub(crate) state: TlsState,

#[cfg(feature = "early-data")]
pub(crate) early_data: (usize, Vec<u8>)
}

#[derive(Debug)]
pub(crate) enum TlsState {
#[cfg(feature = "early-data")] EarlyData,
Stream,
Eof,
Shutdown
pub(crate) early_data: (usize, Vec<u8>),
}

pub(crate) enum MidHandshake<IO> {
Handshaking(TlsStream<IO>),
#[cfg(feature = "early-data")] EarlyData(TlsStream<IO>),
End
#[cfg(feature = "early-data")]
EarlyData(TlsStream<IO>),
End,
}

impl<IO> TlsStream<IO> {
Expand All @@ -47,7 +39,8 @@ impl<IO> TlsStream<IO> {
}

impl<IO> Future for MidHandshake<IO>
where IO: AsyncRead + AsyncWrite,
where
IO: AsyncRead + AsyncWrite,
{
type Item = TlsStream<IO>;
type Error = io::Error;
Expand All @@ -71,13 +64,14 @@ where IO: AsyncRead + AsyncWrite,
MidHandshake::Handshaking(stream) => Ok(Async::Ready(stream)),
#[cfg(feature = "early-data")]
MidHandshake::EarlyData(stream) => Ok(Async::Ready(stream)),
MidHandshake::End => panic!()
MidHandshake::End => panic!(),
}
}
}

impl<IO> io::Read for TlsStream<IO>
where IO: AsyncRead + AsyncWrite
where
IO: AsyncRead + AsyncWrite,
{
fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
match self.state {
Expand Down Expand Up @@ -106,31 +100,35 @@ where IO: AsyncRead + AsyncWrite
}

self.read(buf)
},
TlsState::Stream => {
}
TlsState::Stream | TlsState::WriteShutdown => {
let mut stream = Stream::new(&mut self.io, &mut self.session);

match stream.read(buf) {
Ok(0) => {
self.state = TlsState::Eof;
self.state.shutdown_read();
Ok(0)
},
}
Ok(n) => Ok(n),
Err(ref e) if e.kind() == io::ErrorKind::ConnectionAborted => {
self.state = TlsState::Shutdown;
stream.session.send_close_notify();
self.state.shutdown_read();
if self.state.writeable() {
stream.session.send_close_notify();
self.state.shutdown_write();
}
Ok(0)
},
Err(e) => Err(e)
}
Err(e) => Err(e),
}
},
TlsState::Eof | TlsState::Shutdown => Ok(0),
}
TlsState::ReadShutdown | TlsState::FullyShutdown => Ok(0),
}
}
}

impl<IO> io::Write for TlsStream<IO>
where IO: AsyncRead + AsyncWrite
where
IO: AsyncRead + AsyncWrite,
{
fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
let mut stream = Stream::new(&mut self.io, &mut self.session);
Expand Down Expand Up @@ -164,8 +162,8 @@ where IO: AsyncRead + AsyncWrite
self.state = TlsState::Stream;
data.clear();
stream.write(buf)
},
_ => stream.write(buf)
}
_ => stream.write(buf),
}
}

Expand All @@ -176,22 +174,24 @@ where IO: AsyncRead + AsyncWrite
}

impl<IO> AsyncRead for TlsStream<IO>
where IO: AsyncRead + AsyncWrite
where
IO: AsyncRead + AsyncWrite,
{
unsafe fn prepare_uninitialized_buffer(&self, _: &mut [u8]) -> bool {
false
}
}

impl<IO> AsyncWrite for TlsStream<IO>
where IO: AsyncRead + AsyncWrite
where
IO: AsyncRead + AsyncWrite,
{
fn shutdown(&mut self) -> Poll<(), io::Error> {
match self.state {
TlsState::Shutdown => (),
s if !s.writeable() => (),
_ => {
self.session.send_close_notify();
self.state = TlsState::Shutdown;
self.state.shutdown_write();
}
}

Expand Down
105 changes: 69 additions & 36 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,47 +3,76 @@
pub extern crate rustls;
pub extern crate webpki;

extern crate futures;
extern crate tokio_io;
extern crate bytes;
extern crate futures;
extern crate iovec;
extern crate tokio_io;

mod common;
pub mod client;
mod common;
pub mod server;

use std::{ io, mem };
use common::Stream;
use futures::{Async, Future, Poll};
use rustls::{ClientConfig, ClientSession, ServerConfig, ServerSession};
use std::sync::Arc;
use std::{io, mem};
use tokio_io::{try_nb, AsyncRead, AsyncWrite};
use webpki::DNSNameRef;
use rustls::{
ClientSession, ServerSession,
ClientConfig, ServerConfig
};
use futures::{Async, Future, Poll};
use tokio_io::{ AsyncRead, AsyncWrite, try_nb };
use common::Stream;

#[derive(Debug, Copy, Clone)]
pub enum TlsState {
#[cfg(feature = "early-data")]
EarlyData,
Stream,
ReadShutdown,
WriteShutdown,
FullyShutdown,
}

impl TlsState {
pub(crate) fn shutdown_read(&mut self) {
match *self {
TlsState::WriteShutdown | TlsState::FullyShutdown => *self = TlsState::FullyShutdown,
_ => *self = TlsState::ReadShutdown,
}
}

pub(crate) fn shutdown_write(&mut self) {
match *self {
TlsState::ReadShutdown | TlsState::FullyShutdown => *self = TlsState::FullyShutdown,
_ => *self = TlsState::WriteShutdown,
}
}

pub(crate) fn writeable(&self) -> bool {
match *self {
TlsState::WriteShutdown | TlsState::FullyShutdown => true,
_ => false,
}
}
}

/// A wrapper around a `rustls::ClientConfig`, providing an async `connect` method.
#[derive(Clone)]
pub struct TlsConnector {
inner: Arc<ClientConfig>,
#[cfg(feature = "early-data")]
early_data: bool
early_data: bool,
}

/// A wrapper around a `rustls::ServerConfig`, providing an async `accept` method.
#[derive(Clone)]
pub struct TlsAcceptor {
inner: Arc<ServerConfig>
inner: Arc<ServerConfig>,
}

impl From<Arc<ClientConfig>> for TlsConnector {
fn from(inner: Arc<ClientConfig>) -> TlsConnector {
TlsConnector {
inner,
#[cfg(feature = "early-data")]
early_data: false
early_data: false,
}
}
}
Expand All @@ -66,40 +95,45 @@ impl TlsConnector {
}

pub fn connect<IO>(&self, domain: DNSNameRef, stream: IO) -> Connect<IO>
where IO: AsyncRead + AsyncWrite
where
IO: AsyncRead + AsyncWrite,
{
self.connect_with(domain, stream, |_| ())
}

#[inline]
pub fn connect_with<IO, F>(&self, domain: DNSNameRef, stream: IO, f: F)
-> Connect<IO>
pub fn connect_with<IO, F>(&self, domain: DNSNameRef, stream: IO, f: F) -> Connect<IO>
where
IO: AsyncRead + AsyncWrite,
F: FnOnce(&mut ClientSession)
F: FnOnce(&mut ClientSession),
{
let mut session = ClientSession::new(&self.inner, domain);
f(&mut session);

#[cfg(not(feature = "early-data"))] {
#[cfg(not(feature = "early-data"))]
{
Connect(client::MidHandshake::Handshaking(client::TlsStream {
session, io: stream,
state: client::TlsState::Stream,
session,
io: stream,
state: TlsState::Stream,
}))
}

#[cfg(feature = "early-data")] {
#[cfg(feature = "early-data")]
{
Connect(if self.early_data {
client::MidHandshake::EarlyData(client::TlsStream {
session, io: stream,
state: client::TlsState::EarlyData,
early_data: (0, Vec::new())
session,
io: stream,
state: TlsState::EarlyData,
early_data: (0, Vec::new()),
})
} else {
client::MidHandshake::Handshaking(client::TlsStream {
session, io: stream,
state: client::TlsState::Stream,
early_data: (0, Vec::new())
session,
io: stream,
state: TlsState::Stream,
early_data: (0, Vec::new()),
})
})
}
Expand All @@ -108,29 +142,29 @@ impl TlsConnector {

impl TlsAcceptor {
pub fn accept<IO>(&self, stream: IO) -> Accept<IO>
where IO: AsyncRead + AsyncWrite,
where
IO: AsyncRead + AsyncWrite,
{
self.accept_with(stream, |_| ())
}

#[inline]
pub fn accept_with<IO, F>(&self, stream: IO, f: F)
-> Accept<IO>
pub fn accept_with<IO, F>(&self, stream: IO, f: F) -> Accept<IO>
where
IO: AsyncRead + AsyncWrite,
F: FnOnce(&mut ServerSession)
F: FnOnce(&mut ServerSession),
{
let mut session = ServerSession::new(&self.inner);
f(&mut session);

Accept(server::MidHandshake::Handshaking(server::TlsStream {
session, io: stream,
state: server::TlsState::Stream,
session,
io: stream,
state: TlsState::Stream,
}))
}
}


/// Future returned from `ClientConfigExt::connect_async` which will resolve
/// once the connection handshake has finished.
pub struct Connect<IO>(client::MidHandshake<IO>);
Expand All @@ -139,7 +173,6 @@ pub struct Connect<IO>(client::MidHandshake<IO>);
/// once the accept handshake has finished.
pub struct Accept<IO>(server::MidHandshake<IO>);


impl<IO: AsyncRead + AsyncWrite> Future for Connect<IO> {
type Item = client::TlsStream<IO>;
type Error = io::Error;
Expand Down
Loading

0 comments on commit 87916da

Please sign in to comment.