Skip to content

Commit

Permalink
Receive schema from flight data. (#1670)
Browse files Browse the repository at this point in the history
  • Loading branch information
viirya authored May 8, 2022
1 parent 8b1ad09 commit 03ab9a3
Showing 1 changed file with 31 additions and 20 deletions.
51 changes: 31 additions & 20 deletions integration-testing/src/flight_client_scenarios/integration_test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ use arrow_flight::{
use futures::{channel::mpsc, sink::SinkExt, stream, StreamExt};
use tonic::{Request, Streaming};

use arrow::datatypes::Schema;
use std::sync::Arc;

type Error = Box<dyn std::error::Error + Send + Sync + 'static>;
Expand Down Expand Up @@ -61,7 +62,7 @@ pub async fn run_scenario(host: &str, port: u16, path: &str) -> Result {
batches.clone(),
)
.await?;
verify_data(client, descriptor, schema, &batches).await?;
verify_data(client, descriptor, &batches).await?;

Ok(())
}
Expand Down Expand Up @@ -144,7 +145,6 @@ async fn send_batch(
async fn verify_data(
mut client: Client,
descriptor: FlightDescriptor,
expected_schema: SchemaRef,
expected_data: &[RecordBatch],
) -> Result {
let resp = client.get_flight_info(Request::new(descriptor)).await?;
Expand All @@ -164,13 +164,7 @@ async fn verify_data(
"No locations returned from Flight server",
);
for location in endpoint.location {
consume_flight_location(
location,
ticket.clone(),
expected_data,
expected_schema.clone(),
)
.await?;
consume_flight_location(location, ticket.clone(), expected_data).await?;
}
}

Expand All @@ -181,7 +175,6 @@ async fn consume_flight_location(
location: Location,
ticket: Ticket,
expected_data: &[RecordBatch],
schema: SchemaRef,
) -> Result {
let mut location = location;
// The other Flight implementations use the `grpc+tcp` scheme, but the Rust http libs
Expand All @@ -193,29 +186,33 @@ async fn consume_flight_location(
let resp = client.do_get(ticket).await?;
let mut resp = resp.into_inner();

// We already have the schema from the FlightInfo, but the server sends it again as the
// first FlightData. Ignore this one.
let _schema_again = resp.next().await.unwrap();
let flight_schema = receive_schema_flight_data(&mut resp)
.await
.unwrap_or_else(|| panic!("Failed to receive flight schema"));
let actual_schema = Arc::new(flight_schema);

let mut dictionaries_by_id = HashMap::new();

for (counter, expected_batch) in expected_data.iter().enumerate() {
let data =
receive_batch_flight_data(&mut resp, schema.clone(), &mut dictionaries_by_id)
.await
.unwrap_or_else(|| {
panic!(
let data = receive_batch_flight_data(
&mut resp,
actual_schema.clone(),
&mut dictionaries_by_id,
)
.await
.unwrap_or_else(|| {
panic!(
"Got fewer batches than expected, received so far: {} expected: {}",
counter,
expected_data.len(),
)
});
});

let metadata = counter.to_string().into_bytes();
assert_eq!(metadata, data.app_metadata);

let actual_batch =
flight_data_to_arrow_batch(&data, schema.clone(), &dictionaries_by_id)
flight_data_to_arrow_batch(&data, actual_schema.clone(), &dictionaries_by_id)
.expect("Unable to convert flight data to Arrow batch");

assert_eq!(expected_batch.schema(), actual_batch.schema());
Expand All @@ -242,6 +239,20 @@ async fn consume_flight_location(
Ok(())
}

/// Pulls the next item off the `do_get` response stream and decodes it as an Arrow schema.
///
/// Returns `None` when the stream yields no further item, or when the next item is an
/// error rather than a `FlightData` message.
///
/// # Panics
///
/// Panics if the message's `data_header` cannot be parsed as an IPC message, or if the
/// parsed message's header is not a schema.
async fn receive_schema_flight_data(resp: &mut Streaming<FlightData>) -> Option<Schema> {
    // Both "stream finished" and "stream produced an error" collapse to `None`.
    let flight_data = match resp.next().await {
        Some(Ok(flight_data)) => flight_data,
        _ => return None,
    };

    let ipc_message = arrow::ipc::root_as_message(&flight_data.data_header[..])
        .expect("Error parsing message");

    // The schema is expected to arrive as the header of this message.
    let fb_schema = ipc_message
        .header_as_schema()
        .expect("Unable to read IPC message as schema");

    Some(ipc::convert::fb_to_schema(fb_schema))
}

async fn receive_batch_flight_data(
resp: &mut Streaming<FlightData>,
schema: SchemaRef,
Expand Down

0 comments on commit 03ab9a3

Please sign in to comment.