From b9caf1ea13c9e7c0e09acc2b1dad05a85ac14a81 Mon Sep 17 00:00:00 2001 From: Zach Birenbaum Date: Fri, 5 Apr 2024 12:49:42 -0700 Subject: [PATCH] Implement worker api for killing running actions Implements worker api for requesting a currently running action to be killed. Allows for action cancellation requests to be sent from the scheduler during scenarios such as client disconnection. towards #338 --- .../remote_execution/worker_api.proto | 11 ++- ..._machina.nativelink.remote_execution.pb.rs | 13 ++- nativelink-worker/BUILD.bazel | 1 + nativelink-worker/src/local_worker.rs | 22 +++++ .../src/running_actions_manager.rs | 32 +++++++ nativelink-worker/tests/local_worker_test.rs | 91 +++++++++++++++++++ .../utils/mock_running_actions_manager.rs | 25 ++++- 7 files changed, 192 insertions(+), 3 deletions(-) diff --git a/nativelink-proto/com/github/trace_machina/nativelink/remote_execution/worker_api.proto b/nativelink-proto/com/github/trace_machina/nativelink/remote_execution/worker_api.proto index 1bf03f5d4..f6e45de2a 100644 --- a/nativelink-proto/com/github/trace_machina/nativelink/remote_execution/worker_api.proto +++ b/nativelink-proto/com/github/trace_machina/nativelink/remote_execution/worker_api.proto @@ -131,6 +131,12 @@ message ConnectionResult { reserved 2; // NextId. } +/// Request to kill a running action sent from the scheduler to a worker. +message KillActionRequest { + /// The the hex encoded unique qualifier for the action to be killed. + string action_id = 1; + reserved 2; // NextId. +} /// Communication from the scheduler to the worker. message UpdateForWorker { oneof update { @@ -152,8 +158,11 @@ message UpdateForWorker { /// Informs the worker that it has been disconnected from the pool. /// The worker may discard any outstanding work that is being executed. google.protobuf.Empty disconnect = 4; + + /// Instructs the worker to kill a specific running action. + KillActionRequest kill_action_request = 5; } - reserved 5; // NextId. + reserved 6; // NextId. } message StartExecute { diff --git a/nativelink-proto/genproto/com.github.trace_machina.nativelink.remote_execution.pb.rs b/nativelink-proto/genproto/com.github.trace_machina.nativelink.remote_execution.pb.rs index aa8ae4416..de378ec2d 100644 --- a/nativelink-proto/genproto/com.github.trace_machina.nativelink.remote_execution.pb.rs +++ b/nativelink-proto/genproto/com.github.trace_machina.nativelink.remote_execution.pb.rs @@ -102,11 +102,19 @@ pub struct ConnectionResult { #[prost(string, tag = "1")] pub worker_id: ::prost::alloc::string::String, } +/// / Request to kill a running action sent from the scheduler to a worker. +#[allow(clippy::derive_partial_eq_without_eq)] +#[derive(Clone, PartialEq, ::prost::Message)] +pub struct KillActionRequest { + /// / The the hash of the unique qualifier for the action to be killed. + #[prost(string, tag = "1")] + pub action_id: ::prost::alloc::string::String, +} /// / Communication from the scheduler to the worker. #[allow(clippy::derive_partial_eq_without_eq)] #[derive(Clone, PartialEq, ::prost::Message)] pub struct UpdateForWorker { - #[prost(oneof = "update_for_worker::Update", tags = "1, 2, 3, 4")] + #[prost(oneof = "update_for_worker::Update", tags = "1, 2, 3, 4, 5")] pub update: ::core::option::Option, } /// Nested message and enum types in `UpdateForWorker`. @@ -133,6 +141,9 @@ pub mod update_for_worker { /// / The worker may discard any outstanding work that is being executed. #[prost(message, tag = "4")] Disconnect(()), + /// / Instructs the worker to kill a specific running action. + #[prost(message, tag = "5")] + KillActionRequest(super::KillActionRequest), } } #[allow(clippy::derive_partial_eq_without_eq)] diff --git a/nativelink-worker/BUILD.bazel b/nativelink-worker/BUILD.bazel index 9abaf1ef6..7c93c6701 100644 --- a/nativelink-worker/BUILD.bazel +++ b/nativelink-worker/BUILD.bazel @@ -70,6 +70,7 @@ rust_test_suite( "//nativelink-util", "@crates//:async-lock", "@crates//:futures", + "@crates//:hex", "@crates//:hyper", "@crates//:once_cell", "@crates//:pretty_assertions", diff --git a/nativelink-worker/src/local_worker.rs b/nativelink-worker/src/local_worker.rs index 28c3e9e1e..fd880730a 100644 --- a/nativelink-worker/src/local_worker.rs +++ b/nativelink-worker/src/local_worker.rs @@ -210,6 +210,28 @@ impl<'a, T: WorkerApiClientTrait, U: RunningActionsManager> LocalWorkerImpl<'a, Update::KeepAlive(()) => { self.metrics.keep_alives_received.inc(); } + Update::KillActionRequest(kill_action_request) => { + let mut action_id = [0u8; 32]; + let decode_result = hex::decode_to_slice(kill_action_request.action_id, &mut action_id as &mut [u8]); + + if let Err(e) = decode_result { + return Err(make_input_err!( + "KillActionRequest failed to decode ActionId hex with error {}", + e + )); + } + let kill_action_result = self.running_actions_manager + .kill_action(action_id) + .await; + + if let Err(e) = kill_action_result { + return Err(make_input_err!( + "Kill action {} failed with error - {}", + hex::encode(action_id), + e + )); + } + } Update::StartAction(start_execute) => { self.metrics.start_actions_received.inc(); let add_future_channel = add_future_channel.clone(); diff --git a/nativelink-worker/src/running_actions_manager.rs b/nativelink-worker/src/running_actions_manager.rs index 4e77904b2..afa21170a 100644 --- a/nativelink-worker/src/running_actions_manager.rs +++ b/nativelink-worker/src/running_actions_manager.rs @@ -1285,6 +1285,8 @@ pub trait RunningActionsManager: Sync + Send + Sized + Unpin + 'static { async fn kill_all(&self); + async fn kill_action(&self, action_id: ActionId) -> Result<(), Error>; + fn metrics(&self) -> &Arc; } @@ -1783,6 +1785,36 @@ impl RunningActionsManager for RunningActionsManagerImpl { .await } + async fn kill_action(&self, action_id: ActionId) -> Result<(), Error> { + + let upgrade_or_err: Result, Error> = { + let running_actions = self.running_actions.lock(); + running_actions.get(&action_id).map_or( + Err(make_input_err!( + "Failed to get running action {}", + hex::encode(action_id) + )), + |action| { + match action.upgrade() { + Some(upgraded_action) => Ok(upgraded_action), + None => Err(make_input_err!( + "Failed to upgrade after getting action {}", + hex::encode(action_id) + )) + } + } + ) + }; + + match upgrade_or_err { + Ok(upgrade) => { + Self::kill_action(upgrade).await; + Ok(()) + } + Err(e) => Err(e) + } + } + // Note: When the future returns the process should be fully killed and cleaned up. async fn kill_all(&self) { self.metrics diff --git a/nativelink-worker/tests/local_worker_test.rs b/nativelink-worker/tests/local_worker_test.rs index cf45ef2c3..c62a57785 100644 --- a/nativelink-worker/tests/local_worker_test.rs +++ b/nativelink-worker/tests/local_worker_test.rs @@ -71,6 +71,7 @@ fn make_temp_path(data: &str) -> String { #[cfg(test)] mod local_worker_tests { + use nativelink_proto::com::github::trace_machina::nativelink::remote_execution::KillActionRequest; use pretty_assertions::assert_eq; use super::*; // Must be declared in every module. @@ -638,4 +639,94 @@ mod local_worker_tests { Ok(()) } + + #[tokio::test] + async fn kill_action_request_kills_action() -> Result<(), Box> { + const SALT: u64 = 1000; + + let mut test_context = setup_local_worker(HashMap::new()).await; + + let streaming_response = test_context.maybe_streaming_response.take().unwrap(); + + { + // Ensure our worker connects and properties were sent. + let props = test_context + .client + .expect_connect_worker(Ok(streaming_response)) + .await; + assert_eq!(props, SupportedProperties::default()); + } + + // Handle registration (kill_all not called unless registered). + let mut tx_stream = test_context.maybe_tx_stream.take().unwrap(); + { + tx_stream + .send_data(encode_stream_proto(&UpdateForWorker { + update: Some(Update::ConnectionResult(ConnectionResult { + worker_id: "foobar".to_string(), + })), + })?) + .await + .map_err(|e| make_input_err!("Could not send : {:?}", e))?; + } + + let action_digest = DigestInfo::new([3u8; 32], 10); + let action_info = ActionInfo { + command_digest: DigestInfo::new([1u8; 32], 10), + input_root_digest: DigestInfo::new([2u8; 32], 10), + timeout: Duration::from_secs(1), + platform_properties: PlatformProperties::default(), + priority: 0, + load_timestamp: SystemTime::UNIX_EPOCH, + insert_timestamp: SystemTime::UNIX_EPOCH, + unique_qualifier: ActionInfoHashKey { + instance_name: INSTANCE_NAME.to_string(), + digest: action_digest, + salt: SALT, + }, + skip_cache_lookup: true, + digest_function: DigestHasherFunc::Blake3, + }; + + { + // Send execution request. + tx_stream + .send_data(encode_stream_proto(&UpdateForWorker { + update: Some(Update::StartAction(StartExecute { + execute_request: Some(action_info.clone().into()), + salt: SALT, + queued_timestamp: None, + })), + })?) + .await + .map_err(|e| make_input_err!("Could not send : {:?}", e))?; + } + let running_action = Arc::new(MockRunningAction::new()); + + // Send and wait for response from create_and_add_action to RunningActionsManager. + test_context + .actions_manager + .expect_create_and_add_action(Ok(running_action.clone())) + .await; + + let action_id = action_info.unique_qualifier.get_hash(); + { + // Send kill request. + tx_stream + .send_data(encode_stream_proto(&UpdateForWorker { + update: Some(Update::KillActionRequest(KillActionRequest { + action_id: hex::encode(action_id), + })), + })?) + .await + .map_err(|e| make_input_err!("Could not send : {:?}", e))?; + } + + let killed_action_id = test_context.actions_manager.expect_kill_action().await; + + // Make sure that the killed action is the one we intended + assert_eq!(killed_action_id, action_id); + + Ok(()) + } } diff --git a/nativelink-worker/tests/utils/mock_running_actions_manager.rs b/nativelink-worker/tests/utils/mock_running_actions_manager.rs index 9eadf998d..1e61da83e 100644 --- a/nativelink-worker/tests/utils/mock_running_actions_manager.rs +++ b/nativelink-worker/tests/utils/mock_running_actions_manager.rs @@ -21,7 +21,9 @@ use nativelink_proto::com::github::trace_machina::nativelink::remote_execution:: use nativelink_util::action_messages::ActionResult; use nativelink_util::common::DigestInfo; use nativelink_util::digest_hasher::DigestHasherFunc; -use nativelink_worker::running_actions_manager::{Metrics, RunningAction, RunningActionsManager}; +use nativelink_worker::running_actions_manager::{ + ActionId, Metrics, RunningAction, RunningActionsManager, +}; use tokio::sync::mpsc; #[derive(Debug)] @@ -43,6 +45,9 @@ pub struct MockRunningActionsManager { rx_kill_all: Mutex>, tx_kill_all: mpsc::UnboundedSender<()>, + + rx_kill_action: Mutex>, + tx_kill_action: mpsc::UnboundedSender, metrics: Arc, } @@ -57,6 +62,7 @@ impl MockRunningActionsManager { let (tx_call, rx_call) = mpsc::unbounded_channel(); let (tx_resp, rx_resp) = mpsc::unbounded_channel(); let (tx_kill_all, rx_kill_all) = mpsc::unbounded_channel(); + let (tx_kill_action, rx_kill_action) = mpsc::unbounded_channel(); Self { rx_call: Mutex::new(rx_call), tx_call, @@ -64,6 +70,8 @@ impl MockRunningActionsManager { tx_resp, rx_kill_all: Mutex::new(rx_kill_all), tx_kill_all, + rx_kill_action: Mutex::new(rx_kill_action), + tx_kill_action, metrics: Arc::new(Metrics::default()), } } @@ -108,6 +116,14 @@ impl MockRunningActionsManager { .await .expect("Could not receive msg in mpsc"); } + + pub async fn expect_kill_action(&self) -> ActionId { + let mut rx_kill_action_lock = self.rx_kill_action.lock().await; + rx_kill_action_lock + .recv() + .await + .expect("Could not receive msg in mpsc") + } } #[async_trait] @@ -151,6 +167,13 @@ impl RunningActionsManager for MockRunningActionsManager { Ok(()) } + async fn kill_action(&self, action_id: ActionId) -> Result<(), Error> { + self.tx_kill_action + .send(action_id) + .expect("Could not send request to mpsc"); + Ok(()) + } + async fn kill_all(&self) { self.tx_kill_all .send(())