Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix bug in RequestStream #271

Draft
wants to merge 21 commits into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 14 additions & 10 deletions h3/src/client/builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -111,21 +111,25 @@ impl Builder {

let conn_waker = Some(future::poll_fn(|cx| Poll::Ready(cx.waker().clone())).await);

let inner = ConnectionInner::new(quic, conn_state.clone(), self.config).await?;
let send_request = SendRequest {
open,
conn_state,
conn_waker,
max_field_section_size: self.config.settings.max_field_section_size,
sender_count: Arc::new(AtomicUsize::new(1)),
send_grease_frame: self.config.send_grease,
_buf: PhantomData,
error_sender: inner.error_sender.clone(),
};

Ok((
Connection {
inner: ConnectionInner::new(quic, conn_state.clone(), self.config).await?,
inner,
sent_closing: None,
recv_closing: None,
},
SendRequest {
open,
conn_state,
conn_waker,
max_field_section_size: self.config.settings.max_field_section_size,
sender_count: Arc::new(AtomicUsize::new(1)),
send_grease_frame: self.config.send_grease,
_buf: PhantomData,
},
send_request,
))
}
}
4 changes: 4 additions & 0 deletions h3/src/client/connection.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
use futures_util::future;
use http::request;

use tokio::sync::mpsc::UnboundedSender;
#[cfg(feature = "tracing")]
use tracing::{info, instrument, trace};

Expand Down Expand Up @@ -115,6 +116,7 @@
pub(super) conn_waker: Option<Waker>,
pub(super) _buf: PhantomData<fn(B)>,
pub(super) send_grease_frame: bool,
pub(super) error_sender: UnboundedSender<(Code, &'static str)>,
}

impl<T, B> SendRequest<T, B>
Expand Down Expand Up @@ -188,6 +190,7 @@
self.max_field_section_size,
self.conn_state.clone(),
self.send_grease_frame,
self.error_sender.clone(),
),
};
// send the grease frame only once
Expand Down Expand Up @@ -223,6 +226,7 @@
conn_waker: self.conn_waker.clone(),
_buf: PhantomData,
send_grease_frame: self.send_grease_frame,
error_sender: self.error_sender.clone(),
}
}
}
Expand Down Expand Up @@ -388,7 +392,7 @@
Ok(Frame::Settings(_)) => {
#[cfg(feature = "tracing")]
trace!("Got settings");
()

Check warning on line 395 in h3/src/client/connection.rs

View workflow job for this annotation

GitHub Actions / Lint

unneeded unit expression
}

Ok(Frame::Goaway(id)) => {
Expand Down
2 changes: 1 addition & 1 deletion h3/src/client/stream.rs
Original file line number Diff line number Diff line change
Expand Up @@ -150,7 +150,7 @@ where
// TODO what if called before recv_response ?
#[cfg_attr(feature = "tracing", instrument(skip_all, level = "trace"))]
pub async fn recv_data(&mut self) -> Result<Option<impl Buf>, Error> {
self.inner.recv_data().await
future::poll_fn(|cx| self.poll_recv_data(cx)).await
}

/// Receive request body
Expand Down
118 changes: 94 additions & 24 deletions h3/src/connection.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ use futures_util::{future, ready};
use http::HeaderMap;
use stream::WriteBuf;

use tokio::sync::mpsc::{unbounded_channel, UnboundedReceiver, UnboundedSender};
#[cfg(feature = "tracing")]
use tracing::{instrument, warn};

Expand Down Expand Up @@ -144,6 +145,8 @@ where
// step of the grease sending poll fn
grease_step: GreaseStatus<C::SendStream, B>,
pub config: Config,
error_getter: UnboundedReceiver<(Code, &'static str)>,
pub(crate) error_sender: UnboundedSender<(Code, &'static str)>,
}

enum GreaseStatus<S, B>
Expand Down Expand Up @@ -250,6 +253,8 @@ where
future::poll_fn(|cx| conn.poll_open_send(cx)).await,
);

let (sender, receiver) = unbounded_channel();

//= https://www.rfc-editor.org/rfc/rfc9114#section-6.2.1
//= type=implication
//# The
Expand All @@ -273,6 +278,8 @@ where
send_grease_stream_flag: config.send_grease,
// start at first step
grease_step: GreaseStatus::NotStarted(PhantomData),
error_getter: receiver,
error_sender: sender,
};
conn_inner.send_control_stream_headers().await?;

Expand Down Expand Up @@ -418,13 +425,32 @@ where
Ok(())
}

/// function which checks if there is something in the error channel
/// and closes the connection if there is
/// it returns the error if there is one
#[cfg_attr(feature = "tracing", instrument(skip_all, level = "trace"))]
fn poll_check_stream_errors(&mut self, cx: &mut Context<'_>) -> Result<(), Error> {
// check if there is an error in the error channel
// wake up the connection if there is an error

let x = self.error_getter.poll_recv(cx);
match x {
Poll::Ready(Some((code, cause))) => Err(self.close(code, cause)),
Poll::Ready(None) => Ok(()),
Poll::Pending => Ok(()),
}
}

/// Waits for the control stream to be received and reads subsequent frames.
#[cfg_attr(feature = "tracing", instrument(skip_all, level = "trace"))]
pub fn poll_control(&mut self, cx: &mut Context<'_>) -> Poll<Result<Frame<PayloadLen>, Error>> {
if let Some(ref e) = self.shared.read("poll_accept_request").error {
return Poll::Ready(Err(e.clone()));
}

// check if a connection error occurred on a stream
self.poll_check_stream_errors(cx)?;

let recv = {
// TODO
self.poll_accept_recv(cx)?;
Expand Down Expand Up @@ -705,6 +731,7 @@ pub struct RequestStream<S, B> {
pub(super) conn_state: SharedStateRef,
pub(super) max_field_section_size: u64,
send_grease_frame: bool,
error_sender: UnboundedSender<(Code, &'static str)>,
}

impl<S, B> RequestStream<S, B> {
Expand All @@ -714,13 +741,15 @@ impl<S, B> RequestStream<S, B> {
max_field_section_size: u64,
conn_state: SharedStateRef,
grease: bool,
error_sender: UnboundedSender<(Code, &'static str)>,
) -> Self {
Self {
stream,
conn_state,
max_field_section_size,
trailers: None,
send_grease_frame: grease,
error_sender,
}
}
}
Expand All @@ -731,23 +760,36 @@ impl<S, B> ConnectionState for RequestStream<S, B> {
}
}

impl<S, B> RequestStream<S, B> {
/// Close the connection with an error
//#[cfg_attr(feature = "tracing", instrument(skip_all, level = "trace"))]
pub fn close(&self, code: Code, reason: &'static str) -> Error {
let _ = self.error_sender.send((code, reason));
self.maybe_conn_err(code)
}
}

impl<S, B> RequestStream<S, B>
where
S: quic::RecvStream,
{
/// Receive some of the request body.
#[cfg_attr(feature = "tracing", instrument(skip_all, level = "trace"))]
// #[cfg_attr(feature = "tracing", instrument(skip_all, level = "trace"))]
pub fn poll_recv_data(
&mut self,
cx: &mut Context<'_>,
) -> Poll<Result<Option<impl Buf>, Error>> {
if !self.stream.has_data() {
let frame = self
.stream
.poll_next(cx)
.map_err(|e| self.maybe_conn_err(e))?;
let frame = ready!(self.stream.poll_next(cx)).map_err(|error| {
let error: Error = error.into();
if let Some(code) = error.try_get_code() {
self.close(code, "test")
} else {
error
}
})?;

match ready!(frame) {
match frame {
Some(Frame::Data { .. }) => (),
Some(Frame::Headers(encoded)) => {
self.trailers = Some(encoded);
Expand Down Expand Up @@ -776,20 +818,23 @@ where
//# The MAX_PUSH_ID frame is always sent on the control stream. Receipt
//# of a MAX_PUSH_ID frame on any other stream MUST be treated as a
//# connection error of type H3_FRAME_UNEXPECTED.
Some(_) => return Poll::Ready(Err(Code::H3_FRAME_UNEXPECTED.into())),
Some(_) => {
return Poll::Ready(Err(
self.close(Code::H3_FRAME_UNEXPECTED, "received unexpected frame")
))
}
None => return Poll::Ready(Ok(None)),
}
}

self.stream
.poll_data(cx)
.map_err(|e| self.maybe_conn_err(e))
}

/// Receive some of the request body.
#[cfg_attr(feature = "tracing", instrument(skip_all, level = "trace"))]
pub async fn recv_data(&mut self) -> Result<Option<impl Buf>, Error> {
future::poll_fn(|cx| self.poll_recv_data(cx)).await
self.stream.poll_data(cx).map_err(|error| {
let error: Error = error.into();
if let Some(code) = error.try_get_code() {
self.close(code, "test")
} else {
error
}
})
}

/// Poll receive trailers.
Expand All @@ -801,8 +846,14 @@ where
let mut trailers = if let Some(encoded) = self.trailers.take() {
encoded
} else {
let frame = futures_util::ready!(self.stream.poll_next(cx))
.map_err(|e| self.maybe_conn_err(e))?;
let frame = futures_util::ready!(self.stream.poll_next(cx)).map_err(|error| {
let error: Error = error.into();
if let Some(code) = error.try_get_code() {
self.close(code, "test")
} else {
error
}
})?;
match frame {
Some(Frame::Headers(encoded)) => encoded,

Expand All @@ -828,20 +879,37 @@ where
//# The MAX_PUSH_ID frame is always sent on the control stream. Receipt
//# of a MAX_PUSH_ID frame on any other stream MUST be treated as a
//# connection error of type H3_FRAME_UNEXPECTED.
Some(_) => return Poll::Ready(Err(Code::H3_FRAME_UNEXPECTED.into())),
Some(_) => {
// try closing the connection by sending the error to the connection
// this error does not need to be handled. When the buffer is full, a other stream has closed the connection
// when the receiver is dropped, the connection is closed
let _ = self
.error_sender
.send((Code::H3_FRAME_UNEXPECTED, "received unexpected frame"));
return Poll::Ready(Err(Code::H3_FRAME_UNEXPECTED.into()));
}
None => return Poll::Ready(Ok(None)),
}
};
if !self.stream.is_eos() {
// Get the trailing frame
match self
.stream
.poll_next(cx)
.map_err(|e| self.maybe_conn_err(e))?
{
match self.stream.poll_next(cx).map_err(|error| {
let error: Error = error.into();
if let Some(code) = error.try_get_code() {
self.close(code, "test")
} else {
error
}
})? {
Poll::Ready(trailing_frame) => {
if trailing_frame.is_some() {
// if it's not unknown or reserved, fail.
// try closing the connection by sending the error to the connection
// this error does not need to be handled. When the buffer is full, a other stream has closed the connection
// when the receiver is dropped, the connection is closed
let _ = self
.error_sender
.send((Code::H3_FRAME_UNEXPECTED, "received unexpected frame"));
return Poll::Ready(Err(Code::H3_FRAME_UNEXPECTED.into()));
}
}
Expand Down Expand Up @@ -969,13 +1037,15 @@ where
conn_state: self.conn_state.clone(),
max_field_section_size: 0,
send_grease_frame: self.send_grease_frame,
error_sender: self.error_sender.clone(),
},
RequestStream {
stream: recv,
trailers: self.trailers,
conn_state: self.conn_state,
max_field_section_size: self.max_field_section_size,
send_grease_frame: self.send_grease_frame,
error_sender: self.error_sender,
},
)
}
Expand Down
Loading
Loading