From 6c4ed33b18b6bcb1acbeff40249bd32d5fb47dc1 Mon Sep 17 00:00:00 2001 From: Andrew Lamb Date: Tue, 13 Aug 2024 16:27:39 -0400 Subject: [PATCH] Allow streaming ingest for FlightClient::execute_ingest --- arrow-flight/src/sql/client.rs | 13 ++++++++----- arrow-flight/tests/common/fixture.rs | 1 + arrow-flight/tests/flight_sql_client.rs | 14 +++++++++----- 3 files changed, 18 insertions(+), 10 deletions(-) diff --git a/arrow-flight/src/sql/client.rs b/arrow-flight/src/sql/client.rs index 3f0ed7e07a1e..cfd6c9e78992 100644 --- a/arrow-flight/src/sql/client.rs +++ b/arrow-flight/src/sql/client.rs @@ -53,7 +53,7 @@ use arrow_ipc::convert::fb_to_schema; use arrow_ipc::reader::read_record_batch; use arrow_ipc::{root_as_message, MessageHeader}; use arrow_schema::{ArrowError, Schema, SchemaRef}; -use futures::{stream, StreamExt, TryStreamExt}; +use futures::{stream, Stream, StreamExt, TryStreamExt}; use prost::Message; use tonic::transport::Channel; use tonic::{IntoRequest, IntoStreamingRequest, Streaming}; @@ -228,15 +228,18 @@ impl FlightSqlServiceClient { } /// Execute a bulk ingest on the server and return the number of records added - pub async fn execute_ingest( + pub async fn execute_ingest( &mut self, command: CommandStatementIngest, - batches: Vec, - ) -> Result { + stream: S, + ) -> Result + where + S: Stream> + Send + 'static, + { let descriptor = FlightDescriptor::new_cmd(command.as_any().encode_to_vec()); let flight_data_encoder = FlightDataEncoderBuilder::new() .with_flight_descriptor(Some(descriptor)) - .build(stream::iter(batches).map(Ok)); + .build(stream); // Safe unwrap, explicitly wrapped on line above. let flight_data = flight_data_encoder.map(|fd| fd.unwrap()); let req = self.set_request_headers(flight_data.into_streaming_request())?; diff --git a/arrow-flight/tests/common/fixture.rs b/arrow-flight/tests/common/fixture.rs index 141879e2a358..a666fa5d0d59 100644 --- a/arrow-flight/tests/common/fixture.rs +++ b/arrow-flight/tests/common/fixture.rs @@ -41,6 +41,7 @@ pub struct TestFixture { impl TestFixture { /// create a new test fixture from the server + #[allow(dead_code)] pub async fn new(test_server: FlightServiceServer) -> Self { // let OS choose a free port let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); diff --git a/arrow-flight/tests/flight_sql_client.rs b/arrow-flight/tests/flight_sql_client.rs index 81d58138ed11..e65309493649 100644 --- a/arrow-flight/tests/flight_sql_client.rs +++ b/arrow-flight/tests/flight_sql_client.rs @@ -31,7 +31,7 @@ use arrow_flight::sql::{ TableNotExistOption, }; use arrow_flight::Action; -use futures::TryStreamExt; +use futures::{StreamExt, TryStreamExt}; use std::collections::HashMap; use std::sync::Arc; use tokio::sync::Mutex; @@ -40,9 +40,7 @@ use uuid::Uuid; #[tokio::test] pub async fn test_begin_end_transaction() { - let test_server = FlightSqlServiceImpl { - transactions: Arc::new(Mutex::new(HashMap::new())), - }; + let test_server = FlightSqlServiceImpl::new(); let fixture = TestFixture::new(test_server.service()).await; let channel = fixture.channel().await; let mut flight_sql_client = FlightSqlServiceClient::new(channel); @@ -94,21 +92,26 @@ pub async fn test_execute_ingest() { make_primitive_batch(2), ]; let actual_rows = flight_sql_client - .execute_ingest(cmd, batches) + .execute_ingest(cmd, futures::stream::iter(batches.clone()).map(Ok)) .await .expect("ingest should succeed"); assert_eq!(actual_rows, expected_rows); + // make sure the batches made it through to the server + let ingested_batches = test_server.ingested_batches.lock().await.clone(); + assert_eq!(ingested_batches, batches); } #[derive(Clone)] pub struct FlightSqlServiceImpl { transactions: Arc>>, + ingested_batches: Arc>>, } impl FlightSqlServiceImpl { pub fn new() -> Self { Self { transactions: Arc::new(Mutex::new(HashMap::new())), + ingested_batches: Arc::new(Mutex::new(Vec::new())), } } @@ -177,6 +180,7 @@ impl FlightSqlService for FlightSqlServiceImpl { .try_collect() .await?; let affected_rows = batches.iter().map(|batch| batch.num_rows() as i64).sum(); + *self.ingested_batches.lock().await.as_mut() = batches; Ok(affected_rows) } }