diff --git a/arrow-flight/examples/flight_sql_server.rs b/arrow-flight/examples/flight_sql_server.rs index d5168debc433..81afecf85625 100644 --- a/arrow-flight/examples/flight_sql_server.rs +++ b/arrow-flight/examples/flight_sql_server.rs @@ -46,9 +46,9 @@ use arrow_flight::sql::{ ActionEndTransactionRequest, Any, CommandGetCatalogs, CommandGetCrossReference, CommandGetDbSchemas, CommandGetExportedKeys, CommandGetImportedKeys, CommandGetPrimaryKeys, CommandGetSqlInfo, CommandGetTableTypes, CommandGetTables, CommandGetXdbcTypeInfo, - CommandPreparedStatementQuery, CommandPreparedStatementUpdate, CommandStatementQuery, - CommandStatementSubstraitPlan, CommandStatementUpdate, Nullable, ProstMessageExt, Searchable, - SqlInfo, TicketStatementQuery, XdbcDataType, + CommandPreparedStatementQuery, CommandPreparedStatementUpdate, CommandStatementIngest, + CommandStatementQuery, CommandStatementSubstraitPlan, CommandStatementUpdate, Nullable, + ProstMessageExt, Searchable, SqlInfo, TicketStatementQuery, XdbcDataType, }; use arrow_flight::utils::batches_to_flight_data; use arrow_flight::{ @@ -615,6 +615,14 @@ impl FlightSqlService for FlightSqlServiceImpl { Ok(FAKE_UPDATE_RESULT) } + async fn do_put_statement_ingest( + &self, + _ticket: CommandStatementIngest, + _request: Request, + ) -> Result { + Ok(FAKE_UPDATE_RESULT) + } + async fn do_put_substrait_plan( &self, _ticket: CommandStatementSubstraitPlan, diff --git a/arrow-flight/src/client.rs b/arrow-flight/src/client.rs index 3f62b256d56e..af3c8fba30ff 100644 --- a/arrow-flight/src/client.rs +++ b/arrow-flight/src/client.rs @@ -679,7 +679,7 @@ impl FlightClient { /// it encounters an error it uses the oneshot sender to /// notify the error and stop any further streaming. See `do_put` or /// `do_exchange` for it's uses. -struct FallibleRequestStream { +pub(crate) struct FallibleRequestStream { /// sender to notify error sender: Option>, /// fallible stream @@ -687,7 +687,7 @@ struct FallibleRequestStream { } impl FallibleRequestStream { - fn new( + pub(crate) fn new( sender: Sender, fallible_stream: Pin> + Send + 'static>>, ) -> Self { diff --git a/arrow-flight/src/sql/client.rs b/arrow-flight/src/sql/client.rs index 91790898b1cb..9f9963c92531 100644 --- a/arrow-flight/src/sql/client.rs +++ b/arrow-flight/src/sql/client.rs @@ -24,6 +24,7 @@ use std::collections::HashMap; use std::str::FromStr; use tonic::metadata::AsciiMetadataKey; +use crate::client::FallibleRequestStream; use crate::decode::FlightRecordBatchStream; use crate::encode::FlightDataEncoderBuilder; use crate::error::FlightError; @@ -39,8 +40,8 @@ use crate::sql::{ CommandGetCrossReference, CommandGetDbSchemas, CommandGetExportedKeys, CommandGetImportedKeys, CommandGetPrimaryKeys, CommandGetSqlInfo, CommandGetTableTypes, CommandGetTables, CommandGetXdbcTypeInfo, CommandPreparedStatementQuery, CommandPreparedStatementUpdate, - CommandStatementQuery, CommandStatementUpdate, DoPutPreparedStatementResult, DoPutUpdateResult, - ProstMessageExt, SqlInfo, + CommandStatementIngest, CommandStatementQuery, CommandStatementUpdate, + DoPutPreparedStatementResult, DoPutUpdateResult, ProstMessageExt, SqlInfo, }; use crate::trailers::extract_lazy_trailers; use crate::{ @@ -53,10 +54,10 @@ use arrow_ipc::convert::fb_to_schema; use arrow_ipc::reader::read_record_batch; use arrow_ipc::{root_as_message, MessageHeader}; use arrow_schema::{ArrowError, Schema, SchemaRef}; -use futures::{stream, TryStreamExt}; +use futures::{stream, Stream, TryStreamExt}; use prost::Message; use tonic::transport::Channel; -use tonic::{IntoRequest, Streaming}; +use tonic::{IntoRequest, IntoStreamingRequest, Streaming}; /// A FlightSQLServiceClient is an endpoint for retrieving or storing Arrow data /// by FlightSQL protocol. @@ -227,6 +228,52 @@ impl FlightSqlServiceClient { Ok(result.record_count) } + /// Execute a bulk ingest on the server and return the number of records added + pub async fn execute_ingest( + &mut self, + command: CommandStatementIngest, + stream: S, + ) -> Result + where + S: Stream> + Send + 'static, + { + let (sender, receiver) = futures::channel::oneshot::channel(); + + let descriptor = FlightDescriptor::new_cmd(command.as_any().encode_to_vec()); + let flight_data = FlightDataEncoderBuilder::new() + .with_flight_descriptor(Some(descriptor)) + .build(stream); + + // Intercept client errors and send them to the one shot channel above + let flight_data = Box::pin(flight_data); + let flight_data: FallibleRequestStream = + FallibleRequestStream::new(sender, flight_data); + + let req = self.set_request_headers(flight_data.into_streaming_request())?; + let mut result = self + .flight_client + .do_put(req) + .await + .map_err(status_to_arrow_error)? + .into_inner(); + + // check if the there were any errors in the input stream provided note + // if receiver.await fails, it means the sender was dropped and there is + // no message to return. + if let Ok(msg) = receiver.await { + return Err(ArrowError::ExternalError(Box::new(msg))); + } + + let result = result + .message() + .await + .map_err(status_to_arrow_error)? + .unwrap(); + let any = Any::decode(&*result.app_metadata).map_err(decode_error_to_arrow_error)?; + let result: DoPutUpdateResult = any.unpack()?.unwrap(); + Ok(result.record_count) + } + /// Request a list of catalogs as tabular FlightInfo results pub async fn get_catalogs(&mut self) -> Result { self.get_flight_info_for_command(CommandGetCatalogs {}) diff --git a/arrow-flight/src/sql/mod.rs b/arrow-flight/src/sql/mod.rs index 61eb67b6933e..453f608d353a 100644 --- a/arrow-flight/src/sql/mod.rs +++ b/arrow-flight/src/sql/mod.rs @@ -50,6 +50,10 @@ mod gen { } pub use gen::action_end_transaction_request::EndTransaction; +pub use gen::command_statement_ingest::table_definition_options::{ + TableExistsOption, TableNotExistOption, +}; +pub use gen::command_statement_ingest::TableDefinitionOptions; pub use gen::ActionBeginSavepointRequest; pub use gen::ActionBeginSavepointResult; pub use gen::ActionBeginTransactionRequest; @@ -74,6 +78,7 @@ pub use gen::CommandGetTables; pub use gen::CommandGetXdbcTypeInfo; pub use gen::CommandPreparedStatementQuery; pub use gen::CommandPreparedStatementUpdate; +pub use gen::CommandStatementIngest; pub use gen::CommandStatementQuery; pub use gen::CommandStatementSubstraitPlan; pub use gen::CommandStatementUpdate; @@ -250,11 +255,12 @@ prost_message_ext!( CommandGetXdbcTypeInfo, CommandPreparedStatementQuery, CommandPreparedStatementUpdate, + CommandStatementIngest, CommandStatementQuery, CommandStatementSubstraitPlan, CommandStatementUpdate, - DoPutUpdateResult, DoPutPreparedStatementResult, + DoPutUpdateResult, TicketStatementQuery, ); diff --git a/arrow-flight/src/sql/server.rs b/arrow-flight/src/sql/server.rs index b47691c7da5d..e348367a91eb 100644 --- a/arrow-flight/src/sql/server.rs +++ b/arrow-flight/src/sql/server.rs @@ -32,9 +32,9 @@ use super::{ CommandGetCrossReference, CommandGetDbSchemas, CommandGetExportedKeys, CommandGetImportedKeys, CommandGetPrimaryKeys, CommandGetSqlInfo, CommandGetTableTypes, CommandGetTables, CommandGetXdbcTypeInfo, CommandPreparedStatementQuery, CommandPreparedStatementUpdate, - CommandStatementQuery, CommandStatementSubstraitPlan, CommandStatementUpdate, - DoPutPreparedStatementResult, DoPutUpdateResult, ProstMessageExt, SqlInfo, - TicketStatementQuery, + CommandStatementIngest, CommandStatementQuery, CommandStatementSubstraitPlan, + CommandStatementUpdate, DoPutPreparedStatementResult, DoPutUpdateResult, ProstMessageExt, + SqlInfo, TicketStatementQuery, }; use crate::{ flight_service_server::FlightService, gen::PollInfo, Action, ActionType, Criteria, Empty, @@ -397,6 +397,17 @@ pub trait FlightSqlService: Sync + Send + Sized + 'static { )) } + /// Execute a bulk ingestion. + async fn do_put_statement_ingest( + &self, + _ticket: CommandStatementIngest, + _request: Request, + ) -> Result { + Err(Status::unimplemented( + "do_put_statement_ingest has no default implementation", + )) + } + /// Bind parameters to given prepared statement. /// /// Returns an opaque handle that the client should pass @@ -713,6 +724,14 @@ where })]); Ok(Response::new(Box::pin(output))) } + Command::CommandStatementIngest(command) => { + let record_count = self.do_put_statement_ingest(command, request).await?; + let result = DoPutUpdateResult { record_count }; + let output = futures::stream::iter(vec![Ok(PutResult { + app_metadata: result.as_any().encode_to_vec().into(), + })]); + Ok(Response::new(Box::pin(output))) + } Command::CommandPreparedStatementQuery(command) => { let result = self .do_put_prepared_statement_query(command, request) diff --git a/arrow-flight/tests/common/fixture.rs b/arrow-flight/tests/common/fixture.rs index 141879e2a358..a666fa5d0d59 100644 --- a/arrow-flight/tests/common/fixture.rs +++ b/arrow-flight/tests/common/fixture.rs @@ -41,6 +41,7 @@ pub struct TestFixture { impl TestFixture { /// create a new test fixture from the server + #[allow(dead_code)] pub async fn new(test_server: FlightServiceServer) -> Self { // let OS choose a free port let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); diff --git a/arrow-flight/tests/common/mod.rs b/arrow-flight/tests/common/mod.rs index 85716e56058c..c4ac027c5890 100644 --- a/arrow-flight/tests/common/mod.rs +++ b/arrow-flight/tests/common/mod.rs @@ -18,3 +18,4 @@ pub mod fixture; pub mod server; pub mod trailers_layer; +pub mod utils; diff --git a/arrow-flight/tests/common/utils.rs b/arrow-flight/tests/common/utils.rs new file mode 100644 index 000000000000..0f70e4b31021 --- /dev/null +++ b/arrow-flight/tests/common/utils.rs @@ -0,0 +1,118 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Common utilities for testing flight clients and servers + +use std::sync::Arc; + +use arrow_array::{ + types::Int32Type, ArrayRef, BinaryViewArray, DictionaryArray, Float64Array, RecordBatch, + StringViewArray, UInt8Array, +}; +use arrow_schema::{DataType, Field, Schema}; + +/// Make a primitive batch for testing +/// +/// Example: +/// i: 0, 1, None, 3, 4 +/// f: 5.0, 4.0, None, 2.0, 1.0 +#[allow(dead_code)] +pub fn make_primitive_batch(num_rows: usize) -> RecordBatch { + let i: UInt8Array = (0..num_rows) + .map(|i| { + if i == num_rows / 2 { + None + } else { + Some(i.try_into().unwrap()) + } + }) + .collect(); + + let f: Float64Array = (0..num_rows) + .map(|i| { + if i == num_rows / 2 { + None + } else { + Some((num_rows - i) as f64) + } + }) + .collect(); + + RecordBatch::try_from_iter(vec![("i", Arc::new(i) as ArrayRef), ("f", Arc::new(f))]).unwrap() +} + +/// Make a dictionary batch for testing +/// +/// Example: +/// a: value0, value1, value2, None, value1, value2 +#[allow(dead_code)] +pub fn make_dictionary_batch(num_rows: usize) -> RecordBatch { + let values: Vec<_> = (0..num_rows) + .map(|i| { + if i == num_rows / 2 { + None + } else { + // repeat some values for low cardinality + let v = i / 3; + Some(format!("value{v}")) + } + }) + .collect(); + + let a: DictionaryArray = values + .iter() + .map(|s| s.as_ref().map(|s| s.as_str())) + .collect(); + + RecordBatch::try_from_iter(vec![("a", Arc::new(a) as ArrayRef)]).unwrap() +} + +#[allow(dead_code)] +pub fn make_view_batches(num_rows: usize) -> RecordBatch { + const LONG_TEST_STRING: &str = + "This is a long string to make sure binary view array handles it"; + let schema = Schema::new(vec![ + Field::new("field1", DataType::BinaryView, true), + Field::new("field2", DataType::Utf8View, true), + ]); + + let string_view_values: Vec> = (0..num_rows) + .map(|i| match i % 3 { + 0 => None, + 1 => Some("foo"), + 2 => Some(LONG_TEST_STRING), + _ => unreachable!(), + }) + .collect(); + + let bin_view_values: Vec> = (0..num_rows) + .map(|i| match i % 3 { + 0 => None, + 1 => Some("bar".as_bytes()), + 2 => Some(LONG_TEST_STRING.as_bytes()), + _ => unreachable!(), + }) + .collect(); + + let binary_array = BinaryViewArray::from_iter(bin_view_values); + let utf8_array = StringViewArray::from_iter(string_view_values); + RecordBatch::try_new( + Arc::new(schema.clone()), + vec![Arc::new(binary_array), Arc::new(utf8_array)], + ) + .unwrap() +} diff --git a/arrow-flight/tests/encode_decode.rs b/arrow-flight/tests/encode_decode.rs index 0185fa77f067..cbfae1825845 100644 --- a/arrow-flight/tests/encode_decode.rs +++ b/arrow-flight/tests/encode_decode.rs @@ -19,11 +19,7 @@ use std::{collections::HashMap, sync::Arc}; -use arrow_array::types::Int32Type; -use arrow_array::{ - ArrayRef, BinaryViewArray, DictionaryArray, Float64Array, RecordBatch, StringViewArray, - UInt8Array, -}; +use arrow_array::{ArrayRef, RecordBatch}; use arrow_cast::pretty::pretty_format_batches; use arrow_flight::flight_descriptor::DescriptorType; use arrow_flight::FlightDescriptor; @@ -36,6 +32,9 @@ use arrow_schema::{DataType, Field, Fields, Schema, SchemaRef}; use bytes::Bytes; use futures::{StreamExt, TryStreamExt}; +mod common; +use common::utils::{make_dictionary_batch, make_primitive_batch, make_view_batches}; + #[tokio::test] async fn test_empty() { roundtrip(vec![]).await; @@ -415,95 +414,6 @@ async fn test_mismatched_schema_message() { .await; } -/// Make a primitive batch for testing -/// -/// Example: -/// i: 0, 1, None, 3, 4 -/// f: 5.0, 4.0, None, 2.0, 1.0 -fn make_primitive_batch(num_rows: usize) -> RecordBatch { - let i: UInt8Array = (0..num_rows) - .map(|i| { - if i == num_rows / 2 { - None - } else { - Some(i.try_into().unwrap()) - } - }) - .collect(); - - let f: Float64Array = (0..num_rows) - .map(|i| { - if i == num_rows / 2 { - None - } else { - Some((num_rows - i) as f64) - } - }) - .collect(); - - RecordBatch::try_from_iter(vec![("i", Arc::new(i) as ArrayRef), ("f", Arc::new(f))]).unwrap() -} - -/// Make a dictionary batch for testing -/// -/// Example: -/// a: value0, value1, value2, None, value1, value2 -fn make_dictionary_batch(num_rows: usize) -> RecordBatch { - let values: Vec<_> = (0..num_rows) - .map(|i| { - if i == num_rows / 2 { - None - } else { - // repeat some values for low cardinality - let v = i / 3; - Some(format!("value{v}")) - } - }) - .collect(); - - let a: DictionaryArray = values - .iter() - .map(|s| s.as_ref().map(|s| s.as_str())) - .collect(); - - RecordBatch::try_from_iter(vec![("a", Arc::new(a) as ArrayRef)]).unwrap() -} - -fn make_view_batches(num_rows: usize) -> RecordBatch { - const LONG_TEST_STRING: &str = - "This is a long string to make sure binary view array handles it"; - let schema = Schema::new(vec![ - Field::new("field1", DataType::BinaryView, true), - Field::new("field2", DataType::Utf8View, true), - ]); - - let string_view_values: Vec> = (0..num_rows) - .map(|i| match i % 3 { - 0 => None, - 1 => Some("foo"), - 2 => Some(LONG_TEST_STRING), - _ => unreachable!(), - }) - .collect(); - - let bin_view_values: Vec> = (0..num_rows) - .map(|i| match i % 3 { - 0 => None, - 1 => Some("bar".as_bytes()), - 2 => Some(LONG_TEST_STRING.as_bytes()), - _ => unreachable!(), - }) - .collect(); - - let binary_array = BinaryViewArray::from_iter(bin_view_values); - let utf8_array = StringViewArray::from_iter(string_view_values); - RecordBatch::try_new( - Arc::new(schema.clone()), - vec![Arc::new(binary_array), Arc::new(utf8_array)], - ) - .unwrap() -} - /// Encodes input as a FlightData stream, and then decodes it using /// FlightRecordBatchStream and validates the decoded record batches /// match the input. diff --git a/arrow-flight/tests/flight_sql_client.rs b/arrow-flight/tests/flight_sql_client.rs index 94b768a13621..349da062a82d 100644 --- a/arrow-flight/tests/flight_sql_client.rs +++ b/arrow-flight/tests/flight_sql_client.rs @@ -18,14 +18,21 @@ mod common; use crate::common::fixture::TestFixture; +use crate::common::utils::make_primitive_batch; + +use arrow_array::RecordBatch; +use arrow_flight::decode::FlightRecordBatchStream; +use arrow_flight::error::FlightError; use arrow_flight::flight_service_server::FlightServiceServer; use arrow_flight::sql::client::FlightSqlServiceClient; -use arrow_flight::sql::server::FlightSqlService; +use arrow_flight::sql::server::{FlightSqlService, PeekableFlightDataStream}; use arrow_flight::sql::{ ActionBeginTransactionRequest, ActionBeginTransactionResult, ActionEndTransactionRequest, - EndTransaction, SqlInfo, + CommandStatementIngest, EndTransaction, SqlInfo, TableDefinitionOptions, TableExistsOption, + TableNotExistOption, }; use arrow_flight::Action; +use futures::{StreamExt, TryStreamExt}; use std::collections::HashMap; use std::sync::Arc; use tokio::sync::Mutex; @@ -34,9 +41,7 @@ use uuid::Uuid; #[tokio::test] pub async fn test_begin_end_transaction() { - let test_server = FlightSqlServiceImpl { - transactions: Arc::new(Mutex::new(HashMap::new())), - }; + let test_server = FlightSqlServiceImpl::new(); let fixture = TestFixture::new(test_server.service()).await; let channel = fixture.channel().await; let mut flight_sql_client = FlightSqlServiceClient::new(channel); @@ -63,12 +68,83 @@ pub async fn test_begin_end_transaction() { .is_err()); } +#[tokio::test] +pub async fn test_execute_ingest() { + let test_server = FlightSqlServiceImpl::new(); + let fixture = TestFixture::new(test_server.service()).await; + let channel = fixture.channel().await; + let mut flight_sql_client = FlightSqlServiceClient::new(channel); + let cmd = make_ingest_command(); + let expected_rows = 10; + let batches = vec![ + make_primitive_batch(5), + make_primitive_batch(3), + make_primitive_batch(2), + ]; + let actual_rows = flight_sql_client + .execute_ingest(cmd, futures::stream::iter(batches.clone()).map(Ok)) + .await + .expect("ingest should succeed"); + assert_eq!(actual_rows, expected_rows); + // make sure the batches made it through to the server + let ingested_batches = test_server.ingested_batches.lock().await.clone(); + assert_eq!(ingested_batches, batches); +} + +#[tokio::test] +pub async fn test_execute_ingest_error() { + let test_server = FlightSqlServiceImpl::new(); + let fixture = TestFixture::new(test_server.service()).await; + let channel = fixture.channel().await; + let mut flight_sql_client = FlightSqlServiceClient::new(channel); + let cmd = make_ingest_command(); + // send an error from the client + let batches = vec![ + Ok(make_primitive_batch(5)), + Err(FlightError::NotYetImplemented( + "Client error message".to_string(), + )), + ]; + // make sure the client returns the error from the client + let err = flight_sql_client + .execute_ingest(cmd, futures::stream::iter(batches)) + .await + .unwrap_err(); + assert_eq!( + err.to_string(), + "External error: Not yet implemented: Client error message" + ); +} + +fn make_ingest_command() -> CommandStatementIngest { + CommandStatementIngest { + table_definition_options: Some(TableDefinitionOptions { + if_not_exist: TableNotExistOption::Create.into(), + if_exists: TableExistsOption::Fail.into(), + }), + table: String::from("test"), + schema: None, + catalog: None, + temporary: true, + transaction_id: None, + options: HashMap::default(), + } +} + #[derive(Clone)] pub struct FlightSqlServiceImpl { transactions: Arc>>, + ingested_batches: Arc>>, } impl FlightSqlServiceImpl { + pub fn new() -> Self { + Self { + transactions: Arc::new(Mutex::new(HashMap::new())), + ingested_batches: Arc::new(Mutex::new(Vec::new())), + } + } + /// Return an [`FlightServiceServer`] that can be used with a /// [`Server`](tonic::transport::Server) pub fn service(&self) -> FlightServiceServer { @@ -77,6 +153,12 @@ impl FlightSqlServiceImpl { } } +impl Default for FlightSqlServiceImpl { + fn default() -> Self { + Self::new() + } +} + #[tonic::async_trait] impl FlightSqlService for FlightSqlServiceImpl { type FlightService = FlightSqlServiceImpl; @@ -116,4 +198,19 @@ impl FlightSqlService for FlightSqlServiceImpl { } async fn register_sql_info(&self, _id: i32, _result: &SqlInfo) {} + + async fn do_put_statement_ingest( + &self, + _ticket: CommandStatementIngest, + request: Request, + ) -> Result { + let batches: Vec = FlightRecordBatchStream::new_from_flight_data( + request.into_inner().map_err(|e| e.into()), + ) + .try_collect() + .await?; + let affected_rows = batches.iter().map(|batch| batch.num_rows() as i64).sum(); + *self.ingested_batches.lock().await.as_mut() = batches; + Ok(affected_rows) + } }