diff --git a/sqlx-core/src/io/buf_stream.rs b/sqlx-core/src/io/buf_stream.rs index 4646cb7a29..cbb24ad06e 100644 --- a/sqlx-core/src/io/buf_stream.rs +++ b/sqlx-core/src/io/buf_stream.rs @@ -82,6 +82,17 @@ where pub async fn read_raw_into(&mut self, buf: &mut BytesMut, cnt: usize) -> Result<(), Error> { read_raw_into(&mut self.stream, buf, cnt).await } + + /// Read data from stream into buf + /// * On success: The number of bytes read (> 0) is returned + /// * On error: The error is returned, no bytes will have been read from the stream + /// * When dropped: No bytes will have been read from the stream + pub async fn read_some(&mut self, buf: &mut [u8]) -> Result { + match self.stream.read(buf).await? { + 0 => Err(io::Error::from(io::ErrorKind::ConnectionAborted).into()), + v => Ok(v), + } + } } impl Deref for BufStream diff --git a/sqlx-core/src/mysql/connection/stream.rs b/sqlx-core/src/mysql/connection/stream.rs index 8b2f453608..59609d309c 100644 --- a/sqlx-core/src/mysql/connection/stream.rs +++ b/sqlx-core/src/mysql/connection/stream.rs @@ -11,6 +11,12 @@ use crate::mysql::protocol::{Capabilities, Packet}; use crate::mysql::{MySqlConnectOptions, MySqlDatabaseError}; use crate::net::{MaybeTlsStream, Socket}; +enum RecvPackageState { + ReadingHeader { data: [u8; 4], progress: usize }, + ReadingData { data: Vec, progress: usize }, + Done, +} + pub struct MySqlStream { stream: BufStream>, pub(crate) server_version: (u16, u16, u16), @@ -19,6 +25,7 @@ pub struct MySqlStream { pub(crate) busy: Busy, pub(crate) charset: CharSet, pub(crate) collation: Collation, + recv_package_state: RecvPackageState, } #[derive(Debug, PartialEq, Eq)] @@ -72,6 +79,7 @@ impl MySqlStream { collation, charset, stream: BufStream::new(MaybeTlsStream::Raw(socket)), + recv_package_state: RecvPackageState::Done, }) } @@ -137,14 +145,38 @@ impl MySqlStream { // https://dev.mysql.com/doc/dev/mysql-server/8.0.12/page_protocol_basic_packets.html // https://mariadb.com/kb/en/library/0-packet/#standard-packet - let mut header: Bytes = self.stream.read(4).await?; - - let packet_size = header.get_uint_le(3) as usize; - let sequence_id = header.get_u8(); + if let RecvPackageState::Done = self.recv_package_state { + self.recv_package_state = RecvPackageState::ReadingHeader { + data: [0; 4], + progress: 0, + }; + } - self.sequence_id = sequence_id.wrapping_add(1); + if let RecvPackageState::ReadingHeader { data, progress } = &mut self.recv_package_state { + while *progress != 4 { + *progress += self.stream.read_some(&mut data[*progress..]).await?; + } + let mut header = data.as_ref(); + let packet_size = header.get_uint_le(3) as usize; + let sequence_id = header.get_u8(); + self.sequence_id = sequence_id.wrapping_add(1); + let mut data = Vec::new(); + data.resize(packet_size, 0); + self.recv_package_state = RecvPackageState::ReadingData { data, progress: 0 }; + } - let payload: Bytes = self.stream.read(packet_size).await?; + let payload: Bytes = if let RecvPackageState::ReadingData { data, progress } = + &mut self.recv_package_state + { + while *progress != data.len() { + *progress += self.stream.read_some(&mut data[*progress..]).await?; + } + let data = std::mem::take(data); + self.recv_package_state = RecvPackageState::Done; + data.into() + } else { + panic!("Unexpected RecvPackageState") + }; // TODO: packet compression // TODO: packet joining diff --git a/sqlx-macros/src/derives/attributes.rs b/sqlx-macros/src/derives/attributes.rs index 373feaae98..13bc624e46 100644 --- a/sqlx-macros/src/derives/attributes.rs +++ b/sqlx-macros/src/derives/attributes.rs @@ -55,7 +55,10 @@ pub fn parse_container_attributes(input: &[Attribute]) -> syn::Result anyhow::Result<()> { Ok(()) } +#[tokio::test] +async fn drop_test() -> anyhow::Result<()> { + const CNT: usize = 222; + const POOL_SIZE: u32 = 2; + + setup_if_needed(); + + let pool = MySqlPoolOptions::new() + .max_connections(POOL_SIZE) + .connect(&std::env::var("DATABASE_URL").unwrap()) + .await?; + // Create a temporery table and insert a lot of stuff + { + let mut conn = pool.acquire().await.unwrap(); + sqlx::query("DROP TABLE IF EXISTS drop_test") + .execute(&mut conn) + .await?; + sqlx::query("CREATE TABLE drop_test (id BIGINT PRIMARY KEY AUTO_INCREMENT)") + .execute(&mut conn) + .await?; + let mut q = "INSERT INTO drop_test () VALUES ()".to_string(); + for _ in 0..CNT { + q.push_str(", ()"); + } + sqlx::query(&q).execute(&mut conn).await?; + } + + // It is somewhat tricky to get the timing right for failures to occur + // so we repeat the test a number of times + for _ in 0..50 { + // Create a bunch of long running jobs + // that hopefully will be dropped in read + let s = std::time::Instant::now(); + let mut futures = vec![]; + for i in 0..POOL_SIZE { + let s = s.clone(); + let pool = pool.clone(); + futures.push(async move { + let mut conn = pool.acquire().await.unwrap(); + { + let mut stream = sqlx::query("SELECT 1 FROM drop_test AS a, drop_test AS b") + .fetch(&mut conn); + while let Some(_) = stream.try_next().await? {} + } + println!("Thread {} finished {}", i, s.elapsed().as_secs_f64()); + Result::<(), anyhow::Error>::Ok(()) + }); + } + + /// Some "feature" in tokio causes the timeout to never occur if the + /// sleep time is more than one + #[cfg(feature = "_rt-tokio")] + fn drop_test_timeout() -> u64 { + 1 + } + + #[cfg(not(feature = "_rt-tokio"))] + fn drop_test_timeout() -> u64 { + 23 + } + + if let Ok(_) = sqlx_rt::timeout( + std::time::Duration::from_millis(drop_test_timeout()), + futures::future::join_all(futures), + ) + .await + { + println!( + "All queries finished before timeout, this should not happen. We waited {}s", + s.elapsed().as_secs_f64() + ); + continue; + } + println!("Timeout after {}s", s.elapsed().as_secs_f64()); + + // Perform some query and check the result + let pool = pool.clone(); + let f = async move { + let mut conn = pool.acquire().await.unwrap(); + let row = sqlx::query("SELECT CAST(SUM(id) AS UNSIGNED) AS s FROM drop_test") + .fetch_one(&mut conn) + .await?; + let s: u64 = row.try_get("s")?; + assert_eq!(s, ((CNT + 1) * (CNT + 2) / 2) as u64); + Result::<(), anyhow::Error>::Ok(()) + }; + // We add a timeout here as bugs in the drop handling can cause us to + // wait for gigabytes of data to be pushed from mysql + sqlx_rt::timeout(std::time::Duration::from_secs(7), f).await??; + } + + { + let mut conn = pool.acquire().await.unwrap(); + sqlx::query("DROP TABLE IF EXISTS drop_test") + .execute(&mut conn) + .await?; + } + + Ok(()) +} + // repro is more reliable with the basic scheduler used by `#[tokio::test]` #[cfg(feature = "_rt-tokio")] #[tokio::test]