diff --git a/quinn-proto/src/connection/assembler.rs b/quinn-proto/src/connection/assembler.rs index 7f4aefc1c..617b9695b 100644 --- a/quinn-proto/src/connection/assembler.rs +++ b/quinn-proto/src/connection/assembler.rs @@ -25,14 +25,9 @@ impl Assembler { Self::default() } - // Get the the next ordered chunk - pub(crate) fn read( - &mut self, - max_length: usize, - ordered: bool, - ) -> Result, AssembleError> { + pub(crate) fn ensure_ordering(&mut self, ordered: bool) -> Result<(), IllegalOrderedRead> { if ordered && !self.state.is_ordered() { - return Err(AssembleError::IllegalOrderedRead); + return Err(IllegalOrderedRead); } else if !ordered && self.state.is_ordered() { // Enter unordered mode let mut recvd = RangeSet::new(); @@ -42,17 +37,21 @@ impl Assembler { } self.state = State::Unordered { recvd }; } + Ok(()) + } + /// Get the the next chunk + pub(crate) fn read(&mut self, max_length: usize, ordered: bool) -> Option { loop { let mut chunk = match self.data.peek_mut() { Some(chunk) => chunk, - None => return Ok(None), + None => return None, }; if ordered { if chunk.offset > self.bytes_read { // Next chunk is after current read index - return Ok(None); + return None; } else if (chunk.offset + chunk.bytes.len() as u64) <= self.bytes_read { // Next chunk is useless as the read index is beyond its end PeekMut::pop(chunk); @@ -68,7 +67,7 @@ impl Assembler { } } - return Ok(Some(if max_length < chunk.bytes.len() { + return Some(if max_length < chunk.bytes.len() { self.bytes_read += max_length as u64; let offset = chunk.offset; chunk.offset += max_length as u64; @@ -78,7 +77,7 @@ impl Assembler { self.defragmented = self.defragmented.saturating_sub(1); let chunk = PeekMut::pop(chunk); Chunk::new(chunk.offset, chunk.bytes) - })); + }); } } @@ -242,11 +241,8 @@ impl Default for State { } /// Error indicating that an ordered read was performed on a stream after an unordered read -#[derive(Debug, Copy, Clone)] -pub enum AssembleError { - IllegalOrderedRead, - UnknownStream, -} +#[derive(Debug)] +pub struct IllegalOrderedRead; #[cfg(test)] mod test { @@ -272,6 +268,7 @@ mod test { #[test] fn assemble_unordered() { let mut x = Assembler::new(); + x.ensure_ordering(false).unwrap(); x.insert(3, Bytes::from_static(b"456")); assert_matches!(next(&mut x, 32), None); x.insert(0, Bytes::from_static(b"123")); @@ -426,42 +423,44 @@ mod test { x.insert(7, Bytes::from_static(b"hij")); x.insert(11, Bytes::from_static(b"lmn")); x.defragment(); - assert_matches!(x.read(usize::MAX, true), Ok(Some(ref y)) if &y.bytes[..] == b"abcdef"); + assert_matches!(x.read(usize::MAX, true), Some(ref y) if &y.bytes[..] == b"abcdef"); x.insert(5, Bytes::from_static(b"fghijklmn")); - assert_matches!(x.read(usize::MAX, true), Ok(Some(ref y)) if &y.bytes[..] == b"ghijklmn"); + assert_matches!(x.read(usize::MAX, true), Some(ref y) if &y.bytes[..] == b"ghijklmn"); x.insert(13, Bytes::from_static(b"nopq")); - assert_matches!(x.read(usize::MAX, true), Ok(Some(ref y)) if &y.bytes[..] == b"opq"); + assert_matches!(x.read(usize::MAX, true), Some(ref y) if &y.bytes[..] == b"opq"); x.insert(15, Bytes::from_static(b"pqrs")); - assert_matches!(x.read(usize::MAX, true), Ok(Some(ref y)) if &y.bytes[..] == b"rs"); - assert_matches!(x.read(usize::MAX, true), Ok(None)); + assert_matches!(x.read(usize::MAX, true), Some(ref y) if &y.bytes[..] == b"rs"); + assert_matches!(x.read(usize::MAX, true), None); } #[test] fn unordered_happy_path() { let mut x = Assembler::new(); + x.ensure_ordering(false).unwrap(); x.insert(0, Bytes::from_static(b"abc")); assert_eq!( next_unordered(&mut x), Chunk::new(0, Bytes::from_static(b"abc")) ); - assert_eq!(x.read(usize::MAX, false).unwrap(), None); + assert_eq!(x.read(usize::MAX, false), None); x.insert(3, Bytes::from_static(b"def")); assert_eq!( next_unordered(&mut x), Chunk::new(3, Bytes::from_static(b"def")) ); - assert_eq!(x.read(usize::MAX, false).unwrap(), None); + assert_eq!(x.read(usize::MAX, false), None); } #[test] fn unordered_dedup() { let mut x = Assembler::new(); + x.ensure_ordering(false).unwrap(); x.insert(3, Bytes::from_static(b"def")); assert_eq!( next_unordered(&mut x), Chunk::new(3, Bytes::from_static(b"def")) ); - assert_eq!(x.read(usize::MAX, false).unwrap(), None); + assert_eq!(x.read(usize::MAX, false), None); x.insert(0, Bytes::from_static(b"a")); x.insert(0, Bytes::from_static(b"abcdefghi")); x.insert(0, Bytes::from_static(b"abcd")); @@ -477,54 +476,54 @@ mod test { next_unordered(&mut x), Chunk::new(6, Bytes::from_static(b"ghi")) ); - assert_eq!(x.read(usize::MAX, false).unwrap(), None); + assert_eq!(x.read(usize::MAX, false), None); x.insert(8, Bytes::from_static(b"ijkl")); assert_eq!( next_unordered(&mut x), Chunk::new(9, Bytes::from_static(b"jkl")) ); - assert_eq!(x.read(usize::MAX, false).unwrap(), None); + assert_eq!(x.read(usize::MAX, false), None); x.insert(12, Bytes::from_static(b"mno")); assert_eq!( next_unordered(&mut x), Chunk::new(12, Bytes::from_static(b"mno")) ); - assert_eq!(x.read(usize::MAX, false).unwrap(), None); + assert_eq!(x.read(usize::MAX, false), None); x.insert(2, Bytes::from_static(b"cde")); - assert_eq!(x.read(usize::MAX, false).unwrap(), None); + assert_eq!(x.read(usize::MAX, false), None); } #[test] fn chunks_dedup() { let mut x = Assembler::new(); x.insert(3, Bytes::from_static(b"def")); - assert_eq!(x.read(usize::MAX, true).unwrap(), None); + assert_eq!(x.read(usize::MAX, true), None); x.insert(0, Bytes::from_static(b"a")); x.insert(1, Bytes::from_static(b"bcdefghi")); x.insert(0, Bytes::from_static(b"abcd")); assert_eq!( - x.read(usize::MAX, true).unwrap(), + x.read(usize::MAX, true), Some(Chunk::new(0, Bytes::from_static(b"abcd"))) ); assert_eq!( - x.read(usize::MAX, true).unwrap(), + x.read(usize::MAX, true), Some(Chunk::new(4, Bytes::from_static(b"efghi"))) ); - assert_eq!(x.read(usize::MAX, true).unwrap(), None); + assert_eq!(x.read(usize::MAX, true), None); x.insert(8, Bytes::from_static(b"ijkl")); assert_eq!( - x.read(usize::MAX, true).unwrap(), + x.read(usize::MAX, true), Some(Chunk::new(9, Bytes::from_static(b"jkl"))) ); - assert_eq!(x.read(usize::MAX, true).unwrap(), None); + assert_eq!(x.read(usize::MAX, true), None); x.insert(12, Bytes::from_static(b"mno")); assert_eq!( - x.read(usize::MAX, true).unwrap(), + x.read(usize::MAX, true), Some(Chunk::new(12, Bytes::from_static(b"mno"))) ); - assert_eq!(x.read(usize::MAX, true).unwrap(), None); + assert_eq!(x.read(usize::MAX, true), None); x.insert(2, Bytes::from_static(b"cde")); - assert_eq!(x.read(usize::MAX, true).unwrap(), None); + assert_eq!(x.read(usize::MAX, true), None); } #[test] @@ -533,7 +532,7 @@ mod test { x.insert(0, Bytes::from_static(b"abc")); assert_eq!(x.data.len(), 1); assert_eq!( - x.read(usize::MAX, true).unwrap(), + x.read(usize::MAX, true), Some(Chunk::new(0, Bytes::from_static(b"abc"))) ); x.insert(0, Bytes::from_static(b"ab")); @@ -549,10 +548,10 @@ mod test { } fn next_unordered(x: &mut Assembler) -> Chunk { - x.read(usize::MAX, false).unwrap().unwrap() + x.read(usize::MAX, false).unwrap() } fn next(x: &mut Assembler, size: usize) -> Option { - x.read(size, true).unwrap().map(|chunk| chunk.bytes) + x.read(size, true).map(|chunk| chunk.bytes) } } diff --git a/quinn-proto/src/connection/mod.rs b/quinn-proto/src/connection/mod.rs index b01330ae3..6ac9e7602 100644 --- a/quinn-proto/src/connection/mod.rs +++ b/quinn-proto/src/connection/mod.rs @@ -1760,7 +1760,7 @@ where space .crypto_stream .insert(crypto.offset, crypto.data.clone()); - while let Some(chunk) = space.crypto_stream.read(usize::MAX, true).unwrap() { + while let Some(chunk) = space.crypto_stream.read(usize::MAX, true) { trace!("consumed {} CRYPTO bytes", chunk.bytes.len()); if self.crypto.read_handshake(&chunk.bytes)? { self.events.push_back(Event::HandshakeDataReady); diff --git a/quinn-proto/src/connection/streams/recv.rs b/quinn-proto/src/connection/streams/recv.rs index dcc91155f..7c305e7c4 100644 --- a/quinn-proto/src/connection/streams/recv.rs +++ b/quinn-proto/src/connection/streams/recv.rs @@ -3,7 +3,7 @@ use thiserror::Error; use tracing::debug; use super::{ShouldTransmit, UnknownStream}; -use crate::connection::assembler::{AssembleError, Assembler, Chunk}; +use crate::connection::assembler::{Assembler, Chunk, IllegalOrderedRead}; use crate::{frame, TransportError, VarInt}; #[derive(Debug, Default)] @@ -73,7 +73,8 @@ impl Recv { return Err(ReadError::UnknownStream); } - match self.assembler.read(max_length, ordered)? { + self.assembler.ensure_ordering(ordered)?; + match self.assembler.read(max_length, ordered) { Some(chunk) => Ok(Some(chunk)), None => self.read_blocked().map(|()| None), } @@ -92,7 +93,8 @@ impl Recv { return Ok(Some(out)); } - while let Some(chunk) = self.assembler.read(usize::MAX, true)? { + self.assembler.ensure_ordering(true)?; + while let Some(chunk) = self.assembler.read(usize::MAX, true) { chunks[out.bufs] = chunk.bytes; out.read += chunks[out.bufs].len(); out.bufs += 1; @@ -314,13 +316,9 @@ pub enum ReadError { IllegalOrderedRead, } -impl From for ReadError { - fn from(e: AssembleError) -> Self { - use AssembleError::*; - match e { - IllegalOrderedRead => ReadError::IllegalOrderedRead, - UnknownStream => ReadError::UnknownStream, - } +impl From for ReadError { + fn from(_: IllegalOrderedRead) -> Self { + ReadError::IllegalOrderedRead } }