diff --git a/src/bin/soar-dl/cli.rs b/src/bin/soar-dl/cli.rs index 8718f69..21043f6 100644 --- a/src/bin/soar-dl/cli.rs +++ b/src/bin/soar-dl/cli.rs @@ -48,4 +48,12 @@ pub struct Args { /// Output file path #[arg(required = false, short, long)] pub output: Option, + + /// GHCR concurrency + #[arg(required = false, short, long)] + pub concurrency: Option, + + /// GHCR API to use + #[arg(required = false, long)] + pub ghcr_api: Option, } diff --git a/src/bin/soar-dl/download_manager.rs b/src/bin/soar-dl/download_manager.rs index 7f7ce4a..156d730 100644 --- a/src/bin/soar-dl/download_manager.rs +++ b/src/bin/soar-dl/download_manager.rs @@ -4,7 +4,7 @@ use indicatif::HumanBytes; use regex::Regex; use serde::Deserialize; use soar_dl::{ - downloader::{DownloadOptions, DownloadState, Downloader}, + downloader::{DownloadOptions, DownloadState, Downloader, OciDownloadOptions}, error::{DownloadError, PlatformError}, github::{Github, GithubAsset, GithubRelease}, gitlab::{Gitlab, GitlabAsset, GitlabRelease}, @@ -81,6 +81,8 @@ impl DownloadManager { let assets = handler.filter_releases(&releases, &options).await?; let selected_asset = self.select_asset(&assets)?; + + println!("Downloading asset from {}", selected_asset.download_url()); handler.download(&selected_asset, options.clone()).await?; Ok(()) } @@ -130,10 +132,12 @@ impl DownloadManager { for reference in &self.args.ghcr { println!("Downloading using OCI reference: {}", reference); - let options = DownloadOptions { + let options = OciDownloadOptions { url: reference.clone(), + concurrency: self.args.concurrency.clone(), output_path: self.args.output.clone(), progress_callback: Some(self.progress_callback.clone()), + api: self.args.ghcr_api.clone(), }; let _ = downloader .download_oci(options) @@ -187,10 +191,12 @@ impl DownloadManager { Ok(PlatformUrl::Oci(url)) => { println!("Downloading using OCI reference: {}", url); - let options = DownloadOptions { - url: link.clone(), + let options = OciDownloadOptions { + url: url.clone(), + concurrency: self.args.concurrency.clone(), output_path: self.args.output.clone(), progress_callback: Some(self.progress_callback.clone()), + api: self.args.ghcr_api.clone(), }; let _ = downloader .download_oci(options) diff --git a/src/downloader.rs b/src/downloader.rs index 4649d62..073c6cf 100644 --- a/src/downloader.rs +++ b/src/downloader.rs @@ -11,6 +11,7 @@ use reqwest::header::USER_AGENT; use tokio::{ fs::{self, OpenOptions}, io::AsyncWriteExt, + sync::Semaphore, task, }; use url::Url; @@ -39,6 +40,15 @@ pub struct Downloader { client: reqwest::Client, } +#[derive(Clone)] +pub struct OciDownloadOptions { + pub url: String, + pub concurrency: Option, + pub output_path: Option, + pub progress_callback: Option>, + pub api: Option, +} + impl Downloader { pub async fn download(&self, options: DownloadOptions) -> Result { let url = Url::parse(&options.url).map_err(|err| DownloadError::InvalidUrl { @@ -124,7 +134,7 @@ impl Downloader { pub async fn download_blob( &self, client: OciClient, - options: DownloadOptions, + options: OciDownloadOptions, ) -> Result<(), DownloadError> { let reference = client.reference.clone(); let digest = reference.tag; @@ -170,10 +180,10 @@ impl Downloader { Ok(()) } - pub async fn download_oci(&self, options: DownloadOptions) -> Result<(), DownloadError> { + pub async fn download_oci(&self, options: OciDownloadOptions) -> Result<(), DownloadError> { let url = options.url.clone(); let reference: Reference = url.into(); - let oci_client = OciClient::new(&reference); + let oci_client = OciClient::new(&reference, options.api.clone()); if reference.tag.starts_with("sha256:") { return self.download_blob(oci_client, options).await; @@ -188,6 +198,7 @@ impl Downloader { callback(DownloadState::Preparing(total_bytes)); } + let semaphore = Arc::new(Semaphore::new(options.concurrency.unwrap_or(1) as usize)); let downloaded_bytes = Arc::new(Mutex::new(0u64)); let outdir = options.output_path; let base_path = if let Some(dir) = outdir { @@ -198,6 +209,7 @@ impl Downloader { }; for layer in manifest.layers { + let permit = semaphore.clone().acquire_owned().await.unwrap(); let client_clone = oci_client.clone(); let cb_clone = options.progress_callback.clone(); let downloaded_bytes = downloaded_bytes.clone(); @@ -219,6 +231,7 @@ impl Downloader { Ok::<(), DownloadError>(()) }); + drop(permit); tasks.push(task); } diff --git a/src/oci.rs b/src/oci.rs index 4d68a7f..a48e67b 100644 --- a/src/oci.rs +++ b/src/oci.rs @@ -44,6 +44,7 @@ pub struct OciManifest { pub struct OciClient { client: reqwest::Client, pub reference: Reference, + pub api: Option, } #[derive(Clone)] @@ -86,11 +87,12 @@ impl From for Reference { } impl OciClient { - pub fn new(reference: &Reference) -> Self { + pub fn new(reference: &Reference, api: Option) -> Self { let client = reqwest::Client::new(); Self { client, reference: reference.clone(), + api, } } @@ -103,8 +105,13 @@ impl OciClient { pub async fn manifest(&self) -> Result { let manifest_url = format!( - "https://ghcr.io/v2/{}/manifests/{}", - self.reference.package, self.reference.tag + "{}/{}/manifests/{}", + self.api + .clone() + .unwrap_or("https://ghcr.io/v2".to_string()) + .trim_end_matches('/'), + self.reference.package, + self.reference.tag ); let resp = self .client @@ -139,8 +146,13 @@ impl OciClient { F: Fn(u64, u64) + Send + 'static, { let blob_url = format!( - "https://ghcr.io/v2/{}/blobs/{}", - self.reference.package, layer.digest + "{}/{}/blobs/{}", + self.api + .clone() + .unwrap_or("https://ghcr.io/v2".to_string()) + .trim_end_matches('/'), + self.reference.package, + layer.digest ); let resp = self .client