diff --git a/object_store/src/aws/checksum.rs b/object_store/src/aws/checksum.rs index 57762b641ac6..a50bd2d18b9c 100644 --- a/object_store/src/aws/checksum.rs +++ b/object_store/src/aws/checksum.rs @@ -15,6 +15,7 @@ // specific language governing permissions and limitations // under the License. +use crate::config::Parse; use ring::digest::{self, digest as ring_digest}; use std::str::FromStr; @@ -66,3 +67,12 @@ impl TryFrom<&String> for Checksum { value.parse() } } + +impl Parse for Checksum { + fn parse(v: &str) -> crate::Result { + v.parse().map_err(|_| crate::Error::Generic { + store: "Config", + source: format!("\"{v}\" is not a valid checksum algorithm").into(), + }) + } +} diff --git a/object_store/src/aws/mod.rs b/object_store/src/aws/mod.rs index 5de177afa10a..0a18b9ec69b1 100644 --- a/object_store/src/aws/mod.rs +++ b/object_store/src/aws/mod.rs @@ -53,8 +53,9 @@ use crate::aws::credential::{ AwsCredential, CredentialProvider, InstanceCredentialProvider, StaticCredentialProvider, WebIdentityProvider, }; +use crate::client::ClientConfigKey; +use crate::config::ConfigValue; use crate::multipart::{CloudMultiPartUpload, CloudMultiPartUploadImpl, UploadPart}; -use crate::util::str_is_truthy; use crate::{ ClientOptions, GetResult, ListResult, MultipartId, ObjectMeta, ObjectStore, Path, Result, RetryConfig, StreamExt, @@ -103,9 +104,6 @@ enum Error { source: std::num::ParseIntError, }, - #[snafu(display("Invalid Checksum algorithm"))] - InvalidChecksumAlgorithm, - #[snafu(display("Missing region"))] MissingRegion, @@ -461,13 +459,13 @@ pub struct AmazonS3Builder { /// Retry config retry_config: RetryConfig, /// When set to true, fallback to IMDSv1 - imdsv1_fallback: bool, + imdsv1_fallback: ConfigValue, /// When set to true, virtual hosted style request has to be used - virtual_hosted_style_request: bool, + virtual_hosted_style_request: ConfigValue, /// When set to true, unsigned payload option has to be used - unsigned_payload: bool, + unsigned_payload: ConfigValue, /// Checksum algorithm which has to be used for object integrity check during upload - checksum_algorithm: Option, + checksum_algorithm: Option>, /// Metadata endpoint, see metadata_endpoint: Option, /// Profile name, see @@ -709,8 +707,9 @@ impl AmazonS3Builder { } if let Ok(text) = std::env::var("AWS_ALLOW_HTTP") { - builder.client_options = - builder.client_options.with_allow_http(str_is_truthy(&text)); + builder.client_options = builder + .client_options + .with_config(ClientConfigKey::AllowHttp, text); } builder @@ -755,11 +754,9 @@ impl AmazonS3Builder { AmazonS3ConfigKey::Bucket => self.bucket_name = Some(value.into()), AmazonS3ConfigKey::Endpoint => self.endpoint = Some(value.into()), AmazonS3ConfigKey::Token => self.token = Some(value.into()), - AmazonS3ConfigKey::ImdsV1Fallback => { - self.imdsv1_fallback = str_is_truthy(&value.into()) - } + AmazonS3ConfigKey::ImdsV1Fallback => self.imdsv1_fallback.parse(value), AmazonS3ConfigKey::VirtualHostedStyleRequest => { - self.virtual_hosted_style_request = str_is_truthy(&value.into()) + self.virtual_hosted_style_request.parse(value) } AmazonS3ConfigKey::DefaultRegion => { self.region = self.region.or_else(|| Some(value.into())) @@ -768,10 +765,10 @@ impl AmazonS3Builder { self.metadata_endpoint = Some(value.into()) } AmazonS3ConfigKey::Profile => self.profile = Some(value.into()), - AmazonS3ConfigKey::UnsignedPayload => { - self.unsigned_payload = str_is_truthy(&value.into()) + AmazonS3ConfigKey::UnsignedPayload => self.unsigned_payload.parse(value), + AmazonS3ConfigKey::Checksum => { + self.checksum_algorithm = Some(ConfigValue::Deferred(value.into())) } - AmazonS3ConfigKey::Checksum => self.checksum_algorithm = Some(value.into()), }; self } @@ -833,7 +830,9 @@ impl AmazonS3Builder { AmazonS3ConfigKey::MetadataEndpoint => self.metadata_endpoint.clone(), AmazonS3ConfigKey::Profile => self.profile.clone(), AmazonS3ConfigKey::UnsignedPayload => Some(self.unsigned_payload.to_string()), - AmazonS3ConfigKey::Checksum => self.checksum_algorithm.clone(), + AmazonS3ConfigKey::Checksum => { + self.checksum_algorithm.as_ref().map(ToString::to_string) + } } } @@ -858,7 +857,7 @@ impl AmazonS3Builder { Some((bucket, "s3", region, "amazonaws.com")) => { self.bucket_name = Some(bucket.to_string()); self.region = Some(region.to_string()); - self.virtual_hosted_style_request = true; + self.virtual_hosted_style_request = true.into(); } _ => return Err(UrlNotRecognisedSnafu { url }.build().into()), }, @@ -934,7 +933,7 @@ impl AmazonS3Builder { mut self, virtual_hosted_style_request: bool, ) -> Self { - self.virtual_hosted_style_request = virtual_hosted_style_request; + self.virtual_hosted_style_request = virtual_hosted_style_request.into(); self } @@ -957,7 +956,7 @@ impl AmazonS3Builder { /// [SSRF attack]: https://aws.amazon.com/blogs/security/defense-in-depth-open-firewalls-reverse-proxies-ssrf-vulnerabilities-ec2-instance-metadata-service/ /// pub fn with_imdsv1_fallback(mut self) -> Self { - self.imdsv1_fallback = true; + self.imdsv1_fallback = true.into(); self } @@ -966,7 +965,7 @@ impl AmazonS3Builder { /// * false (default): Signed payload option is used, where the checksum for the request body is computed and included when constructing a canonical request. /// * true: Unsigned payload option is used. `UNSIGNED-PAYLOAD` literal is included when constructing a canonical request, pub fn with_unsigned_payload(mut self, unsigned_payload: bool) -> Self { - self.unsigned_payload = unsigned_payload; + self.unsigned_payload = unsigned_payload.into(); self } @@ -975,7 +974,7 @@ impl AmazonS3Builder { /// [checksum algorithm]: https://docs.aws.amazon.com/AmazonS3/latest/userguide/checking-object-integrity.html pub fn with_checksum_algorithm(mut self, checksum_algorithm: Checksum) -> Self { // Convert to String to enable deferred parsing of config - self.checksum_algorithm = Some(checksum_algorithm.to_string()); + self.checksum_algorithm = Some(checksum_algorithm.into()); self } @@ -1028,11 +1027,7 @@ impl AmazonS3Builder { let bucket = self.bucket_name.context(MissingBucketNameSnafu)?; let region = self.region.context(MissingRegionSnafu)?; - let checksum = self - .checksum_algorithm - .map(|c| c.parse()) - .transpose() - .map_err(|_| Error::InvalidChecksumAlgorithm)?; + let checksum = self.checksum_algorithm.map(|x| x.get()).transpose()?; let credentials = match (self.access_key_id, self.secret_access_key, self.token) { (Some(key_id), Some(secret_key), token) => { @@ -1093,7 +1088,7 @@ impl AmazonS3Builder { cache: Default::default(), client: client_options.client()?, retry_config: self.retry_config.clone(), - imdsv1_fallback: self.imdsv1_fallback, + imdsv1_fallback: self.imdsv1_fallback.get()?, metadata_endpoint: self .metadata_endpoint .unwrap_or_else(|| METADATA_ENDPOINT.into()), @@ -1109,7 +1104,7 @@ impl AmazonS3Builder { // If `endpoint` is provided then its assumed to be consistent with // `virtual_hosted_style_request`. i.e. if `virtual_hosted_style_request` is true then // `endpoint` should have bucket name included. - if self.virtual_hosted_style_request { + if self.virtual_hosted_style_request.get()? { endpoint = self .endpoint .unwrap_or_else(|| format!("https://{bucket}.s3.{region}.amazonaws.com")); @@ -1129,7 +1124,7 @@ impl AmazonS3Builder { credentials, retry_config: self.retry_config, client_options: self.client_options, - sign_payload: !self.unsigned_payload, + sign_payload: !self.unsigned_payload.get()?, checksum, }; @@ -1305,10 +1300,10 @@ mod tests { let metadata_uri = format!("{METADATA_ENDPOINT}{container_creds_relative_uri}"); assert_eq!(builder.metadata_endpoint.unwrap(), metadata_uri); assert_eq!( - builder.checksum_algorithm.unwrap(), - Checksum::SHA256.to_string() + builder.checksum_algorithm.unwrap().get().unwrap(), + Checksum::SHA256 ); - assert!(builder.unsigned_payload); + assert!(builder.unsigned_payload.get().unwrap()); } #[test] @@ -1341,10 +1336,10 @@ mod tests { assert_eq!(builder.endpoint.unwrap(), aws_endpoint); assert_eq!(builder.token.unwrap(), aws_session_token); assert_eq!( - builder.checksum_algorithm.unwrap(), - Checksum::SHA256.to_string() + builder.checksum_algorithm.unwrap().get().unwrap(), + Checksum::SHA256 ); - assert!(builder.unsigned_payload); + assert!(builder.unsigned_payload.get().unwrap()); } #[test] @@ -1554,7 +1549,7 @@ mod tests { .unwrap(); assert_eq!(builder.bucket_name, Some("bucket".to_string())); assert_eq!(builder.region, Some("region".to_string())); - assert!(builder.virtual_hosted_style_request); + assert!(builder.virtual_hosted_style_request.get().unwrap()); let err_cases = [ "mailto://bucket/path", @@ -1569,6 +1564,35 @@ mod tests { builder.parse_url(case).unwrap_err(); } } + + #[test] + fn test_invalid_config() { + let err = AmazonS3Builder::new() + .with_config(AmazonS3ConfigKey::ImdsV1Fallback, "enabled") + .with_bucket_name("bucket") + .with_region("region") + .build() + .unwrap_err() + .to_string(); + + assert_eq!( + err, + "Generic Config error: failed to parse \"enabled\" as boolean" + ); + + let err = AmazonS3Builder::new() + .with_config(AmazonS3ConfigKey::Checksum, "md5") + .with_bucket_name("bucket") + .with_region("region") + .build() + .unwrap_err() + .to_string(); + + assert_eq!( + err, + "Generic Config error: \"md5\" is not a valid checksum algorithm" + ); + } } #[cfg(test)] diff --git a/object_store/src/azure/mod.rs b/object_store/src/azure/mod.rs index 15033dca7ae5..2b5b43adabe0 100644 --- a/object_store/src/azure/mod.rs +++ b/object_store/src/azure/mod.rs @@ -51,7 +51,9 @@ use std::{collections::BTreeSet, str::FromStr}; use tokio::io::AsyncWrite; use url::Url; -use crate::util::{str_is_truthy, RFC1123_FMT}; +use crate::client::ClientConfigKey; +use crate::config::ConfigValue; +use crate::util::RFC1123_FMT; pub use credential::authority_hosts; mod client; @@ -417,7 +419,7 @@ pub struct MicrosoftAzureBuilder { /// Url url: Option, /// When set to true, azurite storage emulator has to be used - use_emulator: bool, + use_emulator: ConfigValue, /// Msi endpoint for acquiring managed identity token msi_endpoint: Option, /// Object id for use with managed identity authentication @@ -427,7 +429,7 @@ pub struct MicrosoftAzureBuilder { /// File containing token for Azure AD workload identity federation federated_token_file: Option, /// When set to true, azure cli has to be used for acquiring access token - use_azure_cli: bool, + use_azure_cli: ConfigValue, /// Retry config retry_config: RetryConfig, /// Client options @@ -672,8 +674,9 @@ impl MicrosoftAzureBuilder { } if let Ok(text) = std::env::var("AZURE_ALLOW_HTTP") { - builder.client_options = - builder.client_options.with_allow_http(str_is_truthy(&text)); + builder.client_options = builder + .client_options + .with_config(ClientConfigKey::AllowHttp, text) } if let Ok(text) = std::env::var(MSI_ENDPOINT_ENV_KEY) { @@ -726,12 +729,8 @@ impl MicrosoftAzureBuilder { AzureConfigKey::FederatedTokenFile => { self.federated_token_file = Some(value.into()) } - AzureConfigKey::UseAzureCli => { - self.use_azure_cli = str_is_truthy(&value.into()) - } - AzureConfigKey::UseEmulator => { - self.use_emulator = str_is_truthy(&value.into()) - } + AzureConfigKey::UseAzureCli => self.use_azure_cli.parse(value), + AzureConfigKey::UseEmulator => self.use_emulator.parse(value), }; self } @@ -898,7 +897,7 @@ impl MicrosoftAzureBuilder { /// Set if the Azure emulator should be used (defaults to false) pub fn with_use_emulator(mut self, use_emulator: bool) -> Self { - self.use_emulator = use_emulator; + self.use_emulator = use_emulator.into(); self } @@ -956,7 +955,7 @@ impl MicrosoftAzureBuilder { /// Set if the Azure Cli should be used for acquiring access token /// pub fn with_use_azure_cli(mut self, use_azure_cli: bool) -> Self { - self.use_azure_cli = use_azure_cli; + self.use_azure_cli = use_azure_cli.into(); self } @@ -969,7 +968,7 @@ impl MicrosoftAzureBuilder { let container = self.container_name.ok_or(Error::MissingContainerName {})?; - let (is_emulator, storage_url, auth, account) = if self.use_emulator { + let (is_emulator, storage_url, auth, account) = if self.use_emulator.get()? { let account_name = self .account_name .unwrap_or_else(|| EMULATOR_ACCOUNT.to_string()); @@ -1022,7 +1021,7 @@ impl MicrosoftAzureBuilder { credential::CredentialProvider::SASToken(query_pairs) } else if let Some(sas) = self.sas_key { credential::CredentialProvider::SASToken(split_sas(&sas)?) - } else if self.use_azure_cli { + } else if self.use_azure_cli.get()? { credential::CredentialProvider::TokenCredential( TokenCache::default(), Box::new(credential::AzureCliCredential::new()), diff --git a/object_store/src/client/mod.rs b/object_store/src/client/mod.rs index d019e8119ac2..d7b0b86d99e5 100644 --- a/object_store/src/client/mod.rs +++ b/object_store/src/client/mod.rs @@ -26,8 +26,10 @@ pub mod retry; #[cfg(any(feature = "aws", feature = "gcp", feature = "azure"))] pub mod token; +use crate::config::ConfigValue; use reqwest::header::{HeaderMap, HeaderValue}; use reqwest::{Client, ClientBuilder, Proxy}; +use serde::{Deserialize, Serialize}; use std::collections::HashMap; use std::time::Duration; @@ -43,6 +45,14 @@ fn map_client_error(e: reqwest::Error) -> super::Error { static DEFAULT_USER_AGENT: &str = concat!(env!("CARGO_PKG_NAME"), "/", env!("CARGO_PKG_VERSION"),); +/// Configuration keys for [`ClientOptions`] +#[derive(PartialEq, Eq, Hash, Clone, Debug, Copy, Deserialize, Serialize)] +#[non_exhaustive] +pub enum ClientConfigKey { + /// Allow non-TLS, i.e. non-HTTPS connections + AllowHttp, +} + /// HTTP client configuration for remote object stores #[derive(Debug, Clone, Default)] pub struct ClientOptions { @@ -51,7 +61,7 @@ pub struct ClientOptions { default_content_type: Option, default_headers: Option, proxy_url: Option, - allow_http: bool, + allow_http: ConfigValue, allow_insecure: bool, timeout: Option, connect_timeout: Option, @@ -70,6 +80,21 @@ impl ClientOptions { Default::default() } + /// Set an option by key + pub fn with_config(mut self, key: ClientConfigKey, value: impl Into) -> Self { + match key { + ClientConfigKey::AllowHttp => self.allow_http.parse(value), + } + self + } + + /// Get an option by key + pub fn get_config_value(&self, key: &ClientConfigKey) -> Option { + match key { + ClientConfigKey::AllowHttp => Some(self.allow_http.to_string()), + } + } + /// Sets the User-Agent header to be used by this client /// /// Default is based on the version of this crate @@ -104,7 +129,7 @@ impl ClientOptions { /// * false (default): Only HTTPS are allowed /// * true: HTTP and HTTPS are allowed pub fn with_allow_http(mut self, allow_http: bool) -> Self { - self.allow_http = allow_http; + self.allow_http = allow_http.into(); self } /// Allows connections to invalid SSL certificates @@ -280,7 +305,7 @@ impl ClientOptions { } builder - .https_only(!self.allow_http) + .https_only(!self.allow_http.get()?) .build() .map_err(map_client_error) } diff --git a/object_store/src/config.rs b/object_store/src/config.rs new file mode 100644 index 000000000000..3ecce2e52bf1 --- /dev/null +++ b/object_store/src/config.rs @@ -0,0 +1,81 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use crate::{Error, Result}; +use std::fmt::{Debug, Display, Formatter}; + +/// Provides deferred parsing of a value +/// +/// This allows builders to defer fallibility to build +#[derive(Debug, Clone)] +pub enum ConfigValue { + Parsed(T), + Deferred(String), +} + +impl Display for ConfigValue { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + match self { + Self::Parsed(v) => write!(f, "{v}"), + Self::Deferred(v) => write!(f, "{v}"), + } + } +} + +impl From for ConfigValue { + fn from(value: T) -> Self { + Self::Parsed(value) + } +} + +impl ConfigValue { + pub fn parse(&mut self, v: impl Into) { + *self = Self::Deferred(v.into()) + } + + pub fn get(&self) -> Result { + match self { + Self::Parsed(v) => Ok(v.clone()), + Self::Deferred(v) => T::parse(v), + } + } +} + +impl Default for ConfigValue { + fn default() -> Self { + Self::Parsed(T::default()) + } +} + +/// A value that can be stored in [`ConfigValue`] +pub trait Parse: Sized { + fn parse(v: &str) -> Result; +} + +impl Parse for bool { + fn parse(v: &str) -> Result { + let lower = v.to_ascii_lowercase(); + match lower.as_str() { + "1" | "true" | "on" | "yes" | "y" => Ok(true), + "0" | "false" | "off" | "no" | "n" => Ok(false), + _ => Err(Error::Generic { + store: "Config", + source: format!("failed to parse \"{v}\" as boolean").into(), + }), + } + } +} diff --git a/object_store/src/lib.rs b/object_store/src/lib.rs index c31027c0715c..1390a0140d1c 100644 --- a/object_store/src/lib.rs +++ b/object_store/src/lib.rs @@ -247,6 +247,9 @@ mod client; #[cfg(any(feature = "gcp", feature = "aws", feature = "azure", feature = "http"))] pub use client::{backoff::BackoffConfig, retry::RetryConfig}; +#[cfg(any(feature = "gcp", feature = "aws", feature = "azure", feature = "http"))] +mod config; + #[cfg(any(feature = "azure", feature = "aws", feature = "gcp"))] mod multipart; mod util; diff --git a/object_store/src/util.rs b/object_store/src/util.rs index 1ec63f219a20..e5c701dd8b1b 100644 --- a/object_store/src/util.rs +++ b/object_store/src/util.rs @@ -185,15 +185,6 @@ fn merge_ranges( ret } -#[allow(dead_code)] -pub(crate) fn str_is_truthy(val: &str) -> bool { - val.eq_ignore_ascii_case("1") - | val.eq_ignore_ascii_case("true") - | val.eq_ignore_ascii_case("on") - | val.eq_ignore_ascii_case("yes") - | val.eq_ignore_ascii_case("y") -} - #[cfg(test)] mod tests { use super::*;