Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Deferred Object Store Config Parsing (#4191) #4192

Merged
merged 1 commit into from
May 10, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions object_store/src/aws/checksum.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
// specific language governing permissions and limitations
// under the License.

use crate::config::Parse;
use ring::digest::{self, digest as ring_digest};
use std::str::FromStr;

Expand Down Expand Up @@ -66,3 +67,12 @@ impl TryFrom<&String> for Checksum {
value.parse()
}
}

impl Parse for Checksum {
fn parse(v: &str) -> crate::Result<Self> {
v.parse().map_err(|_| crate::Error::Generic {
store: "Config",
source: format!("\"{v}\" is not a valid checksum algorithm").into(),
})
}
}
100 changes: 62 additions & 38 deletions object_store/src/aws/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -53,8 +53,9 @@ use crate::aws::credential::{
AwsCredential, CredentialProvider, InstanceCredentialProvider,
StaticCredentialProvider, WebIdentityProvider,
};
use crate::client::ClientConfigKey;
use crate::config::ConfigValue;
use crate::multipart::{CloudMultiPartUpload, CloudMultiPartUploadImpl, UploadPart};
use crate::util::str_is_truthy;
use crate::{
ClientOptions, GetResult, ListResult, MultipartId, ObjectMeta, ObjectStore, Path,
Result, RetryConfig, StreamExt,
Expand Down Expand Up @@ -103,9 +104,6 @@ enum Error {
source: std::num::ParseIntError,
},

#[snafu(display("Invalid Checksum algorithm"))]
InvalidChecksumAlgorithm,

#[snafu(display("Missing region"))]
MissingRegion,

Expand Down Expand Up @@ -461,13 +459,13 @@ pub struct AmazonS3Builder {
/// Retry config
retry_config: RetryConfig,
/// When set to true, fallback to IMDSv1
imdsv1_fallback: bool,
imdsv1_fallback: ConfigValue<bool>,
/// When set to true, virtual hosted style request has to be used
virtual_hosted_style_request: bool,
virtual_hosted_style_request: ConfigValue<bool>,
/// When set to true, unsigned payload option has to be used
unsigned_payload: bool,
unsigned_payload: ConfigValue<bool>,
/// Checksum algorithm which has to be used for object integrity check during upload
checksum_algorithm: Option<String>,
checksum_algorithm: Option<ConfigValue<Checksum>>,
/// Metadata endpoint, see <https://docs.aws.amazon.com/AWSEC2/latest/UserGuide/ec2-instance-metadata.html>
metadata_endpoint: Option<String>,
/// Profile name, see <https://docs.aws.amazon.com/cli/latest/userguide/cli-configure-profiles.html>
Expand Down Expand Up @@ -709,8 +707,9 @@ impl AmazonS3Builder {
}

if let Ok(text) = std::env::var("AWS_ALLOW_HTTP") {
builder.client_options =
builder.client_options.with_allow_http(str_is_truthy(&text));
builder.client_options = builder
.client_options
.with_config(ClientConfigKey::AllowHttp, text);
}

builder
Expand Down Expand Up @@ -755,11 +754,9 @@ impl AmazonS3Builder {
AmazonS3ConfigKey::Bucket => self.bucket_name = Some(value.into()),
AmazonS3ConfigKey::Endpoint => self.endpoint = Some(value.into()),
AmazonS3ConfigKey::Token => self.token = Some(value.into()),
AmazonS3ConfigKey::ImdsV1Fallback => {
self.imdsv1_fallback = str_is_truthy(&value.into())
}
AmazonS3ConfigKey::ImdsV1Fallback => self.imdsv1_fallback.parse(value),
AmazonS3ConfigKey::VirtualHostedStyleRequest => {
self.virtual_hosted_style_request = str_is_truthy(&value.into())
self.virtual_hosted_style_request.parse(value)
}
AmazonS3ConfigKey::DefaultRegion => {
self.region = self.region.or_else(|| Some(value.into()))
Expand All @@ -768,10 +765,10 @@ impl AmazonS3Builder {
self.metadata_endpoint = Some(value.into())
}
AmazonS3ConfigKey::Profile => self.profile = Some(value.into()),
AmazonS3ConfigKey::UnsignedPayload => {
self.unsigned_payload = str_is_truthy(&value.into())
AmazonS3ConfigKey::UnsignedPayload => self.unsigned_payload.parse(value),
AmazonS3ConfigKey::Checksum => {
self.checksum_algorithm = Some(ConfigValue::Deferred(value.into()))
}
AmazonS3ConfigKey::Checksum => self.checksum_algorithm = Some(value.into()),
};
self
}
Expand Down Expand Up @@ -833,7 +830,9 @@ impl AmazonS3Builder {
AmazonS3ConfigKey::MetadataEndpoint => self.metadata_endpoint.clone(),
AmazonS3ConfigKey::Profile => self.profile.clone(),
AmazonS3ConfigKey::UnsignedPayload => Some(self.unsigned_payload.to_string()),
AmazonS3ConfigKey::Checksum => self.checksum_algorithm.clone(),
AmazonS3ConfigKey::Checksum => {
self.checksum_algorithm.as_ref().map(ToString::to_string)
}
}
}

Expand All @@ -858,7 +857,7 @@ impl AmazonS3Builder {
Some((bucket, "s3", region, "amazonaws.com")) => {
self.bucket_name = Some(bucket.to_string());
self.region = Some(region.to_string());
self.virtual_hosted_style_request = true;
self.virtual_hosted_style_request = true.into();
}
_ => return Err(UrlNotRecognisedSnafu { url }.build().into()),
},
Expand Down Expand Up @@ -934,7 +933,7 @@ impl AmazonS3Builder {
mut self,
virtual_hosted_style_request: bool,
) -> Self {
self.virtual_hosted_style_request = virtual_hosted_style_request;
self.virtual_hosted_style_request = virtual_hosted_style_request.into();
self
}

Expand All @@ -957,7 +956,7 @@ impl AmazonS3Builder {
/// [SSRF attack]: https://aws.amazon.com/blogs/security/defense-in-depth-open-firewalls-reverse-proxies-ssrf-vulnerabilities-ec2-instance-metadata-service/
///
pub fn with_imdsv1_fallback(mut self) -> Self {
self.imdsv1_fallback = true;
self.imdsv1_fallback = true.into();
self
}

Expand All @@ -966,7 +965,7 @@ impl AmazonS3Builder {
/// * false (default): Signed payload option is used, where the checksum for the request body is computed and included when constructing a canonical request.
/// * true: Unsigned payload option is used. `UNSIGNED-PAYLOAD` literal is included when constructing a canonical request,
pub fn with_unsigned_payload(mut self, unsigned_payload: bool) -> Self {
self.unsigned_payload = unsigned_payload;
self.unsigned_payload = unsigned_payload.into();
self
}

Expand All @@ -975,7 +974,7 @@ impl AmazonS3Builder {
/// [checksum algorithm]: https://docs.aws.amazon.com/AmazonS3/latest/userguide/checking-object-integrity.html
pub fn with_checksum_algorithm(mut self, checksum_algorithm: Checksum) -> Self {
// Convert to String to enable deferred parsing of config
self.checksum_algorithm = Some(checksum_algorithm.to_string());
self.checksum_algorithm = Some(checksum_algorithm.into());
self
}

Expand Down Expand Up @@ -1028,11 +1027,7 @@ impl AmazonS3Builder {

let bucket = self.bucket_name.context(MissingBucketNameSnafu)?;
let region = self.region.context(MissingRegionSnafu)?;
let checksum = self
.checksum_algorithm
.map(|c| c.parse())
.transpose()
.map_err(|_| Error::InvalidChecksumAlgorithm)?;
let checksum = self.checksum_algorithm.map(|x| x.get()).transpose()?;

let credentials = match (self.access_key_id, self.secret_access_key, self.token) {
(Some(key_id), Some(secret_key), token) => {
Expand Down Expand Up @@ -1093,7 +1088,7 @@ impl AmazonS3Builder {
cache: Default::default(),
client: client_options.client()?,
retry_config: self.retry_config.clone(),
imdsv1_fallback: self.imdsv1_fallback,
imdsv1_fallback: self.imdsv1_fallback.get()?,
metadata_endpoint: self
.metadata_endpoint
.unwrap_or_else(|| METADATA_ENDPOINT.into()),
Expand All @@ -1109,7 +1104,7 @@ impl AmazonS3Builder {
// If `endpoint` is provided then its assumed to be consistent with
// `virtual_hosted_style_request`. i.e. if `virtual_hosted_style_request` is true then
// `endpoint` should have bucket name included.
if self.virtual_hosted_style_request {
if self.virtual_hosted_style_request.get()? {
endpoint = self
.endpoint
.unwrap_or_else(|| format!("https://{bucket}.s3.{region}.amazonaws.com"));
Expand All @@ -1129,7 +1124,7 @@ impl AmazonS3Builder {
credentials,
retry_config: self.retry_config,
client_options: self.client_options,
sign_payload: !self.unsigned_payload,
sign_payload: !self.unsigned_payload.get()?,
checksum,
};

Expand Down Expand Up @@ -1305,10 +1300,10 @@ mod tests {
let metadata_uri = format!("{METADATA_ENDPOINT}{container_creds_relative_uri}");
assert_eq!(builder.metadata_endpoint.unwrap(), metadata_uri);
assert_eq!(
builder.checksum_algorithm.unwrap(),
Checksum::SHA256.to_string()
builder.checksum_algorithm.unwrap().get().unwrap(),
Checksum::SHA256
);
assert!(builder.unsigned_payload);
assert!(builder.unsigned_payload.get().unwrap());
}

#[test]
Expand Down Expand Up @@ -1341,10 +1336,10 @@ mod tests {
assert_eq!(builder.endpoint.unwrap(), aws_endpoint);
assert_eq!(builder.token.unwrap(), aws_session_token);
assert_eq!(
builder.checksum_algorithm.unwrap(),
Checksum::SHA256.to_string()
builder.checksum_algorithm.unwrap().get().unwrap(),
Checksum::SHA256
);
assert!(builder.unsigned_payload);
assert!(builder.unsigned_payload.get().unwrap());
}

#[test]
Expand Down Expand Up @@ -1554,7 +1549,7 @@ mod tests {
.unwrap();
assert_eq!(builder.bucket_name, Some("bucket".to_string()));
assert_eq!(builder.region, Some("region".to_string()));
assert!(builder.virtual_hosted_style_request);
assert!(builder.virtual_hosted_style_request.get().unwrap());

let err_cases = [
"mailto://bucket/path",
Expand All @@ -1569,6 +1564,35 @@ mod tests {
builder.parse_url(case).unwrap_err();
}
}

/// Invalid config values are accepted by `with_config` and only surface as an
/// error, with a descriptive message, once `build` is called.
#[test]
fn test_invalid_config() {
    // Builds a store with a single config override on top of a fixed bucket
    // and region, and returns the rendered build error.
    let build_error = |key: AmazonS3ConfigKey, value: &str| {
        AmazonS3Builder::new()
            .with_config(key, value)
            .with_bucket_name("bucket")
            .with_region("region")
            .build()
            .unwrap_err()
            .to_string()
    };

    // A boolean option given a string that is not a recognised boolean.
    let bool_message = build_error(AmazonS3ConfigKey::ImdsV1Fallback, "enabled");
    assert_eq!(
        bool_message,
        "Generic Config error: failed to parse \"enabled\" as boolean"
    );

    // A checksum option given an algorithm name that does not parse.
    let checksum_message = build_error(AmazonS3ConfigKey::Checksum, "md5");
    assert_eq!(
        checksum_message,
        "Generic Config error: \"md5\" is not a valid checksum algorithm"
    );
}
}

#[cfg(test)]
Expand Down
29 changes: 14 additions & 15 deletions object_store/src/azure/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,9 @@ use std::{collections::BTreeSet, str::FromStr};
use tokio::io::AsyncWrite;
use url::Url;

use crate::util::{str_is_truthy, RFC1123_FMT};
use crate::client::ClientConfigKey;
use crate::config::ConfigValue;
use crate::util::RFC1123_FMT;
pub use credential::authority_hosts;

mod client;
Expand Down Expand Up @@ -417,7 +419,7 @@ pub struct MicrosoftAzureBuilder {
/// Url
url: Option<String>,
/// When set to true, azurite storage emulator has to be used
use_emulator: bool,
use_emulator: ConfigValue<bool>,
/// Msi endpoint for acquiring managed identity token
msi_endpoint: Option<String>,
/// Object id for use with managed identity authentication
Expand All @@ -427,7 +429,7 @@ pub struct MicrosoftAzureBuilder {
/// File containing token for Azure AD workload identity federation
federated_token_file: Option<String>,
/// When set to true, azure cli has to be used for acquiring access token
use_azure_cli: bool,
use_azure_cli: ConfigValue<bool>,
/// Retry config
retry_config: RetryConfig,
/// Client options
Expand Down Expand Up @@ -672,8 +674,9 @@ impl MicrosoftAzureBuilder {
}

if let Ok(text) = std::env::var("AZURE_ALLOW_HTTP") {
builder.client_options =
builder.client_options.with_allow_http(str_is_truthy(&text));
builder.client_options = builder
.client_options
.with_config(ClientConfigKey::AllowHttp, text)
}

if let Ok(text) = std::env::var(MSI_ENDPOINT_ENV_KEY) {
Expand Down Expand Up @@ -726,12 +729,8 @@ impl MicrosoftAzureBuilder {
AzureConfigKey::FederatedTokenFile => {
self.federated_token_file = Some(value.into())
}
AzureConfigKey::UseAzureCli => {
self.use_azure_cli = str_is_truthy(&value.into())
}
AzureConfigKey::UseEmulator => {
self.use_emulator = str_is_truthy(&value.into())
}
AzureConfigKey::UseAzureCli => self.use_azure_cli.parse(value),
AzureConfigKey::UseEmulator => self.use_emulator.parse(value),
};
self
}
Expand Down Expand Up @@ -898,7 +897,7 @@ impl MicrosoftAzureBuilder {

/// Set if the Azure emulator should be used (defaults to false)
pub fn with_use_emulator(mut self, use_emulator: bool) -> Self {
self.use_emulator = use_emulator;
self.use_emulator = use_emulator.into();
self
}

Expand Down Expand Up @@ -956,7 +955,7 @@ impl MicrosoftAzureBuilder {
/// Set if the Azure Cli should be used for acquiring access token
/// <https://learn.microsoft.com/en-us/cli/azure/account?view=azure-cli-latest#az-account-get-access-token>
pub fn with_use_azure_cli(mut self, use_azure_cli: bool) -> Self {
self.use_azure_cli = use_azure_cli;
self.use_azure_cli = use_azure_cli.into();
self
}

Expand All @@ -969,7 +968,7 @@ impl MicrosoftAzureBuilder {

let container = self.container_name.ok_or(Error::MissingContainerName {})?;

let (is_emulator, storage_url, auth, account) = if self.use_emulator {
let (is_emulator, storage_url, auth, account) = if self.use_emulator.get()? {
let account_name = self
.account_name
.unwrap_or_else(|| EMULATOR_ACCOUNT.to_string());
Expand Down Expand Up @@ -1022,7 +1021,7 @@ impl MicrosoftAzureBuilder {
credential::CredentialProvider::SASToken(query_pairs)
} else if let Some(sas) = self.sas_key {
credential::CredentialProvider::SASToken(split_sas(&sas)?)
} else if self.use_azure_cli {
} else if self.use_azure_cli.get()? {
credential::CredentialProvider::TokenCredential(
TokenCache::default(),
Box::new(credential::AzureCliCredential::new()),
Expand Down
Loading