From 94288263af14bcf733e147ffa940d35e94d08e49 Mon Sep 17 00:00:00 2001
From: Blaise Bruer
Date: Thu, 31 Oct 2024 13:10:34 -0500
Subject: [PATCH] changes

Rework how services observe a shutdown request.

* shutdown_manager.rs: keep a `Weak` handle to the shutdown-complete
  sender in the new `maybe_shutdown_weak_sender` slot.
  `wait_for_shutdown` is now a plain fn returning a future: when a
  shutdown is already in progress it resolves immediately, carrying a
  guard upgraded from that weak handle (or no guard if the sender is
  already gone); otherwise it subscribes and waits as before.
  `ShutdownGuard::_guard` becomes `_maybe_guard: Option<..>`.
* local_worker.rs: create the `wait_for_shutdown` future once, pin it
  outside the loop and poll it by `as_mut()` instead of building a new
  future on every `select!` iteration.
* nativelink.rs: same pinning for the TCP listener loop, `select!` is
  now `biased`, and every connection task additionally waits for the
  shutdown signal and calls `graceful_shutdown()` on its connection
  unless `experimental_connections_dont_block_graceful_shutdown` is
  set. The connection future now runs inside the `OriginContext`
  background spawn; the separate `background_spawn!` call is removed.
* Cargo.toml: turn off LTO for the release profile.

---
Review notes (ignored by `git am`):

 * `Option::replace` returns the *previous* value, so chaining
   `.expect("Expected maybe_shutdown_weak_sender to be empty")` on it
   panics on the very first shutdown, when the slot is still `None`.
   The hunk below asserts `is_none()` on the previous value instead;
   the line count of the hunk is unchanged.
 * Because of that edit, the post-image blob hash on the `index` line
   of shutdown_manager.rs no longer matches the patched file.
 * NOTE(review): `wait_for_shutdown` expects the weak sender to be set
   whenever `is_shutting_down()` is true; confirm no caller can observe
   the flag before the slot is populated.

 Cargo.toml                              |   2 +-
 nativelink-util/src/shutdown_manager.rs |  67 ++++++++++++----
 nativelink-worker/src/local_worker.rs   |   4 +-
 src/bin/nativelink.rs                   | 101 ++++++++++++++----------
 4 files changed, 115 insertions(+), 59 deletions(-)

diff --git a/Cargo.toml b/Cargo.toml
index 33fe54582..f1c66a57b 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -8,7 +8,7 @@ edition = "2021"
 rust-version = "1.81.0"
 
 [profile.release]
-lto = true
+lto = false
 opt-level = 3
 
 [profile.dev]
diff --git a/nativelink-util/src/shutdown_manager.rs b/nativelink-util/src/shutdown_manager.rs
index 9170e410f..89935cd02 100644
--- a/nativelink-util/src/shutdown_manager.rs
+++ b/nativelink-util/src/shutdown_manager.rs
@@ -12,9 +12,12 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.
 
+use std::future::Future;
 use std::sync::atomic::{AtomicBool, Ordering};
-use std::sync::Arc;
+use std::sync::{Arc, Weak};
 
+use futures::future::ready;
+use futures::FutureExt;
 use parking_lot::Mutex;
 use tokio::runtime::Handle;
 #[cfg(target_family = "unix")]
@@ -25,6 +28,7 @@ use tracing::{event, Level};
 static SHUTDOWN_MANAGER: ShutdownManager = ShutdownManager {
     is_shutting_down: AtomicBool::new(false),
     shutdown_tx: Mutex::new(None), // Will be initialized in `init`.
+    maybe_shutdown_weak_sender: Mutex::new(None),
 };
 
 /// Broadcast Channel Capacity
@@ -43,6 +47,7 @@ const BROADCAST_CAPACITY: usize = 1;
 pub struct ShutdownManager {
     is_shutting_down: AtomicBool,
     shutdown_tx: Mutex<Option<broadcast::Sender<Arc<oneshot::Sender<()>>>>>,
+    maybe_shutdown_weak_sender: Mutex<Option<Weak<oneshot::Sender<()>>>>,
 }
 
 impl ShutdownManager {
@@ -85,9 +90,14 @@ impl ShutdownManager {
             event!(Level::WARN, "Shutdown already in progress.");
             return;
         }
+        let (complete_tx, complete_rx) = oneshot::channel::<()>();
+        let shutdown_guard = Arc::new(complete_tx);
+        let previous_sender = SHUTDOWN_MANAGER
+            .maybe_shutdown_weak_sender
+            .lock()
+            .replace(Arc::downgrade(&shutdown_guard));
+        assert!(previous_sender.is_none(), "Expected maybe_shutdown_weak_sender to be empty");
         tokio::spawn(async move {
-            let (complete_tx, complete_rx) = oneshot::channel::<()>();
-            let shutdown_guard = Arc::new(complete_tx);
             {
                 let shutdown_tx_lock = SHUTDOWN_MANAGER.shutdown_tx.lock();
                 // No need to check result of send, since it will only fail if
@@ -106,33 +116,58 @@ impl ShutdownManager {
         });
     }
 
-    pub async fn wait_for_shutdown(service_name: impl Into<String>) -> ShutdownGuard {
+    pub fn wait_for_shutdown(service_name: impl Into<String>) -> impl Future<Output = ShutdownGuard> + Send {
         let service_name = service_name.into();
+        if Self::is_shutting_down() {
+            let maybe_shutdown_weak_sender_lock = SHUTDOWN_MANAGER
+                .maybe_shutdown_weak_sender
+                .lock();
+            let maybe_sender = maybe_shutdown_weak_sender_lock
+                .as_ref()
+                .expect("Expected maybe_shutdown_weak_sender to be set");
+            if let Some(sender) = maybe_sender.upgrade() {
+                event!(
+                    Level::INFO,
+                    "Service {service_name} has been notified of shutdown request"
+                );
+                return ready(ShutdownGuard {
+                    service_name,
+                    _maybe_guard: Some(sender),
+                }).left_future();
+            }
+            return ready(ShutdownGuard {
+                service_name,
+                _maybe_guard: None,
+            }).left_future();
+        }
         let mut shutdown_receiver = SHUTDOWN_MANAGER
             .shutdown_tx
             .lock()
             .as_ref()
             .expect("ShutdownManager was never initialized")
             .subscribe();
-        let sender = shutdown_receiver
-            .recv()
-            .await
-            .expect("Shutdown sender dropped. This should never happen.");
-        event!(
-            Level::INFO,
-            "Service {service_name} has been notified of shutdown request"
-        );
-        ShutdownGuard {
-            service_name,
-            _guard: sender,
+        async move {
+            let sender = shutdown_receiver
+                .recv()
+                .await
+                .expect("Shutdown sender dropped. This should never happen.");
+            event!(
+                Level::INFO,
+                "Service {service_name} has been notified of shutdown request"
+            );
+            ShutdownGuard {
+                service_name,
+                _maybe_guard: Some(sender),
+            }
         }
+        .right_future()
     }
 }
 
 #[derive(Clone)]
 pub struct ShutdownGuard {
     service_name: String,
-    _guard: Arc<oneshot::Sender<()>>,
+    _maybe_guard: Option<Arc<oneshot::Sender<()>>>,
 }
 
 impl Drop for ShutdownGuard {
diff --git a/nativelink-worker/src/local_worker.rs b/nativelink-worker/src/local_worker.rs
index 765cd5dd7..af7b70544 100644
--- a/nativelink-worker/src/local_worker.rs
+++ b/nativelink-worker/src/local_worker.rs
@@ -193,6 +193,8 @@ impl<'a, T: WorkerApiClientTrait, U: RunningActionsManager> LocalWorkerImpl<'a,
         // If we are shutting down we need to hold onto the shutdown guard
         // until we are done processing all the futures.
         let mut _maybe_shutdown_guard = None;
+        let wait_for_shutdown_fut = ShutdownManager::wait_for_shutdown("LocalWorker").fuse();
+        tokio::pin!(wait_for_shutdown_fut);
         loop {
             select! {
                 maybe_update = update_for_worker_stream.next() => {
@@ -381,7 +383,7 @@ impl<'a, T: WorkerApiClientTrait, U: RunningActionsManager> LocalWorkerImpl<'a,
                     // If we are not shutting down and get an error, return the error.
                     res?;
                 },
-                shutdown_guard = ShutdownManager::wait_for_shutdown("LocalWorker").fuse() => {
+                shutdown_guard = wait_for_shutdown_fut.as_mut() => {
                     _maybe_shutdown_guard = Some(shutdown_guard);
                     event!(Level::INFO, "Worker loop reveiced shutdown signal. Shutting down worker...",);
                     let mut grpc_client = self.grpc_client.clone();
diff --git a/src/bin/nativelink.rs b/src/bin/nativelink.rs
index 85cd6ff35..3c64944fe 100644
--- a/src/bin/nativelink.rs
+++ b/src/bin/nativelink.rs
@@ -21,6 +21,7 @@ use async_lock::Mutex as AsyncMutex;
 use axum::Router;
 use clap::Parser;
 use futures::future::{try_join_all, BoxFuture, Either, OptionFuture, TryFutureExt};
+use futures::FutureExt;
 use hyper::{Response, StatusCode};
 use hyper_util::rt::tokio::TokioIo;
 use hyper_util::server::conn::auto;
@@ -58,7 +59,7 @@ use nativelink_util::store_trait::{
     set_default_digest_size_health_check, DEFAULT_DIGEST_SIZE_HEALTH_CHECK_CFG,
 };
 use nativelink_util::task::TaskExecutor;
-use nativelink_util::{background_spawn, init_tracing, spawn, spawn_blocking};
+use nativelink_util::{init_tracing, spawn, spawn_blocking};
 use nativelink_worker::local_worker::new_local_worker;
 use opentelemetry::metrics::MeterProvider;
 use opentelemetry_sdk::metrics::SdkMeterProvider;
@@ -779,26 +780,12 @@ async fn inner_main(
         event!(Level::WARN, "Ready, listening on {socket_addr}");
         root_futures.push(Box::pin(async move {
             let shutdown_guard = Arc::new(Mutex::new(None));
-            let name = format!("TcpSocketListener_{name}");
+            let socket_name = format!("TcpSocketListener_{name}");
+            let wait_for_shutdown_fut = ShutdownManager::wait_for_shutdown(socket_name.clone()).fuse();
+            tokio::pin!(wait_for_shutdown_fut);
             loop {
                 select! {
-                    inner_shutdown_guard = ShutdownManager::wait_for_shutdown(name.clone()) => {
-                        if server_cfg.experimental_connections_dont_block_graceful_shutdown {
-                            event!(
-                                target: "nativelink",
-                                Level::INFO,
-                                name,
-                                "Connections will not block graceful shutdown"
-                            );
-                            continue;
-                        }
-                        let connected_clients = connected_clients_mux.inner.lock();
-                        if connected_clients.is_empty() {
-                            drop(shutdown_guard.lock().take());
-                        } else {
-                            *shutdown_guard.lock() = Some(inner_shutdown_guard);
-                        }
-                    }
+                    biased;
                     accept_result = tcp_listener.accept() => {
                         match accept_result {
                             Ok((tcp_stream, remote_addr)) => {
@@ -816,18 +803,19 @@ async fn inner_main(
                                 connected_clients_mux.counter.inc();
 
                                 let shutdown_guard = shutdown_guard.clone();
-                                let name = name.clone();
+                                let socket_name_clone = socket_name.clone();
                                 // This is the safest way to guarantee that if our future
                                 // is ever dropped we will cleanup our data.
                                 let scope_guard = guard(
                                     Arc::downgrade(&connected_clients_mux),
                                     move |weak_connected_clients_mux| {
+                                        let socket_name = socket_name_clone;
                                         event!(
                                             target: "nativelink::services",
                                             Level::INFO,
                                             ?remote_addr,
                                             ?socket_addr,
-                                            name,
+                                            socket_name,
                                             "Client disconnected"
                                         );
                                         if let Some(connected_clients_mux) = weak_connected_clients_mux.upgrade() {
@@ -841,7 +829,7 @@ async fn inner_main(
                                                     Level::INFO,
                                                     ?remote_addr,
                                                     ?socket_addr,
-                                                    name,
+                                                    socket_name,
                                                     "No more clients connected & received shutdown signal."
                                                 );
                                                 drop(shutdown_guard.lock().take());
@@ -849,7 +837,7 @@ async fn inner_main(
                                                 event!(
                                                     target: "nativelink::services",
                                                     Level::INFO,
-                                                    name,
+                                                    socket_name,
                                                     ?connected_clients,
                                                     "Waiting on {} more clients to disconnect before shutting down.",
                                                     connected_clients.len()
@@ -859,7 +847,7 @@ async fn inner_main(
                                         }
                                     },
                                 );
-
+                                let socket_name = socket_name.clone();
                                 let (http, svc, maybe_tls_acceptor) =
                                     (http.clone(), svc.clone(), maybe_tls_acceptor.clone());
                                 Arc::new(OriginContext::new()).background_spawn(
@@ -869,11 +857,7 @@ async fn inner_main(
                                         ?remote_addr,
                                         ?socket_addr
                                     ),
-                                    async move {},
-                                );
-                                background_spawn!(
-                                    name: "http_connection",
-                                    fut: async move {
+                                    async move {
                                         // Move it into our spawn, so if our spawn dies the cleanup happens.
                                         let _guard = scope_guard;
                                         let serve_connection = if let Some(tls_acceptor) = maybe_tls_acceptor {
@@ -893,19 +877,37 @@ async fn inner_main(
                                                 TowerToHyperService::new(svc),
                                             ))
                                         };
-
-                                        if let Err(err) = serve_connection.await {
-                                            event!(
-                                                target: "nativelink::services",
-                                                Level::ERROR,
-                                                ?err,
-                                                "Failed running service"
-                                            );
+                                        let connection_name = format!("Connection_{socket_name}_{remote_addr}");
+                                        let wait_for_shutdown_fut = ShutdownManager::wait_for_shutdown(connection_name.clone()).fuse();
+                                        tokio::pin!(wait_for_shutdown_fut);
+                                        tokio::pin!(serve_connection);
+                                        loop {
+                                            select! {
+                                                biased;
+                                                res = serve_connection.as_mut() => {
+                                                    if let Err(err) = res {
+                                                        event!(
+                                                            target: "nativelink::services",
+                                                            Level::ERROR,
+                                                            ?err,
+                                                            "Failed running service"
+                                                        );
+                                                    }
+                                                    break;
+                                                }
+                                                // Note: We don't need to hold onto this shutdown guard because
+                                                // we already have one captured in the outer scope.
+                                                _shutdown_guard = wait_for_shutdown_fut.as_mut() => {
+                                                    if !server_cfg.experimental_connections_dont_block_graceful_shutdown {
+                                                        match serve_connection.as_mut().as_pin_mut() {
+                                                            Either::Left(conn) => conn.graceful_shutdown(),
+                                                            Either::Right(conn) => conn.graceful_shutdown(),
+                                                        }
+                                                    }
+                                                },
+                                            }
                                         }
-                                    },
-                                    target: "nativelink::services",
-                                    ?remote_addr,
-                                    ?socket_addr,
+                                    }
                                 );
                             },
                             Err(err) => {
@@ -914,6 +916,23 @@ async fn inner_main(
                             }
                         }
                     },
+                    inner_shutdown_guard = wait_for_shutdown_fut.as_mut() => {
+                        if server_cfg.experimental_connections_dont_block_graceful_shutdown {
+                            event!(
+                                target: "nativelink",
+                                Level::INFO,
+                                socket_name,
+                                "Connections will not block graceful shutdown"
+                            );
+                            continue;
+                        }
+                        let connected_clients = connected_clients_mux.inner.lock();
+                        if connected_clients.is_empty() {
+                            drop(shutdown_guard.lock().take());
+                        } else {
+                            *shutdown_guard.lock() = Some(inner_shutdown_guard);
+                        }
+                    }
                 }
             }
             // Unreachable