From a842f3a8bbbfe6145c1935b39264be85272bbe6a Mon Sep 17 00:00:00 2001 From: Chris Staite <137425734+chrisstaite-menlo@users.noreply.github.com> Date: Thu, 11 Apr 2024 08:42:17 +0100 Subject: [PATCH] Connection Manager Rewrite (#806) The connection manager written in grpc_utils made incorrect assumptions about how the tonic and tower implementations were written and is not suitable for maintaining multiple connections and ensuring stability. Completely rewrite this to manage the tonic::Channel for each tonic::Endpoint itself to make a simpler external API and ensure that connection errors are handled correctly. This is performed by using a single worker loop that manages all of the connections and wrapping each connection to inform state to the worker. --- nativelink-config/src/schedulers.rs | 5 + nativelink-config/src/stores.rs | 5 + nativelink-scheduler/src/grpc_scheduler.rs | 45 +- nativelink-store/src/grpc_store.rs | 138 +++--- nativelink-util/BUILD.bazel | 2 +- nativelink-util/src/connection_manager.rs | 480 +++++++++++++++++++++ nativelink-util/src/grpc_utils.rs | 111 ----- nativelink-util/src/lib.rs | 2 +- nativelink-util/src/retry.rs | 2 +- 9 files changed, 592 insertions(+), 198 deletions(-) create mode 100644 nativelink-util/src/connection_manager.rs delete mode 100644 nativelink-util/src/grpc_utils.rs diff --git a/nativelink-config/src/schedulers.rs b/nativelink-config/src/schedulers.rs index ff447dca4..ddd522baf 100644 --- a/nativelink-config/src/schedulers.rs +++ b/nativelink-config/src/schedulers.rs @@ -140,6 +140,11 @@ pub struct GrpcScheduler { /// request is queued. #[serde(default)] pub max_concurrent_requests: usize, + + /// The number of connections to make to each specified endpoint to balance + /// the load over multiple TCP connections. Default 1. + #[serde(default)] + pub connections_per_endpoint: usize, } #[derive(Deserialize, Debug)] diff --git a/nativelink-config/src/stores.rs b/nativelink-config/src/stores.rs index e82a5dd9e..0a9daa983 100644 --- a/nativelink-config/src/stores.rs +++ b/nativelink-config/src/stores.rs @@ -560,6 +560,11 @@ pub struct GrpcStore { /// request is queued. #[serde(default)] pub max_concurrent_requests: usize, + + /// The number of connections to make to each specified endpoint to balance + /// the load over multiple TCP connections. Default 1. + #[serde(default)] + pub connections_per_endpoint: usize, } /// The possible error codes that might occur on an upstream request. diff --git a/nativelink-scheduler/src/grpc_scheduler.rs b/nativelink-scheduler/src/grpc_scheduler.rs index 155a67f4f..4373b8d9f 100644 --- a/nativelink-scheduler/src/grpc_scheduler.rs +++ b/nativelink-scheduler/src/grpc_scheduler.rs @@ -30,7 +30,7 @@ use nativelink_proto::google::longrunning::Operation; use nativelink_util::action_messages::{ ActionInfo, ActionInfoHashKey, ActionState, DEFAULT_EXECUTION_PRIORITY, }; -use nativelink_util::grpc_utils::ConnectionManager; +use nativelink_util::connection_manager::ConnectionManager; use nativelink_util::retry::{Retrier, RetryResult}; use nativelink_util::tls_utils; use parking_lot::Mutex; @@ -72,16 +72,20 @@ impl GrpcScheduler { jitter_fn: Box Duration + Send + Sync>, ) -> Result { let endpoint = tls_utils::endpoint(&config.endpoint)?; + let jitter_fn = Arc::new(jitter_fn); Ok(Self { platform_property_managers: Mutex::new(HashMap::new()), retrier: Retrier::new( Arc::new(|duration| Box::pin(sleep(duration))), - Arc::new(jitter_fn), + jitter_fn.clone(), config.retry.to_owned(), ), connection_manager: ConnectionManager::new( std::iter::once(endpoint), + config.connections_per_endpoint, config.max_concurrent_requests, + config.retry.to_owned(), + jitter_fn, ), }) } @@ -164,16 +168,17 @@ impl ActionScheduler for GrpcScheduler { self.perform_request(instance_name, |instance_name| async move { // Not in the cache, lookup the capabilities with the upstream. - let (connection, channel) = self.connection_manager.get_connection().await; + let channel = self + .connection_manager + .connection() + .await + .err_tip(|| "in get_platform_property_manager()")?; let capabilities_result = CapabilitiesClient::new(channel) .get_capabilities(GetCapabilitiesRequest { instance_name: instance_name.to_string(), }) .await .err_tip(|| "Retrieving upstream GrpcScheduler capabilities"); - if let Err(err) = &capabilities_result { - connection.on_error(err); - } let capabilities = capabilities_result?.into_inner(); let platform_property_manager = Arc::new(PlatformPropertyManager::new( capabilities @@ -220,15 +225,15 @@ impl ActionScheduler for GrpcScheduler { }; let result_stream = self .perform_request(request, |request| async move { - let (connection, channel) = self.connection_manager.get_connection().await; - let result = ExecutionClient::new(channel) + let channel = self + .connection_manager + .connection() + .await + .err_tip(|| "in add_action()")?; + ExecutionClient::new(channel) .execute(Request::new(request)) .await - .err_tip(|| "Sending action to upstream scheduler"); - if let Err(err) = &result { - connection.on_error(err); - } - result + .err_tip(|| "Sending action to upstream scheduler") }) .await? .into_inner(); @@ -244,15 +249,15 @@ impl ActionScheduler for GrpcScheduler { }; let result_stream = self .perform_request(request, |request| async move { - let (connection, channel) = self.connection_manager.get_connection().await; - let result = ExecutionClient::new(channel) + let channel = self + .connection_manager + .connection() + .await + .err_tip(|| "in find_existing_action()")?; + ExecutionClient::new(channel) .wait_execution(Request::new(request)) .await - .err_tip(|| "While getting wait_execution stream"); - if let Err(err) = &result { - connection.on_error(err); - } - result + .err_tip(|| "While getting wait_execution stream") }) .and_then(|result_stream| Self::stream_state(result_stream.into_inner())) .await; diff --git a/nativelink-store/src/grpc_store.rs b/nativelink-store/src/grpc_store.rs index 47cbea2cb..368f6bfa6 100644 --- a/nativelink-store/src/grpc_store.rs +++ b/nativelink-store/src/grpc_store.rs @@ -19,7 +19,7 @@ use std::time::Duration; use async_trait::async_trait; use bytes::BytesMut; use futures::stream::{unfold, FuturesUnordered}; -use futures::{future, Future, Stream, StreamExt, TryStreamExt}; +use futures::{future, Future, Stream, StreamExt, TryFutureExt, TryStreamExt}; use nativelink_error::{error_if, make_input_err, Error, ResultExt}; use nativelink_proto::build::bazel::remote::execution::v2::action_cache_client::ActionCacheClient; use nativelink_proto::build::bazel::remote::execution::v2::content_addressable_storage_client::ContentAddressableStorageClient; @@ -36,7 +36,7 @@ use nativelink_proto::google::bytestream::{ }; use nativelink_util::buf_channel::{DropCloserReadHalf, DropCloserWriteHalf}; use nativelink_util::common::DigestInfo; -use nativelink_util::grpc_utils::ConnectionManager; +use nativelink_util::connection_manager::ConnectionManager; use nativelink_util::health_utils::HealthStatusIndicator; use nativelink_util::proto_stream_utils::{ FirstStream, WriteRequestStreamWrapper, WriteState, WriteStateWrapper, @@ -98,17 +98,21 @@ impl GrpcStore { endpoints.push(endpoint); } + let jitter_fn = Arc::new(jitter_fn); Ok(GrpcStore { instance_name: config.instance_name.clone(), store_type: config.store_type, retrier: Retrier::new( Arc::new(|duration| Box::pin(sleep(duration))), - Arc::new(jitter_fn), + jitter_fn.clone(), config.retry.to_owned(), ), connection_manager: ConnectionManager::new( endpoints.into_iter(), + config.connections_per_endpoint, config.max_concurrent_requests, + config.retry.to_owned(), + jitter_fn, ), }) } @@ -145,15 +149,15 @@ impl GrpcStore { let mut request = grpc_request.into_inner(); request.instance_name.clone_from(&self.instance_name); self.perform_request(request, |request| async move { - let (connection, channel) = self.connection_manager.get_connection().await; - let result = ContentAddressableStorageClient::new(channel) + let channel = self + .connection_manager + .connection() + .await + .err_tip(|| "in find_missing_blobs")?; + ContentAddressableStorageClient::new(channel) .find_missing_blobs(Request::new(request)) .await - .err_tip(|| "in GrpcStore::find_missing_blobs"); - if let Err(err) = &result { - connection.on_error(err); - } - result + .err_tip(|| "in GrpcStore::find_missing_blobs") }) .await } @@ -170,15 +174,15 @@ impl GrpcStore { let mut request = grpc_request.into_inner(); request.instance_name.clone_from(&self.instance_name); self.perform_request(request, |request| async move { - let (connection, channel) = self.connection_manager.get_connection().await; - let result = ContentAddressableStorageClient::new(channel) + let channel = self + .connection_manager + .connection() + .await + .err_tip(|| "in batch_update_blobs")?; + ContentAddressableStorageClient::new(channel) .batch_update_blobs(Request::new(request)) .await - .err_tip(|| "in GrpcStore::batch_update_blobs"); - if let Err(err) = &result { - connection.on_error(err); - } - result + .err_tip(|| "in GrpcStore::batch_update_blobs") }) .await } @@ -195,15 +199,15 @@ impl GrpcStore { let mut request = grpc_request.into_inner(); request.instance_name.clone_from(&self.instance_name); self.perform_request(request, |request| async move { - let (connection, channel) = self.connection_manager.get_connection().await; - let result = ContentAddressableStorageClient::new(channel) + let channel = self + .connection_manager + .connection() + .await + .err_tip(|| "in batch_read_blobs")?; + ContentAddressableStorageClient::new(channel) .batch_read_blobs(Request::new(request)) .await - .err_tip(|| "in GrpcStore::batch_read_blobs"); - if let Err(err) = &result { - connection.on_error(err); - } - result + .err_tip(|| "in GrpcStore::batch_read_blobs") }) .await } @@ -220,15 +224,15 @@ impl GrpcStore { let mut request = grpc_request.into_inner(); request.instance_name.clone_from(&self.instance_name); self.perform_request(request, |request| async move { - let (connection, channel) = self.connection_manager.get_connection().await; - let result = ContentAddressableStorageClient::new(channel) + let channel = self + .connection_manager + .connection() + .await + .err_tip(|| "in get_tree")?; + ContentAddressableStorageClient::new(channel) .get_tree(Request::new(request)) .await - .err_tip(|| "in GrpcStore::get_tree"); - if let Err(err) = &result { - connection.on_error(err); - } - result + .err_tip(|| "in GrpcStore::get_tree") }) .await } @@ -247,15 +251,16 @@ impl GrpcStore { &self, request: ReadRequest, ) -> Result>, Error> { - let (connection, channel) = self.connection_manager.get_connection().await; - let result = ByteStreamClient::new(channel) + let channel = self + .connection_manager + .connection() + .await + .err_tip(|| "in read_internal")?; + let mut response = ByteStreamClient::new(channel) .read(Request::new(request)) .await - .err_tip(|| "in GrpcStore::read"); - if let Err(err) = &result { - connection.on_error(err); - } - let mut response = result?.into_inner(); + .err_tip(|| "in GrpcStore::read")? + .into_inner(); let first_response = response .message() .await @@ -300,14 +305,20 @@ impl GrpcStore { let result = self .retrier .retry(unfold(local_state, move |local_state| async move { - let (connection, channel) = self.connection_manager.get_connection().await; // The client write may occur on a separate thread and // therefore in order to share the state with it we have to // wrap it in a Mutex and retrieve it after the write // has completed. There is no way to get the value back // from the client. - let result = ByteStreamClient::new(channel) - .write(WriteStateWrapper::new(local_state.clone())) + let result = self + .connection_manager + .connection() + .and_then(|channel| async { + ByteStreamClient::new(channel) + .write(WriteStateWrapper::new(local_state.clone())) + .await + .err_tip(|| "in GrpcStore::write") + }) .await; // Get the state back from StateWrapper, this should be @@ -319,9 +330,8 @@ impl GrpcStore { RetryResult::Err(err.append("Where read_stream_error was set")) } else { // On error determine whether it is possible to retry. - match result.err_tip(|| "in GrpcStore::write") { + match result { Err(err) => { - connection.on_error(&err); if local_state_locked.can_resume() { local_state_locked.resume(); RetryResult::Retry(err) @@ -359,15 +369,15 @@ impl GrpcStore { } self.perform_request(request, |request| async move { - let (connection, channel) = self.connection_manager.get_connection().await; - let result = ByteStreamClient::new(channel) + let channel = self + .connection_manager + .connection() + .await + .err_tip(|| "in query_write_status")?; + ByteStreamClient::new(channel) .query_write_status(Request::new(request)) .await - .err_tip(|| "in GrpcStore::query_write_status"); - if let Err(err) = &result { - connection.on_error(err); - } - result + .err_tip(|| "in GrpcStore::query_write_status") }) .await } @@ -379,15 +389,15 @@ impl GrpcStore { let mut request = grpc_request.into_inner(); request.instance_name.clone_from(&self.instance_name); self.perform_request(request, |request| async move { - let (connection, channel) = self.connection_manager.get_connection().await; - let result = ActionCacheClient::new(channel) + let channel = self + .connection_manager + .connection() + .await + .err_tip(|| "in get_action_result")?; + ActionCacheClient::new(channel) .get_action_result(Request::new(request)) .await - .err_tip(|| "in GrpcStore::get_action_result"); - if let Err(err) = &result { - connection.on_error(err); - } - result + .err_tip(|| "in GrpcStore::get_action_result") }) .await } @@ -399,15 +409,15 @@ impl GrpcStore { let mut request = grpc_request.into_inner(); request.instance_name.clone_from(&self.instance_name); self.perform_request(request, |request| async move { - let (connection, channel) = self.connection_manager.get_connection().await; - let result = ActionCacheClient::new(channel) + let channel = self + .connection_manager + .connection() + .await + .err_tip(|| "in update_action_result")?; + ActionCacheClient::new(channel) .update_action_result(Request::new(request)) .await - .err_tip(|| "in GrpcStore::update_action_result"); - if let Err(err) = &result { - connection.on_error(err); - } - result + .err_tip(|| "in GrpcStore::update_action_result") }) .await } diff --git a/nativelink-util/BUILD.bazel b/nativelink-util/BUILD.bazel index b6eb4edcf..3e1c87a64 100644 --- a/nativelink-util/BUILD.bazel +++ b/nativelink-util/BUILD.bazel @@ -12,11 +12,11 @@ rust_library( "src/action_messages.rs", "src/buf_channel.rs", "src/common.rs", + "src/connection_manager.rs", "src/digest_hasher.rs", "src/evicting_map.rs", "src/fastcdc.rs", "src/fs.rs", - "src/grpc_utils.rs", "src/health_utils.rs", "src/lib.rs", "src/metrics_utils.rs", diff --git a/nativelink-util/src/connection_manager.rs b/nativelink-util/src/connection_manager.rs new file mode 100644 index 000000000..4a7248df4 --- /dev/null +++ b/nativelink-util/src/connection_manager.rs @@ -0,0 +1,480 @@ +// Copyright 2024 The NativeLink Authors. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use std::collections::VecDeque; +use std::pin::Pin; +use std::sync::Arc; +use std::task::{Context, Poll}; +use std::time::Duration; + +use futures::stream::{unfold, FuturesUnordered, StreamExt}; +use futures::Future; +use nativelink_config::stores::Retry; +use nativelink_error::{make_err, Code, Error}; +use tokio::sync::{mpsc, oneshot}; +use tonic::transport::{channel, Channel, Endpoint}; +use tracing::{debug, error, info, warn}; + +use crate::retry::{self, Retrier, RetryResult}; + +/// A helper utility that enables management of a suite of connections to an +/// upstream gRPC endpoint using Tonic. +pub struct ConnectionManager { + // The channel to request connections from the worker. + worker_tx: mpsc::Sender>, +} + +/// The index into ConnectionManagerWorker::endpoints. +type EndpointIndex = usize; +/// The identifier for a given connection to a given Endpoint, used to identify +/// when a particular connection has failed or becomes available. +type ConnectionIndex = usize; + +#[derive(Clone, Copy, PartialEq, Eq, Hash)] +struct ChannelIdentifier { + /// The index into ConnectionManagerWorker::endpoints that established this + /// Channel. + endpoint_index: EndpointIndex, + /// A unique identifier for this particular connection to the Endpoint. + connection_index: ConnectionIndex, +} + +/// The requests that can be made from a Connection to the +/// ConnectionManagerWorker such as informing it that it's been dropped or that +/// an error occurred. +enum ConnectionRequest { + /// Notify that a Connection was dropped, if it was dropped while the + /// connection was still pending, then return the pending Channel to be + /// added back to the available channels. + Dropped(Option), + /// Notify that a Connection was established, return the Channel to the + /// available channels. + Connected(EstablishedChannel), + /// Notify that there was a transport error on the given Channel, the bool + /// specifies whether the connection was in the process of being established + /// or not (i.e. whether it's been returned to available channels yet). + Error((ChannelIdentifier, bool)), +} + +/// The result of a Future that connects to a given Endpoint. This is a tuple +/// of the index into the ConnectionManagerWorker::endpoints that this +/// connection is for, the iteration of the connection and the result of the +/// connection itself. +type IndexedChannel = Result; + +/// A channel that has been established to an endpoint with some metadata around +/// it to allow identification of the Channel if it errors in order to correctly +/// remove it. +#[derive(Clone)] +struct EstablishedChannel { + /// The Channel itself that the meta data relates to. + channel: Channel, + /// The identifier of the channel in the worker. + identifier: ChannelIdentifier, +} + +/// The context of the worker used to manage all of the connections. This +/// handles reconnecting to endpoints on errors and multiple connections to a +/// given endpoint. +struct ConnectionManagerWorker { + /// The endpoints to establish Channels and the identifier of the last + /// connection attempt to that endpoint. + endpoints: Vec<(ConnectionIndex, Endpoint)>, + /// The channel used to communicate between a Connection and the worker. + connection_tx: mpsc::UnboundedSender, + /// The number of connections that are currently allowed to be made. + available_connections: usize, + /// Channels that are currently being connected. + connecting_channels: FuturesUnordered + Send>>>, + /// Connected channels that are available for use. + available_channels: VecDeque, + /// Requests for a Channel when available. + waiting_connections: VecDeque>, + /// The retry configuration for connecting to an Endpoint, on failure will + /// restart the retrier after a 1 second delay. + retrier: Retrier, +} + +/// The maximum number of queued requests to obtain a connection from the +/// worker before applying back pressure to the requestor. It makes sense to +/// keep this small since it has to wait for a response anyway. +const WORKER_BACKLOG: usize = 8; + +impl ConnectionManager { + /// Create a connection manager that creates a balance list between a given + /// set of Endpoints. This will restrict the number of concurrent requests + /// and automatically re-connect upon transport error. + pub fn new( + endpoints: impl IntoIterator, + mut connections_per_endpoint: usize, + mut max_concurrent_requests: usize, + retry: Retry, + jitter_fn: retry::JitterFn, + ) -> Self { + let (worker_tx, worker_rx) = mpsc::channel(WORKER_BACKLOG); + // The connection messages always come from sync contexts (e.g. drop) + // and therefore, we'd end up spawning for them if this was bounded + // which defeats the object since there would be no backpressure + // applied. Therefore it makes sense for this to be unbounded. + let (connection_tx, connection_rx) = mpsc::unbounded_channel(); + let endpoints = Vec::from_iter(endpoints.into_iter().map(|endpoint| (0, endpoint))); + if max_concurrent_requests == 0 { + max_concurrent_requests = usize::MAX; + } + if connections_per_endpoint == 0 { + connections_per_endpoint = 1; + } + let worker = ConnectionManagerWorker { + endpoints, + available_connections: max_concurrent_requests, + connection_tx, + connecting_channels: FuturesUnordered::new(), + available_channels: VecDeque::new(), + waiting_connections: VecDeque::new(), + retrier: Retrier::new( + Arc::new(|duration| Box::pin(tokio::time::sleep(duration))), + jitter_fn, + retry, + ), + }; + tokio::spawn(async move { + worker + .service_requests(connections_per_endpoint, worker_rx, connection_rx) + .await; + }); + Self { worker_tx } + } + + /// Get a Connection that can be used as a tonic::Channel, except it + /// performs some additional counting to reconnect on error and restrict + /// the number of concurrent connections. + pub async fn connection(&self) -> Result { + let (tx, rx) = oneshot::channel(); + self.worker_tx + .send(tx) + .await + .map_err(|err| make_err!(Code::Unavailable, "Requesting a new connection: {err:?}"))?; + rx.await + .map_err(|err| make_err!(Code::Unavailable, "Waiting for a new connection: {err:?}")) + } +} + +impl ConnectionManagerWorker { + async fn service_requests( + mut self, + connections_per_endpoint: usize, + mut worker_rx: mpsc::Receiver>, + mut connection_rx: mpsc::UnboundedReceiver, + ) { + // Make the initial set of connections, connection failures will be + // handled in the same way as future transport failures, so no need to + // do anything special. + for endpoint_index in 0..self.endpoints.len() { + for _ in 0..connections_per_endpoint { + self.connect_endpoint(endpoint_index, None); + } + } + + // The main worker loop, when select resolves one of its arms the other + // ones are cancelled, therefore it's important that they maintain no + // state while `await`-ing. This is enforced through the use of + // non-async functions to do all of the work. + loop { + tokio::select! { + request = worker_rx.recv() => { + let Some(request) = request else { + // The ConnectionManager was dropped, shut down the + // worker. + break; + }; + self.handle_worker(request); + } + maybe_request = connection_rx.recv() => { + if let Some(request) = maybe_request { + self.handle_connection(request); + } + } + maybe_connection_result = self.connect_next() => { + if let Some(connection_result) = maybe_connection_result { + self.handle_connected(connection_result); + } + } + } + } + } + + async fn connect_next(&mut self) -> Option { + if self.connecting_channels.is_empty() { + // Make this Future never resolve, we will get cancelled by the + // select if there's some change in state to `self` and can re-enter + // and evaluate `connecting_channels` again. + futures::future::pending::<()>().await; + } + self.connecting_channels.next().await + } + + // This must never be made async otherwise the select may cancel it. + fn handle_connected(&mut self, connection_result: IndexedChannel) { + match connection_result { + Ok(established_channel) => { + self.available_channels.push_back(established_channel); + self.maybe_available_connection(); + } + // When the retrier runs out of attempts start again from the + // beginning of the retry period. Never want to be in a + // situation where we give up on an Endpoint forever. + Err((identifier, _)) => { + self.connect_endpoint(identifier.endpoint_index, Some(identifier.connection_index)) + } + } + } + + fn connect_endpoint(&mut self, endpoint_index: usize, connection_index: Option) { + let Some((current_connection_index, endpoint)) = self.endpoints.get_mut(endpoint_index) + else { + // Unknown endpoint, this should never happen. + error!("Connection to unknown endpoint {endpoint_index} requested."); + return; + }; + let is_backoff = connection_index.is_some(); + let connection_index = connection_index.unwrap_or_else(|| { + *current_connection_index += 1; + *current_connection_index + }); + if is_backoff { + warn!( + "Connection {connection_index} failed to {:?}, reconnecting.", + endpoint.uri() + ); + } else { + info!( + "Creating new connection {connection_index} to {:?}.", + endpoint.uri() + ); + } + let identifier = ChannelIdentifier { + endpoint_index, + connection_index, + }; + let connection_stream = unfold(endpoint.clone(), move |endpoint| async move { + let result = endpoint.connect().await.map_err(|err| { + make_err!( + Code::Unavailable, + "Failed to connect to {:?}: {err:?}", + endpoint.uri() + ) + }); + Some(( + result.map_or_else(RetryResult::Retry, RetryResult::Ok), + endpoint, + )) + }); + let retrier = self.retrier.clone(); + self.connecting_channels.push(Box::pin(async move { + if is_backoff { + // Just in case the retry config is 0, then we need to + // introduce some delay so we aren't in a hard loop. + tokio::time::sleep(Duration::from_secs(1)).await; + } + retrier.retry(connection_stream).await.map_or_else( + |err| Err((identifier, err)), + |channel| { + Ok(EstablishedChannel { + identifier, + channel, + }) + }, + ) + })); + } + + // This must never be made async otherwise the select may cancel it. + fn handle_worker(&mut self, tx: oneshot::Sender) { + if let Some(channel) = (self.available_connections > 0) + .then_some(()) + .and_then(|_| self.available_channels.pop_front()) + { + self.provide_channel(channel, tx); + } else { + self.waiting_connections.push_back(tx); + } + } + + fn provide_channel(&mut self, channel: EstablishedChannel, tx: oneshot::Sender) { + // We decrement here because we create Connection, this will signal when + // it is Dropped and therefore increment this again. + self.available_connections -= 1; + let _ = tx.send(Connection { + connection_tx: self.connection_tx.clone(), + pending_channel: Some(channel.channel.clone()), + channel, + }); + } + + fn maybe_available_connection(&mut self) { + while self.available_connections > 0 + && !self.waiting_connections.is_empty() + && !self.available_channels.is_empty() + { + if let Some(channel) = self.available_channels.pop_front() { + if let Some(tx) = self.waiting_connections.pop_front() { + self.provide_channel(channel, tx); + } else { + // This should never happen, but better than an unwrap. + self.available_channels.push_front(channel); + } + } + } + } + + // This must never be made async otherwise the select may cancel it. + fn handle_connection(&mut self, request: ConnectionRequest) { + match request { + ConnectionRequest::Dropped(maybe_channel) => { + if let Some(channel) = maybe_channel { + self.available_channels.push_back(channel); + } + self.available_connections += 1; + self.maybe_available_connection(); + } + ConnectionRequest::Connected(channel) => { + self.available_channels.push_back(channel); + self.maybe_available_connection(); + } + // Handle a transport error on a connection by making it unavailable + // for use and establishing a new connection to the endpoint. + ConnectionRequest::Error((identifier, was_pending)) => { + let should_reconnect = if was_pending { + true + } else { + let original_length = self.available_channels.len(); + self.available_channels + .retain(|channel| channel.identifier != identifier); + // Only reconnect if it wasn't already disconnected. + original_length != self.available_channels.len() + }; + if should_reconnect { + self.connect_endpoint(identifier.endpoint_index, None); + } + } + } + } +} + +/// An instance of this is obtained for every communication with the gGRPC +/// service. This handles the permit for limiting concurrency, and also +/// re-connecting the underlying channel on error. It depends on users +/// reporting all errors. +/// NOTE: This should never be cloneable because its lifetime is linked to the +/// ConnectionManagerWorker::available_connections. +pub struct Connection { + /// Communication with ConnectionManagerWorker to inform about transport + /// errors and when the Connection is dropped. + connection_tx: mpsc::UnboundedSender, + /// If set, the Channel that will be returned to the worker when connection + /// completes (success or failure) or when the Connection is dropped if that + /// happens before connection completes. + pending_channel: Option, + /// The identifier to send to connection_tx. + channel: EstablishedChannel, +} + +impl Drop for Connection { + fn drop(&mut self) { + let pending_channel = self + .pending_channel + .take() + .map(|channel| EstablishedChannel { + channel, + identifier: self.channel.identifier, + }); + let _ = self + .connection_tx + .send(ConnectionRequest::Dropped(pending_channel)); + } +} + +/// A wrapper around the channel::ResponseFuture that forwards errors to the +/// connection_tx. +pub struct ResponseFuture { + /// The wrapped future that actually does the work. + inner: channel::ResponseFuture, + /// Communication with ConnectionManagerWorker to inform about transport + /// errors. + connection_tx: mpsc::UnboundedSender, + /// The identifier to send to connection_tx on a transport error. + identifier: ChannelIdentifier, +} + +/// This is mostly copied from tonic::transport::channel except it wraps it +/// to allow messaging about connection success and failure. +impl tonic::codegen::Service> for Connection { + type Response = tonic::codegen::http::Response; + type Error = tonic::transport::Error; + type Future = ResponseFuture; + + fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll> { + let result = self.channel.channel.poll_ready(cx); + if let Poll::Ready(result) = &result { + match result { + Ok(_) => { + if let Some(pending_channel) = self.pending_channel.take() { + let _ = self.connection_tx.send(ConnectionRequest::Connected( + EstablishedChannel { + channel: pending_channel, + identifier: self.channel.identifier, + }, + )); + } + } + Err(err) => { + debug!("Error while creating connection on channel: {err:?}"); + let _ = self.connection_tx.send(ConnectionRequest::Error(( + self.channel.identifier, + self.pending_channel.take().is_some(), + ))); + } + } + } + result + } + + fn call( + &mut self, + request: tonic::codegen::http::Request, + ) -> Self::Future { + ResponseFuture { + inner: self.channel.channel.call(request), + connection_tx: self.connection_tx.clone(), + identifier: self.channel.identifier, + } + } +} + +/// This is mostly copied from tonic::transport::channel except it wraps it +/// to allow messaging about connection failure. +impl Future for ResponseFuture { + type Output = + Result, tonic::transport::Error>; + + fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + let result = Pin::new(&mut self.inner).poll(cx); + if let Poll::Ready(Err(_)) = &result { + let _ = self + .connection_tx + .send(ConnectionRequest::Error((self.identifier, false))); + } + result + } +} diff --git a/nativelink-util/src/grpc_utils.rs b/nativelink-util/src/grpc_utils.rs deleted file mode 100644 index 631682d64..000000000 --- a/nativelink-util/src/grpc_utils.rs +++ /dev/null @@ -1,111 +0,0 @@ -// Copyright 2024 The NativeLink Authors. All rights reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -use async_lock::{Semaphore, SemaphoreGuard}; -use nativelink_error::{Code, Error}; -use parking_lot::Mutex; -use tonic::transport::{Channel, Endpoint}; - -/// A helper utility that enables management of a suite of connections to an -/// upstream gRPC endpoint using Tonic. -pub struct ConnectionManager { - /// The endpoints to establish Channels for. - endpoints: Vec, - /// A balance channel over the above endpoints which is kept with a - /// monotonic index to ensure we only re-create a channel on the first error - /// received on it. - channel: Mutex<(usize, Channel)>, - /// If a maximum number of upstream requests are allowed at a time, this - /// is a Semaphore to manage that. - request_semaphore: Option, -} - -impl ConnectionManager { - /// Create a connection manager that creates a balance list between a given - /// set of Endpoints. This will restrict the number of concurrent requests - /// assuming that the user of this connection manager uses the connection - /// only once and reports all errors. - pub fn new( - endpoints: impl IntoIterator, - max_concurrent_requests: usize, - ) -> Self { - let endpoints = Vec::from_iter(endpoints); - let channel = Channel::balance_list(endpoints.iter().cloned()); - Self { - endpoints, - channel: Mutex::new((0, channel)), - request_semaphore: (max_concurrent_requests > 0) - .then_some(Semaphore::new(max_concurrent_requests)), - } - } - - /// Get a connection slot for an Endpoint, this contains a Channel which - /// should be used once and any errors should be reported back to the - /// on_error method to ensure that the Channel is re-connected on error. - pub async fn get_connection(&self) -> (Connection<'_>, Channel) { - let _permit = if let Some(semaphore) = &self.request_semaphore { - Some(semaphore.acquire().await) - } else { - None - }; - let channel_lock = self.channel.lock(); - ( - Connection { - channel_id: channel_lock.0, - parent: self, - _permit, - }, - channel_lock.1.clone(), - ) - } -} - -/// An instance of this is obtained for every communication with the gGRPC -/// service. This handles the permit for limiting concurrency, and also -/// re-connecting the underlying channel on error. It depends on users -/// reporting all errors. -pub struct Connection<'a> { - channel_id: usize, - parent: &'a ConnectionManager, - _permit: Option>, -} - -impl<'a> Connection<'a> { - pub fn on_error(self, err: &Error) { - // Usually Tonic reconnects on upstream errors (like Unavailable) but - // if there are protocol errors (such as GoAway) then it will not - // attempt to re-connect, and therefore we are forced to manually do - // that. - if err.code != Code::Internal { - return; - } - // Create a new channel for future requests to use upon a new request - // to ConnectionManager::get_connection(). In order to ensure we only - // do this for the first error on a cloned Channel we check the ID - // matches the current ID when we get the lock. - let mut channel_lock = self.parent.channel.lock(); - if channel_lock.0 != self.channel_id { - // The connection was already re-established by another user getting - // and error on a clone of this Channel, so don't make another one. - return; - } - // Create a new channel with a unique ID to track if it gets an error. - // This new Channel will be used when a new request comes into - // ConnectionManager::get_connection() as this request has been and gone - // with an error now and it's up to the user whether they retry by - // getting a new connection. - channel_lock.0 += 1; - channel_lock.1 = Channel::balance_list(self.parent.endpoints.iter().cloned()); - } -} diff --git a/nativelink-util/src/lib.rs b/nativelink-util/src/lib.rs index b94c64931..3ac0caa2b 100644 --- a/nativelink-util/src/lib.rs +++ b/nativelink-util/src/lib.rs @@ -15,11 +15,11 @@ pub mod action_messages; pub mod buf_channel; pub mod common; +pub mod connection_manager; pub mod digest_hasher; pub mod evicting_map; pub mod fastcdc; pub mod fs; -pub mod grpc_utils; pub mod health_utils; pub mod metrics_utils; pub mod platform_properties; diff --git a/nativelink-util/src/retry.rs b/nativelink-util/src/retry.rs index 9bb7c1c6b..8157aeda0 100644 --- a/nativelink-util/src/retry.rs +++ b/nativelink-util/src/retry.rs @@ -42,7 +42,7 @@ impl Iterator for ExponentialBackoff { } type SleepFn = Arc Pin + Send>> + Sync + Send>; -type JitterFn = Arc Duration + Send + Sync>; +pub(crate) type JitterFn = Arc Duration + Send + Sync>; #[derive(PartialEq, Eq, Debug)] pub enum RetryResult {