Skip to content

Commit

Permalink
Merge pull request #1 from alamb/alamb/flight_ingest_check
Browse files Browse the repository at this point in the history
Allow streaming ingest for FlightClient::execute_ingest
  • Loading branch information
djanderson authored Aug 15, 2024
2 parents b48a060 + 6c4ed33 commit 5641350
Show file tree
Hide file tree
Showing 3 changed files with 18 additions and 10 deletions.
13 changes: 8 additions & 5 deletions arrow-flight/src/sql/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ use arrow_ipc::convert::fb_to_schema;
use arrow_ipc::reader::read_record_batch;
use arrow_ipc::{root_as_message, MessageHeader};
use arrow_schema::{ArrowError, Schema, SchemaRef};
use futures::{stream, StreamExt, TryStreamExt};
use futures::{stream, Stream, StreamExt, TryStreamExt};
use prost::Message;
use tonic::transport::Channel;
use tonic::{IntoRequest, IntoStreamingRequest, Streaming};
Expand Down Expand Up @@ -228,15 +228,18 @@ impl FlightSqlServiceClient<Channel> {
}

/// Execute a bulk ingest on the server and return the number of records added
pub async fn execute_ingest(
pub async fn execute_ingest<S>(
&mut self,
command: CommandStatementIngest,
batches: Vec<RecordBatch>,
) -> Result<i64, ArrowError> {
stream: S,
) -> Result<i64, ArrowError>
where
S: Stream<Item = crate::error::Result<RecordBatch>> + Send + 'static,
{
let descriptor = FlightDescriptor::new_cmd(command.as_any().encode_to_vec());
let flight_data_encoder = FlightDataEncoderBuilder::new()
.with_flight_descriptor(Some(descriptor))
.build(stream::iter(batches).map(Ok));
.build(stream);
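// NOTE(review): this unwrap was only guaranteed safe while the input was wrapped with
// `.map(Ok)`. With a caller-supplied `stream` the encoder can yield `Err`, which would
// panic here -- confirm and propagate the error to the caller instead.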
let flight_data = flight_data_encoder.map(|fd| fd.unwrap());
let req = self.set_request_headers(flight_data.into_streaming_request())?;
Expand Down
1 change: 1 addition & 0 deletions arrow-flight/tests/common/fixture.rs
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ pub struct TestFixture {

impl TestFixture {
/// create a new test fixture from the server
#[allow(dead_code)]
pub async fn new<T: FlightService>(test_server: FlightServiceServer<T>) -> Self {
// let OS choose a free port
let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
Expand Down
14 changes: 9 additions & 5 deletions arrow-flight/tests/flight_sql_client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ use arrow_flight::sql::{
TableNotExistOption,
};
use arrow_flight::Action;
use futures::TryStreamExt;
use futures::{StreamExt, TryStreamExt};
use std::collections::HashMap;
use std::sync::Arc;
use tokio::sync::Mutex;
Expand All @@ -40,9 +40,7 @@ use uuid::Uuid;

#[tokio::test]
pub async fn test_begin_end_transaction() {
let test_server = FlightSqlServiceImpl {
transactions: Arc::new(Mutex::new(HashMap::new())),
};
let test_server = FlightSqlServiceImpl::new();
let fixture = TestFixture::new(test_server.service()).await;
let channel = fixture.channel().await;
let mut flight_sql_client = FlightSqlServiceClient::new(channel);
Expand Down Expand Up @@ -94,21 +92,26 @@ pub async fn test_execute_ingest() {
make_primitive_batch(2),
];
let actual_rows = flight_sql_client
.execute_ingest(cmd, batches)
.execute_ingest(cmd, futures::stream::iter(batches.clone()).map(Ok))
.await
.expect("ingest should succeed");
assert_eq!(actual_rows, expected_rows);
// make sure the batches made it through to the server
let ingested_batches = test_server.ingested_batches.lock().await.clone();
assert_eq!(ingested_batches, batches);
}

/// Flight SQL service implementation used by the client integration tests.
///
/// Both fields are `Arc`-wrapped, so a clone shares the same underlying state
/// as the value it was cloned from.
#[derive(Clone)]
pub struct FlightSqlServiceImpl {
/// Shared map with `String` keys and unit values.
/// NOTE(review): presumably keyed by transaction id -- the handlers that
/// read and write it are not visible here; confirm.
transactions: Arc<Mutex<HashMap<String, ()>>>,
/// Record batches stored by the ingest handler. The handler overwrites the
/// whole vector on each call, and the ingest test reads it back to check
/// that the batches reached the server.
ingested_batches: Arc<Mutex<Vec<RecordBatch>>>,
}

impl FlightSqlServiceImpl {
/// Builds a service whose transaction map and ingested-batch list both start empty.
pub fn new() -> Self {
    let transactions = Arc::new(Mutex::new(HashMap::new()));
    let ingested_batches = Arc::new(Mutex::new(Vec::new()));
    Self {
        transactions,
        ingested_batches,
    }
}

Expand Down Expand Up @@ -177,6 +180,7 @@ impl FlightSqlService for FlightSqlServiceImpl {
.try_collect()
.await?;
let affected_rows = batches.iter().map(|batch| batch.num_rows() as i64).sum();
*self.ingested_batches.lock().await.as_mut() = batches;
Ok(affected_rows)
}
}

0 comments on commit 5641350

Please sign in to comment.