Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: set FlightDescriptor on FlightDataEncoderBuilder #4101

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 24 additions & 2 deletions arrow-flight/src/encode.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@

use std::{collections::VecDeque, fmt::Debug, pin::Pin, sync::Arc, task::Poll};

use crate::{error::Result, FlightData, SchemaAsIpc};
use crate::{error::Result, FlightData, FlightDescriptor, SchemaAsIpc};
use arrow_array::{ArrayRef, RecordBatch, RecordBatchOptions};
use arrow_ipc::writer::{DictionaryTracker, IpcDataGenerator, IpcWriteOptions};
use arrow_schema::{DataType, Field, Fields, Schema, SchemaRef};
Expand Down Expand Up @@ -72,6 +72,8 @@ pub struct FlightDataEncoderBuilder {
app_metadata: Bytes,
/// Optional schema, if known before data.
schema: Option<SchemaRef>,
/// Optional flight descriptor, if known before data.
descriptor: Option<FlightDescriptor>,
}

/// Default target size for encoded [`FlightData`].
Expand All @@ -87,6 +89,7 @@ impl Default for FlightDataEncoderBuilder {
options: IpcWriteOptions::default(),
app_metadata: Bytes::new(),
schema: None,
descriptor: None,
}
}
}
Expand Down Expand Up @@ -134,6 +137,15 @@ impl FlightDataEncoderBuilder {
self
}

/// Specify a flight descriptor in the first FlightData message.
///
/// The descriptor is stored on the builder and handed to the
/// [`FlightDataEncoder`] by `build`. Passing `None` replaces (clears) any
/// descriptor previously set on this builder.
pub fn with_flight_descriptor(
alamb marked this conversation as resolved.
Show resolved Hide resolved
mut self,
descriptor: Option<FlightDescriptor>,
) -> Self {
// Overwrite unconditionally, then return the builder so calls can be chained.
self.descriptor = descriptor;
self
}

/// Return a [`Stream`](futures::Stream) of [`FlightData`],
/// consuming self. More details on [`FlightDataEncoder`]
pub fn build<S>(self, input: S) -> FlightDataEncoder
Expand All @@ -145,6 +157,7 @@ impl FlightDataEncoderBuilder {
options,
app_metadata,
schema,
descriptor,
} = self;

FlightDataEncoder::new(
Expand All @@ -153,6 +166,7 @@ impl FlightDataEncoderBuilder {
max_flight_data_size,
options,
app_metadata,
descriptor,
)
}
}
Expand All @@ -176,6 +190,8 @@ pub struct FlightDataEncoder {
queue: VecDeque<FlightData>,
/// Is this stream done (inner is empty or errored)
done: bool,
/// Descriptor to attach to the first queued FlightData message; taken (left as `None`) once attached
descriptor: Option<FlightDescriptor>,
}

impl FlightDataEncoder {
Expand All @@ -185,6 +201,7 @@ impl FlightDataEncoder {
max_flight_data_size: usize,
options: IpcWriteOptions,
app_metadata: Bytes,
descriptor: Option<FlightDescriptor>,
) -> Self {
let mut encoder = Self {
inner,
Expand All @@ -194,17 +211,22 @@ impl FlightDataEncoder {
app_metadata: Some(app_metadata),
queue: VecDeque::new(),
done: false,
descriptor,
};

// If schema is known up front, enqueue it immediately
if let Some(schema) = schema {
encoder.encode_schema(&schema);
}

encoder
}

/// Place the `FlightData` in the queue to send
///
/// If a flight descriptor is still pending on the encoder, it is moved
/// onto `data` before queuing, so only the first queued message carries it.
fn queue_message(&mut self, data: FlightData) {
fn queue_message(&mut self, mut data: FlightData) {
// `take()` leaves `None` behind, so every later message is queued
// with its `flight_descriptor` untouched.
if let Some(descriptor) = self.descriptor.take() {
data.flight_descriptor = Some(descriptor);
}
self.queue.push_back(data);
}

Expand Down
25 changes: 25 additions & 0 deletions arrow-flight/tests/encode_decode.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@ use std::{collections::HashMap, sync::Arc};
use arrow_array::types::Int32Type;
use arrow_array::{ArrayRef, DictionaryArray, Float64Array, RecordBatch, UInt8Array};
use arrow_cast::pretty::pretty_format_batches;
use arrow_flight::flight_descriptor::DescriptorType;
use arrow_flight::FlightDescriptor;
use arrow_flight::{
decode::{DecodedPayload, FlightDataDecoder, FlightRecordBatchStream},
encode::FlightDataEncoderBuilder,
Expand Down Expand Up @@ -136,6 +138,29 @@ async fn test_zero_batches_schema_specified() {
assert_eq!(decoder.schema(), Some(&schema));
}

/// A descriptor set on the builder must be attached to the first
/// `FlightData` message of the stream, and to that message only.
#[tokio::test]
async fn test_with_flight_descriptor() {
    let stream = futures::stream::iter(vec![Ok(make_dictionary_batch(5))]);
    let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Utf8, true)]));

    let descriptor = Some(FlightDescriptor {
        r#type: DescriptorType::Path.into(),
        path: vec!["table_name".to_string()],
        cmd: Bytes::default(),
    });

    let mut encoder = FlightDataEncoderBuilder::default()
        .with_schema(schema)
        .with_flight_descriptor(descriptor.clone())
        .build(stream);

    // The schema is supplied up front, so the first message out of the
    // encoder is the schema message; it must carry the descriptor.
    let first_message = encoder.next().await.unwrap().unwrap();
    assert_eq!(first_message.flight_descriptor, descriptor);

    // The encoder gives the descriptor away exactly once: every message
    // after the first must be sent without one.
    while let Some(message) = encoder.next().await {
        let message = message.expect("encoding the remaining messages should succeed");
        assert_eq!(message.flight_descriptor, None);
    }
}

#[tokio::test]
async fn test_zero_batches_dictionary_schema_specified() {
let schema = Arc::new(Schema::new(vec![
Expand Down