Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft: UUID regeneration on retry #26

Open
wants to merge 12 commits into
base: main
Choose a base branch
from
3 changes: 3 additions & 0 deletions snowflake-api/.gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
/target
.env
Cargo.lock
7 changes: 7 additions & 0 deletions snowflake-api/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ async-trait = "0.1"
base64 = "0.21"
bytes = "1"
futures = "0.3"
http = "1"
log = "0.4"
object_store = { version = "0.9", features = ["aws"] }
regex = "1"
Expand All @@ -33,13 +34,15 @@ reqwest = { version = "0.11", default-features = false, features = [
"rustls-tls",
] }
reqwest-middleware = "0.2"
task-local-extensions = "0.1"
reqwest-retry = "0.3"
serde = { version = "1", features = ["derive"] }
serde_json = "1"
snowflake-jwt = { version = "0.3.0", optional = true }
thiserror = "1"
url = "2"
uuid = { version = "1", features = ["v4"] }

polars-io = { version = ">=0.32", features = ["json", "ipc_streaming"], optional = true}
polars-core = { version = ">=0.32", optional = true}

Expand All @@ -50,3 +53,7 @@ arrow = { version = "50", features = ["prettyprint"] }
clap = { version = "4", features = ["derive"] }
pretty_env_logger = "0.5"
tokio = { version = "1", features = ["macros", "rt-multi-thread"] }
mockito = "1.3.1"
tracing-subscriber = "0.3"
serde_urlencoded = "0.7.1"
dashmap = "5"
2 changes: 1 addition & 1 deletion snowflake-api/examples/tracing/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ async fn main() -> Result<()> {

#[tracing::instrument(name = "snowflake_api", skip(api))]
async fn run_in_span(api: &snowflake_api::SnowflakeApi) -> anyhow::Result<()> {
let res = api.exec("select 'hello from snowflake' as col1;").await?;
let res = api.exec("select 1;").await?;

match res {
QueryResult::Arrow(a) => {
Expand Down
143 changes: 118 additions & 25 deletions snowflake-api/src/connection.rs
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
use http::uri::Scheme;
use reqwest::header::{self, HeaderMap, HeaderName, HeaderValue};
use reqwest_middleware::ClientWithMiddleware;
use reqwest_retry::policies::ExponentialBackoff;
use reqwest_retry::RetryTransientMiddleware;
use std::collections::HashMap;
use std::time::{SystemTime, UNIX_EPOCH};
use thiserror::Error;
use url::Url;
use uuid::Uuid;

use crate::middleware::UuidMiddleware;

#[derive(Error, Debug)]
pub enum ConnectionError {
Expand Down Expand Up @@ -73,13 +74,19 @@ impl QueryType {
pub struct Connection {
// no need for Arc as it's already inside the reqwest client
client: ClientWithMiddleware,
base_url: String,
scheme: http::uri::Scheme,
}

impl Connection {
pub fn new() -> Result<Self, ConnectionError> {
let client = Self::default_client_builder()?;

Ok(Self::new_with_middware(client.build()))
Ok(Self::new_with_middware(
client.build(),
None,
Some(http::uri::Scheme::HTTPS),
))
}

/// Allow a user to provide their own middleware
Expand All @@ -89,11 +96,19 @@ impl Connection {
/// use snowflake_api::connection::Connection;
/// let mut client = Connection::default_client_builder();
/// // modify the client builder here
/// let connection = Connection::new_with_middware(client.unwrap().build());
/// let connection = Connection::new_with_middware(client.unwrap().build(), None, Some(http::uri::Scheme::HTTPS));
/// ```
/// This is not intended to be called directly, but is used by `SnowflakeApiBuilder::with_client`
pub fn new_with_middware(client: ClientWithMiddleware) -> Self {
Self { client }
pub fn new_with_middware(
client: ClientWithMiddleware,
base_url: Option<String>,
scheme: Option<Scheme>,
) -> Self {
Self {
client,
base_url: base_url.unwrap_or(".snowflakecomputing.com".to_string()),
scheme: scheme.unwrap_or(Scheme::HTTPS),
}
}

pub fn default_client_builder() -> Result<reqwest_middleware::ClientBuilder, ConnectionError> {
Expand All @@ -110,7 +125,8 @@ impl Connection {
let client = client.build()?;

Ok(reqwest_middleware::ClientBuilder::new(client)
.with(RetryTransientMiddleware::new_with_policy(retry_policy)))
.with(RetryTransientMiddleware::new_with_policy(retry_policy))
.with(UuidMiddleware))
}

/// Perform request of given query type with extra body or parameters
Expand All @@ -126,27 +142,12 @@ impl Connection {
) -> Result<R, ConnectionError> {
let context = query_type.query_context();

let request_id = Uuid::new_v4();
let request_guid = Uuid::new_v4();
let client_start_time = SystemTime::now()
.duration_since(UNIX_EPOCH)
.unwrap()
.as_secs()
.to_string();
// fixme: update uuid's on the retry
let request_id = request_id.to_string();
let request_guid = request_guid.to_string();

let mut get_params = vec![
("clientStartTime", client_start_time.as_str()),
("requestId", request_id.as_str()),
("request_guid", request_guid.as_str()),
];
let mut get_params = vec![];
get_params.extend_from_slice(extra_get_params);

let url = format!(
"https://{}.snowflakecomputing.com/{}",
&account_identifier, context.path
"{}://{}{}/{}",
self.scheme, &account_identifier, self.base_url, context.path
);
let url = Url::parse_with_params(&url, get_params)?;

Expand Down Expand Up @@ -197,3 +198,95 @@ impl Connection {
Ok(bytes)
}
}

#[cfg(test)]
mod tests {
use super::*;
use dashmap::DashMap;
use http::uri::Scheme;
use serde_json::json;
use std::sync::Arc;
use uuid::Uuid;

#[tokio::test]
async fn test_request() {
tracing_subscriber::fmt::init();

let opts = mockito::ServerOpts {
host: "127.0.0.1",
port: 1234,
..Default::default()
};

let client = Connection::default_client_builder();
let conn = Connection::new_with_middware(
client.unwrap().build(),
Some("127.0.0.1:1234".to_string()),
Some(Scheme::HTTP),
);

let mut server = mockito::Server::new_with_opts_async(opts).await;

let ctx = QueryType::LoginRequest.query_context();

// using a dashmap to capture the requestIds across
// all requests to our mock server
let request_ids = Arc::new(DashMap::new());
let request_ids_clone = Arc::clone(&request_ids);

let _m1 = server
.mock("POST", "/session/v1/login-request")
.match_query(mockito::Matcher::Any)
// force an error to validate retries
.with_status(500)
.with_header("content-type", ctx.accept_mime)
// mechanism to validate the request body (feed it back to the client)
.with_body_from_request(move |request| {
let path_and_query = request.path_and_query();
let binding = String::new();
let query = path_and_query.split('?').nth(1).unwrap_or(&binding);
let params: HashMap<String, String> =
serde_urlencoded::from_str(query).unwrap_or_else(|_| HashMap::new());

let another_binding = String::new();
let request_id = params.get("requestId").unwrap_or(&another_binding);

request_ids_clone.insert(request_id.clone(), true);

let body = json!({"error": "an error happened", "requestId": request_id});
body.to_string().as_bytes().to_vec()
})
.expect(4)
.create_async()
.await;

match conn
.request::<serde_json::Value>(
QueryType::LoginRequest,
"",
&[],
None,
json!({"query": "SELECT 1"}),
)
.await
{
Ok(res) => {
assert_eq!(res["error"], "an error happened");
}
Err(e) => {
log::error!("Error: {}", e);
}
};

// assert that all requests were made with different requestIds
assert_eq!(request_ids.len(), 4);

request_ids.iter().for_each(|entry| {
let request_id = entry.key();
log::info!("Captured Request ID: {}", request_id);
assert_eq!(Uuid::parse_str(request_id).is_ok(), true);
});

_m1.assert_async().await;
}
}
9 changes: 8 additions & 1 deletion snowflake-api/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ use arrow::record_batch::RecordBatch;
use base64::Engine;
use bytes::{Buf, Bytes};
use futures::future::try_join_all;
use http::uri::Scheme;
use object_store::aws::AmazonS3Builder;
use object_store::local::LocalFileSystem;
use object_store::ObjectStore;
Expand All @@ -42,6 +43,8 @@ use crate::responses::{
};

pub mod connection;

mod middleware;
#[cfg(feature = "polars")]
mod polars;
mod requests;
Expand Down Expand Up @@ -218,7 +221,11 @@ impl SnowflakeApiBuilder {

pub fn build(self) -> Result<SnowflakeApi, SnowflakeApiError> {
let connection = match self.client {
Some(client) => Arc::new(Connection::new_with_middware(client)),
Some(client) => Arc::new(Connection::new_with_middware(
client,
None,
Some(Scheme::HTTPS),
)),
None => Arc::new(Connection::new()?),
};

Expand Down
40 changes: 40 additions & 0 deletions snowflake-api/src/middleware.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
use reqwest::Request;
use reqwest::Response;
use reqwest_middleware::{Middleware, Next, Result as MiddlewareResult};

use std::time::{SystemTime, UNIX_EPOCH};

use task_local_extensions::Extensions;
use uuid::Uuid;

pub struct UuidMiddleware;

#[async_trait::async_trait]
impl Middleware for UuidMiddleware {
async fn handle(
&self,
req: Request,
extensions: &mut Extensions,
next: Next<'_>,
) -> MiddlewareResult<Response> {
let request_id = Uuid::new_v4();
let request_guid = Uuid::new_v4();
let client_start_time = SystemTime::now()
.duration_since(UNIX_EPOCH)
.unwrap()
.as_secs();

let mut new_req = req.try_clone().unwrap();

// Modify the request URL to include the new UUIDs and client start time
let url = new_req.url_mut();

let query = format!(
"{}&clientStartTime={client_start_time}&requestId={request_id}&request_guid={request_guid}",
url.query().unwrap_or("")
);

url.set_query(Some(query.as_str()));
next.run(new_req, extensions).await
}
}
2 changes: 1 addition & 1 deletion snowflake-api/src/responses.rs
Original file line number Diff line number Diff line change
Expand Up @@ -165,9 +165,9 @@ pub struct ExecResponseRowType {
pub nullable: bool,
}

// fixme: is it good idea to keep this as an enum if more types could be added in future?
#[derive(Deserialize, Debug)]
#[serde(rename_all = "snake_case")]
#[non_exhaustive]
pub enum SnowflakeType {
Fixed,
Real,
Expand Down
Loading