diff --git a/Cargo.lock b/Cargo.lock index 9bb3a082bb..dabb96dc93 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -39,6 +39,21 @@ dependencies = [ "memchr", ] +[[package]] +name = "alloc-no-stdlib" +version = "2.0.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "cc7bb162ec39d46ab1ca8c77bf72e890535becd1751bb45f64c597edb4c8c6b3" + +[[package]] +name = "alloc-stdlib" +version = "0.2.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "94fb8275041c72129eb51b7d0322c29b8387a0386127718b096429201a5d6ece" +dependencies = [ + "alloc-no-stdlib", +] + [[package]] name = "anyhow" version = "1.0.69" @@ -99,6 +114,19 @@ dependencies = [ "syn", ] +[[package]] +name = "async-compression" +version = "0.3.15" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "942c7cd7ae39e91bde4820d74132e9862e62c2f386c3aa90ccf55949f5bad63a" +dependencies = [ + "brotli", + "futures-core", + "memchr", + "pin-project-lite", + "tokio", +] + [[package]] name = "autocfg" version = "1.1.0" @@ -212,6 +240,27 @@ dependencies = [ "generic-array", ] +[[package]] +name = "brotli" +version = "3.3.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a1a0b1dbcc8ae29329621f8d4f0d835787c1c38bb1401979b49d13b0b305ff68" +dependencies = [ + "alloc-no-stdlib", + "alloc-stdlib", + "brotli-decompressor", +] + +[[package]] +name = "brotli-decompressor" +version = "2.3.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4b6561fd3f895a11e8f72af2cb7d22e08366bebc2b6b57f7744c4bda27034744" +dependencies = [ + "alloc-no-stdlib", + "alloc-stdlib", +] + [[package]] name = "bumpalo" version = "3.12.0" @@ -930,6 +979,7 @@ version = "0.3.0" dependencies = [ "abao", "anyhow", + "async-compression", "base64 0.21.0", "blake3", "bytes", diff --git a/Cargo.toml b/Cargo.toml index 036b4ca075..89ce62966a 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -14,6 +14,7 @@ rust-version = "1.63" [dependencies] abao = { version = "0.2.0", features = ["group_size_16k", "tokio_io"], default-features = false } anyhow = { version = "1", features = ["backtrace"] } +async-compression = { version = "0.3.15", features = ["tokio", "brotli"] } base64 = "0.21.0" blake3 = "1.3.3" bytes = "1" diff --git a/src/get.rs b/src/get.rs index ed12394749..29f4cb35a7 100644 --- a/src/get.rs +++ b/src/get.rs @@ -19,7 +19,7 @@ use anyhow::{anyhow, bail, Context, Result}; use bytes::BytesMut; use futures::Future; use postcard::experimental::max_size::MaxSize; -use tokio::io::{AsyncRead, AsyncReadExt, ReadBuf}; +use tokio::io::{AsyncRead, AsyncReadExt, BufReader, ReadBuf}; use tracing::{debug, error}; pub use crate::util::Hash; @@ -103,10 +103,13 @@ impl Stats { /// We guarantee that the data is correct by incrementally verifying a hash #[repr(transparent)] #[derive(Debug)] -pub struct DataStream(AsyncSliceDecoder); +pub struct DataStream(AsyncSliceDecoder); + +type RecvStream = + async_compression::tokio::bufread::BrotliDecoder>; impl DataStream { - fn new(inner: quinn::RecvStream, hash: Hash) -> Self { + fn new(inner: RecvStream, hash: Hash) -> Self { DataStream(AsyncSliceDecoder::new(inner, &hash.into(), 0, u64::MAX)) } @@ -114,7 +117,7 @@ impl DataStream { self.0.read_size().await } - fn into_inner(self) -> quinn::RecvStream { + fn into_inner(self) -> RecvStream { self.0.into_inner() } } @@ -149,7 +152,7 @@ where let now = Instant::now(); let connection = setup(opts).await?; - let (mut writer, mut reader) = connection.open_bi().await?; + let (mut writer, reader) = connection.open_bi().await?; on_connected().await?; @@ -181,6 +184,7 @@ where { debug!("reading response"); let mut in_buffer = BytesMut::with_capacity(1024); + let mut reader = BufReader::new(reader); // track total amount of blob data transferred let mut data_len = 0; @@ -218,7 +222,7 @@ where if blob_reader.read_exact(&mut [0u8; 1]).await.is_ok() { bail!("`on_blob` callback did not fully read the blob content") } - reader = blob_reader.into_inner(); + reader = blob_reader.into_inner().into_inner(); } } @@ -236,11 +240,12 @@ where } // Shut down the stream - if let Some(chunk) = reader.read_chunk(8, false).await? { - reader.stop(0u8.into()).ok(); - error!("Received unexpected data from the provider: {chunk:?}"); + if let Ok(bytes) = reader.read_u8().await { + reader.into_inner().stop(0u8.into()).ok(); + error!("Received unexpected data from the provider: {bytes:?}"); + } else { + drop(reader); } - drop(reader); let elapsed = now.elapsed(); @@ -261,7 +266,7 @@ where /// The `AsyncReader` can be used to read the content. async fn handle_blob_response( hash: Hash, - mut reader: quinn::RecvStream, + mut reader: BufReader, buffer: &mut BytesMut, ) -> Result { match read_lp(&mut reader, buffer).await? { @@ -277,7 +282,10 @@ async fn handle_blob_response( // next blob in collection will be sent over Res::Found => { assert!(buffer.is_empty()); - let decoder = DataStream::new(reader, hash); + // Decompress data + let decompress_reader = + async_compression::tokio::bufread::BrotliDecoder::new(reader); + let decoder = DataStream::new(decompress_reader, hash); Ok(decoder) } } diff --git a/src/lib.rs b/src/lib.rs index a27bf4ac07..0b82924d89 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -41,6 +41,7 @@ mod tests { #[tokio::test] async fn basics() -> Result<()> { + setup_logging(); transfer_data(vec![("hello_world", "hello world!".as_bytes().to_vec())]).await } diff --git a/src/provider/mod.rs b/src/provider/mod.rs index 8ed5bb4400..d4985bdc72 100644 --- a/src/provider/mod.rs +++ b/src/provider/mod.rs @@ -821,6 +821,38 @@ enum SentStatus { NotFound, } +struct ShutdownCatcher(W); + +impl ShutdownCatcher { + fn into_inner(self) -> W { + self.0 + } +} + +impl AsyncWrite for ShutdownCatcher { + fn poll_write( + mut self: Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + buf: &[u8], + ) -> Poll> { + Pin::new(&mut self.0).poll_write(cx, buf) + } + + fn poll_flush( + mut self: Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + ) -> Poll> { + Pin::new(&mut self.0).poll_flush(cx) + } + + fn poll_shutdown( + mut self: Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + ) -> Poll> { + Pin::new(&mut self.0).poll_flush(cx) + } +} + async fn send_blob( db: Database, name: Hash, @@ -837,10 +869,15 @@ async fn send_blob( // need to thread the writer though the spawn_blocking, since // taking a reference does not work. spawn_blocking requires // 'static lifetime. - writer = tokio::task::spawn_blocking(move || { + + // Compress data + let mut compressed_writer = + async_compression::tokio::write::BrotliEncoder::new(ShutdownCatcher(writer)); + + compressed_writer = tokio::task::spawn_blocking(move || { let file_reader = std::fs::File::open(&path)?; let outboard_reader = std::io::Cursor::new(outboard); - let mut wrapper = SyncIoBridge::new(&mut writer); + let mut wrapper = SyncIoBridge::new(&mut compressed_writer); let mut slice_extractor = abao::encode::SliceExtractor::new_outboard( file_reader, outboard_reader, @@ -848,10 +885,13 @@ async fn send_blob( size, ); let _copied = std::io::copy(&mut slice_extractor, &mut wrapper)?; - std::io::Result::Ok(writer) + std::io::Result::Ok(compressed_writer) }) .await??; + compressed_writer.shutdown().await?; + let writer = compressed_writer.into_inner().into_inner(); + Ok((SentStatus::Sent, writer, size)) } _ => { @@ -1069,6 +1109,7 @@ async fn write_response( } let used = postcard::to_slice(&response, buffer)?; + // Write lp write_lp(&mut writer, used).await?; debug!("written response of length {}", used.len());