Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor: Polish multipart writer to allow oneshot optimization #3031

Merged
merged 2 commits into from
Sep 11, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
98 changes: 82 additions & 16 deletions core/src/raw/oio/write/multipart_upload_write.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ use std::task::Context;
use std::task::Poll;

use async_trait::async_trait;
use bytes::Bytes;
use futures::future::BoxFuture;

use crate::raw::*;
Expand All @@ -37,8 +38,26 @@ use crate::*;
/// - Services impl `MultipartUploadWrite`
/// - `MultipartUploadWriter` impl `Write`
/// - Expose `MultipartUploadWriter` as `Accessor::Writer`
///
/// # Notes
///
/// `MultipartUploadWrite` has an oneshot optimization when `write` has been called only once:
///
/// ```no_build
/// w.write(bs).await?;
/// w.close().await?;
/// ```
///
/// We will use `write_once` instead of starting a new multipart upload.
#[async_trait]
pub trait MultipartUploadWrite: Send + Sync + Unpin + 'static {
/// write_once is used to write the data to underlying storage at once.
///
/// MultipartUploadWriter will call this API when:
///
/// - All the data has been written to the buffer and we can perform the upload at once.
async fn write_once(&self, size: u64, body: AsyncBody) -> Result<()>;

/// initiate_part will call start a multipart upload and return the upload id.
///
/// MultipartUploadWriter will call this when:
Expand Down Expand Up @@ -90,14 +109,15 @@ pub struct MultipartUploadPart {
pub struct MultipartUploadWriter<W: MultipartUploadWrite> {
state: State<W>,

cache: Option<Bytes>,
upload_id: Option<Arc<String>>,
parts: Vec<MultipartUploadPart>,
}

enum State<W> {
Idle(Option<W>),
Init(BoxFuture<'static, (W, Result<String>)>),
Write(BoxFuture<'static, (W, usize, Result<MultipartUploadPart>)>),
Write(BoxFuture<'static, (W, Result<MultipartUploadPart>)>),
Close(BoxFuture<'static, (W, Result<()>)>),
Abort(BoxFuture<'static, (W, Result<()>)>),
}
Expand All @@ -113,6 +133,7 @@ impl<W: MultipartUploadWrite> MultipartUploadWriter<W> {
Self {
state: State::Idle(Some(inner)),

cache: None,
upload_id: None,
parts: Vec::new(),
}
Expand All @@ -128,15 +149,15 @@ where
loop {
match &mut self.state {
State::Idle(w) => {
let w = w.take().expect("writer must be valid");
match self.upload_id.as_ref() {
Some(upload_id) => {
let size = bs.remaining();
let bs = bs.copy_to_bytes(size);
let upload_id = upload_id.clone();
let part_number = self.parts.len();

let bs = self.cache.clone().expect("cache must be valid").clone();
let w = w.take().expect("writer must be valid");
self.state = State::Write(Box::pin(async move {
let size = bs.len();
let part = w
.write_part(
&upload_id,
Expand All @@ -146,10 +167,18 @@ where
)
.await;

(w, size, part)
(w, part)
}));
}
None => {
// Fill cache with the first write.
if self.cache.is_none() {
let size = bs.remaining();
self.cache = Some(bs.copy_to_bytes(size));
return Poll::Ready(Ok(size));
}

let w = w.take().expect("writer must be valid");
self.state = State::Init(Box::pin(async move {
let upload_id = w.initiate_part().await;
(w, upload_id)
Expand All @@ -163,10 +192,12 @@ where
self.upload_id = Some(Arc::new(upload_id?));
}
State::Write(fut) => {
let (w, size, part) = ready!(fut.as_mut().poll(cx));
let (w, part) = ready!(fut.as_mut().poll(cx));
self.state = State::Idle(Some(w));

self.parts.push(part?);
// Replace the cache when last write succeeded
let size = bs.remaining();
self.cache = Some(bs.copy_to_bytes(size));
return Poll::Ready(Ok(size));
}
State::Close(_) => {
Expand All @@ -191,25 +222,57 @@ where
match self.upload_id.clone() {
Some(upload_id) => {
let parts = self.parts.clone();
self.state = State::Close(Box::pin(async move {
let res = w.complete_part(&upload_id, &parts).await;
(w, res)
}));
match self.cache.clone() {
Some(bs) => {
let upload_id = upload_id.clone();
self.state = State::Write(Box::pin(async move {
let size = bs.len();
let part = w
.write_part(
&upload_id,
parts.len(),
size as u64,
AsyncBody::Bytes(bs),
)
.await;
(w, part)
}));
}
None => {
self.state = State::Close(Box::pin(async move {
let res = w.complete_part(&upload_id, &parts).await;
(w, res)
}));
}
}
}
None => return Poll::Ready(Ok(())),
None => match self.cache.clone() {
Some(bs) => {
self.state = State::Close(Box::pin(async move {
let size = bs.len();
let res = w.write_once(size as u64, AsyncBody::Bytes(bs)).await;
(w, res)
}));
}
None => return Poll::Ready(Ok(())),
},
}
}
State::Close(fut) => {
let (w, res) = futures::ready!(fut.as_mut().poll(cx));
self.state = State::Idle(Some(w));
self.cache = None;
return Poll::Ready(res);
}
State::Init(_) => unreachable!(
"MultipartUploadWriter must not go into State::Init during poll_close"
),
State::Write(_) => unreachable!(
"MultipartUploadWriter must not go into State::Write during poll_close"
),
State::Write(fut) => {
let (w, part) = ready!(fut.as_mut().poll(cx));
self.state = State::Idle(Some(w));
self.parts.push(part?);
self.cache = None;
}
State::Abort(_) => unreachable!(
"MultipartUploadWriter must not go into State::Abort during poll_close"
),
Expand All @@ -229,7 +292,10 @@ where
(w, res)
}));
}
None => return Poll::Ready(Ok(())),
None => {
self.cache = None;
return Poll::Ready(Ok(()));
}
}
}
State::Abort(fut) => {
Expand Down
6 changes: 2 additions & 4 deletions core/src/services/cos/backend.rs
Original file line number Diff line number Diff line change
Expand Up @@ -337,11 +337,9 @@ impl Accessor for CosBackend {
let writer = CosWriter::new(self.core.clone(), path, args.clone());

let w = if args.append() {
CosWriters::Three(oio::AppendObjectWriter::new(writer))
} else if args.content_length().is_some() {
CosWriters::One(oio::OneShotWriter::new(writer))
CosWriters::Two(oio::AppendObjectWriter::new(writer))
} else {
CosWriters::Two(oio::MultipartUploadWriter::new(writer))
CosWriters::One(oio::MultipartUploadWriter::new(writer))
};

let w = if let Some(buffer_size) = args.buffer_size() {
Expand Down
20 changes: 6 additions & 14 deletions core/src/services/cos/writer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,19 +18,15 @@
use std::sync::Arc;

use async_trait::async_trait;
use bytes::Bytes;
use http::StatusCode;

use super::core::*;
use super::error::parse_error;
use crate::raw::*;
use crate::*;

pub type CosWriters = oio::ThreeWaysWriter<
oio::OneShotWriter<CosWriter>,
oio::MultipartUploadWriter<CosWriter>,
oio::AppendObjectWriter<CosWriter>,
>;
pub type CosWriters =
oio::TwoWaysWriter<oio::MultipartUploadWriter<CosWriter>, oio::AppendObjectWriter<CosWriter>>;

pub struct CosWriter {
core: Arc<CosCore>,
Expand All @@ -50,16 +46,15 @@ impl CosWriter {
}

#[async_trait]
impl oio::OneShotWrite for CosWriter {
async fn write_once(&self, buf: Bytes) -> Result<()> {
let size = buf.len();
impl oio::MultipartUploadWrite for CosWriter {
async fn write_once(&self, size: u64, body: AsyncBody) -> Result<()> {
let mut req = self.core.cos_put_object_request(
&self.path,
Some(size as u64),
Some(size),
self.op.content_type(),
self.op.content_disposition(),
self.op.cache_control(),
AsyncBody::Bytes(buf),
body,
)?;

self.core.sign(&mut req).await?;
Expand All @@ -76,10 +71,7 @@ impl oio::OneShotWrite for CosWriter {
_ => Err(parse_error(resp).await?),
}
}
}

#[async_trait]
impl oio::MultipartUploadWrite for CosWriter {
async fn initiate_part(&self) -> Result<String> {
let resp = self
.core
Expand Down
6 changes: 2 additions & 4 deletions core/src/services/obs/backend.rs
Original file line number Diff line number Diff line change
Expand Up @@ -375,11 +375,9 @@ impl Accessor for ObsBackend {
let writer = ObsWriter::new(self.core.clone(), path, args.clone());

let w = if args.append() {
ObsWriters::Three(oio::AppendObjectWriter::new(writer))
} else if args.content_length().is_some() {
ObsWriters::One(oio::OneShotWriter::new(writer))
ObsWriters::Two(oio::AppendObjectWriter::new(writer))
} else {
ObsWriters::Two(oio::MultipartUploadWriter::new(writer))
ObsWriters::One(oio::MultipartUploadWriter::new(writer))
};

let w = if let Some(buffer_size) = args.buffer_size() {
Expand Down
20 changes: 6 additions & 14 deletions core/src/services/obs/writer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@
use std::sync::Arc;

use async_trait::async_trait;
use bytes::Bytes;
use http::StatusCode;

use super::core::*;
Expand All @@ -27,11 +26,8 @@ use crate::raw::oio::MultipartUploadPart;
use crate::raw::*;
use crate::*;

pub type ObsWriters = oio::ThreeWaysWriter<
oio::OneShotWriter<ObsWriter>,
oio::MultipartUploadWriter<ObsWriter>,
oio::AppendObjectWriter<ObsWriter>,
>;
pub type ObsWriters =
oio::TwoWaysWriter<oio::MultipartUploadWriter<ObsWriter>, oio::AppendObjectWriter<ObsWriter>>;

pub struct ObsWriter {
core: Arc<ObsCore>,
Expand All @@ -51,15 +47,14 @@ impl ObsWriter {
}

#[async_trait]
impl oio::OneShotWrite for ObsWriter {
async fn write_once(&self, bs: Bytes) -> Result<()> {
let size = bs.len();
impl oio::MultipartUploadWrite for ObsWriter {
async fn write_once(&self, size: u64, body: AsyncBody) -> Result<()> {
let mut req = self.core.obs_put_object_request(
&self.path,
Some(size as u64),
Some(size),
self.op.content_type(),
self.op.cache_control(),
AsyncBody::Bytes(bs),
body,
)?;

self.core.sign(&mut req).await?;
Expand All @@ -76,10 +71,7 @@ impl oio::OneShotWrite for ObsWriter {
_ => Err(parse_error(resp).await?),
}
}
}

#[async_trait]
impl oio::MultipartUploadWrite for ObsWriter {
async fn initiate_part(&self) -> Result<String> {
let resp = self
.core
Expand Down
6 changes: 2 additions & 4 deletions core/src/services/oss/backend.rs
Original file line number Diff line number Diff line change
Expand Up @@ -473,11 +473,9 @@ impl Accessor for OssBackend {
let writer = OssWriter::new(self.core.clone(), path, args.clone());

let w = if args.append() {
OssWriters::Three(oio::AppendObjectWriter::new(writer))
} else if args.content_length().is_some() {
OssWriters::One(oio::OneShotWriter::new(writer))
OssWriters::Two(oio::AppendObjectWriter::new(writer))
} else {
OssWriters::Two(oio::MultipartUploadWriter::new(writer))
OssWriters::One(oio::MultipartUploadWriter::new(writer))
};

let w = if let Some(buffer_size) = args.buffer_size() {
Expand Down
20 changes: 6 additions & 14 deletions core/src/services/oss/writer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,19 +18,15 @@
use std::sync::Arc;

use async_trait::async_trait;
use bytes::Bytes;
use http::StatusCode;

use super::core::*;
use super::error::parse_error;
use crate::raw::*;
use crate::*;

pub type OssWriters = oio::ThreeWaysWriter<
oio::OneShotWriter<OssWriter>,
oio::MultipartUploadWriter<OssWriter>,
oio::AppendObjectWriter<OssWriter>,
>;
pub type OssWriters =
oio::TwoWaysWriter<oio::MultipartUploadWriter<OssWriter>, oio::AppendObjectWriter<OssWriter>>;

pub struct OssWriter {
core: Arc<OssCore>,
Expand All @@ -50,16 +46,15 @@ impl OssWriter {
}

#[async_trait]
impl oio::OneShotWrite for OssWriter {
async fn write_once(&self, bs: Bytes) -> Result<()> {
let size = bs.len();
impl oio::MultipartUploadWrite for OssWriter {
async fn write_once(&self, size: u64, body: AsyncBody) -> Result<()> {
let mut req = self.core.oss_put_object_request(
&self.path,
Some(size as u64),
Some(size),
self.op.content_type(),
self.op.content_disposition(),
self.op.cache_control(),
AsyncBody::Bytes(bs),
body,
false,
)?;

Expand All @@ -77,10 +72,7 @@ impl oio::OneShotWrite for OssWriter {
_ => Err(parse_error(resp).await?),
}
}
}

#[async_trait]
impl oio::MultipartUploadWrite for OssWriter {
async fn initiate_part(&self) -> Result<String> {
let resp = self
.core
Expand Down
Loading