diff --git a/quinn-proto/src/connection/assembler.rs b/quinn-proto/src/connection/assembler.rs index 617b9695b1..c9c2a136ac 100644 --- a/quinn-proto/src/connection/assembler.rs +++ b/quinn-proto/src/connection/assembler.rs @@ -1,5 +1,5 @@ use std::{ - cmp::Ordering, + cmp::{max, Ordering}, collections::{binary_heap::PeekMut, BinaryHeap}, mem, }; @@ -13,7 +13,8 @@ use crate::range_set::RangeSet; pub(crate) struct Assembler { state: State, data: BinaryHeap<Buffer>, - defragmented: usize, + buffered: usize, + allocated: usize, /// Number of bytes read by the application. When only ordered reads have been used, this is the /// length of the contiguous prefix of the stream which has been consumed by the application, /// aka the stream offset. @@ -54,8 +55,9 @@ impl Assembler { return None; } else if (chunk.offset + chunk.bytes.len() as u64) <= self.bytes_read { // Next chunk is useless as the read index is beyond its end + self.buffered -= chunk.size; + self.allocated -= chunk.allocation_size; PeekMut::pop(chunk); - self.defragmented = self.defragmented.saturating_sub(1); continue; } @@ -74,7 +76,8 @@ impl Assembler { Chunk::new(offset, chunk.bytes.split_to(max_length)) } else { self.bytes_read += chunk.bytes.len() as u64; - self.defragmented = self.defragmented.saturating_sub(1); + self.buffered -= chunk.size; + self.allocated -= chunk.allocation_size; let chunk = PeekMut::pop(chunk); Chunk::new(chunk.offset, chunk.bytes) }); @@ -87,8 +90,13 @@ impl Assembler { // counter to the new number of chunks left in the heap so that we can decide // when to defragment the queue again if necessary.
fn defragment(&mut self) { - let buffered = self.data.iter().map(|c| c.bytes.len()).sum::<usize>(); - let mut buffer = BytesMut::with_capacity(buffered); + let fragmented_buffered = self + .data + .iter() + .filter(|c| c.is_fragmented()) + .map(|c| c.bytes.len()) + .sum::<usize>(); + let mut buffer = BytesMut::with_capacity(fragmented_buffered); let mut offset = self .data .peek() @@ -99,33 +107,49 @@ impl Assembler { let new = BinaryHeap::with_capacity(self.data.len()); let old = mem::replace(&mut self.data, new); for chunk in old.into_sorted_vec().into_iter().rev() { + if !chunk.is_fragmented() { + self.data.push(chunk); + continue; + } let end = offset + (buffer.len() as u64); if let Some(overlap) = end.checked_sub(chunk.offset) { if let Some(bytes) = chunk.bytes.get(overlap as usize..) { buffer.extend_from_slice(bytes); } } else { - let bytes = buffer.split().freeze(); - self.data.push(Buffer { offset, bytes }); + self.data + .push(Buffer::new_defragmented(offset, buffer.split().freeze())); offset = chunk.offset; buffer.extend_from_slice(&chunk.bytes); } } - let bytes = buffer.split().freeze(); - self.data.push(Buffer { offset, bytes }); - self.defragmented = self.data.len(); + self.data + .push(Buffer::new_defragmented(offset, buffer.split().freeze())); + self.allocated = self.buffered; } - pub(crate) fn insert(&mut self, mut offset: u64, mut bytes: Bytes) { + // Note: If a packet contains many frames from the same stream, the estimated over-allocation + // will be much higher because we are counting the same allocation multiple times.
+ pub(crate) fn insert(&mut self, mut offset: u64, mut bytes: Bytes, allocation_size: usize) { + debug_assert!( + bytes.len() <= allocation_size, + "allocation_size less than bytes.len(): {:?} < {:?}", + allocation_size, + bytes.len() + ); if let State::Unordered { ref mut recvd } = self.state { // Discard duplicate data for duplicate in recvd.replace(offset..offset + bytes.len() as u64) { if duplicate.start > offset { - self.data.push(Buffer { + let buffer = Buffer::new( offset, - bytes: bytes.split_to((duplicate.start - offset) as usize), - }); + bytes.split_to((duplicate.start - offset) as usize), + allocation_size, + ); + self.buffered += buffer.size; + self.allocated += buffer.allocation_size; + self.data.push(buffer); offset = duplicate.start; } bytes.advance((duplicate.end - offset) as usize); @@ -144,16 +168,20 @@ impl Assembler { if bytes.is_empty() { return; } - - self.data.push(Buffer { offset, bytes }); - // Why 32: on the one hand, we want to defragment rarely, ideally never + let buffer = Buffer::new(offset, bytes, allocation_size); + self.buffered += buffer.size; + self.allocated += buffer.allocation_size; + self.data.push(buffer); + // Rationale: on the one hand, we want to defragment rarely, ideally never // in non-pathological scenarios. However, a pathological or malicious // peer could send us one-byte frames, and since we use reference-counted // buffers in order to prevent copying, this could result in keeping a lot - // of memory allocated. In the worst case scenario of 32 1-byte chunks, - // each one from a ~1000-byte datagram, using 32 limits us to having a - // maximum pathological over-allocation of about 32k bytes. - if self.data.len() - self.defragmented > 32 { + // of memory allocated. This limits over-allocation in proportion to the + // buffered data. The constants are chosen somewhat arbitrarily and try to + // balance between defragmentation overhead and over-allocation. 
+ let over_allocation = self.allocated - self.buffered; + let threshold = max(self.buffered * 3 / 2, 32 * 1024); + if over_allocation > threshold { self.defragment() } } @@ -170,7 +198,8 @@ impl Assembler { /// Discard all buffered data pub(crate) fn clear(&mut self) { self.data.clear(); - self.defragmented = 0; + self.buffered = 0; + self.allocated = 0; } } @@ -193,6 +222,44 @@ impl Chunk { struct Buffer { offset: u64, bytes: Bytes, + size: usize, + allocation_size: usize, +} + +impl Buffer { + /// Constructs a new, possibly fragmented Buffer + fn new(offset: u64, bytes: Bytes, allocation_size: usize) -> Self { + let size = bytes.len(); + // Treat buffers with small over-allocation as defragmented + let threshold = size * 6 / 5; + let allocation_size = if allocation_size > threshold { + allocation_size + } else { + size + }; + Self { + offset, + bytes, + size, + allocation_size, + } + } + + /// Constructs a new Buffer that is not fragmented + fn new_defragmented(offset: u64, bytes: Bytes) -> Self { + let size = bytes.len(); + Self { + offset, + bytes, + size, + allocation_size: size, + } + } + + /// Returns `true` if the buffer is fragmented + fn is_fragmented(&self) -> bool { + self.size < self.allocation_size + } } impl Ord for Buffer { @@ -253,13 +320,13 @@ mod test { fn assemble_ordered() { let mut x = Assembler::new(); assert_matches!(next(&mut x, 32), None); - x.insert(0, Bytes::from_static(b"123")); + x.insert(0, Bytes::from_static(b"123"), 3); assert_matches!(next(&mut x, 1), Some(ref y) if &y[..] == b"1"); assert_matches!(next(&mut x, 3), Some(ref y) if &y[..] == b"23"); - x.insert(3, Bytes::from_static(b"456")); + x.insert(3, Bytes::from_static(b"456"), 3); assert_matches!(next(&mut x, 32), Some(ref y) if &y[..] == b"456"); - x.insert(6, Bytes::from_static(b"789")); - x.insert(9, Bytes::from_static(b"10")); + x.insert(6, Bytes::from_static(b"789"), 3); + x.insert(9, Bytes::from_static(b"10"), 2); assert_matches!(next(&mut x, 32), Some(ref y) if &y[..] 
== b"789"); assert_matches!(next(&mut x, 32), Some(ref y) if &y[..] == b"10"); assert_matches!(next(&mut x, 32), None); @@ -269,9 +336,9 @@ mod test { fn assemble_unordered() { let mut x = Assembler::new(); x.ensure_ordering(false).unwrap(); - x.insert(3, Bytes::from_static(b"456")); + x.insert(3, Bytes::from_static(b"456"), 3); assert_matches!(next(&mut x, 32), None); - x.insert(0, Bytes::from_static(b"123")); + x.insert(0, Bytes::from_static(b"123"), 3); assert_matches!(next(&mut x, 32), Some(ref y) if &y[..] == b"123"); assert_matches!(next(&mut x, 32), Some(ref y) if &y[..] == b"456"); assert_matches!(next(&mut x, 32), None); @@ -280,8 +347,8 @@ mod test { #[test] fn assemble_duplicate() { let mut x = Assembler::new(); - x.insert(0, Bytes::from_static(b"123")); - x.insert(0, Bytes::from_static(b"123")); + x.insert(0, Bytes::from_static(b"123"), 3); + x.insert(0, Bytes::from_static(b"123"), 3); assert_matches!(next(&mut x, 32), Some(ref y) if &y[..] == b"123"); assert_matches!(next(&mut x, 32), None); } @@ -289,8 +356,8 @@ mod test { #[test] fn assemble_duplicate_compact() { let mut x = Assembler::new(); - x.insert(0, Bytes::from_static(b"123")); - x.insert(0, Bytes::from_static(b"123")); + x.insert(0, Bytes::from_static(b"123"), 3); + x.insert(0, Bytes::from_static(b"123"), 3); x.defragment(); assert_matches!(next(&mut x, 32), Some(ref y) if &y[..] == b"123"); assert_matches!(next(&mut x, 32), None); @@ -299,8 +366,8 @@ mod test { #[test] fn assemble_contained() { let mut x = Assembler::new(); - x.insert(0, Bytes::from_static(b"12345")); - x.insert(1, Bytes::from_static(b"234")); + x.insert(0, Bytes::from_static(b"12345"), 5); + x.insert(1, Bytes::from_static(b"234"), 3); assert_matches!(next(&mut x, 32), Some(ref y) if &y[..] 
== b"12345"); assert_matches!(next(&mut x, 32), None); } @@ -308,8 +375,8 @@ mod test { #[test] fn assemble_contained_compact() { let mut x = Assembler::new(); - x.insert(0, Bytes::from_static(b"12345")); - x.insert(1, Bytes::from_static(b"234")); + x.insert(0, Bytes::from_static(b"12345"), 5); + x.insert(1, Bytes::from_static(b"234"), 3); x.defragment(); assert_matches!(next(&mut x, 32), Some(ref y) if &y[..] == b"12345"); assert_matches!(next(&mut x, 32), None); @@ -318,8 +385,8 @@ mod test { #[test] fn assemble_contains() { let mut x = Assembler::new(); - x.insert(1, Bytes::from_static(b"234")); - x.insert(0, Bytes::from_static(b"12345")); + x.insert(1, Bytes::from_static(b"234"), 3); + x.insert(0, Bytes::from_static(b"12345"), 5); assert_matches!(next(&mut x, 32), Some(ref y) if &y[..] == b"12345"); assert_matches!(next(&mut x, 32), None); } @@ -327,8 +394,8 @@ mod test { #[test] fn assemble_contains_compact() { let mut x = Assembler::new(); - x.insert(1, Bytes::from_static(b"234")); - x.insert(0, Bytes::from_static(b"12345")); + x.insert(1, Bytes::from_static(b"234"), 3); + x.insert(0, Bytes::from_static(b"12345"), 5); x.defragment(); assert_matches!(next(&mut x, 32), Some(ref y) if &y[..] == b"12345"); assert_matches!(next(&mut x, 32), None); @@ -337,8 +404,8 @@ mod test { #[test] fn assemble_overlapping() { let mut x = Assembler::new(); - x.insert(0, Bytes::from_static(b"123")); - x.insert(1, Bytes::from_static(b"234")); + x.insert(0, Bytes::from_static(b"123"), 3); + x.insert(1, Bytes::from_static(b"234"), 3); assert_matches!(next(&mut x, 32), Some(ref y) if &y[..] == b"123"); assert_matches!(next(&mut x, 32), Some(ref y) if &y[..] 
== b"4"); assert_matches!(next(&mut x, 32), None); @@ -347,8 +414,8 @@ mod test { #[test] fn assemble_overlapping_compact() { let mut x = Assembler::new(); - x.insert(0, Bytes::from_static(b"123")); - x.insert(1, Bytes::from_static(b"234")); + x.insert(0, Bytes::from_static(b"123"), 4); + x.insert(1, Bytes::from_static(b"234"), 4); x.defragment(); assert_matches!(next(&mut x, 32), Some(ref y) if &y[..] == b"1234"); assert_matches!(next(&mut x, 32), None); @@ -357,10 +424,10 @@ mod test { #[test] fn assemble_complex() { let mut x = Assembler::new(); - x.insert(0, Bytes::from_static(b"1")); - x.insert(2, Bytes::from_static(b"3")); - x.insert(4, Bytes::from_static(b"5")); - x.insert(0, Bytes::from_static(b"123456")); + x.insert(0, Bytes::from_static(b"1"), 1); + x.insert(2, Bytes::from_static(b"3"), 1); + x.insert(4, Bytes::from_static(b"5"), 1); + x.insert(0, Bytes::from_static(b"123456"), 6); assert_matches!(next(&mut x, 32), Some(ref y) if &y[..] == b"123456"); assert_matches!(next(&mut x, 32), None); } @@ -368,10 +435,10 @@ mod test { #[test] fn assemble_complex_compact() { let mut x = Assembler::new(); - x.insert(0, Bytes::from_static(b"1")); - x.insert(2, Bytes::from_static(b"3")); - x.insert(4, Bytes::from_static(b"5")); - x.insert(0, Bytes::from_static(b"123456")); + x.insert(0, Bytes::from_static(b"1"), 1); + x.insert(2, Bytes::from_static(b"3"), 1); + x.insert(4, Bytes::from_static(b"5"), 1); + x.insert(0, Bytes::from_static(b"123456"), 6); x.defragment(); assert_matches!(next(&mut x, 32), Some(ref y) if &y[..] == b"123456"); assert_matches!(next(&mut x, 32), None); @@ -380,19 +447,19 @@ mod test { #[test] fn assemble_old() { let mut x = Assembler::new(); - x.insert(0, Bytes::from_static(b"1234")); + x.insert(0, Bytes::from_static(b"1234"), 4); assert_matches!(next(&mut x, 32), Some(ref y) if &y[..] 
== b"1234"); - x.insert(0, Bytes::from_static(b"1234")); + x.insert(0, Bytes::from_static(b"1234"), 4); assert_matches!(next(&mut x, 32), None); } #[test] fn compact() { let mut x = Assembler::new(); - x.insert(0, Bytes::from_static(b"abc")); - x.insert(3, Bytes::from_static(b"def")); - x.insert(9, Bytes::from_static(b"jkl")); - x.insert(12, Bytes::from_static(b"mno")); + x.insert(0, Bytes::from_static(b"abc"), 4); + x.insert(3, Bytes::from_static(b"def"), 4); + x.insert(9, Bytes::from_static(b"jkl"), 4); + x.insert(12, Bytes::from_static(b"mno"), 4); x.defragment(); assert_eq!( next_unordered(&mut x), @@ -407,7 +474,7 @@ mod test { #[test] fn defrag_with_missing_prefix() { let mut x = Assembler::new(); - x.insert(3, Bytes::from_static(b"def")); + x.insert(3, Bytes::from_static(b"def"), 3); x.defragment(); assert_eq!( next_unordered(&mut x), @@ -418,17 +485,17 @@ mod test { #[test] fn defrag_read_chunk() { let mut x = Assembler::new(); - x.insert(3, Bytes::from_static(b"def")); - x.insert(0, Bytes::from_static(b"abc")); - x.insert(7, Bytes::from_static(b"hij")); - x.insert(11, Bytes::from_static(b"lmn")); + x.insert(3, Bytes::from_static(b"def"), 4); + x.insert(0, Bytes::from_static(b"abc"), 4); + x.insert(7, Bytes::from_static(b"hij"), 4); + x.insert(11, Bytes::from_static(b"lmn"), 4); x.defragment(); assert_matches!(x.read(usize::MAX, true), Some(ref y) if &y.bytes[..] == b"abcdef"); - x.insert(5, Bytes::from_static(b"fghijklmn")); + x.insert(5, Bytes::from_static(b"fghijklmn"), 9); assert_matches!(x.read(usize::MAX, true), Some(ref y) if &y.bytes[..] == b"ghijklmn"); - x.insert(13, Bytes::from_static(b"nopq")); + x.insert(13, Bytes::from_static(b"nopq"), 4); assert_matches!(x.read(usize::MAX, true), Some(ref y) if &y.bytes[..] == b"opq"); - x.insert(15, Bytes::from_static(b"pqrs")); + x.insert(15, Bytes::from_static(b"pqrs"), 4); assert_matches!(x.read(usize::MAX, true), Some(ref y) if &y.bytes[..] 
== b"rs"); assert_matches!(x.read(usize::MAX, true), None); } @@ -437,13 +504,13 @@ mod test { fn unordered_happy_path() { let mut x = Assembler::new(); x.ensure_ordering(false).unwrap(); - x.insert(0, Bytes::from_static(b"abc")); + x.insert(0, Bytes::from_static(b"abc"), 3); assert_eq!( next_unordered(&mut x), Chunk::new(0, Bytes::from_static(b"abc")) ); assert_eq!(x.read(usize::MAX, false), None); - x.insert(3, Bytes::from_static(b"def")); + x.insert(3, Bytes::from_static(b"def"), 3); assert_eq!( next_unordered(&mut x), Chunk::new(3, Bytes::from_static(b"def")) @@ -455,15 +522,15 @@ mod test { fn unordered_dedup() { let mut x = Assembler::new(); x.ensure_ordering(false).unwrap(); - x.insert(3, Bytes::from_static(b"def")); + x.insert(3, Bytes::from_static(b"def"), 3); assert_eq!( next_unordered(&mut x), Chunk::new(3, Bytes::from_static(b"def")) ); assert_eq!(x.read(usize::MAX, false), None); - x.insert(0, Bytes::from_static(b"a")); - x.insert(0, Bytes::from_static(b"abcdefghi")); - x.insert(0, Bytes::from_static(b"abcd")); + x.insert(0, Bytes::from_static(b"a"), 1); + x.insert(0, Bytes::from_static(b"abcdefghi"), 9); + x.insert(0, Bytes::from_static(b"abcd"), 4); assert_eq!( next_unordered(&mut x), Chunk::new(0, Bytes::from_static(b"a")) @@ -477,30 +544,30 @@ mod test { Chunk::new(6, Bytes::from_static(b"ghi")) ); assert_eq!(x.read(usize::MAX, false), None); - x.insert(8, Bytes::from_static(b"ijkl")); + x.insert(8, Bytes::from_static(b"ijkl"), 4); assert_eq!( next_unordered(&mut x), Chunk::new(9, Bytes::from_static(b"jkl")) ); assert_eq!(x.read(usize::MAX, false), None); - x.insert(12, Bytes::from_static(b"mno")); + x.insert(12, Bytes::from_static(b"mno"), 3); assert_eq!( next_unordered(&mut x), Chunk::new(12, Bytes::from_static(b"mno")) ); assert_eq!(x.read(usize::MAX, false), None); - x.insert(2, Bytes::from_static(b"cde")); + x.insert(2, Bytes::from_static(b"cde"), 3); assert_eq!(x.read(usize::MAX, false), None); } #[test] fn chunks_dedup() { let mut x = 
Assembler::new(); - x.insert(3, Bytes::from_static(b"def")); + x.insert(3, Bytes::from_static(b"def"), 3); assert_eq!(x.read(usize::MAX, true), None); - x.insert(0, Bytes::from_static(b"a")); - x.insert(1, Bytes::from_static(b"bcdefghi")); - x.insert(0, Bytes::from_static(b"abcd")); + x.insert(0, Bytes::from_static(b"a"), 1); + x.insert(1, Bytes::from_static(b"bcdefghi"), 9); + x.insert(0, Bytes::from_static(b"abcd"), 4); assert_eq!( x.read(usize::MAX, true), Some(Chunk::new(0, Bytes::from_static(b"abcd"))) @@ -510,39 +577,41 @@ mod test { Some(Chunk::new(4, Bytes::from_static(b"efghi"))) ); assert_eq!(x.read(usize::MAX, true), None); - x.insert(8, Bytes::from_static(b"ijkl")); + x.insert(8, Bytes::from_static(b"ijkl"), 4); assert_eq!( x.read(usize::MAX, true), Some(Chunk::new(9, Bytes::from_static(b"jkl"))) ); assert_eq!(x.read(usize::MAX, true), None); - x.insert(12, Bytes::from_static(b"mno")); + x.insert(12, Bytes::from_static(b"mno"), 3); assert_eq!( x.read(usize::MAX, true), Some(Chunk::new(12, Bytes::from_static(b"mno"))) ); assert_eq!(x.read(usize::MAX, true), None); - x.insert(2, Bytes::from_static(b"cde")); + x.insert(2, Bytes::from_static(b"cde"), 3); assert_eq!(x.read(usize::MAX, true), None); } #[test] fn ordered_eager_discard() { let mut x = Assembler::new(); - x.insert(0, Bytes::from_static(b"abc")); + x.insert(0, Bytes::from_static(b"abc"), 3); assert_eq!(x.data.len(), 1); assert_eq!( x.read(usize::MAX, true), Some(Chunk::new(0, Bytes::from_static(b"abc"))) ); - x.insert(0, Bytes::from_static(b"ab")); + x.insert(0, Bytes::from_static(b"ab"), 2); assert_eq!(x.data.len(), 0); - x.insert(2, Bytes::from_static(b"cd")); + x.insert(2, Bytes::from_static(b"cd"), 2); assert_eq!( x.data.peek(), Some(&Buffer { offset: 3, - bytes: Bytes::from_static(b"d") + bytes: Bytes::from_static(b"d"), + size: 1, + allocation_size: 2 }) ); } diff --git a/quinn-proto/src/connection/mod.rs b/quinn-proto/src/connection/mod.rs index b79b977f03..62edd7f598 100644 --- 
a/quinn-proto/src/connection/mod.rs +++ b/quinn-proto/src/connection/mod.rs @@ -1756,6 +1756,7 @@ where &mut self, space: SpaceId, crypto: &frame::Crypto, + payload_len: usize, ) -> Result<(), TransportError> { let expected = if !self.state.is_handshake() { SpaceId::Data @@ -1790,7 +1791,7 @@ where space .crypto_stream - .insert(crypto.offset, crypto.data.clone()); + .insert(crypto.offset, crypto.data.clone(), payload_len); while let Some(chunk) = space.crypto_stream.read(usize::MAX, true) { trace!("consumed {} CRYPTO bytes", chunk.bytes.len()); if self.crypto.read_handshake(&chunk.bytes)? { @@ -2325,6 +2326,7 @@ where packet: Packet, ) -> Result<(), TransportError> { debug_assert_ne!(packet.header.space(), SpaceId::Data); + let payload_len = packet.payload.len(); for frame in frame::Iter::new(packet.payload.freeze()) { let span = match frame { Frame::Padding => continue, @@ -2345,7 +2347,7 @@ where match frame { Frame::Padding | Frame::Ping => {} Frame::Crypto(frame) => { - self.read_crypto(packet.header.space(), &frame)?; + self.read_crypto(packet.header.space(), &frame, payload_len)?; } Frame::Ack(ack) => { self.on_ack_received(now, packet.header.space(), ack)?; @@ -2383,6 +2385,7 @@ where let is_0rtt = self.spaces[SpaceId::Data].crypto.is_none(); let mut is_probing_packet = true; let mut close = None; + let payload_len = payload.len(); for frame in frame::Iter::new(payload) { let span = match frame { Frame::Padding => continue, @@ -2427,10 +2430,10 @@ where return Err(err); } Frame::Crypto(frame) => { - self.read_crypto(SpaceId::Data, &frame)?; + self.read_crypto(SpaceId::Data, &frame, payload_len)?; } Frame::Stream(frame) => { - if self.streams.received(frame)?.should_transmit() { + if self.streams.received(frame, payload_len)?.should_transmit() { self.spaces[SpaceId::Data].pending.max_data = true; } } diff --git a/quinn-proto/src/connection/streams.rs b/quinn-proto/src/connection/streams.rs index 91b073e7e0..e81da0c2d5 100644 --- 
a/quinn-proto/src/connection/streams.rs b/quinn-proto/src/connection/streams.rs index 91b073e7e0..e81da0c2d5 100644 --- a/quinn-proto/src/connection/streams.rs +++ b/quinn-proto/src/connection/streams.rs @@ -281,7 +281,11 @@ impl Streams { /// Process incoming stream frame /// /// If successful, returns whether a `MAX_DATA` frame needs to be transmitted - pub fn received(&mut self, frame: frame::Stream) -> Result<ShouldTransmit, TransportError> { + pub fn received( + &mut self, + frame: frame::Stream, + payload_len: usize, + ) -> Result<ShouldTransmit, TransportError> { trace!(id = %frame.id, offset = frame.offset, len = frame.data.len(), fin = frame.fin, "got stream"); let stream = frame.id; self.validate_receive_id(stream).map_err(|e| { @@ -302,7 +306,7 @@ impl Streams { return Ok(ShouldTransmit(false)); } - let new_bytes = rs.ingest(frame, self.data_recvd, self.local_max_data)?; + let new_bytes = rs.ingest(frame, payload_len, self.data_recvd, self.local_max_data)?; self.data_recvd = self.data_recvd.saturating_add(new_bytes); if !rs.stopped { @@ -1079,12 +1083,15 @@ mod tests { let initial_max = client.local_max_data; assert_eq!( client - .received(frame::Stream { - id, - offset: 0, - fin: false, - data: Bytes::from_static(&[0; 2048]), - }) + .received( + frame::Stream { + id, + offset: 0, + fin: false, + data: Bytes::from_static(&[0; 2048]), + }, + 2048 + ) .unwrap(), ShouldTransmit(false) ); @@ -1113,12 +1120,15 @@ mod tests { let initial_max = client.local_max_data; assert_eq!( client - .received(frame::Stream { - id, - offset: 4096, - fin: false, - data: Bytes::from_static(&[0; 0]), - }) + .received( + frame::Stream { + id, + offset: 4096, + fin: false, + data: Bytes::from_static(&[0; 0]), + }, + 0 + ) .unwrap(), ShouldTransmit(false) ); @@ -1173,12 +1183,15 @@ mod tests { let initial_max = client.local_max_data; assert_eq!( client - .received(frame::Stream { - id, - offset: 0, - fin: false, - data: Bytes::from_static(&[0; 32]), - }) + .received( + frame::Stream { + id, + offset: 0, + fin: false, + data: Bytes::from_static(&[0; 32]), + }, + 32 + ) .unwrap(), ShouldTransmit(false) ); @@ -1199,12 +1212,15 @@
assert_eq!(client.local_max_data - initial_max, 32); assert_eq!( client - .received(frame::Stream { - id, - offset: 32, - fin: true, - data: Bytes::from_static(&[0; 16]), - }) + .received( + frame::Stream { + id, + offset: 32, + fin: true, + data: Bytes::from_static(&[0; 16]), + }, + 16 + ) .unwrap(), ShouldTransmit(false) ); @@ -1219,12 +1235,15 @@ mod tests { // Server opens stream assert_eq!( client - .received(frame::Stream { - id, - offset: 0, - fin: false, - data: Bytes::from_static(&[0; 32]), - }) + .received( + frame::Stream { + id, + offset: 0, + fin: false, + data: Bytes::from_static(&[0; 32]) + }, + 32 + ) .unwrap(), ShouldTransmit(false) ); diff --git a/quinn-proto/src/connection/streams/recv.rs b/quinn-proto/src/connection/streams/recv.rs index 7c305e7c4a..bcbeeec76a 100644 --- a/quinn-proto/src/connection/streams/recv.rs +++ b/quinn-proto/src/connection/streams/recv.rs @@ -29,6 +29,7 @@ impl Recv { pub(super) fn ingest( &mut self, frame: frame::Stream, + payload_len: usize, received: u64, max_data: u64, ) -> Result<u64, TransportError> { @@ -60,7 +61,7 @@ impl Recv { self.end = self.end.max(end); if !self.stopped { - self.assembler.insert(frame.offset, frame.data, payload_len); + self.assembler.insert(frame.offset, frame.data, payload_len); } else { self.assembler.set_bytes_read(end); }