Skip to content

Commit

Permalink
feat: implement recursive download for dfget
Browse files Browse the repository at this point in the history
Signed-off-by: Lzzzt <[email protected]>

feat: implement recursive download for dfget

Signed-off-by: Lzzzt <[email protected]>

feat: implement recursive download for dfget

Signed-off-by: lzzzt <[email protected]>

feat: implement recursive download for dfget

Signed-off-by: lzzzt <[email protected]>

feat: implement recursive download for dfget

Signed-off-by: lzzzt <[email protected]>

feat: implement recursive download for dfget

Signed-off-by: lzzzt <[email protected]>
  • Loading branch information
Lzzzzzt committed Jul 18, 2024
1 parent c202b14 commit 9ade0cb
Show file tree
Hide file tree
Showing 4 changed files with 208 additions and 32 deletions.
1 change: 1 addition & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

6 changes: 6 additions & 0 deletions dragonfly-client-core/src/error/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -168,6 +168,12 @@ pub enum DFError {
// ExternalError is the error for external error.
#[error(transparent)]
ExternalError(#[from] ExternalError),

#[error("max download file count {0} exceeded")]
MaxDownloadFileCountExceeded(usize),

#[error(transparent)]
TokioJoinError(tokio::task::JoinError),
}

// SendError is the error for send.
Expand Down
1 change: 1 addition & 0 deletions dragonfly-client/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,7 @@ futures-util = "0.3.30"
termion = "4.0.2"
tabled = "0.15.0"
path-absolutize = "3.1.1"
percent-encoding = "2.3.1"

[target.'cfg(not(target_env = "msvc"))'.dependencies]
tikv-jemallocator = { version = "0.5.4", features = ["profiling", "stats", "unprefixed_malloc_on_supported_platforms", "background_threads"] }
Expand Down
232 changes: 200 additions & 32 deletions dragonfly-client/src/bin/dfget/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,19 +21,21 @@ use dragonfly_api::errordetails::v2::Backend;
use dragonfly_client::grpc::dfdaemon_download::DfdaemonDownloadClient;
use dragonfly_client::grpc::health::HealthClient;
use dragonfly_client::tracing::init_tracing;
use dragonfly_client_backend::{BackendFactory, HeadRequest};
use dragonfly_client_config::{self, default_piece_length, dfdaemon, dfget};
use dragonfly_client_core::{
error::{ErrorType, OrErr},
Error, Result,
};
use dragonfly_client_core::error::{BackendError, ErrorType, OrErr};
use dragonfly_client_core::{Error, Result};
use dragonfly_client_util::http::header_vec_to_hashmap;
use indicatif::{ProgressBar, ProgressState, ProgressStyle};
use indicatif::{MultiProgress, ProgressBar, ProgressState, ProgressStyle};
use path_absolutize::*;
use percent_encoding::percent_decode_str;
use std::path::{Path, PathBuf};
use std::sync::Arc;
use std::time::Duration;
use std::{cmp::min, fmt::Write};
use termion::{color, style};
use tracing::{error, info, Level};
use tokio::sync::Semaphore;
use tracing::{error, info, warn, Level};
use url::Url;

const LONG_ABOUT: &str = r#"
Expand Down Expand Up @@ -64,7 +66,9 @@ Examples:
$ dfget cos://<bucket>/<path> -O /tmp/file.txt --storage-access-key-id=<access_key_id> --storage-access-key-secret=<access_key_secret> --storage-endpoint=<endpoint>
"#;

#[derive(Debug, Parser)]
const DFGET_HEAD_REQUEST_TASK_ID: &str = "dfget";

#[derive(Debug, Parser, Clone)]
#[command(
name = dfget::NAME,
author,
Expand Down Expand Up @@ -195,6 +199,20 @@ struct Args {
)]
storage_predefined_acl: Option<String>,

#[arg(
long,
default_value_t = 8,
help = "Specify the max number of file to download"
)]
download_max_files: usize,

#[arg(
long,
default_value_t = 5,
help = "Specify the concurrent count of download task"
)]
download_concurrent_count: usize,

#[arg(
short = 'l',
long,
Expand Down Expand Up @@ -448,37 +466,182 @@ async fn main() -> anyhow::Result<()> {
}

// run runs the dfget command.
async fn run(args: Args) -> Result<()> {
async fn run(mut args: Args) -> Result<()> {
let dfdaemon_download_client = get_dfdaemon_download_client(args.endpoint.to_path_buf())
.await
.map_err(|err| {
error!("initialize dfdaemon download client failed: {}", err);
err
})?;

// Get the absolute path of the output file.
let absolute_path = Path::new(&args.output).absolutize()?;
info!("download file to: {}", absolute_path.to_string_lossy());

args.output = absolute_path.into();

// If url end with '/', treat it as a directory and download the whole directory.
if args.url.path().ends_with('/') {
download_tasks(args, dfdaemon_download_client).await
} else {
// Download single file.
let progress_bar = ProgressBar::new(0);
download_task(args, progress_bar, dfdaemon_download_client).await
}
}

async fn download_tasks(args: Args, download_client: DfdaemonDownloadClient) -> Result<()> {
// Only when the `access_key_id` and `access_key_secret` are provided at the same time,
// they will be pass to the `DownloadTaskRequest`.
// they will be passed to the `DownloadTaskRequest`.
let mut object_storage = None;
if let (Some(access_key_id), Some(access_key_secret)) =
(args.storage_access_key_id, args.storage_access_key_secret)
{
if let (Some(access_key_id), Some(access_key_secret)) = (
args.storage_access_key_id.clone(),
args.storage_access_key_secret.clone(),
) {
object_storage = Some(ObjectStorage {
region: args.storage_region,
endpoint: args.storage_endpoint,
access_key_id,
access_key_secret,
session_token: args.storage_session_token,
credential: args.storage_credential,
predefined_acl: args.storage_predefined_acl,
session_token: args.storage_session_token.clone(),
region: args.storage_region.clone(),
endpoint: args.storage_endpoint.clone(),
credential: args.storage_credential.clone(),
predefined_acl: args.storage_predefined_acl.clone(),
});
}

// Get the absolute path of the output file.
let absolute_path = Path::new(&args.output).absolutize()?;
info!("download file to: {}", absolute_path.to_string_lossy());
// Init the backend factory to choose which backend to use for send head request.
let backend_factory = BackendFactory::new(None)?;

// Get the actual backend to send head request.
let backend = backend_factory.build(args.url.as_str())?;

// Send head request.
let head_response = backend
.head(HeadRequest {
task_id: DFGET_HEAD_REQUEST_TASK_ID.into(),
url: args.url.to_string(),
http_header: None,
timeout: args.timeout,
client_certs: None,
object_storage,
})
.await?;

// Return error when response is failed.
if !head_response.success {
return Err(Error::BackendError(BackendError {
message: head_response.error_message.unwrap_or_default(),
status_code: Some(head_response.http_status_code.unwrap_or_default()),
header: Some(head_response.http_header.unwrap_or_default()),
}));
}

// If target directory is empty, then just return.
let Some(entries) = head_response.entries else {
warn!("no file is found in {}", args.url);
return Ok(());
};

// Calc the total file count and compare it with args to decide whether to execute download task.
let file_count = entries.iter().filter(|e| !e.is_dir).count();
if file_count > args.download_max_files {
return Err(Error::MaxDownloadFileCountExceeded(file_count));
}

// Due to the root_dir always end with '/', but the args.output may end with '/', so
// append '/' to output_root_dir if need.
// These two variable root_dir and output_root_dir will be used to build the actual output
// directory.
// For example, if root_dir is '/test/' and output_root_dir is '/path/to/target/', the actual output
// directory will be '/path/to/target/file-to-download', so, if output_root_dir is not suffix with
// '/', it's necessary to append a '/' to it.
let root_dir = args.url.path();
let output_root_dir = if args.output.to_string_lossy().ends_with('/') {
args.output.to_string_lossy().to_string()
} else {
format!("{}/", args.output.to_string_lossy())
};

let multi_progress_bar = MultiProgress::new();

// Use the semaphore to control the concurrent download task number.
// The initial value of semaphore is taken from the user input.
let concurrent_control = Arc::new(Semaphore::new(args.download_concurrent_count));

// To store the download task handler.
let mut handlers = Vec::with_capacity(file_count);

for entry in entries {
let url: Url = entry.url.parse().expect("unexpected url");

// If entry is a directory, then create it, or execute the download task.
if entry.is_dir {
// The url in the entry is percentage encoded, so we should decode it to get right path.
let decoded_url_path = percent_decode_str(url.path()).decode_utf8_lossy();
// Get the actual path.
let output_dir = decoded_url_path.replacen(root_dir, &output_root_dir, 1);

tokio::fs::create_dir(&output_dir).await.map_err(|e| {
error!("create {} failed: {}", output_dir, e);
e
})?;
} else {
let mut args = args.clone();
// The url in the entry is percentage encoded, so we should decode it to get right path.
let decoded_url_path = percent_decode_str(url.path()).decode_utf8_lossy();
// Get the actual path.
args.output = decoded_url_path
.replacen(root_dir, &output_root_dir, 1)
.into();
args.url = url;

let progress_bar = multi_progress_bar.add(ProgressBar::new(0));
let client = download_client.clone();
let semaphore = concurrent_control.clone();

handlers.push(tokio::spawn(async move {
// This is used for concurrent control.
// semaphore will live until all the download task finished, it should
// remain open when the download task executing, so we can expect it directly.
let _permit = semaphore.acquire().await.expect("semaphore closed");
download_task(args, progress_bar, client).await
}));
}
}

// Wait for all download tasks finished.
for handler in handlers {
handler.await.map_err(Error::TokioJoinError)??;
}

Ok(())
}

async fn download_task(
args: Args,
progress_bar: ProgressBar,
download_client: DfdaemonDownloadClient,
) -> Result<()> {
// Only when the `access_key_id` and `access_key_secret` are provided at the same time,
// they will be passed to the `DownloadTaskRequest`.
let mut object_storage = None;
if let (Some(access_key_id), Some(access_key_secret)) = (
args.storage_access_key_id.clone(),
args.storage_access_key_secret.clone(),
) {
object_storage = Some(ObjectStorage {
access_key_id,
access_key_secret,
session_token: args.storage_session_token.clone(),
region: args.storage_region.clone(),
endpoint: args.storage_endpoint.clone(),
credential: args.storage_credential.clone(),
predefined_acl: args.storage_predefined_acl.clone(),
});
}

// Create dfdaemon client.
let response = dfdaemon_download_client
let response = download_client
.download_task(DownloadTaskRequest {
download: Some(Download {
url: args.url.to_string(),
Expand All @@ -492,7 +655,7 @@ async fn run(args: Args) -> Result<()> {
filtered_query_params: args.filtered_query_params.unwrap_or_default(),
request_header: header_vec_to_hashmap(args.header.unwrap_or_default())?,
piece_length: args.piece_length,
output_path: Some(absolute_path.to_string_lossy().to_string()),
output_path: Some(args.output.to_string_lossy().to_string()),
timeout: Some(
prost_wkt_types::Duration::try_from(args.timeout)
.or_err(ErrorType::ParseError)?,
Expand All @@ -510,20 +673,22 @@ async fn run(args: Args) -> Result<()> {
err
})?;

// Initialize progress bar.
let pb = ProgressBar::new(0);
pb.set_style(
// Get actual path rather than percentage encoded path as task name.
let task_name = percent_decode_str(args.url.path()).decode_utf8_lossy();

progress_bar.set_style(
ProgressStyle::with_template(
"[{elapsed_precise}] [{wide_bar}] {bytes}/{total_bytes} ({bytes_per_sec}, {eta})",
"{msg:.bold}\n[{elapsed_precise}] [{bar:60.green/red}] {percent:3}% ({bytes_per_sec:.red}, {eta:.cyan})",
)
.or_err(ErrorType::ParseError)?
.with_key("eta", |state: &ProgressState, w: &mut dyn Write| {
write!(w, "{:.1}s", state.eta().as_secs_f64()).unwrap()
})
.progress_chars("#>-"),
.progress_chars("-"),
);
progress_bar.set_message(task_name.to_string());

// Download file.
// Download file.
let mut downloaded = 0;
let mut out_stream = response.into_inner();
while let Some(message) = out_stream.message().await.map_err(|err| {
Expand All @@ -532,20 +697,23 @@ async fn run(args: Args) -> Result<()> {
})? {
match message.response {
Some(download_task_response::Response::DownloadTaskStartedResponse(response)) => {
pb.set_length(response.content_length);
progress_bar.set_length(response.content_length);
}
Some(download_task_response::Response::DownloadPieceFinishedResponse(response)) => {
let piece = response.piece.ok_or(Error::InvalidParameter)?;

downloaded += piece.length;
let position = min(downloaded + piece.length, pb.length().unwrap_or(0));
pb.set_position(position);
let position = min(
downloaded + piece.length,
progress_bar.length().unwrap_or(0),
);
progress_bar.set_position(position);
}
None => {}
}
}

pb.finish_with_message("downloaded");
progress_bar.finish_with_message(format!("{} downloaded", task_name));
Ok(())
}

Expand Down

0 comments on commit 9ade0cb

Please sign in to comment.