Skip to content

Commit

Permalink
Minor: move FallibleRequestStream and FallibleTonicResponseStream
Browse files Browse the repository at this point in the history
… to a module (#6258)

* Minor: move FallibleRequestStream and FallibleTonicResponseStream to their own modules

* Improve documentation and add links
  • Loading branch information
alamb authored Aug 20, 2024
1 parent 23b6ff9 commit 0bbad36
Show file tree
Hide file tree
Showing 4 changed files with 138 additions and 106 deletions.
107 changes: 2 additions & 105 deletions arrow-flight/src/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,6 @@
// specific language governing permissions and limitations
// under the License.

use std::{pin::Pin, task::Poll};

use crate::{
decode::FlightRecordBatchStream,
flight_service_client::FlightServiceClient,
Expand All @@ -28,16 +26,15 @@ use crate::{
use arrow_schema::Schema;
use bytes::Bytes;
use futures::{
channel::oneshot::{Receiver, Sender},
future::ready,
ready,
stream::{self, BoxStream},
FutureExt, Stream, StreamExt, TryStreamExt,
Stream, StreamExt, TryStreamExt,
};
use prost::Message;
use tonic::{metadata::MetadataMap, transport::Channel};

use crate::error::{FlightError, Result};
use crate::streams::{FallibleRequestStream, FallibleTonicResponseStream};

/// A "Mid level" [Apache Arrow Flight](https://arrow.apache.org/docs/format/Flight.html) client.
///
Expand Down Expand Up @@ -674,103 +671,3 @@ impl FlightClient {
request
}
}

/// Wrapper around fallible stream such that when
/// it encounters an error it uses the oneshot sender to
/// notify the error and stop any further streaming. See `do_put` or
/// `do_exchange` for it's uses.
pub(crate) struct FallibleRequestStream<T, E> {
/// sender to notify error
sender: Option<Sender<E>>,
/// fallible stream
fallible_stream: Pin<Box<dyn Stream<Item = std::result::Result<T, E>> + Send + 'static>>,
}

impl<T, E> FallibleRequestStream<T, E> {
pub(crate) fn new(
sender: Sender<E>,
fallible_stream: Pin<Box<dyn Stream<Item = std::result::Result<T, E>> + Send + 'static>>,
) -> Self {
Self {
sender: Some(sender),
fallible_stream,
}
}
}

impl<T, E> Stream for FallibleRequestStream<T, E> {
type Item = T;

fn poll_next(
self: std::pin::Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
) -> std::task::Poll<Option<Self::Item>> {
let pinned = self.get_mut();
let mut request_streams = pinned.fallible_stream.as_mut();
match ready!(request_streams.poll_next_unpin(cx)) {
Some(Ok(data)) => Poll::Ready(Some(data)),
Some(Err(e)) => {
// in theory this should only ever be called once
// as this stream should not be polled again after returning
// None, however we still check for None to be safe
if let Some(sender) = pinned.sender.take() {
// an error means the other end of the channel is not around
// to receive the error, so ignore it
let _ = sender.send(e);
}
Poll::Ready(None)
}
None => Poll::Ready(None),
}
}
}

/// Wrapper for a tonic response stream that can produce a tonic
/// error. This is tied to a oneshot receiver which can be notified
/// of other errors. When it receives an error through receiver
/// end, it prioritises that error to be sent back. See `do_put` or
/// `do_exchange` for it's uses
struct FallibleTonicResponseStream<T> {
/// Receiver for FlightError
receiver: Receiver<FlightError>,
/// Tonic response stream
response_stream:
Pin<Box<dyn Stream<Item = std::result::Result<T, tonic::Status>> + Send + 'static>>,
}

impl<T> FallibleTonicResponseStream<T> {
fn new(
receiver: Receiver<FlightError>,
response_stream: Pin<
Box<dyn Stream<Item = std::result::Result<T, tonic::Status>> + Send + 'static>,
>,
) -> Self {
Self {
receiver,
response_stream,
}
}
}

impl<T> Stream for FallibleTonicResponseStream<T> {
type Item = Result<T>;

fn poll_next(
self: Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
) -> Poll<Option<Self::Item>> {
let pinned = self.get_mut();
let receiver = &mut pinned.receiver;
// Prioritise sending the error that's been notified over
// polling the response_stream
if let Poll::Ready(Ok(err)) = receiver.poll_unpin(cx) {
return Poll::Ready(Some(Err(err)));
};

match ready!(pinned.response_stream.poll_next_unpin(cx)) {
Some(Ok(res)) => Poll::Ready(Some(Ok(res))),
Some(Err(status)) => Poll::Ready(Some(Err(FlightError::Tonic(status)))),
None => Poll::Ready(None),
}
}
}
1 change: 1 addition & 0 deletions arrow-flight/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,7 @@ pub mod utils;

#[cfg(feature = "flight-sql-experimental")]
pub mod sql;
mod streams;

use flight_descriptor::DescriptorType;

Expand Down
2 changes: 1 addition & 1 deletion arrow-flight/src/sql/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@ use std::collections::HashMap;
use std::str::FromStr;
use tonic::metadata::AsciiMetadataKey;

use crate::client::FallibleRequestStream;
use crate::decode::FlightRecordBatchStream;
use crate::encode::FlightDataEncoderBuilder;
use crate::error::FlightError;
Expand All @@ -43,6 +42,7 @@ use crate::sql::{
CommandStatementIngest, CommandStatementQuery, CommandStatementUpdate,
DoPutPreparedStatementResult, DoPutUpdateResult, ProstMessageExt, SqlInfo,
};
use crate::streams::FallibleRequestStream;
use crate::trailers::extract_lazy_trailers;
use crate::{
Action, FlightData, FlightDescriptor, FlightInfo, HandshakeRequest, HandshakeResponse,
Expand Down
134 changes: 134 additions & 0 deletions arrow-flight/src/streams.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,134 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.

//! [`FallibleRequestStream`] and [`FallibleTonicResponseStream`] adapters
use crate::error::FlightError;
use futures::{
channel::oneshot::{Receiver, Sender},
FutureExt, Stream, StreamExt,
};
use std::pin::Pin;
use std::task::{ready, Poll};

/// Wrapper around a fallible stream (one that returns errors) that makes it infallible.
///
/// Any errors encountered in the stream are ignored are sent to the provided
/// oneshot sender.
///
/// This can be used to accept a stream of `Result<_>` from a client API and send
/// them to the remote server that wants only the successful results.
pub(crate) struct FallibleRequestStream<T, E> {
/// sender to notify error
sender: Option<Sender<E>>,
/// fallible stream
fallible_stream: Pin<Box<dyn Stream<Item = std::result::Result<T, E>> + Send + 'static>>,
}

impl<T, E> FallibleRequestStream<T, E> {
pub(crate) fn new(
sender: Sender<E>,
fallible_stream: Pin<Box<dyn Stream<Item = std::result::Result<T, E>> + Send + 'static>>,
) -> Self {
Self {
sender: Some(sender),
fallible_stream,
}
}
}

impl<T, E> Stream for FallibleRequestStream<T, E> {
type Item = T;

fn poll_next(
self: std::pin::Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
) -> std::task::Poll<Option<Self::Item>> {
let pinned = self.get_mut();
let mut request_streams = pinned.fallible_stream.as_mut();
match ready!(request_streams.poll_next_unpin(cx)) {
Some(Ok(data)) => Poll::Ready(Some(data)),
Some(Err(e)) => {
// in theory this should only ever be called once
// as this stream should not be polled again after returning
// None, however we still check for None to be safe
if let Some(sender) = pinned.sender.take() {
// an error means the other end of the channel is not around
// to receive the error, so ignore it
let _ = sender.send(e);
}
Poll::Ready(None)
}
None => Poll::Ready(None),
}
}
}

/// Wrapper for a tonic response stream that maps errors to `FlightError` and
/// returns errors from a oneshot channel into the stream.
///
/// The user of this stream can inject an error into the response stream using
/// the one shot receiver. This is used to propagate errors in
/// [`FlightClient::do_put`] and [`FlightClient::do_exchange`] from the client
/// provided input stream to the response stream.
///
/// # Error Priority
/// Error from the receiver are prioritised over the response stream.
///
/// [`FlightClient::do_put`]: crate::FlightClient::do_put
/// [`FlightClient::do_exchange`]: crate::FlightClient::do_exchange
pub(crate) struct FallibleTonicResponseStream<T> {
/// Receiver for FlightError
receiver: Receiver<FlightError>,
/// Tonic response stream
response_stream: Pin<Box<dyn Stream<Item = Result<T, tonic::Status>> + Send + 'static>>,
}

impl<T> FallibleTonicResponseStream<T> {
pub(crate) fn new(
receiver: Receiver<FlightError>,
response_stream: Pin<Box<dyn Stream<Item = Result<T, tonic::Status>> + Send + 'static>>,
) -> Self {
Self {
receiver,
response_stream,
}
}
}

impl<T> Stream for FallibleTonicResponseStream<T> {
type Item = Result<T, FlightError>;

fn poll_next(
self: Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
) -> Poll<Option<Self::Item>> {
let pinned = self.get_mut();
let receiver = &mut pinned.receiver;
// Prioritise sending the error that's been notified over
// polling the response_stream
if let Poll::Ready(Ok(err)) = receiver.poll_unpin(cx) {
return Poll::Ready(Some(Err(err)));
};

match ready!(pinned.response_stream.poll_next_unpin(cx)) {
Some(Ok(res)) => Poll::Ready(Some(Ok(res))),
Some(Err(status)) => Poll::Ready(Some(Err(FlightError::Tonic(status)))),
None => Poll::Ready(None),
}
}
}

0 comments on commit 0bbad36

Please sign in to comment.