Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

simplify traits and error handling #244

Open
wants to merge 19 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion ci/h3spec.sh
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
LOGFILE=h3server.log
if ! [ -e "h3spec-linux-x86_64" ] ; then
# if we don't already have a h3spec executable, wget it from github
wget https://github.com/kazu-yamamoto/h3spec/releases/download/v0.1.10/h3spec-linux-x86_64
wget https://github.com/kazu-yamamoto/h3spec/releases/download/v0.1.11/h3spec-linux-x86_64
chmod +x h3spec-linux-x86_64
fi

Expand Down
27 changes: 16 additions & 11 deletions examples/webtransport_server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
.with_env_filter(tracing_subscriber::EnvFilter::from_default_env())
.with_span_events(tracing_subscriber::fmt::format::FmtSpan::FULL)
.with_writer(std::io::stderr)
.with_max_level(tracing::Level::INFO)
.init();

#[cfg(feature = "tree")]
Expand All @@ -69,6 +70,7 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
tracing_subscriber::registry()
.with(tracing_subscriber::EnvFilter::from_default_env())
.with(tracing_tree::HierarchicalLayer::new(4).with_bracketed_fields(true))
.with_max_level(tracing::Level::INFO)
.init();

// process cli arguments
Expand Down Expand Up @@ -156,7 +158,7 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
Ok(())
}

async fn handle_connection(mut conn: Connection<h3_quinn::Connection, Bytes>) -> Result<()> {
async fn handle_connection(mut conn: Connection<h3_quinn::Connection<Bytes>, Bytes>) -> Result<()> {
// 3. TODO: Conditionally, if the client indicated that this is a webtransport session, we should accept it here, else use regular h3.
// if this is a webtransport session, then h3 needs to stop handing the datagrams, bidirectional streams, and unidirectional streams and give them
// to the webtransport session.
Expand Down Expand Up @@ -314,16 +316,19 @@ where
tracing::info!("Finished sending datagram");
}
}
uni_stream = session.accept_uni() => {
let (id, stream) = uni_stream?.unwrap();

let send = session.open_uni(id).await?;
tokio::spawn( async move { log_result!(echo_stream(send, stream).await); });
}
stream = session.accept_bi() => {
if let Some(server::AcceptedBi::BidiStream(_, stream)) = stream? {
let (send, recv) = quic::BidiStream::split(stream);
tokio::spawn( async move { log_result!(echo_stream(send, recv).await); });
stream = session.accept_streams() => {
match stream? {
Some(server::AcceptStream::BidiStream(_, stream)) =>{
info!("Received bidirectional stream");
let (send, recv) = quic::BidiStream::split(stream);
tokio::spawn( async move { log_result!(echo_stream(send, recv).await); });
},
Some(server::AcceptStream::UnidirectionalStream(id, stream)) => {
info!("Received unidirectional stream with id: {:?}", id);
let send = session.open_uni(id).await?;
tokio::spawn( async move { log_result!(echo_stream(send, stream).await); });
},
_ => (),
}
}
else => {
Expand Down
4 changes: 2 additions & 2 deletions h3-quinn/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -13,11 +13,11 @@ categories = ["network-programming", "web-programming"]
license = "MIT"

[dependencies]
h3 = { version = "0.0.6", path = "../h3" }
bytes = "1"
quinn = { version = "0.11", default-features = false, features = [
"futures-io",
] }
h3 = { version = "0.0.6", path = "../h3" }
bytes = "1"
tokio-util = { version = "0.7.9" }
futures = { version = "0.3.28" }
tokio = { version = "1", features = ["io-util"], default-features = false }
Expand Down
117 changes: 71 additions & 46 deletions h3-quinn/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,15 +16,15 @@ use bytes::{Buf, Bytes, BytesMut};

use futures::{
ready,
stream::{self},
stream::{self, select, Select},
Stream, StreamExt,
};
pub use quinn::{self, AcceptBi, AcceptUni, Endpoint, OpenBi, OpenUni, VarInt, WriteError};
pub use quinn::{self, Endpoint, OpenBi, OpenUni, VarInt, WriteError};
use quinn::{ApplicationClose, ClosedStream, ReadDatagram};

use h3::{
ext::Datagram,
quic::{self, Error, StreamId, WriteBuf},
quic::{self, Error, IncomingStreamType, StreamId, WriteBuf},
};
use tokio_util::sync::ReusableBoxFuture;

Expand All @@ -37,31 +37,67 @@ type BoxStreamSync<'a, T> = Pin<Box<dyn Stream<Item = T> + Sync + Send + 'a>>;
/// A QUIC connection backed by Quinn
///
/// Implements a [`quic::Connection`] backed by a [`quinn::Connection`].
pub struct Connection {
pub struct Connection<B>
where
B: Buf,
{
conn: quinn::Connection,
incoming_bi: BoxStreamSync<'static, <AcceptBi<'static> as Future>::Output>,
opening_bi: Option<BoxStreamSync<'static, <OpenBi<'static> as Future>::Output>>,
incoming_uni: BoxStreamSync<'static, <AcceptUni<'static> as Future>::Output>,
opening_uni: Option<BoxStreamSync<'static, <OpenUni<'static> as Future>::Output>>,
datagrams: BoxStreamSync<'static, <ReadDatagram<'static> as Future>::Output>,
}

impl Connection {
incoming: Select<
BoxStreamSync<
'static,
Result<IncomingStreamType<BidiStream<B>, RecvStream, B>, quinn::ConnectionError>,
>,
BoxStreamSync<
'static,
Result<IncomingStreamType<BidiStream<B>, RecvStream, B>, quinn::ConnectionError>,
>,
>,
}

impl<B> Connection<B>
where
B: Buf,
{
/// Create a [`Connection`] from a [`quinn::Connection`]
pub fn new(conn: quinn::Connection) -> Self {
let incoming_uni = Box::pin(stream::unfold(conn.clone(), |conn| async {
Some((
conn.accept_uni().await.map(|recv_stream| {
IncomingStreamType::<BidiStream<B>, RecvStream, B>::Unidirectional(
RecvStream::new(recv_stream),
)
}),
conn,
))
}));
let incoming_bi = Box::pin(stream::unfold(conn.clone(), |conn| async {
Some((
conn.accept_bi().await.map(|bidi_stream| {
IncomingStreamType::<BidiStream<B>, RecvStream, B>::Bidirectional(
BidiStream {
send: SendStream::new(bidi_stream.0),
recv: RecvStream::new(bidi_stream.1),
},
std::marker::PhantomData,
)
}),
conn,
))
}));

Self {
conn: conn.clone(),
incoming_bi: Box::pin(stream::unfold(conn.clone(), |conn| async {
Some((conn.accept_bi().await, conn))
})),

opening_bi: None,
incoming_uni: Box::pin(stream::unfold(conn.clone(), |conn| async {
Some((conn.accept_uni().await, conn))
})),

opening_uni: None,
datagrams: Box::pin(stream::unfold(conn, |conn| async {
Some((conn.read_datagram().await, conn))
})),
incoming: select(incoming_bi, incoming_uni),
}
}
}
Expand Down Expand Up @@ -153,51 +189,37 @@ impl From<quinn::SendDatagramError> for SendDatagramError {
}
}

impl<B> quic::Connection<B> for Connection
impl<B> quic::Connection<B> for Connection<B>
where
B: Buf,
{
type RecvStream = RecvStream;
type OpenStreams = OpenStreams;
type AcceptError = ConnectionError;

#[cfg_attr(feature = "tracing", instrument(skip_all, level = "trace"))]
fn poll_accept_bidi(
&mut self,
cx: &mut task::Context<'_>,
) -> Poll<Result<Option<Self::BidiStream>, Self::AcceptError>> {
let (send, recv) = match ready!(self.incoming_bi.poll_next_unpin(cx)) {
Some(x) => x?,
None => return Poll::Ready(Ok(None)),
};
Poll::Ready(Ok(Some(Self::BidiStream {
send: Self::SendStream::new(send),
recv: Self::RecvStream::new(recv),
})))
}

#[cfg_attr(feature = "tracing", instrument(skip_all, level = "trace"))]
fn poll_accept_recv(
&mut self,
cx: &mut task::Context<'_>,
) -> Poll<Result<Option<Self::RecvStream>, Self::AcceptError>> {
let recv = match ready!(self.incoming_uni.poll_next_unpin(cx)) {
Some(x) => x?,
None => return Poll::Ready(Ok(None)),
};
Poll::Ready(Ok(Some(Self::RecvStream::new(recv))))
}

fn opener(&self) -> Self::OpenStreams {
OpenStreams {
conn: self.conn.clone(),
opening_bi: None,
opening_uni: None,
}
}
#[cfg_attr(feature = "tracing", instrument(skip_all, level = "trace"))]
fn poll_incoming(
&mut self,
cx: &mut task::Context<'_>,
) -> Poll<Result<IncomingStreamType<Self::BidiStream, Self::RecvStream, B>, Self::AcceptError>>
{
// put the two streams together

match ready!(self.incoming.poll_next_unpin(cx)).unwrap() {
Ok(x) => Poll::Ready(Ok(x)),
Err(e) => Poll::Ready(Err(ConnectionError(e).into())),
}
}
}

impl<B> quic::OpenStreams<B> for Connection
impl<B> quic::OpenStreams<B> for Connection<B>
where
B: Buf,
{
Expand Down Expand Up @@ -248,7 +270,7 @@ where
}
}

impl<B> quic::SendDatagramExt<B> for Connection
impl<B> quic::SendDatagramExt<B> for Connection<B>
where
B: Buf,
{
Expand All @@ -265,7 +287,10 @@ where
}
}

impl quic::RecvDatagramExt for Connection {
impl<B> quic::RecvDatagramExt for Connection<B>
where
B: Buf,
{
type Buf = Bytes;

type Error = ConnectionError;
Expand Down
Loading
Loading