diff --git a/sqlx-core/src/any/connection/backend.rs b/sqlx-core/src/any/connection/backend.rs index 001d76f739..8dff4ba29f 100644 --- a/sqlx-core/src/any/connection/backend.rs +++ b/sqlx-core/src/any/connection/backend.rs @@ -47,6 +47,11 @@ pub trait AnyConnectionBackend: std::any::Any + Debug + Send + 'static { Box::pin(async move { Ok(()) }) } + /// Forward to [`Connection::shrink_buffers()`]. + /// + /// [`Connection::shrink_buffers()`]: method@crate::connection::Connection::shrink_buffers + fn shrink_buffers(&mut self); + #[doc(hidden)] fn flush(&mut self) -> BoxFuture<'_, crate::Result<()>>; diff --git a/sqlx-core/src/any/connection/mod.rs b/sqlx-core/src/any/connection/mod.rs index 004a94e9f6..57fc726ec2 100644 --- a/sqlx-core/src/any/connection/mod.rs +++ b/sqlx-core/src/any/connection/mod.rs @@ -92,6 +92,10 @@ impl Connection for AnyConnection { self.backend.clear_cached_statements() } + fn shrink_buffers(&mut self) { + self.backend.shrink_buffers() + } + #[doc(hidden)] fn flush(&mut self) -> BoxFuture<'_, Result<(), Error>> { self.backend.flush() diff --git a/sqlx-core/src/connection.rs b/sqlx-core/src/connection.rs index d437269157..dfafda686c 100644 --- a/sqlx-core/src/connection.rs +++ b/sqlx-core/src/connection.rs @@ -112,6 +112,20 @@ pub trait Connection: Send { Box::pin(async move { Ok(()) }) } + /// Restore any buffers in the connection to their default capacity, if possible. + /// + /// Sending a large query or receiving a resultset with many columns can cause the connection + /// to allocate additional buffer space to fit the data which is retained afterwards in + /// case it's needed again. This can give the outward appearance of a memory leak, but is + /// in fact the intended behavior. + /// + /// Calling this method tells the connection to release that excess memory if it can, + /// though be aware that calling this too often can cause unnecessary thrashing or + /// fragmentation in the global allocator. If there's still data in the connection buffers + /// (unlikely if the last query was run to completion) then it may need to be moved to + /// allow the buffers to shrink. + fn shrink_buffers(&mut self); + #[doc(hidden)] fn flush(&mut self) -> BoxFuture<'_, Result<(), Error>>; diff --git a/sqlx-core/src/net/socket/buffered.rs b/sqlx-core/src/net/socket/buffered.rs index dc05c87863..3dd5f64146 100644 --- a/sqlx-core/src/net/socket/buffered.rs +++ b/sqlx-core/src/net/socket/buffered.rs @@ -1,6 +1,6 @@ use crate::net::Socket; use bytes::BytesMut; -use std::io; +use std::{cmp, io}; use crate::error::Error; @@ -46,26 +46,7 @@ impl BufferedSocket { } pub async fn read_buffered(&mut self, len: usize) -> io::Result { - while self.read_buf.read.len() < len { - self.read_buf.reserve(len); - - let read = self.socket.read(&mut self.read_buf.available).await?; - - if read == 0 { - return Err(io::Error::new( - io::ErrorKind::UnexpectedEof, - format!( - "expected to read {} bytes, got {} bytes at EOF", - len, - self.read_buf.read.len() - ), - )); - } - - self.read_buf.advance(read); - } - - Ok(self.read_buf.drain(len)) + self.read_buf.read(len, &mut self.socket).await } pub fn write_buffer(&self) -> &WriteBuffer { @@ -123,6 +104,12 @@ impl BufferedSocket { self.socket.shutdown().await } + pub fn shrink_buffers(&mut self) { + // Won't drop data still in the buffer. + self.write_buf.shrink(); + self.read_buf.shrink(); + } + pub fn into_inner(self) -> S { self.socket } @@ -197,6 +184,22 @@ impl WriteBuffer { &mut self.buf[self.bytes_flushed..self.bytes_written] } + pub fn shrink(&mut self) { + if self.bytes_flushed > 0 { + // Move any data that remains to be flushed to the beginning of the buffer, + // if necessary. + self.buf + .copy_within(self.bytes_flushed..self.bytes_written, 0); + self.bytes_written -= self.bytes_flushed; + self.bytes_flushed = 0 + } + + // Drop excess capacity. + self.buf + .truncate(cmp::max(self.bytes_written, DEFAULT_BUF_SIZE)); + self.buf.shrink_to_fit(); + } + fn consume(&mut self, amt: usize) { let new_bytes_flushed = self .bytes_flushed @@ -218,6 +221,31 @@ impl WriteBuffer { } impl ReadBuffer { + async fn read(&mut self, len: usize, socket: &mut impl Socket) -> io::Result { + // Because of how `BytesMut` works, we should only be shifting capacity back and forth + // between `read` and `available` unless we have to read an oversize message. + while self.read.len() < len { + self.reserve(len - self.read.len()); + + let read = socket.read(&mut self.available).await?; + + if read == 0 { + return Err(io::Error::new( + io::ErrorKind::UnexpectedEof, + format!( + "expected to read {} bytes, got {} bytes at EOF", + len, + self.read.len() + ), + )); + } + + self.advance(read); + } + + Ok(self.drain(len)) + } + fn reserve(&mut self, amt: usize) { if let Some(additional) = amt.checked_sub(self.available.capacity()) { self.available.reserve(additional); @@ -231,4 +259,21 @@ impl ReadBuffer { fn drain(&mut self, amt: usize) -> BytesMut { self.read.split_to(amt) } + + fn shrink(&mut self) { + if self.available.capacity() > DEFAULT_BUF_SIZE { + // `BytesMut` doesn't have a way to shrink its capacity, + // but we only use `available` for spare capacity anyway so we can just replace it. + // + // If `self.read` still contains data on the next call to `advance` then this might + // force a memcpy as they'll no longer be pointing to the same allocation, + // but that's kind of unavoidable. + // + // The `async-std` impl of `Socket` will also need to re-zero the buffer, + // but that's also kind of unavoidable. + // + // We should be warning the user not to call this often. + self.available = BytesMut::with_capacity(DEFAULT_BUF_SIZE); + } + } } diff --git a/sqlx-core/src/pool/connection.rs b/sqlx-core/src/pool/connection.rs index 9bd0ebb85b..e73f4f2735 100644 --- a/sqlx-core/src/pool/connection.rs +++ b/sqlx-core/src/pool/connection.rs @@ -73,6 +73,18 @@ impl AsMut for PoolConnection { } impl PoolConnection { + /// Close this connection, allowing the pool to open a replacement. + /// + /// Equivalent to calling [`.detach()`] then [`.close()`], but the connection permit is retained + /// for the duration so that the pool may not exceed `max_connections`. + /// + /// [`.detach()`]: PoolConnection::detach + /// [`.close()`]: Connection::close + pub async fn close(mut self) -> Result<(), Error> { + let floating = self.take_live().float(self.pool.clone()); + floating.inner.raw.close().await + } + /// Detach this connection from the pool, allowing it to open a replacement. /// /// Note that if your application uses a single shared pool, this diff --git a/sqlx-mysql/src/any.rs b/sqlx-mysql/src/any.rs index 303522fa6a..768d38ff83 100644 --- a/sqlx-mysql/src/any.rs +++ b/sqlx-mysql/src/any.rs @@ -54,6 +54,10 @@ impl AnyConnectionBackend for MySqlConnection { MySqlTransactionManager::start_rollback(self) } + fn shrink_buffers(&mut self) { + Connection::shrink_buffers(self); + } + fn flush(&mut self) -> BoxFuture<'_, sqlx_core::Result<()>> { Connection::flush(self) } diff --git a/sqlx-mysql/src/connection/mod.rs b/sqlx-mysql/src/connection/mod.rs index 98144fc0d6..c7b3543d23 100644 --- a/sqlx-mysql/src/connection/mod.rs +++ b/sqlx-mysql/src/connection/mod.rs @@ -108,4 +108,8 @@ impl Connection for MySqlConnection { { Transaction::begin(self) } + + fn shrink_buffers(&mut self) { + self.stream.shrink_buffers(); + } } diff --git a/sqlx-postgres/src/any.rs b/sqlx-postgres/src/any.rs index fed7422046..32effc3121 100644 --- a/sqlx-postgres/src/any.rs +++ b/sqlx-postgres/src/any.rs @@ -57,6 +57,10 @@ impl AnyConnectionBackend for PgConnection { PgTransactionManager::start_rollback(self) } + fn shrink_buffers(&mut self) { + Connection::shrink_buffers(self); + } + fn flush(&mut self) -> BoxFuture<'_, sqlx_core::Result<()>> { Connection::flush(self) } diff --git a/sqlx-postgres/src/connection/mod.rs b/sqlx-postgres/src/connection/mod.rs index e1a44e80e9..2bbc38fa4a 100644 --- a/sqlx-postgres/src/connection/mod.rs +++ b/sqlx-postgres/src/connection/mod.rs @@ -199,6 +199,10 @@ impl Connection for PgConnection { }) } + fn shrink_buffers(&mut self) { + self.stream.shrink_buffers(); + } + #[doc(hidden)] fn flush(&mut self) -> BoxFuture<'_, Result<(), Error>> { self.wait_until_ready().boxed() diff --git a/sqlx-sqlite/src/any.rs b/sqlx-sqlite/src/any.rs index 517d601986..e4cfb609cc 100644 --- a/sqlx-sqlite/src/any.rs +++ b/sqlx-sqlite/src/any.rs @@ -59,6 +59,10 @@ impl AnyConnectionBackend for SqliteConnection { SqliteTransactionManager::start_rollback(self) } + fn shrink_buffers(&mut self) { + // NO-OP. + } + fn flush(&mut self) -> BoxFuture<'_, sqlx_core::Result<()>> { Connection::flush(self) } diff --git a/sqlx-sqlite/src/connection/mod.rs b/sqlx-sqlite/src/connection/mod.rs index 2ea1b66ed9..903353be14 100644 --- a/sqlx-sqlite/src/connection/mod.rs +++ b/sqlx-sqlite/src/connection/mod.rs @@ -157,6 +157,11 @@ impl Connection for SqliteConnection { }) } + #[inline] + fn shrink_buffers(&mut self) { + // No-op. + } + #[doc(hidden)] fn flush(&mut self) -> BoxFuture<'_, Result<(), Error>> { // For SQLite, FLUSH does effectively nothing... diff --git a/tests/mysql/mysql.rs b/tests/mysql/mysql.rs index 009a23f5e0..58f4369095 100644 --- a/tests/mysql/mysql.rs +++ b/tests/mysql/mysql.rs @@ -446,3 +446,34 @@ async fn it_can_work_with_transactions() -> anyhow::Result<()> { Ok(()) } + +#[sqlx_macros::test] +async fn test_shrink_buffers() -> anyhow::Result<()> { + // We don't really have a good way to test that `.shrink_buffers()` functions as expected + // without exposing a lot of internals, but we can at least be sure it doesn't + // materially affect the operation of the connection. + + let mut conn = new::().await?; + + // The connection buffer is only 8 KiB by default so this should definitely force it to grow. + let data = "This string should be 32 bytes!\n".repeat(1024); + assert_eq!(data.len(), 32 * 1024); + + let ret: String = sqlx::query_scalar("SELECT ?") + .bind(&data) + .fetch_one(&mut conn) + .await?; + + assert_eq!(ret, data); + + conn.shrink_buffers(); + + let ret: i64 = sqlx::query_scalar("SELECT ?") + .bind(&12345678i64) + .fetch_one(&mut conn) + .await?; + + assert_eq!(ret, 12345678i64); + + Ok(()) +} diff --git a/tests/postgres/postgres.rs b/tests/postgres/postgres.rs index 68650a10c8..7bee4d86ac 100644 --- a/tests/postgres/postgres.rs +++ b/tests/postgres/postgres.rs @@ -1789,3 +1789,33 @@ async fn test_postgres_bytea_hex_deserialization_errors() -> anyhow::Result<()> } Ok(()) } + +#[sqlx_macros::test] +async fn test_shrink_buffers() -> anyhow::Result<()> { + // We don't really have a good way to test that `.shrink_buffers()` functions as expected + // without exposing a lot of internals, but we can at least be sure it doesn't + // materially affect the operation of the connection. + + let mut conn = new::().await?; + + // The connection buffer is only 8 KiB by default so this should definitely force it to grow. + let data = vec![0u8; 32 * 1024]; + + let ret: Vec = sqlx::query_scalar("SELECT $1::bytea") + .bind(&data) + .fetch_one(&mut conn) + .await?; + + assert_eq!(ret, data); + + conn.shrink_buffers(); + + let ret: i64 = sqlx::query_scalar("SELECT $1::int8") + .bind(&12345678i64) + .fetch_one(&mut conn) + .await?; + + assert_eq!(ret, 12345678i64); + + Ok(()) +}