[FEAT] Native Downloader add Retry Config parameters #1244

Merged: 5 commits, Aug 9, 2023
1 change: 1 addition & 0 deletions Cargo.lock

Some generated files are not rendered by default.

1 change: 1 addition & 0 deletions src/daft-io/Cargo.toml
@@ -6,6 +6,7 @@ aws-credential-types = {version = "0.55.3", features = ["hardcoded-credentials"]
aws-sdk-s3 = "0.28.0"
aws-sig-auth = "0.55.3"
aws-sigv4 = "0.55.3"
aws-smithy-async = "0.55.3"
bytes = {workspace = true}
common-error = {path = "../common/error", default-features = false}
daft-core = {path = "../daft-core", default-features = false}
23 changes: 22 additions & 1 deletion src/daft-io/src/config.rs
@@ -3,16 +3,33 @@ use std::fmt::Formatter

use serde::Deserialize;
use serde::Serialize;
#[derive(Clone, Default, Debug, Serialize, Deserialize, PartialEq, Eq, Hash)]
#[derive(Clone, Debug, Serialize, Deserialize, PartialEq, Eq, Hash)]
pub struct S3Config {
pub region_name: Option<String>,
pub endpoint_url: Option<String>,
pub key_id: Option<String>,
pub session_token: Option<String>,
pub access_key: Option<String>,
pub retry_initial_backoff_ms: u32,
pub num_tries: u32,
pub anonymous: bool,
}

impl Default for S3Config {
fn default() -> Self {
S3Config {
region_name: None,
endpoint_url: None,
key_id: None,
session_token: None,
access_key: None,
retry_initial_backoff_ms: 1000,
num_tries: 5,
anonymous: false,
}
}
}

#[derive(Clone, Default, Debug, Serialize, Deserialize, PartialEq, Eq, Hash)]
pub struct IOConfig {
pub s3: S3Config,
@@ -28,12 +45,16 @@ impl Display for S3Config {
key_id: {:?}
session_token: {:?},
access_key: {:?}
retry_initial_backoff_ms: {:?},
num_tries: {:?},
anonymous: {}",
self.region_name,
self.endpoint_url,
self.key_id,
self.session_token,
self.access_key,
self.retry_initial_backoff_ms,
self.num_tries,
self.anonymous
)
}
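The manual Default impl above pins the new retry fields to 1000 ms and 5 attempts. A quick Python-side sanity check of those defaults (a sketch only; it assumes a daft build exposing the S3Config getters added in python.rs below):

import daft

cfg = daft.io.S3Config()  # no arguments: every field falls back to the Rust-side Default impl
assert cfg.retry_initial_backoff_ms == 1000
assert cfg.num_tries == 5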
41 changes: 34 additions & 7 deletions src/daft-io/src/python.rs
@@ -11,6 +11,8 @@ use pyo3::prelude::*;
/// key_id: AWS Access Key ID, defaults to auto-detection from the current environment
/// access_key: AWS Secret Access Key, defaults to auto-detection from the current environment
/// session_token: AWS Session Token, required only if `key_id` and `access_key` are temporary credentials
/// retry_initial_backoff_ms: Initial backoff duration in milliseconds for an S3 retry, defaults to 1000ms
/// num_tries: Number of attempts to make a connection, defaults to 5
/// anonymous: Whether or not to use "anonymous mode", which will access S3 without any credentials
///
/// Example:
@@ -28,7 +30,7 @@ pub struct S3Config {
/// s3: Configurations to use when accessing URLs with the `s3://` scheme
///
/// Example:
/// >>> io_config = IOConfig(s3=S3Config(key_id="xxx", access_key="xxx"))
/// >>> io_config = IOConfig(s3=S3Config(key_id="xxx", access_key="xxx", num_tries=10))
/// >>> daft.read_parquet("s3://some-path", io_config=io_config)
#[derive(Clone, Default)]
#[pyclass]
@@ -73,23 +75,30 @@ impl IOConfig {

#[pymethods]
impl S3Config {
#[allow(clippy::too_many_arguments)]
#[new]
pub fn new(
region_name: Option<String>,
endpoint_url: Option<String>,
key_id: Option<String>,
session_token: Option<String>,
access_key: Option<String>,
retry_initial_backoff_ms: Option<u32>,
num_tries: Option<u32>,
anonymous: Option<bool>,
) -> Self {
let def = config::S3Config::default();
S3Config {
config: config::S3Config {
region_name,
endpoint_url,
key_id,
session_token,
access_key,
anonymous: anonymous.unwrap_or(false),
region_name: region_name.or(def.region_name),
endpoint_url: endpoint_url.or(def.endpoint_url),
key_id: key_id.or(def.key_id),
session_token: session_token.or(def.session_token),
access_key: access_key.or(def.access_key),
retry_initial_backoff_ms: retry_initial_backoff_ms
.unwrap_or(def.retry_initial_backoff_ms),
num_tries: num_tries.unwrap_or(def.num_tries),
anonymous: anonymous.unwrap_or(def.anonymous),
},
}
}
@@ -116,11 +125,29 @@ impl S3Config {
Ok(self.config.key_id.clone())
}

/// AWS Session Token
#[getter]
pub fn session_token(&self) -> PyResult<Option<String>> {
Ok(self.config.session_token.clone())
}

/// AWS Secret Access Key
#[getter]
pub fn access_key(&self) -> PyResult<Option<String>> {
Ok(self.config.access_key.clone())
}

/// AWS Retry Initial Backoff Time in Milliseconds
#[getter]
pub fn retry_initial_backoff_ms(&self) -> PyResult<u32> {
Ok(self.config.retry_initial_backoff_ms)
}

/// AWS Number of Retries
#[getter]
pub fn num_tries(&self) -> PyResult<u32> {
Ok(self.config.num_tries)
}
}

impl From<config::IOConfig> for IOConfig {
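Putting the Python-facing pieces together, a minimal usage sketch of the new parameters (it extends the docstring example above; the bucket path and credentials are placeholders):

import daft

# Ten attempts with a 2-second initial backoff; the credential fields behave as before.
io_config = daft.io.IOConfig(
    s3=daft.io.S3Config(
        key_id="xxx",
        access_key="xxx",
        retry_initial_backoff_ms=2000,
        num_tries=10,
    )
)
df = daft.read_parquet("s3://some-path", io_config=io_config)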
22 changes: 20 additions & 2 deletions src/daft-io/src/s3_like.rs
@@ -1,9 +1,10 @@
use async_trait::async_trait;
use aws_smithy_async::rt::sleep::TokioSleep;
use reqwest::StatusCode;
use s3::operation::head_object::HeadObjectError;

use crate::config::S3Config;
use crate::SourceType;
use crate::{InvalidArgumentSnafu, SourceType};
use aws_config::SdkConfig;
use aws_credential_types::cache::ProvideCachedCredentials;
use aws_credential_types::provider::error::CredentialsError;
@@ -13,7 +14,7 @@ use s3::client::customize::Response;
use s3::config::{Credentials, Region};
use s3::error::SdkError;
use s3::operation::get_object::GetObjectError;
use snafu::{IntoError, ResultExt, Snafu};
use snafu::{ensure, IntoError, ResultExt, Snafu};
use url::ParseError;

use super::object_io::{GetResult, ObjectSource};
@@ -25,6 +26,7 @@ use std::collections::HashMap;
use std::ops::Range;
use std::string::FromUtf8Error;
use std::sync::Arc;
use std::time::Duration;
pub(crate) struct S3LikeSource {
region_to_client_map: tokio::sync::RwLock<HashMap<Region, Arc<s3::Client>>>,
default_region: Region,
@@ -140,6 +142,22 @@ async fn build_s3_client(config: &S3Config) -> super::Result<(bool, s3::Client)>
builder
};

ensure!(
config.num_tries > 0,
InvalidArgumentSnafu {
msg: "num_tries must be greater than zero"
}
);
let retry_config = s3::config::retry::RetryConfig::standard()
.with_max_attempts(config.num_tries)
.with_initial_backoff(Duration::from_millis(
config.retry_initial_backoff_ms as u64,
));
let builder = builder.retry_config(retry_config);

let sleep_impl = Arc::new(TokioSleep::new());
let builder = builder.sleep_impl(sleep_impl);

let builder = if config.access_key.is_some() && config.key_id.is_some() {
let creds = Credentials::from_keys(
config.key_id.clone().unwrap(),
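The builder above hands num_tries and retry_initial_backoff_ms to the SDK's standard retry strategy. As a rough illustration only (assuming the standard strategy roughly doubles the wait between attempts, before jitter and the SDK's max-backoff cap are applied), the wait schedule looks like this:

def approx_backoff_schedule_ms(num_tries: int, retry_initial_backoff_ms: int) -> list[int]:
    # One initial attempt plus up to (num_tries - 1) retries, each waiting
    # roughly twice as long as the previous one.
    return [retry_initial_backoff_ms * (2 ** i) for i in range(num_tries - 1)]

print(approx_backoff_schedule_ms(5, 1000))   # defaults: [1000, 2000, 4000, 8000]
print(approx_backoff_schedule_ms(10, 1000))  # the retry-server test config below uses num_tries=10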
2 changes: 1 addition & 1 deletion tests/integration/io/conftest.py
@@ -58,7 +58,7 @@ def nginx_config() -> tuple[str, pathlib.Path]:
def retry_server_s3_config() -> daft.io.IOConfig:
"""Returns the URL to the local retry_server fixture"""
return daft.io.IOConfig(
s3=daft.io.S3Config(endpoint_url="http://127.0.0.1:8001"),
s3=daft.io.S3Config(endpoint_url="http://127.0.0.1:8001", anonymous=True, num_tries=10),
)


@@ -6,12 +6,6 @@


@pytest.mark.integration()
@pytest.mark.skip(
reason="""[IO-RETRIES] This currently fails: we need better retry policies to have this work consistently.
Currently, if all the retries for a given URL happens to land in the same 1-second window, the request fails.
We should be able to get around this with a more generous retry policy, with larger increments between backoffs.
"""
)
def test_url_download_local_retry_server(retry_server_s3_config):
bucket = "80-per-second-rate-limited-gets-bucket"
data = {"urls": [f"s3://{bucket}/foo{i}" for i in range(100)]}