diff --git a/library/std/src/io/buffered/bufreader.rs b/library/std/src/io/buffered/bufreader.rs index 989cec976b72f..f7fbaa9c27649 100644 --- a/library/std/src/io/buffered/bufreader.rs +++ b/library/std/src/io/buffered/bufreader.rs @@ -1,9 +1,10 @@ -use crate::cmp; +mod buffer; + use crate::fmt; use crate::io::{ self, BufRead, IoSliceMut, Read, ReadBuf, Seek, SeekFrom, SizeHint, DEFAULT_BUF_SIZE, }; -use crate::mem::MaybeUninit; +use buffer::Buffer; /// The `BufReader` struct adds buffering to any reader. /// @@ -48,10 +49,7 @@ use crate::mem::MaybeUninit; #[stable(feature = "rust1", since = "1.0.0")] pub struct BufReader { inner: R, - buf: Box<[MaybeUninit]>, - pos: usize, - cap: usize, - init: usize, + buf: Buffer, } impl BufReader { @@ -93,8 +91,7 @@ impl BufReader { /// ``` #[stable(feature = "rust1", since = "1.0.0")] pub fn with_capacity(capacity: usize, inner: R) -> BufReader { - let buf = Box::new_uninit_slice(capacity); - BufReader { inner, buf, pos: 0, cap: 0, init: 0 } + BufReader { inner, buf: Buffer::with_capacity(capacity) } } } @@ -170,8 +167,7 @@ impl BufReader { /// ``` #[stable(feature = "bufreader_buffer", since = "1.37.0")] pub fn buffer(&self) -> &[u8] { - // SAFETY: self.cap is always <= self.init, so self.buf[self.pos..self.cap] is always init - unsafe { MaybeUninit::slice_assume_init_ref(&self.buf[self.pos..self.cap]) } + self.buf.buffer() } /// Returns the number of bytes the internal buffer can hold at once. @@ -194,7 +190,7 @@ impl BufReader { /// ``` #[stable(feature = "buffered_io_capacity", since = "1.46.0")] pub fn capacity(&self) -> usize { - self.buf.len() + self.buf.capacity() } /// Unwraps this `BufReader`, returning the underlying reader. @@ -224,8 +220,7 @@ impl BufReader { /// Invalidates all data in the internal buffer. #[inline] fn discard_buffer(&mut self) { - self.pos = 0; - self.cap = 0; + self.buf.discard_buffer() } } @@ -236,15 +231,15 @@ impl BufReader { /// must track this information themselves if it is required. 
#[stable(feature = "bufreader_seek_relative", since = "1.53.0")] pub fn seek_relative(&mut self, offset: i64) -> io::Result<()> { - let pos = self.pos as u64; + let pos = self.buf.pos() as u64; if offset < 0 { - if let Some(new_pos) = pos.checked_sub((-offset) as u64) { - self.pos = new_pos as usize; + if let Some(_) = pos.checked_sub((-offset) as u64) { + self.buf.unconsume((-offset) as usize); return Ok(()); } } else if let Some(new_pos) = pos.checked_add(offset as u64) { - if new_pos <= self.cap as u64 { - self.pos = new_pos as usize; + if new_pos <= self.buf.filled() as u64 { + self.buf.consume(offset as usize); return Ok(()); } } @@ -259,7 +254,7 @@ impl Read for BufReader { // If we don't have any buffered data and we're doing a massive read // (larger than our internal buffer), bypass our internal buffer // entirely. - if self.pos == self.cap && buf.len() >= self.buf.len() { + if self.buf.pos() == self.buf.filled() && buf.len() >= self.capacity() { self.discard_buffer(); return self.inner.read(buf); } @@ -275,7 +270,7 @@ impl Read for BufReader { // If we don't have any buffered data and we're doing a massive read // (larger than our internal buffer), bypass our internal buffer // entirely. - if self.pos == self.cap && buf.remaining() >= self.buf.len() { + if self.buf.pos() == self.buf.filled() && buf.remaining() >= self.capacity() { self.discard_buffer(); return self.inner.read_buf(buf); } @@ -295,9 +290,7 @@ impl Read for BufReader { // generation for the common path where the buffer has enough bytes to fill the passed-in // buffer. 
fn read_exact(&mut self, buf: &mut [u8]) -> io::Result<()> { - if self.buffer().len() >= buf.len() { - buf.copy_from_slice(&self.buffer()[..buf.len()]); - self.consume(buf.len()); + if self.buf.consume_with(buf.len(), |claimed| buf.copy_from_slice(claimed)) { return Ok(()); } @@ -306,7 +299,7 @@ impl Read for BufReader { fn read_vectored(&mut self, bufs: &mut [IoSliceMut<'_>]) -> io::Result { let total_len = bufs.iter().map(|b| b.len()).sum::(); - if self.pos == self.cap && total_len >= self.buf.len() { + if self.buf.pos() == self.buf.filled() && total_len >= self.capacity() { self.discard_buffer(); return self.inner.read_vectored(bufs); } @@ -325,8 +318,9 @@ impl Read for BufReader { // The inner reader might have an optimized `read_to_end`. Drain our buffer and then // delegate to the inner implementation. fn read_to_end(&mut self, buf: &mut Vec) -> io::Result { - let nread = self.cap - self.pos; - buf.extend_from_slice(&self.buffer()); + let inner_buf = self.buffer(); + buf.extend_from_slice(inner_buf); + let nread = inner_buf.len(); self.discard_buffer(); Ok(nread + self.inner.read_to_end(buf)?) } @@ -371,33 +365,11 @@ impl Read for BufReader { #[stable(feature = "rust1", since = "1.0.0")] impl BufRead for BufReader { fn fill_buf(&mut self) -> io::Result<&[u8]> { - // If we've reached the end of our internal buffer then we need to fetch - // some more data from the underlying reader. - // Branch using `>=` instead of the more correct `==` - // to tell the compiler that the pos..cap slice is always valid. 
- if self.pos >= self.cap { - debug_assert!(self.pos == self.cap); - - let mut readbuf = ReadBuf::uninit(&mut self.buf); - - // SAFETY: `self.init` is either 0 or set to `readbuf.initialized_len()` - // from the last time this function was called - unsafe { - readbuf.assume_init(self.init); - } - - self.inner.read_buf(&mut readbuf)?; - - self.cap = readbuf.filled_len(); - self.init = readbuf.initialized_len(); - - self.pos = 0; - } - Ok(self.buffer()) + self.buf.fill_buf(&mut self.inner) } fn consume(&mut self, amt: usize) { - self.pos = cmp::min(self.pos + amt, self.cap); + self.buf.consume(amt) } } @@ -409,7 +381,10 @@ where fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result { fmt.debug_struct("BufReader") .field("reader", &self.inner) - .field("buffer", &format_args!("{}/{}", self.cap - self.pos, self.buf.len())) + .field( + "buffer", + &format_args!("{}/{}", self.buf.filled() - self.buf.pos(), self.capacity()), + ) .finish() } } @@ -441,7 +416,7 @@ impl Seek for BufReader { fn seek(&mut self, pos: SeekFrom) -> io::Result { let result: u64; if let SeekFrom::Current(n) = pos { - let remainder = (self.cap - self.pos) as i64; + let remainder = (self.buf.filled() - self.buf.pos()) as i64; // it should be safe to assume that remainder fits within an i64 as the alternative // means we managed to allocate 8 exbibytes and that's absurd. 
// But it's not out of the realm of possibility for some weird underlying reader to @@ -499,7 +474,7 @@ impl Seek for BufReader { /// } /// ``` fn stream_position(&mut self) -> io::Result { - let remainder = (self.cap - self.pos) as u64; + let remainder = (self.buf.filled() - self.buf.pos()) as u64; self.inner.stream_position().map(|pos| { pos.checked_sub(remainder).expect( "overflow when subtracting remaining buffer size from inner stream position", diff --git a/library/std/src/io/buffered/bufreader/buffer.rs b/library/std/src/io/buffered/bufreader/buffer.rs new file mode 100644 index 0000000000000..8ae01f3b0ad8a --- /dev/null +++ b/library/std/src/io/buffered/bufreader/buffer.rs @@ -0,0 +1,105 @@ +//! An encapsulation of `BufReader`'s buffer management logic. +//! +//! This module factors out the basic functionality of `BufReader` in order to protect two core +//! invariants: +//! * `filled` bytes of `buf` are always initialized +//! * `pos` is always <= `filled` +//! Since this module encapsulates the buffer management logic, we can ensure that the range +//! `pos..filled` is always a valid index into the initialized region of the buffer. This means +//! that user code which wants to do reads from a `BufReader` via `buffer` + `consume` can do so +//! without encountering any runtime bounds checks. +use crate::cmp; +use crate::io::{self, Read, ReadBuf}; +use crate::mem::MaybeUninit; + +pub struct Buffer { + // The buffer. + buf: Box<[MaybeUninit]>, + // The current seek offset into `buf`, must always be <= `filled`. + pos: usize, + // Each call to `fill_buf` sets `filled` to indicate how many bytes at the start of `buf` are + // initialized with bytes from a read. 
+ filled: usize, +} + +impl Buffer { + #[inline] + pub fn with_capacity(capacity: usize) -> Self { + let buf = Box::new_uninit_slice(capacity); + Self { buf, pos: 0, filled: 0 } + } + + #[inline] + pub fn buffer(&self) -> &[u8] { + // SAFETY: self.pos <= self.filled, so self.pos..self.filled is in bounds, and + // that region is initialized because those are all invariants of this type. + unsafe { MaybeUninit::slice_assume_init_ref(self.buf.get_unchecked(self.pos..self.filled)) } + } + + #[inline] + pub fn capacity(&self) -> usize { + self.buf.len() + } + + #[inline] + pub fn filled(&self) -> usize { + self.filled + } + + #[inline] + pub fn pos(&self) -> usize { + self.pos + } + + #[inline] + pub fn discard_buffer(&mut self) { + self.pos = 0; + self.filled = 0; + } + + #[inline] + pub fn consume(&mut self, amt: usize) { + self.pos = cmp::min(self.pos + amt, self.filled); + } + + /// If there are `amt` bytes available in the buffer, pass a slice containing those bytes to + /// `visitor` and return true. If there are not enough bytes available, return false. + #[inline] + pub fn consume_with(&mut self, amt: usize, mut visitor: V) -> bool + where + V: FnMut(&[u8]), + { + if let Some(claimed) = self.buffer().get(..amt) { + visitor(claimed); + // If the indexing into self.buffer() succeeds, amt must be a valid increment. + self.pos += amt; + true + } else { + false + } + } + + #[inline] + pub fn unconsume(&mut self, amt: usize) { + self.pos = self.pos.saturating_sub(amt); + } + + #[inline] + pub fn fill_buf(&mut self, mut reader: impl Read) -> io::Result<&[u8]> { + // If we've reached the end of our internal buffer then we need to fetch + // some more data from the reader. + // Branch using `>=` instead of the more correct `==` + // to tell the compiler that the pos..filled slice is always valid. 
+ if self.pos >= self.filled { + debug_assert!(self.pos == self.filled); + + let mut readbuf = ReadBuf::uninit(&mut self.buf); + + reader.read_buf(&mut readbuf)?; + + self.filled = readbuf.filled_len(); + self.pos = 0; + } + Ok(self.buffer()) + } +} diff --git a/library/std/src/io/buffered/tests.rs b/library/std/src/io/buffered/tests.rs index 9d429e7090e83..fe45b13263844 100644 --- a/library/std/src/io/buffered/tests.rs +++ b/library/std/src/io/buffered/tests.rs @@ -523,6 +523,7 @@ fn bench_buffered_reader_small_reads(b: &mut test::Bencher) { let mut buf = [0u8; 4]; for _ in 0..1024 { reader.read_exact(&mut buf).unwrap(); + core::hint::black_box(&buf); } }); }