Skip to content

Commit

Permalink
Deferred config parsing (#4191) (#4192)
Browse files Browse the repository at this point in the history
  • Loading branch information
tustvold authored May 10, 2023
1 parent 615dde0 commit b314118
Show file tree
Hide file tree
Showing 7 changed files with 198 additions and 65 deletions.
10 changes: 10 additions & 0 deletions object_store/src/aws/checksum.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
// specific language governing permissions and limitations
// under the License.

use crate::config::Parse;
use ring::digest::{self, digest as ring_digest};
use std::str::FromStr;

Expand Down Expand Up @@ -66,3 +67,12 @@ impl TryFrom<&String> for Checksum {
value.parse()
}
}

impl Parse for Checksum {
fn parse(v: &str) -> crate::Result<Self> {
v.parse().map_err(|_| crate::Error::Generic {
store: "Config",
source: format!("\"{v}\" is not a valid checksum algorithm").into(),
})
}
}
100 changes: 62 additions & 38 deletions object_store/src/aws/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -53,8 +53,9 @@ use crate::aws::credential::{
AwsCredential, CredentialProvider, InstanceCredentialProvider,
StaticCredentialProvider, WebIdentityProvider,
};
use crate::client::ClientConfigKey;
use crate::config::ConfigValue;
use crate::multipart::{CloudMultiPartUpload, CloudMultiPartUploadImpl, UploadPart};
use crate::util::str_is_truthy;
use crate::{
ClientOptions, GetResult, ListResult, MultipartId, ObjectMeta, ObjectStore, Path,
Result, RetryConfig, StreamExt,
Expand Down Expand Up @@ -103,9 +104,6 @@ enum Error {
source: std::num::ParseIntError,
},

#[snafu(display("Invalid Checksum algorithm"))]
InvalidChecksumAlgorithm,

#[snafu(display("Missing region"))]
MissingRegion,

Expand Down Expand Up @@ -461,13 +459,13 @@ pub struct AmazonS3Builder {
/// Retry config
retry_config: RetryConfig,
/// When set to true, fallback to IMDSv1
imdsv1_fallback: bool,
imdsv1_fallback: ConfigValue<bool>,
/// When set to true, virtual hosted style request has to be used
virtual_hosted_style_request: bool,
virtual_hosted_style_request: ConfigValue<bool>,
/// When set to true, unsigned payload option has to be used
unsigned_payload: bool,
unsigned_payload: ConfigValue<bool>,
/// Checksum algorithm which has to be used for object integrity check during upload
checksum_algorithm: Option<String>,
checksum_algorithm: Option<ConfigValue<Checksum>>,
/// Metadata endpoint, see <https://docs.aws.amazon.com/AWSEC2/latest/UserGuide/ec2-instance-metadata.html>
metadata_endpoint: Option<String>,
/// Profile name, see <https://docs.aws.amazon.com/cli/latest/userguide/cli-configure-profiles.html>
Expand Down Expand Up @@ -709,8 +707,9 @@ impl AmazonS3Builder {
}

if let Ok(text) = std::env::var("AWS_ALLOW_HTTP") {
builder.client_options =
builder.client_options.with_allow_http(str_is_truthy(&text));
builder.client_options = builder
.client_options
.with_config(ClientConfigKey::AllowHttp, text);
}

builder
Expand Down Expand Up @@ -756,11 +755,9 @@ impl AmazonS3Builder {
AmazonS3ConfigKey::Bucket => self.bucket_name = Some(value.into()),
AmazonS3ConfigKey::Endpoint => self.endpoint = Some(value.into()),
AmazonS3ConfigKey::Token => self.token = Some(value.into()),
AmazonS3ConfigKey::ImdsV1Fallback => {
self.imdsv1_fallback = str_is_truthy(&value.into())
}
AmazonS3ConfigKey::ImdsV1Fallback => self.imdsv1_fallback.parse(value),
AmazonS3ConfigKey::VirtualHostedStyleRequest => {
self.virtual_hosted_style_request = str_is_truthy(&value.into())
self.virtual_hosted_style_request.parse(value)
}
AmazonS3ConfigKey::DefaultRegion => {
self.region = self.region.or_else(|| Some(value.into()))
Expand All @@ -769,10 +766,10 @@ impl AmazonS3Builder {
self.metadata_endpoint = Some(value.into())
}
AmazonS3ConfigKey::Profile => self.profile = Some(value.into()),
AmazonS3ConfigKey::UnsignedPayload => {
self.unsigned_payload = str_is_truthy(&value.into())
AmazonS3ConfigKey::UnsignedPayload => self.unsigned_payload.parse(value),
AmazonS3ConfigKey::Checksum => {
self.checksum_algorithm = Some(ConfigValue::Deferred(value.into()))
}
AmazonS3ConfigKey::Checksum => self.checksum_algorithm = Some(value.into()),
};
self
}
Expand Down Expand Up @@ -834,7 +831,9 @@ impl AmazonS3Builder {
AmazonS3ConfigKey::MetadataEndpoint => self.metadata_endpoint.clone(),
AmazonS3ConfigKey::Profile => self.profile.clone(),
AmazonS3ConfigKey::UnsignedPayload => Some(self.unsigned_payload.to_string()),
AmazonS3ConfigKey::Checksum => self.checksum_algorithm.clone(),
AmazonS3ConfigKey::Checksum => {
self.checksum_algorithm.as_ref().map(ToString::to_string)
}
}
}

Expand All @@ -858,7 +857,7 @@ impl AmazonS3Builder {
Some((bucket, "s3", region, "amazonaws.com")) => {
self.bucket_name = Some(bucket.to_string());
self.region = Some(region.to_string());
self.virtual_hosted_style_request = true;
self.virtual_hosted_style_request = true.into();
}
Some((account, "r2", "cloudflarestorage", "com")) => {
self.region = Some("auto".to_string());
Expand Down Expand Up @@ -944,7 +943,7 @@ impl AmazonS3Builder {
mut self,
virtual_hosted_style_request: bool,
) -> Self {
self.virtual_hosted_style_request = virtual_hosted_style_request;
self.virtual_hosted_style_request = virtual_hosted_style_request.into();
self
}

Expand All @@ -967,7 +966,7 @@ impl AmazonS3Builder {
/// [SSRF attack]: https://aws.amazon.com/blogs/security/defense-in-depth-open-firewalls-reverse-proxies-ssrf-vulnerabilities-ec2-instance-metadata-service/
///
pub fn with_imdsv1_fallback(mut self) -> Self {
self.imdsv1_fallback = true;
self.imdsv1_fallback = true.into();
self
}

Expand All @@ -976,7 +975,7 @@ impl AmazonS3Builder {
/// * false (default): Signed payload option is used, where the checksum for the request body is computed and included when constructing a canonical request.
/// * true: Unsigned payload option is used. `UNSIGNED-PAYLOAD` literal is included when constructing a canonical request,
pub fn with_unsigned_payload(mut self, unsigned_payload: bool) -> Self {
self.unsigned_payload = unsigned_payload;
self.unsigned_payload = unsigned_payload.into();
self
}

Expand All @@ -985,7 +984,7 @@ impl AmazonS3Builder {
/// [checksum algorithm]: https://docs.aws.amazon.com/AmazonS3/latest/userguide/checking-object-integrity.html
pub fn with_checksum_algorithm(mut self, checksum_algorithm: Checksum) -> Self {
// Convert to String to enable deferred parsing of config
self.checksum_algorithm = Some(checksum_algorithm.to_string());
self.checksum_algorithm = Some(checksum_algorithm.into());
self
}

Expand Down Expand Up @@ -1038,11 +1037,7 @@ impl AmazonS3Builder {

let bucket = self.bucket_name.context(MissingBucketNameSnafu)?;
let region = self.region.context(MissingRegionSnafu)?;
let checksum = self
.checksum_algorithm
.map(|c| c.parse())
.transpose()
.map_err(|_| Error::InvalidChecksumAlgorithm)?;
let checksum = self.checksum_algorithm.map(|x| x.get()).transpose()?;

let credentials = match (self.access_key_id, self.secret_access_key, self.token) {
(Some(key_id), Some(secret_key), token) => {
Expand Down Expand Up @@ -1103,7 +1098,7 @@ impl AmazonS3Builder {
cache: Default::default(),
client: client_options.client()?,
retry_config: self.retry_config.clone(),
imdsv1_fallback: self.imdsv1_fallback,
imdsv1_fallback: self.imdsv1_fallback.get()?,
metadata_endpoint: self
.metadata_endpoint
.unwrap_or_else(|| METADATA_ENDPOINT.into()),
Expand All @@ -1119,7 +1114,7 @@ impl AmazonS3Builder {
// If `endpoint` is provided then its assumed to be consistent with
// `virtual_hosted_style_request`. i.e. if `virtual_hosted_style_request` is true then
// `endpoint` should have bucket name included.
if self.virtual_hosted_style_request {
if self.virtual_hosted_style_request.get()? {
endpoint = self
.endpoint
.unwrap_or_else(|| format!("https://{bucket}.s3.{region}.amazonaws.com"));
Expand All @@ -1139,7 +1134,7 @@ impl AmazonS3Builder {
credentials,
retry_config: self.retry_config,
client_options: self.client_options,
sign_payload: !self.unsigned_payload,
sign_payload: !self.unsigned_payload.get()?,
checksum,
};

Expand Down Expand Up @@ -1315,10 +1310,10 @@ mod tests {
let metadata_uri = format!("{METADATA_ENDPOINT}{container_creds_relative_uri}");
assert_eq!(builder.metadata_endpoint.unwrap(), metadata_uri);
assert_eq!(
builder.checksum_algorithm.unwrap(),
Checksum::SHA256.to_string()
builder.checksum_algorithm.unwrap().get().unwrap(),
Checksum::SHA256
);
assert!(builder.unsigned_payload);
assert!(builder.unsigned_payload.get().unwrap());
}

#[test]
Expand Down Expand Up @@ -1351,10 +1346,10 @@ mod tests {
assert_eq!(builder.endpoint.unwrap(), aws_endpoint);
assert_eq!(builder.token.unwrap(), aws_session_token);
assert_eq!(
builder.checksum_algorithm.unwrap(),
Checksum::SHA256.to_string()
builder.checksum_algorithm.unwrap().get().unwrap(),
Checksum::SHA256
);
assert!(builder.unsigned_payload);
assert!(builder.unsigned_payload.get().unwrap());
}

#[test]
Expand Down Expand Up @@ -1564,7 +1559,7 @@ mod tests {
.unwrap();
assert_eq!(builder.bucket_name, Some("bucket".to_string()));
assert_eq!(builder.region, Some("region".to_string()));
assert!(builder.virtual_hosted_style_request);
assert!(builder.virtual_hosted_style_request.get().unwrap());

let mut builder = AmazonS3Builder::new();
builder
Expand All @@ -1591,6 +1586,35 @@ mod tests {
builder.parse_url(case).unwrap_err();
}
}

#[test]
fn test_invalid_config() {
    // Configures a builder with a single (invalid) option plus the minimum
    // required bucket/region, then returns the rendered `build()` error.
    let build_error = |key: AmazonS3ConfigKey, value: &str| -> String {
        AmazonS3Builder::new()
            .with_config(key, value)
            .with_bucket_name("bucket")
            .with_region("region")
            .build()
            .unwrap_err()
            .to_string()
    };

    // A non-boolean value for a boolean option is only rejected at build time.
    let bool_err = build_error(AmazonS3ConfigKey::ImdsV1Fallback, "enabled");
    assert_eq!(
        bool_err,
        "Generic Config error: failed to parse \"enabled\" as boolean"
    );

    // An unknown checksum algorithm is likewise rejected at build time.
    let checksum_err = build_error(AmazonS3ConfigKey::Checksum, "md5");
    assert_eq!(
        checksum_err,
        "Generic Config error: \"md5\" is not a valid checksum algorithm"
    );
}
}

#[cfg(test)]
Expand Down
29 changes: 14 additions & 15 deletions object_store/src/azure/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,9 @@ use std::{collections::BTreeSet, str::FromStr};
use tokio::io::AsyncWrite;
use url::Url;

use crate::util::{str_is_truthy, RFC1123_FMT};
use crate::client::ClientConfigKey;
use crate::config::ConfigValue;
use crate::util::RFC1123_FMT;
pub use credential::authority_hosts;

mod client;
Expand Down Expand Up @@ -417,7 +419,7 @@ pub struct MicrosoftAzureBuilder {
/// Url
url: Option<String>,
/// When set to true, azurite storage emulator has to be used
use_emulator: bool,
use_emulator: ConfigValue<bool>,
/// Msi endpoint for acquiring managed identity token
msi_endpoint: Option<String>,
/// Object id for use with managed identity authentication
Expand All @@ -427,7 +429,7 @@ pub struct MicrosoftAzureBuilder {
/// File containing token for Azure AD workload identity federation
federated_token_file: Option<String>,
/// When set to true, azure cli has to be used for acquiring access token
use_azure_cli: bool,
use_azure_cli: ConfigValue<bool>,
/// Retry config
retry_config: RetryConfig,
/// Client options
Expand Down Expand Up @@ -672,8 +674,9 @@ impl MicrosoftAzureBuilder {
}

if let Ok(text) = std::env::var("AZURE_ALLOW_HTTP") {
builder.client_options =
builder.client_options.with_allow_http(str_is_truthy(&text));
builder.client_options = builder
.client_options
.with_config(ClientConfigKey::AllowHttp, text)
}

if let Ok(text) = std::env::var(MSI_ENDPOINT_ENV_KEY) {
Expand Down Expand Up @@ -726,12 +729,8 @@ impl MicrosoftAzureBuilder {
AzureConfigKey::FederatedTokenFile => {
self.federated_token_file = Some(value.into())
}
AzureConfigKey::UseAzureCli => {
self.use_azure_cli = str_is_truthy(&value.into())
}
AzureConfigKey::UseEmulator => {
self.use_emulator = str_is_truthy(&value.into())
}
AzureConfigKey::UseAzureCli => self.use_azure_cli.parse(value),
AzureConfigKey::UseEmulator => self.use_emulator.parse(value),
};
self
}
Expand Down Expand Up @@ -898,7 +897,7 @@ impl MicrosoftAzureBuilder {

/// Set if the Azure emulator should be used (defaults to false)
pub fn with_use_emulator(mut self, use_emulator: bool) -> Self {
self.use_emulator = use_emulator;
self.use_emulator = use_emulator.into();
self
}

Expand Down Expand Up @@ -956,7 +955,7 @@ impl MicrosoftAzureBuilder {
/// Set if the Azure Cli should be used for acquiring access token
/// <https://learn.microsoft.com/en-us/cli/azure/account?view=azure-cli-latest#az-account-get-access-token>
pub fn with_use_azure_cli(mut self, use_azure_cli: bool) -> Self {
self.use_azure_cli = use_azure_cli;
self.use_azure_cli = use_azure_cli.into();
self
}

Expand All @@ -969,7 +968,7 @@ impl MicrosoftAzureBuilder {

let container = self.container_name.ok_or(Error::MissingContainerName {})?;

let (is_emulator, storage_url, auth, account) = if self.use_emulator {
let (is_emulator, storage_url, auth, account) = if self.use_emulator.get()? {
let account_name = self
.account_name
.unwrap_or_else(|| EMULATOR_ACCOUNT.to_string());
Expand Down Expand Up @@ -1022,7 +1021,7 @@ impl MicrosoftAzureBuilder {
credential::CredentialProvider::SASToken(query_pairs)
} else if let Some(sas) = self.sas_key {
credential::CredentialProvider::SASToken(split_sas(&sas)?)
} else if self.use_azure_cli {
} else if self.use_azure_cli.get()? {
credential::CredentialProvider::TokenCredential(
TokenCache::default(),
Box::new(credential::AzureCliCredential::new()),
Expand Down
Loading

0 comments on commit b314118

Please sign in to comment.