Expose bulk ingest in flight sql client and server #6201

Merged: 12 commits, Aug 15, 2024
14 changes: 11 additions & 3 deletions arrow-flight/examples/flight_sql_server.rs
@@ -46,9 +46,9 @@ use arrow_flight::sql::{
ActionEndTransactionRequest, Any, CommandGetCatalogs, CommandGetCrossReference,
CommandGetDbSchemas, CommandGetExportedKeys, CommandGetImportedKeys, CommandGetPrimaryKeys,
CommandGetSqlInfo, CommandGetTableTypes, CommandGetTables, CommandGetXdbcTypeInfo,
CommandPreparedStatementQuery, CommandPreparedStatementUpdate, CommandStatementQuery,
CommandStatementSubstraitPlan, CommandStatementUpdate, Nullable, ProstMessageExt, Searchable,
SqlInfo, TicketStatementQuery, XdbcDataType,
CommandPreparedStatementQuery, CommandPreparedStatementUpdate, CommandStatementIngest,
CommandStatementQuery, CommandStatementSubstraitPlan, CommandStatementUpdate, Nullable,
ProstMessageExt, Searchable, SqlInfo, TicketStatementQuery, XdbcDataType,
};
use arrow_flight::utils::batches_to_flight_data;
use arrow_flight::{
@@ -615,6 +615,14 @@ impl FlightSqlService for FlightSqlServiceImpl {
Ok(FAKE_UPDATE_RESULT)
}

async fn do_put_statement_ingest(
&self,
_ticket: CommandStatementIngest,
_request: Request<PeekableFlightDataStream>,
) -> Result<i64, Status> {
Ok(FAKE_UPDATE_RESULT)
}

async fn do_put_substrait_plan(
&self,
_ticket: CommandStatementSubstraitPlan,
37 changes: 33 additions & 4 deletions arrow-flight/src/sql/client.rs
@@ -39,8 +39,8 @@ use crate::sql::{
CommandGetCrossReference, CommandGetDbSchemas, CommandGetExportedKeys, CommandGetImportedKeys,
CommandGetPrimaryKeys, CommandGetSqlInfo, CommandGetTableTypes, CommandGetTables,
CommandGetXdbcTypeInfo, CommandPreparedStatementQuery, CommandPreparedStatementUpdate,
CommandStatementQuery, CommandStatementUpdate, DoPutPreparedStatementResult, DoPutUpdateResult,
ProstMessageExt, SqlInfo,
CommandStatementIngest, CommandStatementQuery, CommandStatementUpdate,
DoPutPreparedStatementResult, DoPutUpdateResult, ProstMessageExt, SqlInfo,
};
use crate::trailers::extract_lazy_trailers;
use crate::{
@@ -53,10 +53,10 @@ use arrow_ipc::convert::fb_to_schema;
use arrow_ipc::reader::read_record_batch;
use arrow_ipc::{root_as_message, MessageHeader};
use arrow_schema::{ArrowError, Schema, SchemaRef};
use futures::{stream, TryStreamExt};
use futures::{stream, StreamExt, TryStreamExt};
use prost::Message;
use tonic::transport::Channel;
use tonic::{IntoRequest, Streaming};
use tonic::{IntoRequest, IntoStreamingRequest, Streaming};

/// A FlightSQLServiceClient is an endpoint for retrieving or storing Arrow data
/// by FlightSQL protocol.
@@ -227,6 +227,35 @@ impl FlightSqlServiceClient<Channel> {
Ok(result.record_count)
}

/// Execute a bulk ingest on the server and return the number of records added
pub async fn execute_ingest(
&mut self,
command: CommandStatementIngest,
batches: Vec<RecordBatch>,
Contributor:

I see two potential problems with this implementation:

  1. It requires the client to buffer the entire dataset into memory
  2. It is not consistent with other APIs that send RecordBatches to the server, such as do_put
    /// Push a stream to the flight service associated with a particular flight stream.
    pub async fn do_put(
        &mut self,
        request: impl tonic::IntoStreamingRequest<Message = FlightData>,
    ) -> Result<Streaming<PutResult>, ArrowError> {

I think it is non-obvious that this craziness means "RecordBatchStream":

        request: impl tonic::IntoStreamingRequest<Message = FlightData>,

So in other words, I think it would be better if execute_ingest did the same thing (rather than internally encoding the stream).

Contributor:

I took a shot at making this change, and here is what I came up with, for your consideration: djanderson#1

) -> Result<i64, ArrowError> {
let descriptor = FlightDescriptor::new_cmd(command.as_any().encode_to_vec());
let flight_data_encoder = FlightDataEncoderBuilder::new()
.with_flight_descriptor(Some(descriptor))
.build(stream::iter(batches).map(Ok));
// Safe unwrap, explicitly wrapped on line above.
let flight_data = flight_data_encoder.map(|fd| fd.unwrap());
Contributor (author):

This comment is no longer true, but is it still reasonable to unwrap or should we raise the error, or fuse the input stream, or...?

Contributor:

I think we should return the error from the client. Let me see if I can figure it out....

Contributor:

I think I fixed that in 22a4a1d -- but it was quite fiddly

let req = self.set_request_headers(flight_data.into_streaming_request())?;
let mut result = self
.flight_client
.do_put(req)
.await
.map_err(status_to_arrow_error)?
.into_inner();
let result = result
.message()
.await
.map_err(status_to_arrow_error)?
.unwrap();
let any = Any::decode(&*result.app_metadata).map_err(decode_error_to_arrow_error)?;
let result: DoPutUpdateResult = any.unpack()?.unwrap();
Ok(result.record_count)
}
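For reference, a client-side call matching the `Vec<RecordBatch>` signature added in this diff might look like the sketch below. Channel setup and batch construction are omitted, and the table name and option values are illustrative, not taken from this PR:

```rust
use arrow_flight::sql::client::FlightSqlServiceClient;
use arrow_flight::sql::{
    CommandStatementIngest, TableDefinitionOptions, TableExistsOption, TableNotExistOption,
};

// `channel` is an established tonic `Channel`; `batch` is a `RecordBatch`
// built elsewhere. "ingest_target" is a made-up table name.
let mut client = FlightSqlServiceClient::new(channel);
let command = CommandStatementIngest {
    table_definition_options: Some(TableDefinitionOptions {
        if_not_exist: TableNotExistOption::Create.into(),
        if_exists: TableExistsOption::Fail.into(),
    }),
    table: String::from("ingest_target"),
    schema: None,
    catalog: None,
    temporary: true,
    transaction_id: None,
    options: Default::default(),
};
// Returns the server-reported number of ingested records.
let rows_added: i64 = client.execute_ingest(command, vec![batch]).await?;
```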

/// Request a list of catalogs as tabular FlightInfo results
pub async fn get_catalogs(&mut self) -> Result<FlightInfo, ArrowError> {
self.get_flight_info_for_command(CommandGetCatalogs {})
Expand Down
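Following the review discussion above, a stream-based `execute_ingest` could mirror `do_put`'s shape instead of buffering a `Vec<RecordBatch>`. The sketch below is hypothetical; the actual follow-up (djanderson#1, commit 22a4a1d) may differ in details:

```rust
// Hypothetical stream-based signature: accept a fallible RecordBatch
// stream rather than a fully materialized Vec<RecordBatch>.
pub async fn execute_ingest<S>(
    &mut self,
    command: CommandStatementIngest,
    stream: S,
) -> Result<i64, ArrowError>
where
    S: futures::Stream<Item = crate::error::Result<RecordBatch>> + Send + 'static,
{
    // Encode `stream` lazily with FlightDataEncoderBuilder and hand the
    // resulting FlightData stream to do_put, so batches are sent as they
    // are produced instead of being collected in client memory first.
    todo!()
}
```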
8 changes: 7 additions & 1 deletion arrow-flight/src/sql/mod.rs
@@ -50,6 +50,10 @@ mod gen {
}

pub use gen::action_end_transaction_request::EndTransaction;
pub use gen::command_statement_ingest::table_definition_options::{
TableExistsOption, TableNotExistOption,
};
pub use gen::command_statement_ingest::TableDefinitionOptions;
pub use gen::ActionBeginSavepointRequest;
pub use gen::ActionBeginSavepointResult;
pub use gen::ActionBeginTransactionRequest;
@@ -74,6 +78,7 @@ pub use gen::CommandGetTables;
pub use gen::CommandGetXdbcTypeInfo;
pub use gen::CommandPreparedStatementQuery;
pub use gen::CommandPreparedStatementUpdate;
pub use gen::CommandStatementIngest;
pub use gen::CommandStatementQuery;
pub use gen::CommandStatementSubstraitPlan;
pub use gen::CommandStatementUpdate;
@@ -250,11 +255,12 @@ prost_message_ext!(
CommandGetXdbcTypeInfo,
CommandPreparedStatementQuery,
CommandPreparedStatementUpdate,
CommandStatementIngest,
CommandStatementQuery,
CommandStatementSubstraitPlan,
CommandStatementUpdate,
DoPutUpdateResult,
DoPutPreparedStatementResult,
DoPutUpdateResult,
TicketStatementQuery,
);

25 changes: 22 additions & 3 deletions arrow-flight/src/sql/server.rs
@@ -32,9 +32,9 @@ use super::{
CommandGetCrossReference, CommandGetDbSchemas, CommandGetExportedKeys, CommandGetImportedKeys,
CommandGetPrimaryKeys, CommandGetSqlInfo, CommandGetTableTypes, CommandGetTables,
CommandGetXdbcTypeInfo, CommandPreparedStatementQuery, CommandPreparedStatementUpdate,
CommandStatementQuery, CommandStatementSubstraitPlan, CommandStatementUpdate,
DoPutPreparedStatementResult, DoPutUpdateResult, ProstMessageExt, SqlInfo,
TicketStatementQuery,
CommandStatementIngest, CommandStatementQuery, CommandStatementSubstraitPlan,
CommandStatementUpdate, DoPutPreparedStatementResult, DoPutUpdateResult, ProstMessageExt,
SqlInfo, TicketStatementQuery,
};
use crate::{
flight_service_server::FlightService, gen::PollInfo, Action, ActionType, Criteria, Empty,
@@ -397,6 +397,17 @@ pub trait FlightSqlService: Sync + Send + Sized + 'static {
))
}

/// Execute a bulk ingestion.
async fn do_put_statement_ingest(
&self,
_ticket: CommandStatementIngest,
_request: Request<PeekableFlightDataStream>,
) -> Result<i64, Status> {
Err(Status::unimplemented(
"do_put_statement_ingest has no default implementation",
))
}
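A concrete server override would decode the incoming `FlightData` stream into `RecordBatch`es, persist them, and return the row count. The sketch below is one possible shape under the assumption that `arrow_flight::decode::FlightRecordBatchStream` is used for decoding and that the payload carries encoded batches; it is not the implementation from this PR:

```rust
use arrow_flight::decode::FlightRecordBatchStream;
use arrow_flight::error::FlightError;
use futures::{StreamExt, TryStreamExt};

// Hypothetical override: count the rows in the ingested stream.
async fn do_put_statement_ingest(
    &self,
    _ticket: CommandStatementIngest,
    request: Request<PeekableFlightDataStream>,
) -> Result<i64, Status> {
    // Convert tonic Status errors into FlightError so the decoder accepts them.
    let data_stream = request.into_inner().map_err(FlightError::from);
    let mut batches = FlightRecordBatchStream::new_from_flight_data(data_stream);

    let mut total_rows: i64 = 0;
    while let Some(batch) = batches.next().await {
        let batch = batch.map_err(|e| Status::internal(e.to_string()))?;
        // A real implementation would write `batch` to storage here.
        total_rows += batch.num_rows() as i64;
    }
    Ok(total_rows)
}
```

The returned `i64` is wrapped into a `DoPutUpdateResult` by the dispatch code added below, so clients see it as `record_count`.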

/// Bind parameters to given prepared statement.
///
/// Returns an opaque handle that the client should pass
@@ -713,6 +724,14 @@
})]);
Ok(Response::new(Box::pin(output)))
}
Command::CommandStatementIngest(command) => {
let record_count = self.do_put_statement_ingest(command, request).await?;
let result = DoPutUpdateResult { record_count };
let output = futures::stream::iter(vec![Ok(PutResult {
app_metadata: result.as_any().encode_to_vec().into(),
})]);
Ok(Response::new(Box::pin(output)))
}
Command::CommandPreparedStatementQuery(command) => {
let result = self
.do_put_prepared_statement_query(command, request)
1 change: 1 addition & 0 deletions arrow-flight/tests/common/mod.rs
@@ -18,3 +18,4 @@
pub mod fixture;
pub mod server;
pub mod trailers_layer;
pub mod utils;
118 changes: 118 additions & 0 deletions arrow-flight/tests/common/utils.rs
@@ -0,0 +1,118 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.

//! Common utilities for testing flight clients and servers

use std::sync::Arc;

use arrow_array::{
types::Int32Type, ArrayRef, BinaryViewArray, DictionaryArray, Float64Array, RecordBatch,
StringViewArray, UInt8Array,
};
use arrow_schema::{DataType, Field, Schema};

/// Make a primitive batch for testing
///
/// Example:
/// i: 0, 1, None, 3, 4
/// f: 5.0, 4.0, None, 2.0, 1.0
#[allow(dead_code)]
pub fn make_primitive_batch(num_rows: usize) -> RecordBatch {
let i: UInt8Array = (0..num_rows)
.map(|i| {
if i == num_rows / 2 {
None
} else {
Some(i.try_into().unwrap())
}
})
.collect();

let f: Float64Array = (0..num_rows)
.map(|i| {
if i == num_rows / 2 {
None
} else {
Some((num_rows - i) as f64)
}
})
.collect();

RecordBatch::try_from_iter(vec![("i", Arc::new(i) as ArrayRef), ("f", Arc::new(f))]).unwrap()
}

/// Make a dictionary batch for testing
///
/// Example:
/// a: value0, value1, value2, None, value1, value2
#[allow(dead_code)]
pub fn make_dictionary_batch(num_rows: usize) -> RecordBatch {
let values: Vec<_> = (0..num_rows)
.map(|i| {
if i == num_rows / 2 {
None
} else {
// repeat some values for low cardinality
let v = i / 3;
Some(format!("value{v}"))
}
})
.collect();

let a: DictionaryArray<Int32Type> = values
.iter()
.map(|s| s.as_ref().map(|s| s.as_str()))
.collect();

RecordBatch::try_from_iter(vec![("a", Arc::new(a) as ArrayRef)]).unwrap()
}

#[allow(dead_code)]
pub fn make_view_batches(num_rows: usize) -> RecordBatch {
const LONG_TEST_STRING: &str =
"This is a long string to make sure binary view array handles it";
let schema = Schema::new(vec![
Field::new("field1", DataType::BinaryView, true),
Field::new("field2", DataType::Utf8View, true),
]);

let string_view_values: Vec<Option<&str>> = (0..num_rows)
.map(|i| match i % 3 {
0 => None,
1 => Some("foo"),
2 => Some(LONG_TEST_STRING),
_ => unreachable!(),
})
.collect();

let bin_view_values: Vec<Option<&[u8]>> = (0..num_rows)
.map(|i| match i % 3 {
0 => None,
1 => Some("bar".as_bytes()),
2 => Some(LONG_TEST_STRING.as_bytes()),
_ => unreachable!(),
})
.collect();

let binary_array = BinaryViewArray::from_iter(bin_view_values);
let utf8_array = StringViewArray::from_iter(string_view_values);
RecordBatch::try_new(
Arc::new(schema.clone()),
vec![Arc::new(binary_array), Arc::new(utf8_array)],
)
.unwrap()
}