Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Better flight SQL example codes #4144

Merged
merged 5 commits into from
Apr 28, 2023
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
196 changes: 107 additions & 89 deletions arrow-flight/examples/flight_sql_server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -546,8 +546,7 @@ impl ProstMessageExt for FetchResults {
#[cfg(test)]
mod tests {
use super::*;
use futures::future::BoxFuture;
use futures::{FutureExt, TryStreamExt};
use futures::TryStreamExt;
use std::fs;
use std::future::Future;
use std::net::SocketAddr;
Expand All @@ -571,42 +570,6 @@ mod tests {
(incoming, addr)
}

async fn client_with_uds(path: String) -> FlightSqlServiceClient<Channel> {
let connector = service_fn(move |_| UnixStream::connect(path.clone()));
let channel = Endpoint::try_from("http://example.com")
.unwrap()
.connect_with_connector(connector)
.await
.unwrap();
FlightSqlServiceClient::new(channel)
}

type ServeFut = BoxFuture<'static, Result<(), tonic::transport::Error>>;

async fn create_https_server(
) -> Result<(ServeFut, SocketAddr), tonic::transport::Error> {
let cert = std::fs::read_to_string("examples/data/server.pem").unwrap();
let key = std::fs::read_to_string("examples/data/server.key").unwrap();
let client_ca = std::fs::read_to_string("examples/data/client_ca.pem").unwrap();

let tls_config = ServerTlsConfig::new()
.identity(Identity::from_pem(&cert, &key))
.client_ca_root(Certificate::from_pem(&client_ca));

let (incoming, addr) = bind_tcp().await;

let svc = FlightServiceServer::new(FlightSqlServiceImpl {});

let serve = Server::builder()
.tls_config(tls_config)
.unwrap()
.add_service(svc)
.serve_with_incoming(incoming)
.boxed();

Ok((serve, addr))
}

fn endpoint(uri: String) -> Result<Endpoint, ArrowError> {
let endpoint = Endpoint::new(uri)
.map_err(|_| ArrowError::IoError("Cannot create endpoint".to_string()))?
Expand All @@ -621,56 +584,12 @@ mod tests {
Ok(endpoint)
}

#[tokio::test]
async fn test_select_https() {
let (serve, addr) = create_https_server().await.unwrap();
let uri = format!("https://{}:{}", addr.ip(), addr.port());

let request_future = async {
let cert = std::fs::read_to_string("examples/data/client1.pem").unwrap();
let key = std::fs::read_to_string("examples/data/client1.key").unwrap();
let server_ca = std::fs::read_to_string("examples/data/ca.pem").unwrap();

let tls_config = ClientTlsConfig::new()
.domain_name("localhost")
.ca_certificate(Certificate::from_pem(&server_ca))
.identity(Identity::from_pem(cert, key));
let endpoint = endpoint(uri).unwrap().tls_config(tls_config).unwrap();
let channel = endpoint.connect().await.unwrap();
let mut client = FlightSqlServiceClient::new(channel);
let token = client.handshake("admin", "password").await.unwrap();
client.set_token(String::from_utf8(token.to_vec()).unwrap());
println!("Auth succeeded with token: {:?}", token);
let mut stmt = client.prepare("select 1;".to_string()).await.unwrap();
let flight_info = stmt.execute().await.unwrap();
let ticket = flight_info.endpoint[0].ticket.as_ref().unwrap().clone();
let flight_data = client.do_get(ticket).await.unwrap();
let flight_data: Vec<FlightData> = flight_data.try_collect().await.unwrap();
let batches = flight_data_to_batches(&flight_data).unwrap();
let res = pretty_format_batches(batches.as_slice()).unwrap();
let expected = r#"
+-------------------+
| salutation |
+-------------------+
| Hello, FlightSQL! |
+-------------------+"#
.trim()
.to_string();
assert_eq!(res.to_string(), expected);
};

tokio::select! {
_ = serve => panic!("server finished"),
_ = request_future => println!("Client finished!"),
}
}

async fn auth_client(client: &mut FlightSqlServiceClient<Channel>) {
let token = client.handshake("admin", "password").await.unwrap();
client.set_token(String::from_utf8(token.to_vec()).unwrap());
}

async fn test_client<F, C>(f: F)
async fn test_uds_client<F, C>(f: F)
where
F: FnOnce(FlightSqlServiceClient<Channel>) -> C,
C: Future<Output = ()>,
Expand All @@ -682,14 +601,91 @@ mod tests {
let uds = UnixListener::bind(path.clone()).unwrap();
let stream = UnixListenerStream::new(uds);

// We would just listen on TCP, but it seems impossible to know when tonic is ready to serve
let service = FlightSqlServiceImpl {};
let serve_future = Server::builder()
.add_service(FlightServiceServer::new(service))
.serve_with_incoming(stream);

let request_future = async {
let client = client_with_uds(path).await;
let connector = service_fn(move |_| UnixStream::connect(path.clone()));
let channel = Endpoint::try_from("http://example.com")
.unwrap()
.connect_with_connector(connector)
.await
.unwrap();
let client = FlightSqlServiceClient::new(channel);
f(client).await
};

tokio::select! {
_ = serve_future => panic!("server returned first"),
_ = request_future => println!("Client finished!"),
}
}

async fn test_http_client<F, C>(f: F)
where
F: FnOnce(FlightSqlServiceClient<Channel>) -> C,
C: Future<Output = ()>,
{
let (incoming, addr) = bind_tcp().await;
let uri = format!("http://{}:{}", addr.ip(), addr.port());

let service = FlightSqlServiceImpl {};
let serve_future = Server::builder()
.add_service(FlightServiceServer::new(service))
.serve_with_incoming(incoming);

let request_future = async {
let endpoint = endpoint(uri).unwrap();
let channel = endpoint.connect().await.unwrap();
let client = FlightSqlServiceClient::new(channel);
f(client).await
};

tokio::select! {
_ = serve_future => panic!("server returned first"),
_ = request_future => println!("Client finished!"),
}
}

async fn test_https_client<F, C>(f: F)
where
F: FnOnce(FlightSqlServiceClient<Channel>) -> C,
C: Future<Output = ()>,
{
let cert = std::fs::read_to_string("examples/data/server.pem").unwrap();
let key = std::fs::read_to_string("examples/data/server.key").unwrap();
let client_ca = std::fs::read_to_string("examples/data/client_ca.pem").unwrap();

let tls_config = ServerTlsConfig::new()
.identity(Identity::from_pem(&cert, &key))
.client_ca_root(Certificate::from_pem(&client_ca));

let (incoming, addr) = bind_tcp().await;
let uri = format!("https://{}:{}", addr.ip(), addr.port());

let svc = FlightServiceServer::new(FlightSqlServiceImpl {});

let serve_future = Server::builder()
.tls_config(tls_config)
.unwrap()
.add_service(svc)
.serve_with_incoming(incoming);

let request_future = async {
let cert = std::fs::read_to_string("examples/data/client1.pem").unwrap();
let key = std::fs::read_to_string("examples/data/client1.key").unwrap();
let server_ca = std::fs::read_to_string("examples/data/ca.pem").unwrap();

let tls_config = ClientTlsConfig::new()
.domain_name("localhost")
.ca_certificate(Certificate::from_pem(&server_ca))
.identity(Identity::from_pem(cert, key));

let endpoint = endpoint(uri).unwrap().tls_config(tls_config).unwrap();
let channel = endpoint.connect().await.unwrap();
let client = FlightSqlServiceClient::new(channel);
f(client).await
};

Expand All @@ -699,16 +695,38 @@ mod tests {
}
}

async fn test_all_clients<F, C>(task: F)
where
F: FnOnce(FlightSqlServiceClient<Channel>) -> C + Copy,
C: Future<Output = ()>,
{
println!("testing uds client");
test_uds_client(task).await;
println!("=======");

println!("testing http client");
test_http_client(task).await;
println!("=======");

println!("testing https client");
test_https_client(task).await;
println!("=======");
}

#[tokio::test]
async fn test_select_1() {
test_client(|mut client| async move {
async fn test_select() {
test_all_clients(|mut client| async move {
auth_client(&mut client).await;

let mut stmt = client.prepare("select 1;".to_string()).await.unwrap();

let flight_info = stmt.execute().await.unwrap();

let ticket = flight_info.endpoint[0].ticket.as_ref().unwrap().clone();
let flight_data = client.do_get(ticket).await.unwrap();
let flight_data: Vec<FlightData> = flight_data.try_collect().await.unwrap();
let batches = flight_data_to_batches(&flight_data).unwrap();

let res = pretty_format_batches(batches.as_slice()).unwrap();
let expected = r#"
+-------------------+
Expand All @@ -725,7 +743,7 @@ mod tests {

#[tokio::test]
async fn test_execute_update() {
test_client(|mut client| async move {
test_all_clients(|mut client| async move {
auth_client(&mut client).await;
let res = client
.execute_update("creat table test(a int);".to_string())
Expand All @@ -738,7 +756,7 @@ mod tests {

#[tokio::test]
async fn test_auth() {
test_client(|mut client| async move {
test_all_clients(|mut client| async move {
// no handshake
assert!(client
.prepare("select 1;".to_string())
Expand Down