From 03ab9a3a60445431e248972acea3c2775858a706 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Sun, 8 May 2022 13:25:31 -0700 Subject: [PATCH] Receive schema from flight data. (#1670) --- .../integration_test.rs | 51 +++++++++++-------- 1 file changed, 31 insertions(+), 20 deletions(-) diff --git a/integration-testing/src/flight_client_scenarios/integration_test.rs b/integration-testing/src/flight_client_scenarios/integration_test.rs index 703a0f9cfba7..fa24952947e3 100644 --- a/integration-testing/src/flight_client_scenarios/integration_test.rs +++ b/integration-testing/src/flight_client_scenarios/integration_test.rs @@ -32,6 +32,7 @@ use arrow_flight::{ use futures::{channel::mpsc, sink::SinkExt, stream, StreamExt}; use tonic::{Request, Streaming}; +use arrow::datatypes::Schema; use std::sync::Arc; type Error = Box; @@ -61,7 +62,7 @@ pub async fn run_scenario(host: &str, port: u16, path: &str) -> Result { batches.clone(), ) .await?; - verify_data(client, descriptor, schema, &batches).await?; + verify_data(client, descriptor, &batches).await?; Ok(()) } @@ -144,7 +145,6 @@ async fn send_batch( async fn verify_data( mut client: Client, descriptor: FlightDescriptor, - expected_schema: SchemaRef, expected_data: &[RecordBatch], ) -> Result { let resp = client.get_flight_info(Request::new(descriptor)).await?; @@ -164,13 +164,7 @@ async fn verify_data( "No locations returned from Flight server", ); for location in endpoint.location { - consume_flight_location( - location, - ticket.clone(), - expected_data, - expected_schema.clone(), - ) - .await?; + consume_flight_location(location, ticket.clone(), expected_data).await?; } } @@ -181,7 +175,6 @@ async fn consume_flight_location( location: Location, ticket: Ticket, expected_data: &[RecordBatch], - schema: SchemaRef, ) -> Result { let mut location = location; // The other Flight implementations use the `grpc+tcp` scheme, but the Rust http libs @@ -193,29 +186,33 @@ async fn consume_flight_location( let resp = client.do_get(ticket).await?; let mut resp = resp.into_inner(); - // We already have the schema from the FlightInfo, but the server sends it again as the - // first FlightData. Ignore this one. - let _schema_again = resp.next().await.unwrap(); + let flight_schema = receive_schema_flight_data(&mut resp) + .await + .unwrap_or_else(|| panic!("Failed to receive flight schema")); + let actual_schema = Arc::new(flight_schema); let mut dictionaries_by_id = HashMap::new(); for (counter, expected_batch) in expected_data.iter().enumerate() { - let data = - receive_batch_flight_data(&mut resp, schema.clone(), &mut dictionaries_by_id) - .await - .unwrap_or_else(|| { - panic!( + let data = receive_batch_flight_data( + &mut resp, + actual_schema.clone(), + &mut dictionaries_by_id, + ) + .await + .unwrap_or_else(|| { + panic!( "Got fewer batches than expected, received so far: {} expected: {}", counter, expected_data.len(), ) - }); + }); let metadata = counter.to_string().into_bytes(); assert_eq!(metadata, data.app_metadata); let actual_batch = - flight_data_to_arrow_batch(&data, schema.clone(), &dictionaries_by_id) + flight_data_to_arrow_batch(&data, actual_schema.clone(), &dictionaries_by_id) .expect("Unable to convert flight data to Arrow batch"); assert_eq!(expected_batch.schema(), actual_batch.schema()); @@ -242,6 +239,20 @@ async fn consume_flight_location( Ok(()) } +async fn receive_schema_flight_data(resp: &mut Streaming) -> Option { + let data = resp.next().await?.ok()?; + let message = arrow::ipc::root_as_message(&data.data_header[..]) + .expect("Error parsing message"); + + // message header is a Schema, so read it + let ipc_schema: ipc::Schema = message + .header_as_schema() + .expect("Unable to read IPC message as schema"); + let schema = ipc::convert::fb_to_schema(ipc_schema); + + Some(schema) +} + async fn receive_batch_flight_data( resp: &mut Streaming, schema: SchemaRef,