Skip to content

Commit

Permalink
Implement worker api for killing running actions
Browse files Browse the repository at this point in the history
Implements worker api for requesting a currently running action to be
killed. Allows for action cancellation requests to be sent from the
scheduler during scenarios such as client disconnection.

towards #338
  • Loading branch information
Zach Birenbaum committed Apr 5, 2024
1 parent 6b9e68e commit b91f05d
Show file tree
Hide file tree
Showing 6 changed files with 186 additions and 5 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -131,6 +131,12 @@ message ConnectionResult {
reserved 2; // NextId.
}

/// Request to kill a running action sent from the scheduler to a worker.
message KillActionRequest {
/// The hex-encoded hash of the unique qualifier for the action to be killed.
string action_id = 1;
reserved 2; // NextId.
}
/// Communication from the scheduler to the worker.
message UpdateForWorker {
oneof update {
Expand All @@ -152,8 +158,11 @@ message UpdateForWorker {
/// Informs the worker that it has been disconnected from the pool.
/// The worker may discard any outstanding work that is being executed.
google.protobuf.Empty disconnect = 4;

/// Instructs the worker to kill a specific running action.
KillActionRequest kill_action_request = 5;
}
reserved 5; // NextId.
reserved 6; // NextId.
}

message StartExecute {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -102,11 +102,19 @@ pub struct ConnectionResult {
#[prost(string, tag = "1")]
pub worker_id: ::prost::alloc::string::String,
}
/// / Request to kill a running action sent from the scheduler to a worker.
#[allow(clippy::derive_partial_eq_without_eq)]
#[derive(Clone, PartialEq, ::prost::Message)]
pub struct KillActionRequest {
/// / The hex-encoded hash of the unique qualifier for the action to be killed.
#[prost(string, tag = "1")]
pub action_id: ::prost::alloc::string::String,
}
/// / Communication from the scheduler to the worker.
#[allow(clippy::derive_partial_eq_without_eq)]
#[derive(Clone, PartialEq, ::prost::Message)]
pub struct UpdateForWorker {
#[prost(oneof = "update_for_worker::Update", tags = "1, 2, 3, 4")]
#[prost(oneof = "update_for_worker::Update", tags = "1, 2, 3, 4, 5")]
pub update: ::core::option::Option<update_for_worker::Update>,
}
/// Nested message and enum types in `UpdateForWorker`.
Expand All @@ -133,6 +141,9 @@ pub mod update_for_worker {
/// / The worker may discard any outstanding work that is being executed.
#[prost(message, tag = "4")]
Disconnect(()),
/// / Instructs the worker to kill a specific running action.
#[prost(message, tag = "5")]
KillActionRequest(super::KillActionRequest),
}
}
#[allow(clippy::derive_partial_eq_without_eq)]
Expand Down
12 changes: 12 additions & 0 deletions nativelink-worker/src/local_worker.rs
Original file line number Diff line number Diff line change
Expand Up @@ -210,6 +210,18 @@ impl<'a, T: WorkerApiClientTrait, U: RunningActionsManager> LocalWorkerImpl<'a,
Update::KeepAlive(()) => {
self.metrics.keep_alives_received.inc();
}
Update::KillActionRequest(kill_action_request) => {
let kill_action_result = self.running_actions_manager
.kill_action(kill_action_request.clone())
.await;
if let Err(e) = kill_action_result {
warn!(
"Kill action {} failed with error - {}",
kill_action_request.action_id,
e
)
}
}
Update::StartAction(start_execute) => {
self.metrics.start_actions_received.inc();
let add_future_channel = add_future_channel.clone();
Expand Down
27 changes: 26 additions & 1 deletion nativelink-worker/src/running_actions_manager.rs
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ use nativelink_proto::build::bazel::remote::execution::v2::{
Tree as ProtoTree, UpdateActionResultRequest,
};
use nativelink_proto::com::github::trace_machina::nativelink::remote_execution::{
HistoricalExecuteResponse, StartExecute,
HistoricalExecuteResponse, KillActionRequest, StartExecute,
};
use nativelink_store::ac_utils::{
compute_buf_digest, get_and_decode_digest, serialize_and_upload_message, ESTIMATED_DIGEST_SIZE,
Expand Down Expand Up @@ -1285,6 +1285,8 @@ pub trait RunningActionsManager: Sync + Send + Sized + Unpin + 'static {

async fn kill_all(&self);

async fn kill_action(&self, kill_action_request: KillActionRequest) -> Result<(), Error>;

fn metrics(&self) -> &Arc<Metrics>;
}

Expand Down Expand Up @@ -1783,6 +1785,29 @@ impl RunningActionsManager for RunningActionsManagerImpl {
.await
}

/// Kills the single running action identified by `kill_action_request`.
///
/// `kill_action_request.action_id` must be the hex encoding of a 32-byte `ActionId`.
///
/// # Errors
///
/// * `Code::InvalidArgument` if `action_id` is not valid hex of the expected length.
/// * `Code::Internal` if no running action is registered under the decoded id.
///
/// Returns `Ok(())` without doing anything when the action is registered but its weak
/// reference can no longer be upgraded.
async fn kill_action(&self, kill_action_request: KillActionRequest) -> Result<(), Error> {
    let mut action_id: ActionId = [0u8; 32];

    // The id arrives over the wire, so a malformed value must be reported to the caller as an
    // error instead of panicking the whole worker.
    hex::decode_to_slice(
        &kill_action_request.action_id,
        &mut action_id as &mut [u8],
    )
    .map_err(|e| {
        make_err!(
            Code::InvalidArgument,
            "Failed to decode action id {:?} - {:?}",
            kill_action_request.action_id,
            e
        )
    })?;

    // Resolve the action inside its own scope so the lock guard is dropped before the
    // `.await` below; the guard must not be held across an await point.
    let maybe_action = {
        let running_actions = self.running_actions.lock();
        match running_actions.get(&action_id) {
            Some(weak_action) => weak_action.upgrade(),
            None => {
                return Err(make_err!(
                    Code::Internal,
                    "Failed to get running action {}",
                    hex::encode(action_id)
                ))
            }
        }
    };

    // A future does nothing until it is polled, so the kill has to be awaited explicitly.
    // NOTE(review): assumes the single-argument `Self::kill_action` helper is an `async fn`
    // defined elsewhere in this file - confirm.
    if let Some(action) = maybe_action {
        Self::kill_action(action).await;
    }
    Ok(())
}

// Note: When the future returns the process should be fully killed and cleaned up.
async fn kill_all(&self) {
self.metrics
Expand Down
91 changes: 91 additions & 0 deletions nativelink-worker/tests/local_worker_test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,7 @@ fn make_temp_path(data: &str) -> String {

#[cfg(test)]
mod local_worker_tests {
use nativelink_proto::com::github::trace_machina::nativelink::remote_execution::KillActionRequest;
use pretty_assertions::assert_eq;

use super::*; // Must be declared in every module.
Expand Down Expand Up @@ -638,4 +639,94 @@ mod local_worker_tests {

Ok(())
}

// Verifies that a `KillActionRequest` update sent on the worker stream is forwarded to the
// actions manager and that the forwarded action id matches the one of the started action.
#[tokio::test]
async fn kill_action_request_kills_action() -> Result<(), Box<dyn std::error::Error>> {
// Salt shared by the action's unique qualifier and the `StartExecute` message below.
const SALT: u64 = 1000;

let mut test_context = setup_local_worker(HashMap::new()).await;

let streaming_response = test_context.maybe_streaming_response.take().unwrap();

{
// Ensure our worker connects and properties were sent.
let props = test_context
.client
.expect_connect_worker(Ok(streaming_response))
.await;
assert_eq!(props, SupportedProperties::default());
}

// Complete registration by delivering the `ConnectionResult` carrying the worker id.
let mut tx_stream = test_context.maybe_tx_stream.take().unwrap();
{
tx_stream
.send_data(encode_stream_proto(&UpdateForWorker {
update: Some(Update::ConnectionResult(ConnectionResult {
worker_id: "foobar".to_string(),
})),
})?)
.await
.map_err(|e| make_input_err!("Could not send : {:?}", e))?;
}

// Build the action whose unique qualifier hash is later used as the id to kill.
let action_digest = DigestInfo::new([3u8; 32], 10);
let action_info = ActionInfo {
command_digest: DigestInfo::new([1u8; 32], 10),
input_root_digest: DigestInfo::new([2u8; 32], 10),
timeout: Duration::from_secs(1),
platform_properties: PlatformProperties::default(),
priority: 0,
load_timestamp: SystemTime::UNIX_EPOCH,
insert_timestamp: SystemTime::UNIX_EPOCH,
unique_qualifier: ActionInfoHashKey {
instance_name: INSTANCE_NAME.to_string(),
digest: action_digest,
salt: SALT,
},
skip_cache_lookup: true,
digest_function: DigestHasherFunc::Blake3,
};

{
// Send execution request.
tx_stream
.send_data(encode_stream_proto(&UpdateForWorker {
update: Some(Update::StartAction(StartExecute {
execute_request: Some(action_info.clone().into()),
salt: SALT,
queued_timestamp: None,
})),
})?)
.await
.map_err(|e| make_input_err!("Could not send : {:?}", e))?;
}
let running_action = Arc::new(MockRunningAction::new());

// Send and wait for response from create_and_add_action to RunningActionsManager.
test_context
.actions_manager
.expect_create_and_add_action(Ok(running_action.clone()))
.await;

// The kill request carries the hex encoding of the unique qualifier's hash.
let action_id = action_info.unique_qualifier.get_hash();
{
// Send kill request.
tx_stream
.send_data(encode_stream_proto(&UpdateForWorker {
update: Some(Update::KillActionRequest(KillActionRequest {
action_id: hex::encode(action_id),
})),
})?)
.await
.map_err(|e| make_input_err!("Could not send : {:?}", e))?;
}

// Wait until the mock manager observes the kill request and get the id it decoded.
let killed_action_id = test_context.actions_manager.expect_kill_action().await;

// Make sure that the killed action is the one we intended
assert_eq!(killed_action_id, action_id);

Ok(())
}
}
37 changes: 35 additions & 2 deletions nativelink-worker/tests/utils/mock_running_actions_manager.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,15 @@ use std::sync::Arc;
use async_lock::Mutex;
use async_trait::async_trait;
use nativelink_error::{make_input_err, Error};
use nativelink_proto::com::github::trace_machina::nativelink::remote_execution::StartExecute;
use nativelink_proto::com::github::trace_machina::nativelink::remote_execution::{
KillActionRequest, StartExecute,
};
use nativelink_util::action_messages::ActionResult;
use nativelink_util::common::DigestInfo;
use nativelink_util::digest_hasher::DigestHasherFunc;
use nativelink_worker::running_actions_manager::{Metrics, RunningAction, RunningActionsManager};
use nativelink_worker::running_actions_manager::{
ActionId, Metrics, RunningAction, RunningActionsManager,
};
use tokio::sync::mpsc;

#[derive(Debug)]
Expand All @@ -43,6 +47,9 @@ pub struct MockRunningActionsManager {

rx_kill_all: Mutex<mpsc::UnboundedReceiver<()>>,
tx_kill_all: mpsc::UnboundedSender<()>,

rx_kill_action: Mutex<mpsc::UnboundedReceiver<KillActionRequest>>,
tx_kill_action: mpsc::UnboundedSender<KillActionRequest>,
metrics: Arc<Metrics>,
}

Expand All @@ -57,13 +64,16 @@ impl MockRunningActionsManager {
let (tx_call, rx_call) = mpsc::unbounded_channel();
let (tx_resp, rx_resp) = mpsc::unbounded_channel();
let (tx_kill_all, rx_kill_all) = mpsc::unbounded_channel();
let (tx_kill_action, rx_kill_action) = mpsc::unbounded_channel();
Self {
rx_call: Mutex::new(rx_call),
tx_call,
rx_resp: Mutex::new(rx_resp),
tx_resp,
rx_kill_all: Mutex::new(rx_kill_all),
tx_kill_all,
rx_kill_action: Mutex::new(rx_kill_action),
tx_kill_action,
metrics: Arc::new(Metrics::default()),
}
}
Expand Down Expand Up @@ -108,6 +118,22 @@ impl MockRunningActionsManager {
.await
.expect("Could not receive msg in mpsc");
}

/// Waits for the next `KillActionRequest` delivered on the kill-action channel and returns
/// the `ActionId` obtained by hex-decoding its `action_id` field.
///
/// # Panics
///
/// Panics if the channel yields no message or if `action_id` cannot be hex-decoded into a
/// 32-byte id.
pub async fn expect_kill_action(&self) -> ActionId {
    let request = {
        let mut receiver = self.rx_kill_action.lock().await;
        receiver
            .recv()
            .await
            .expect("Could not receive msg in mpsc")
    };

    let mut decoded_id: ActionId = [0u8; 32];
    let decode_result = hex::decode_to_slice(request.action_id, &mut decoded_id as &mut [u8]);
    assert_eq!(decode_result, Ok(()), "Failed to decode action id");
    decoded_id
}
}

#[async_trait]
Expand Down Expand Up @@ -151,6 +177,13 @@ impl RunningActionsManager for MockRunningActionsManager {
Ok(())
}

/// Records the request on the kill-action channel so that a test can later observe it via
/// `expect_kill_action`; always reports success.
///
/// # Panics
///
/// Panics if the request cannot be sent on the channel.
async fn kill_action(&self, kill_action_request: KillActionRequest) -> Result<(), Error> {
    let send_result = self.tx_kill_action.send(kill_action_request);
    send_result.expect("Could not send request to mpsc");
    Ok(())
}

async fn kill_all(&self) {
self.tx_kill_all
.send(())
Expand Down

0 comments on commit b91f05d

Please sign in to comment.