Skip to content

Commit

Permalink
Merge pull request from GHSA-3999-5ffv-wp2r
Browse files Browse the repository at this point in the history
feat: switch pending_frames VecDequeue for an Option to bound it
  • Loading branch information
jxs authored Apr 30, 2024
2 parents cf6456f + af8f693 commit 460baf2
Show file tree
Hide file tree
Showing 3 changed files with 86 additions and 81 deletions.
2 changes: 1 addition & 1 deletion test-harness/tests/poll_api.rs
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ fn concurrent_streams() {
const PAYLOAD_SIZE: usize = 128 * 1024;

let data = Msg(vec![0x42; PAYLOAD_SIZE]);
let n_streams = 1000;
let n_streams = 512;

let mut cfg = Config::default();
cfg.set_split_send_size(PAYLOAD_SIZE); // Use a large frame size to speed up the test.
Expand Down
141 changes: 73 additions & 68 deletions yamux/src/connection.rs
Original file line number Diff line number Diff line change
Expand Up @@ -286,7 +286,8 @@ struct Active<T> {
stream_receivers: SelectAll<TaggedStream<StreamId, mpsc::Receiver<StreamCommand>>>,
no_streams_waker: Option<Waker>,

pending_frames: VecDeque<Frame<()>>,
pending_read_frame: Option<Frame<()>>,
pending_write_frame: Option<Frame<()>>,
new_outbound_stream_waker: Option<Waker>,

rtt: rtt::Rtt,
Expand Down Expand Up @@ -360,7 +361,8 @@ impl<T: AsyncRead + AsyncWrite + Unpin> Active<T> {
Mode::Client => 1,
Mode::Server => 2,
},
pending_frames: VecDeque::default(),
pending_read_frame: None,
pending_write_frame: None,
new_outbound_stream_waker: None,
rtt: rtt::Rtt::new(),
accumulated_max_stream_windows: Default::default(),
Expand All @@ -369,7 +371,12 @@ impl<T: AsyncRead + AsyncWrite + Unpin> Active<T> {

/// Gracefully close the connection to the remote.
fn close(self) -> Closing<T> {
Closing::new(self.stream_receivers, self.pending_frames, self.socket)
let pending_frames = self
.pending_read_frame
.into_iter()
.chain(self.pending_write_frame)
.collect::<VecDeque<Frame<()>>>();
Closing::new(self.stream_receivers, pending_frames, self.socket)
}

/// Cleanup all our resources.
Expand All @@ -392,7 +399,13 @@ impl<T: AsyncRead + AsyncWrite + Unpin> Active<T> {
continue;
}

if let Some(frame) = self.pending_frames.pop_front() {
// Privilege pending `Pong` and `GoAway` `Frame`s
// over `Frame`s from the receivers.
if let Some(frame) = self
.pending_read_frame
.take()
.or_else(|| self.pending_write_frame.take())
{
self.socket.start_send_unpin(frame)?;
continue;
}
Expand All @@ -403,36 +416,63 @@ impl<T: AsyncRead + AsyncWrite + Unpin> Active<T> {
Poll::Pending => {}
}

match self.stream_receivers.poll_next_unpin(cx) {
Poll::Ready(Some((_, Some(StreamCommand::SendFrame(frame))))) => {
self.on_send_frame(frame);
continue;
}
Poll::Ready(Some((id, Some(StreamCommand::CloseStream { ack })))) => {
self.on_close_stream(id, ack);
continue;
}
Poll::Ready(Some((id, None))) => {
self.on_drop_stream(id);
continue;
}
Poll::Ready(None) => {
self.no_streams_waker = Some(cx.waker().clone());
if self.pending_write_frame.is_none() {
match self.stream_receivers.poll_next_unpin(cx) {
Poll::Ready(Some((_, Some(StreamCommand::SendFrame(frame))))) => {
log::trace!(
"{}/{}: sending: {}",
self.id,
frame.header().stream_id(),
frame.header()
);
self.pending_write_frame.replace(frame.into());
continue;
}
Poll::Ready(Some((id, Some(StreamCommand::CloseStream { ack })))) => {
log::trace!("{}/{}: sending close", self.id, id);
self.pending_write_frame
.replace(Frame::close_stream(id, ack).into());
continue;
}
Poll::Ready(Some((id, None))) => {
if let Some(frame) = self.on_drop_stream(id) {
log::trace!("{}/{}: sending: {}", self.id, id, frame.header());
self.pending_write_frame.replace(frame);
};
continue;
}
Poll::Ready(None) => {
self.no_streams_waker = Some(cx.waker().clone());
}
Poll::Pending => {}
}
Poll::Pending => {}
}

match self.socket.poll_next_unpin(cx) {
Poll::Ready(Some(frame)) => {
if let Some(stream) = self.on_frame(frame?)? {
return Poll::Ready(Ok(stream));
if self.pending_read_frame.is_none() {
match self.socket.poll_next_unpin(cx) {
Poll::Ready(Some(frame)) => {
match self.on_frame(frame?)? {
Action::None => {}
Action::New(stream) => {
log::trace!("{}: new inbound {} of {}", self.id, stream, self);
return Poll::Ready(Ok(stream));
}
Action::Ping(f) => {
log::trace!("{}/{}: pong", self.id, f.header().stream_id());
self.pending_read_frame.replace(f.into());
}
Action::Terminate(f) => {
log::trace!("{}: sending term", self.id);
self.pending_read_frame.replace(f.into());
}
}
continue;
}
continue;
}
Poll::Ready(None) => {
return Poll::Ready(Err(ConnectionError::Closed));
Poll::Ready(None) => {
return Poll::Ready(Err(ConnectionError::Closed));
}
Poll::Pending => {}
}
Poll::Pending => {}
}

// If we make it this far, at least one of the above must have registered a waker.
Expand Down Expand Up @@ -463,23 +503,7 @@ impl<T: AsyncRead + AsyncWrite + Unpin> Active<T> {
Poll::Ready(Ok(stream))
}

fn on_send_frame(&mut self, frame: Frame<Either<Data, WindowUpdate>>) {
log::trace!(
"{}/{}: sending: {}",
self.id,
frame.header().stream_id(),
frame.header()
);
self.pending_frames.push_back(frame.into());
}

fn on_close_stream(&mut self, id: StreamId, ack: bool) {
log::trace!("{}/{}: sending close", self.id, id);
self.pending_frames
.push_back(Frame::close_stream(id, ack).into());
}

fn on_drop_stream(&mut self, stream_id: StreamId) {
fn on_drop_stream(&mut self, stream_id: StreamId) -> Option<Frame<()>> {
let s = self.streams.remove(&stream_id).expect("stream not found");

log::trace!("{}: removing dropped stream {}", self.id, stream_id);
Expand Down Expand Up @@ -525,10 +549,7 @@ impl<T: AsyncRead + AsyncWrite + Unpin> Active<T> {
}
frame
};
if let Some(f) = frame {
log::trace!("{}/{}: sending: {}", self.id, stream_id, f.header());
self.pending_frames.push_back(f.into());
}
frame.map(Into::into)
}

/// Process the result of reading from the socket.
Expand All @@ -537,7 +558,7 @@ impl<T: AsyncRead + AsyncWrite + Unpin> Active<T> {
/// and return a corresponding error, which terminates the connection.
/// Otherwise we process the frame and potentially return a new `Stream`
/// if one was opened by the remote.
fn on_frame(&mut self, frame: Frame<()>) -> Result<Option<Stream>> {
fn on_frame(&mut self, frame: Frame<()>) -> Result<Action> {
log::trace!("{}: received: {}", self.id, frame.header());

if frame.header().flags().contains(header::ACK)
Expand All @@ -560,23 +581,7 @@ impl<T: AsyncRead + AsyncWrite + Unpin> Active<T> {
Tag::Ping => self.on_ping(&frame.into_ping()),
Tag::GoAway => return Err(ConnectionError::Closed),
};
match action {
Action::None => {}
Action::New(stream) => {
log::trace!("{}: new inbound {} of {}", self.id, stream, self);
return Ok(Some(stream));
}
Action::Ping(f) => {
log::trace!("{}/{}: pong", self.id, f.header().stream_id());
self.pending_frames.push_back(f.into());
}
Action::Terminate(f) => {
log::trace!("{}: sending term", self.id);
self.pending_frames.push_back(f.into());
}
}

Ok(None)
Ok(action)
}

fn on_data(&mut self, frame: Frame<Data>) -> Action {
Expand Down
24 changes: 12 additions & 12 deletions yamux/src/connection/closing.rs
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ where
socket: Fuse<frame::Io<T>>,
) -> Self {
Self {
state: State::ClosingStreamReceiver,
state: State::FlushingPendingFrames,
stream_receivers,
pending_frames,
socket,
Expand All @@ -49,6 +49,14 @@ where

loop {
match this.state {
State::FlushingPendingFrames => {
ready!(this.socket.poll_ready_unpin(cx))?;

match this.pending_frames.pop_front() {
Some(frame) => this.socket.start_send_unpin(frame)?,
None => this.state = State::ClosingStreamReceiver,
}
}
State::ClosingStreamReceiver => {
for stream in this.stream_receivers.iter_mut() {
stream.inner_mut().close();
Expand All @@ -59,7 +67,7 @@ where
State::DrainingStreamReceiver => {
match this.stream_receivers.poll_next_unpin(cx) {
Poll::Ready(Some((_, Some(StreamCommand::SendFrame(frame))))) => {
this.pending_frames.push_back(frame.into())
this.pending_frames.push_back(frame.into());
}
Poll::Ready(Some((id, Some(StreamCommand::CloseStream { ack })))) => {
this.pending_frames
Expand All @@ -69,19 +77,11 @@ where
Poll::Pending | Poll::Ready(None) => {
// No more frames from streams, append `Term` frame and flush them all.
this.pending_frames.push_back(Frame::term().into());
this.state = State::FlushingPendingFrames;
this.state = State::ClosingSocket;
continue;
}
}
}
State::FlushingPendingFrames => {
ready!(this.socket.poll_ready_unpin(cx))?;

match this.pending_frames.pop_front() {
Some(frame) => this.socket.start_send_unpin(frame)?,
None => this.state = State::ClosingSocket,
}
}
State::ClosingSocket => {
ready!(this.socket.poll_close_unpin(cx))?;

Expand All @@ -93,8 +93,8 @@ where
}

enum State {
FlushingPendingFrames,
ClosingStreamReceiver,
DrainingStreamReceiver,
FlushingPendingFrames,
ClosingSocket,
}

0 comments on commit 460baf2

Please sign in to comment.