
Commit

io_config
samster25 committed Aug 8, 2023
1 parent c0a5040 commit ce72b4c
Showing 5 changed files with 72 additions and 18 deletions.
23 changes: 22 additions & 1 deletion src/daft-io/src/config.rs
@@ -3,16 +3,33 @@ use std::fmt::Formatter;

use serde::Deserialize;
use serde::Serialize;
- #[derive(Clone, Default, Debug, Serialize, Deserialize, PartialEq, Eq, Hash)]
+ #[derive(Clone, Debug, Serialize, Deserialize, PartialEq, Eq, Hash)]
pub struct S3Config {
pub region_name: Option<String>,
pub endpoint_url: Option<String>,
pub key_id: Option<String>,
pub session_token: Option<String>,
pub access_key: Option<String>,
pub retry_initial_backoff_ms: u32,
pub num_tries: u32,
pub anonymous: bool,
}

impl Default for S3Config {
fn default() -> Self {
S3Config {
region_name: None,
endpoint_url: None,
key_id: None,
session_token: None,
access_key: None,
retry_initial_backoff_ms: 1000,
num_tries: 5,
anonymous: false,
}
}
}

#[derive(Clone, Default, Debug, Serialize, Deserialize, PartialEq, Eq, Hash)]
pub struct IOConfig {
pub s3: S3Config,
@@ -28,12 +45,16 @@ impl Display for S3Config {
key_id: {:?}
session_token: {:?},
access_key: {:?}
retry_initial_backoff_ms: {:?},
num_tries: {:?},
anonymous: {}",
self.region_name,
self.endpoint_url,
self.key_id,
self.session_token,
self.access_key,
self.retry_initial_backoff_ms,
self.num_tries,
self.anonymous
)
}
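Note: the derived Default is dropped above because deriving it would zero-initialize the new numeric fields; the hand-written impl gives retry_initial_backoff_ms and num_tries usable defaults (1000 ms, 5 tries). A minimal standalone sketch of the pattern, using a trimmed-down struct rather than the real S3Config:

// Trimmed-down sketch: a manual Default impl supplies non-zero defaults,
// and struct-update syntax overrides only the fields a caller cares about.
#[derive(Clone, Debug, PartialEq, Eq)]
struct RetrySettings {
    retry_initial_backoff_ms: u32,
    num_tries: u32,
    anonymous: bool,
}

impl Default for RetrySettings {
    fn default() -> Self {
        RetrySettings {
            retry_initial_backoff_ms: 1000,
            num_tries: 5,
            anonymous: false,
        }
    }
}

fn main() {
    // Keep every default except the retry count.
    let cfg = RetrySettings { num_tries: 10, ..Default::default() };
    assert_eq!(cfg.retry_initial_backoff_ms, 1000);
    assert_eq!(cfg.num_tries, 10);
}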
36 changes: 30 additions & 6 deletions src/daft-io/src/python.rs
@@ -80,16 +80,22 @@ impl S3Config {
key_id: Option<String>,
session_token: Option<String>,
access_key: Option<String>,
retry_initial_backoff_ms: Option<u32>,
num_tries: Option<u32>,
anonymous: Option<bool>,
) -> Self {
let def = config::S3Config::default();
S3Config {
config: config::S3Config {
- region_name,
- endpoint_url,
- key_id,
- session_token,
- access_key,
- anonymous: anonymous.unwrap_or(false),
+ region_name: region_name.or(def.region_name),
+ endpoint_url: endpoint_url.or(def.endpoint_url),
+ key_id: key_id.or(def.key_id),
+ session_token: session_token.or(def.session_token),
+ access_key: access_key.or(def.access_key),
+ retry_initial_backoff_ms: retry_initial_backoff_ms
+ .unwrap_or(def.retry_initial_backoff_ms),
+ num_tries: num_tries.unwrap_or(def.num_tries),
+ anonymous: anonymous.unwrap_or(def.anonymous),
},
}
}
@@ -116,11 +122,29 @@ impl S3Config {
Ok(self.config.key_id.clone())
}

/// AWS Session Token
#[getter]
pub fn session_token(&self) -> PyResult<Option<String>> {
Ok(self.config.session_token.clone())
}

/// AWS Secret Access Key
#[getter]
pub fn access_key(&self) -> PyResult<Option<String>> {
Ok(self.config.access_key.clone())
}

/// AWS Retry Initial Backoff Time in Milliseconds
#[getter]
pub fn retry_initial_backoff_ms(&self) -> PyResult<u32> {
Ok(self.config.retry_initial_backoff_ms)
}

/// AWS Number Retries
#[getter]
pub fn num_tries(&self) -> PyResult<u32> {
Ok(self.config.num_tries)
}
}

impl From<config::IOConfig> for IOConfig {
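Note: in the constructor above every argument is optional and falls back to the Rust defaults, so Python callers only pass the fields they want to override. Option::or keeps an optional override when present (for the Option<String> fields), while Option::unwrap_or collapses the override into a plain value (for the numeric fields). A tiny sketch of the two, independent of the actual binding code:

fn main() {
    let def_backoff_ms: u32 = 1000;
    let def_region: Option<String> = None;

    // Caller supplied a retry override but no region.
    let retry_override: Option<u32> = Some(2000);
    let region_override: Option<String> = None;

    // unwrap_or collapses Option<T> to T, falling back to the default value.
    let backoff_ms = retry_override.unwrap_or(def_backoff_ms);
    // or keeps the value optional, preferring the override when it is Some.
    let region = region_override.or(def_region);

    assert_eq!(backoff_ms, 2000);
    assert_eq!(region, None);
}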
17 changes: 13 additions & 4 deletions src/daft-io/src/s3_like.rs
@@ -4,7 +4,7 @@ use reqwest::StatusCode;
use s3::operation::head_object::HeadObjectError;

use crate::config::S3Config;
use crate::SourceType;
use crate::{InvalidArgumentSnafu, SourceType};
use aws_config::SdkConfig;
use aws_credential_types::cache::ProvideCachedCredentials;
use aws_credential_types::provider::error::CredentialsError;
@@ -14,7 +14,7 @@ use s3::client::customize::Response;
use s3::config::{Credentials, Region};
use s3::error::SdkError;
use s3::operation::get_object::GetObjectError;
- use snafu::{IntoError, ResultExt, Snafu};
+ use snafu::{ensure, IntoError, ResultExt, Snafu};
use url::ParseError;

use super::object_io::{GetResult, ObjectSource};
@@ -26,6 +26,7 @@ use std::collections::HashMap;
use std::ops::Range;
use std::string::FromUtf8Error;
use std::sync::Arc;
use std::time::Duration;
pub(crate) struct S3LikeSource {
region_to_client_map: tokio::sync::RwLock<HashMap<Region, Arc<s3::Client>>>,
default_region: Region,
@@ -141,9 +142,17 @@ async fn build_s3_client(config: &S3Config) -> super::Result<(bool, s3::Client)>
builder
};

ensure!(
config.num_tries > 0,
InvalidArgumentSnafu {
msg: "num_tries must be greater than zero"
}
);
let retry_config = s3::config::retry::RetryConfig::standard()
.with_retry_mode(aws_config::retry::RetryMode::Adaptive)
- .with_max_attempts(3);
+ .with_max_attempts(config.num_tries)
+ .with_initial_backoff(Duration::from_millis(
+ config.retry_initial_backoff_ms as u64,
+ ));
let builder = builder.retry_config(retry_config);

let sleep_impl = Arc::new(TokioSleep::new());
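Note: with the standard/adaptive retry strategy the SDK roughly doubles the wait between attempts, starting from the configured initial backoff, so the defaults (num_tries = 5, retry_initial_backoff_ms = 1000) spread retries over several seconds rather than clustering them in the same one-second window. A rough sketch of that schedule, ignoring the jitter and backoff cap the real SDK applies:

use std::time::Duration;

fn main() {
    let num_tries: u32 = 5;
    let initial = Duration::from_millis(1000);

    // Approximate exponential backoff: ~1s, 2s, 4s, 8s before retries 1..4.
    // (The AWS SDK adds jitter and a maximum backoff on top of this.)
    for attempt in 1..num_tries {
        let backoff = initial * 2u32.pow(attempt - 1);
        println!("retry {attempt} after ~{backoff:?}");
    }
}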
2 changes: 1 addition & 1 deletion tests/integration/io/conftest.py
@@ -58,7 +58,7 @@ def nginx_config() -> tuple[str, pathlib.Path]:
def retry_server_s3_config() -> daft.io.IOConfig:
"""Returns the URL to the local retry_server fixture"""
return daft.io.IOConfig(
- s3=daft.io.S3Config(endpoint_url="http://127.0.0.1:8001"),
+ s3=daft.io.S3Config(endpoint_url="http://127.0.0.1:8001", anonymous=True, num_tries=10),
)


12 changes: 6 additions & 6 deletions tests/integration/io/test_url_download_s3_local_retry_server.py
@@ -6,12 +6,12 @@


@pytest.mark.integration()
- @pytest.mark.skip(
- reason="""[IO-RETRIES] This currently fails: we need better retry policies to have this work consistently.
- Currently, if all the retries for a given URL happens to land in the same 1-second window, the request fails.
- We should be able to get around this with a more generous retry policy, with larger increments between backoffs.
- """
- )
+ # @pytest.mark.skip(
+ # reason="""[IO-RETRIES] This currently fails: we need better retry policies to have this work consistently.
+ # Currently, if all the retries for a given URL happens to land in the same 1-second window, the request fails.
+ # We should be able to get around this with a more generous retry policy, with larger increments between backoffs.
+ # """
+ # )
def test_url_download_local_retry_server(retry_server_s3_config):
bucket = "80-per-second-rate-limited-gets-bucket"
data = {"urls": [f"s3://{bucket}/foo{i}" for i in range(100)]}
