Skip to content

Commit

Permalink
fix: allow for assuming a role when bootstrapping AWS credential conf…
Browse files Browse the repository at this point in the history
…iguration

Fixes #2879
  • Loading branch information
rtyler committed Sep 17, 2024
1 parent feb2f4c commit 83043b6
Show file tree
Hide file tree
Showing 3 changed files with 79 additions and 146 deletions.
1 change: 1 addition & 0 deletions crates/aws/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ maplit = "1"
# workspace dependencies
async-trait = { workspace = true }
bytes = { workspace = true }
chrono = { workspace = true }
futures = { workspace = true }
tracing = { workspace = true }
object_store = { workspace = true, features = ["aws"]}
Expand Down
126 changes: 64 additions & 62 deletions crates/aws/src/credentials.rs
Original file line number Diff line number Diff line change
Expand Up @@ -38,8 +38,8 @@ impl AWSForObjectStore {
impl CredentialProvider for AWSForObjectStore {
type Credential = AwsCredential;

/// Invoke the underlying [AssumeRoleProvider] to retrieve the temporary credentials associated
/// with the role assumed
/// Provide the necessary configured credentials from the AWS SDK for use by
/// [object_store::aws::AmazonS3]
async fn get_credential(&self) -> ObjectStoreResult<Arc<Self::Credential>> {
let provider = self
.sdk_config
Expand All @@ -66,61 +66,6 @@ impl CredentialProvider for AWSForObjectStore {
}
}

/// An [object_store::CredentialProvider] which handles retrieving the necessary
/// temporary credentials associated with the assumed role
#[derive(Debug)]
pub(crate) struct AssumeRoleCredentialProvider {
sdk_config: SdkConfig,
}

impl AssumeRoleCredentialProvider {
fn session_name(&self) -> String {
/*
if let Some(_) = str_option(options, s3_constants::AWS_S3_ROLE_SESSION_NAME) {
warn!(
"AWS_S3_ROLE_SESSION_NAME is deprecated please AWS_IAM_ROLE_SESSION_NAME instead!"
);
}
str_option(options, s3_constants::AWS_IAM_ROLE_SESSION_NAME)
.or(str_option(options, s3_constants::AWS_S3_ROLE_SESSION_NAME))
.unwrap_or("delta-rs".into())
*/
todo!()
}

fn iam_role(&self) -> String {
todo!()
}
}

#[async_trait::async_trait]
impl CredentialProvider for AssumeRoleCredentialProvider {
type Credential = AwsCredential;

/// Invoke the underlying [AssumeRoleProvider] to retrieve the temporary credentials associated
/// with the role assumed
async fn get_credential(&self) -> ObjectStoreResult<Arc<Self::Credential>> {
let provider = AssumeRoleProvider::builder(self.iam_role())
.configure(&self.sdk_config)
.session_name(self.session_name())
.build()
.await;
let credentials =
provider
.provide_credentials()
.await
.map_err(|e| ObjectStoreError::NotSupported {
source: Box::new(e),
})?;

Ok(Arc::new(Self::Credential {
key_id: credentials.access_key_id().into(),
secret_key: credentials.secret_access_key().into(),
token: credentials.session_token().map(|o| o.to_string()),
}))
}
}

/// Name of the [OptionsCredentialsProvider] for AWS SDK use
const OPTS_PROVIDER: &str = "DeltaStorageOptionsProvider";

Expand Down Expand Up @@ -165,15 +110,72 @@ impl ProvideCredentials for OptionsCredentialsProvider {
}
}

/// Generate a random session name for assuming IAM roles
fn assume_role_sessio_name() -> String {
let now = chrono::Utc::now();

format!("delta-rs_{}", now.timestamp_millis())
}

/// Return the configured IAM role ARN or whatever is defined in the environment
fn assume_role_arn(options: &StorageOptions) -> Option<String> {
options
.0
.get(constants::AWS_IAM_ROLE_ARN)
.or(options.0.get(constants::AWS_S3_ASSUME_ROLE_ARN))
.or(std::env::var_os(constants::AWS_IAM_ROLE_ARN)
.map(|o| {
o.into_string()
.expect("Failed to unwrap AWS_IAM_ROLE_ARN which may have invalid data")
})
.as_ref())
.or(std::env::var_os(constants::AWS_S3_ASSUME_ROLE_ARN)
.map(|o| {
o.into_string()
.expect("Failed to unwrap AWS_S3_ASSUME_ROLE_ARN which may have invalid data")
})
.as_ref())
.cloned()
}

/// Return the configured IAM assume role session name or provide a unique one
fn assume_session_name(options: &StorageOptions) -> String {
let assume_session = options
.0
.get(constants::AWS_IAM_ROLE_SESSION_NAME)
.or(options.0.get(constants::AWS_S3_ROLE_SESSION_NAME))
.cloned();

match assume_session {
Some(s) => s,
None => assume_role_sessio_name(),
}
}

/// Take a set of [StorageOptions] and produce an appropriate AWS SDK [SdkConfig]
/// for use with various AWS SDK APIs, such as in our [crate::logstore::S3DynamoDbLogStore]
pub async fn resolve_credentials(options: StorageOptions) -> DeltaResult<SdkConfig> {
let options_provider = OptionsCredentialsProvider { options };

let default_provider = DefaultCredentialsChain::builder().build().await;
let credentials_provider =
CredentialsProviderChain::first_try("StorageOptions", options_provider)
.or_else("DefaultChain", default_provider);

let credentials_provider = match assume_role_arn(&options) {
Some(arn) => {
debug!("Configuring AssumeRoleProvider with role arn: {arn}");
CredentialsProviderChain::first_try(
"AssumeRoleProvider",
AssumeRoleProvider::builder(arn)
.session_name(assume_session_name(&options))
.build()
.await,
)
.or_else("StorageOptions", OptionsCredentialsProvider { options })
.or_else("DefaultChain", default_provider)
}
None => CredentialsProviderChain::first_try(
"StorageOptions",
OptionsCredentialsProvider { options },
)
.or_else("DefaultChain", default_provider),
};

Ok(aws_config::from_env()
.credentials_provider(credentials_provider)
Expand Down
98 changes: 14 additions & 84 deletions crates/aws/src/storage.rs
Original file line number Diff line number Diff line change
Expand Up @@ -573,34 +573,20 @@ mod tests {
constants::AWS_SECRET_ACCESS_KEY.to_string() => "test_secret".to_string(),
}).unwrap();

// Get a default SdkConfig first, this ensures that if there are environment or profile
// information in the default load of credentials for the test run that it will pass
// the equivalence below
let storage_options = StorageOptions(HashMap::new());
let sdk_config =
execute_sdk_future(crate::credentials::resolve_credentials(storage_options))
.expect("Failed to run future")
.expect("Failed to load default SdkConfig")
.to_builder()
.endpoint_url("http://localhost:1234".to_string())
.region(Region::from_static("us-west-2"))
.build();

assert_eq!(
S3StorageOptions {
sdk_config,
virtual_hosted_style_request: true,
locking_provider: Some("another_locking_provider".to_string()),
dynamodb_endpoint: None,
s3_pool_idle_timeout: Duration::from_secs(1),
sts_pool_idle_timeout: Duration::from_secs(2),
s3_get_internal_server_error_retries: 3,
extra_opts: hashmap! {
s3_constants::AWS_S3_ADDRESSING_STYLE.to_string() => "virtual".to_string()
},
allow_unsafe_rename: false,
Some("another_locking_provider"),
options.locking_provider.as_deref()
);
assert_eq!(Duration::from_secs(1), options.s3_pool_idle_timeout);
assert_eq!(Duration::from_secs(2), options.sts_pool_idle_timeout);
assert_eq!(3, options.s3_get_internal_server_error_retries);
assert!(options.virtual_hosted_style_request);
assert!(!options.allow_unsafe_rename);
assert_eq!(
hashmap! {
constants::AWS_S3_ADDRESSING_STYLE.to_string() => "virtual".to_string()
},
options
options.extra_opts
);
});
}
Expand Down Expand Up @@ -628,23 +614,8 @@ mod tests {
}).unwrap();

assert_eq!(
S3StorageOptions {
sdk_config: SdkConfig::builder()
.endpoint_url("http://localhost:1234".to_string())
.region(Region::from_static("us-west-2"))
.build(),
virtual_hosted_style_request: true,
locking_provider: Some("another_locking_provider".to_string()),
dynamodb_endpoint: Some("http://localhost:2345".to_string()),
s3_pool_idle_timeout: Duration::from_secs(1),
sts_pool_idle_timeout: Duration::from_secs(2),
s3_get_internal_server_error_retries: 3,
extra_opts: hashmap! {
s3_constants::AWS_S3_ADDRESSING_STYLE.to_string() => "virtual".to_string()
},
allow_unsafe_rename: false,
},
options
Some("http://localhost:2345"),
options.dynamodb_endpoint.as_deref()
);
});
}
Expand Down Expand Up @@ -781,45 +752,4 @@ mod tests {
}
});
}

#[tokio::test]
#[serial]
async fn storage_options_toggle_imds() {
ScopedEnv::run_async(async {
clear_env_of_aws_keys();
let disabled_time = storage_options_configure_imds(Some("true")).await;
let enabled_time = storage_options_configure_imds(Some("false")).await;
let default_time = storage_options_configure_imds(None).await;
println!(
"enabled_time: {}, disabled_time: {}, default_time: {}",
enabled_time.as_micros(),
disabled_time.as_micros(),
default_time.as_micros(),
);
assert!(disabled_time < enabled_time);
assert!(default_time < enabled_time);
})
.await;
}

async fn storage_options_configure_imds(value: Option<&str>) -> Duration {
let _options = match value {
Some(value) => S3StorageOptions::from_map(&hashmap! {
constants::AWS_REGION.to_string() => "eu-west-1".to_string(),
constants::AWS_EC2_METADATA_DISABLED.to_string() => value.to_string(),
})
.unwrap(),
None => S3StorageOptions::from_map(&hashmap! {
constants::AWS_REGION.to_string() => "eu-west-1".to_string(),
})
.unwrap(),
};

assert_eq!("eu-west-1", std::env::var(constants::AWS_REGION).unwrap());

let provider = _options.sdk_config.credentials_provider().unwrap();
let now = SystemTime::now();
_ = provider.provide_credentials().await;
now.elapsed().unwrap()
}
}

0 comments on commit 83043b6

Please sign in to comment.