Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for async query response #42

Open
wants to merge 3 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion snowflake-api/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ Since it does a lot of I/O the library is async-only, and currently has hard dep
- [x] PUT support [example](./examples/filetransfer.rs)
- [ ] GET support
- [x] AWS integration
- [ ] GCloud integration
- [ ] `GCloud` integration
- [ ] Azure integration
- [x] Parallel uploading of small files
- [x] Glob support for PUT (eg `*.csv`)
Expand Down
47 changes: 31 additions & 16 deletions snowflake-api/src/connection.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,9 @@ pub enum ConnectionError {
/// Container for query parameters
/// This API has different endpoints and MIME types for different requests
struct QueryContext {
path: &'static str,
path: String,
accept_mime: &'static str,
method: reqwest::Method,
}

pub enum QueryType {
Expand All @@ -39,30 +40,40 @@ pub enum QueryType {
CloseSession,
JsonQuery,
ArrowQuery,
ArrowQueryResult(String),
}

impl QueryType {
const fn query_context(&self) -> QueryContext {
fn query_context(&self) -> QueryContext {
match self {
Self::LoginRequest => QueryContext {
path: "session/v1/login-request",
path: "session/v1/login-request".to_string(),
accept_mime: "application/json",
method: reqwest::Method::POST,
},
Self::TokenRequest => QueryContext {
path: "/session/token-request",
path: "/session/token-request".to_string(),
accept_mime: "application/snowflake",
method: reqwest::Method::POST,
},
Self::CloseSession => QueryContext {
path: "session",
path: "session".to_string(),
accept_mime: "application/snowflake",
method: reqwest::Method::POST,
},
Self::JsonQuery => QueryContext {
path: "queries/v1/query-request",
path: "queries/v1/query-request".to_string(),
accept_mime: "application/json",
method: reqwest::Method::POST,
},
Self::ArrowQuery => QueryContext {
path: "queries/v1/query-request",
path: "queries/v1/query-request".to_string(),
accept_mime: "application/snowflake",
method: reqwest::Method::POST,
},
Self::ArrowQueryResult(query_result_url) => QueryContext {
path: query_result_url.to_string(),
accept_mime: "application/snowflake",
method: reqwest::Method::GET,
},
}
}
Expand Down Expand Up @@ -163,14 +174,18 @@ impl Connection {
}

// todo: persist client to use connection pooling
let resp = self
.client
.post(url)
.headers(headers)
.json(&body)
.send()
.await?;

let resp = match context.method {
reqwest::Method::POST => {
self.client
.post(url)
.headers(headers)
.json(&body)
.send()
.await?
}
reqwest::Method::GET => self.client.get(url).headers(headers).send().await?,
_ => panic!("Unsupported method"),
};
Ok(resp.json::<R>().await?)
}

Expand Down
51 changes: 47 additions & 4 deletions snowflake-api/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -406,7 +406,9 @@ impl SnowflakeApi {
log::debug!("Got PUT response: {:?}", resp);

match resp {
ExecResponse::Query(_) => Err(SnowflakeApiError::UnexpectedResponse),
ExecResponse::Query(_) | ExecResponse::QueryAsync(_) => {
Err(SnowflakeApiError::UnexpectedResponse)
}
ExecResponse::PutGet(pg) => put::put(pg).await,
ExecResponse::Error(e) => Err(SnowflakeApiError::ApiError(
e.data.error_code,
Expand All @@ -430,15 +432,25 @@ impl SnowflakeApi {
}

async fn exec_arrow_raw(&self, sql: &str) -> Result<RawQueryResult, SnowflakeApiError> {
let resp = self
let mut resp = self
.run_sql::<ExecResponse>(sql, QueryType::ArrowQuery)
.await?;
log::debug!("Got query response: {:?}", resp);

if let ExecResponse::QueryAsync(data) = &resp {
log::debug!("Got async exec response");
resp = self
.get_async_exec_result(&data.data.get_result_url)
.await?;
log::debug!("Got result for async exec: {:?}", resp);
}

let resp = match resp {
// processable response
ExecResponse::Query(qr) => Ok(qr),
ExecResponse::PutGet(_) => Err(SnowflakeApiError::UnexpectedResponse),
ExecResponse::PutGet(_) | ExecResponse::QueryAsync(_) => {
Err(SnowflakeApiError::UnexpectedResponse)
}
ExecResponse::Error(e) => Err(SnowflakeApiError::ApiError(
e.data.error_code,
e.message.unwrap_or_default(),
Expand Down Expand Up @@ -504,10 +516,41 @@ impl SnowflakeApi {
&self.account_identifier,
&[],
Some(&parts.session_token_auth_header),
body,
Some(body),
)
.await?;

Ok(resp)
}

/// Polls `query_result_url` until the server stops reporting the query as still running.
///
/// Each attempt obtains a session token and issues a `QueryType::ArrowQueryResult`
/// request for the given URL. While the response deserializes as
/// `ExecResponse::QueryAsync`, the call sleeps and tries again; the first response of any
/// other variant is returned unchanged, so the caller still has to handle
/// `ExecResponse::Error` and friends.
///
/// The wait between attempts starts at 1 second and doubles after every attempt, capped
/// at 5 seconds. There is no limit on the number of attempts: this future only completes
/// once the server returns a non-async response or a request fails.
///
/// # Errors
///
/// Propagates any error from fetching the session token or from the HTTP request.
pub async fn get_async_exec_result(
    &self,
    query_result_url: &str,
) -> Result<ExecResponse, SnowflakeApiError> {
    // Back-off bounds, in seconds.
    const INITIAL_DELAY_SECS: u64 = 1;
    const MAX_DELAY_SECS: u64 = 5;

    log::debug!("Getting async exec result: {}", query_result_url);

    let mut delay_secs = INITIAL_DELAY_SECS;

    loop {
        // The token is requested on every iteration rather than once up front.
        // NOTE(review): presumably so a long-running poll picks up a renewed token — confirm.
        let parts = self.session.get_token().await?;
        let resp = self
            .connection
            .request::<ExecResponse>(
                QueryType::ArrowQueryResult(query_result_url.to_owned()),
                &self.account_identifier,
                &[],
                Some(&parts.session_token_auth_header),
                serde_json::Value::default(),
            )
            .await?;

        // Anything other than "still running" is final from the poller's point of view.
        if !matches!(resp, ExecResponse::QueryAsync(_)) {
            return Ok(resp);
        }

        // Simple exponential back-off with a capped wait time.
        tokio::time::sleep(tokio::time::Duration::from_secs(delay_secs)).await;
        delay_secs = (delay_secs * 2).min(MAX_DELAY_SECS);
    }
}
}
24 changes: 22 additions & 2 deletions snowflake-api/src/responses.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,14 @@ use serde::Deserialize;
#[serde(untagged)]
/// Response body of an exec request (query or PUT/GET).
///
/// The enum is deserialized untagged: serde tries the variants in declaration order and
/// keeps the first one whose shape matches the payload, so the order of the variants
/// below is significant.
pub enum ExecResponse {
/// Completed query; the only variant the query path treats as a processable result.
Query(QueryExecResponse),
/// Query that is still executing; carries the URL to poll for the result.
QueryAsync(QueryAsyncExecResponse),
/// Response to a PUT/GET statement, handed to the file-transfer code.
PutGet(PutGetExecResponse),
/// Error reported by the server; surfaced to callers as an API error.
Error(ExecErrorResponse),
}

// todo: add close session response, which should be just empty?
#[allow(clippy::large_enum_variant)]
// FIXME: dead_code
#[allow(clippy::large_enum_variant, dead_code)]
#[derive(Deserialize, Debug)]
#[serde(untagged)]
pub enum AuthResponse {
Expand All @@ -34,6 +36,7 @@ pub struct BaseRestResponse<D> {

pub type PutGetExecResponse = BaseRestResponse<PutGetResponseData>;
pub type QueryExecResponse = BaseRestResponse<QueryExecResponseData>;
pub type QueryAsyncExecResponse = BaseRestResponse<QueryAsyncExecResponseData>;
pub type ExecErrorResponse = BaseRestResponse<ExecErrorResponseData>;
pub type AuthErrorResponse = BaseRestResponse<AuthErrorResponseData>;
pub type AuthenticatorResponse = BaseRestResponse<AuthenticatorResponseData>;
Expand All @@ -54,12 +57,14 @@ pub struct ExecErrorResponseData {
pub pos: Option<i64>,

// fixme: only valid for exec query response error? present in any exec query response?
pub query_id: String,
pub query_id: Option<String>,
sgrebnov marked this conversation as resolved.
Show resolved Hide resolved
pub sql_state: String,
}

#[derive(Deserialize, Debug)]
#[serde(rename_all = "camelCase")]
// FIXME: dead_code
#[allow(dead_code)]
pub struct AuthErrorResponseData {
pub authn_method: String,
}
Expand All @@ -72,6 +77,8 @@ pub struct NameValueParameter {

#[derive(Deserialize, Debug)]
#[serde(rename_all = "camelCase")]
// FIXME
#[allow(dead_code)]
pub struct LoginResponseData {
pub session_id: i64,
pub token: String,
Expand All @@ -86,6 +93,8 @@ pub struct LoginResponseData {

#[derive(Deserialize, Debug)]
#[serde(rename_all = "camelCase")]
// FIXME: dead_code
#[allow(dead_code)]
pub struct SessionInfo {
pub database_name: Option<String>,
pub schema_name: Option<String>,
Expand All @@ -95,6 +104,8 @@ pub struct SessionInfo {

#[derive(Deserialize, Debug)]
#[serde(rename_all = "camelCase")]
// FIXME: dead_code
#[allow(dead_code)]
pub struct AuthenticatorResponseData {
pub token_url: String,
pub sso_url: String,
Expand All @@ -103,6 +114,8 @@ pub struct AuthenticatorResponseData {

#[derive(Deserialize, Debug)]
#[serde(rename_all = "camelCase")]
// FIXME: dead_code
#[allow(dead_code)]
pub struct RenewSessionResponseData {
pub session_token: String,
pub validity_in_seconds_s_t: i64,
Expand Down Expand Up @@ -151,6 +164,13 @@ pub struct QueryExecResponseData {
// `sendResultTime`, `queryResultFormat`, `queryContext` also exist
}

/// Data section of a response reporting that a query is being executed asynchronously.
///
/// Fields are read from their camelCase names in the payload (`queryId`, `getResultUrl`).
#[derive(Deserialize, Debug)]
#[serde(rename_all = "camelCase")]
pub struct QueryAsyncExecResponseData {
/// Query identifier as reported by the server.
pub query_id: String,
/// URL to request in order to fetch the query result.
pub get_result_url: String,
}

#[derive(Deserialize, Debug)]
pub struct ExecResponseRowType {
pub name: String,
Expand Down
Loading