Skip to content

Commit

Permalink
Add scheduler function to kill action on worker
Browse files Browse the repository at this point in the history
Adds methods that let the scheduler ask a worker to kill a running
action. This allows the scheduler to send action cancellation
requests in scenarios such as client disconnection.

towards TraceMachina#338
  • Loading branch information
Zach Birenbaum committed Apr 7, 2024
1 parent b9caf1e commit e13020b
Show file tree
Hide file tree
Showing 8 changed files with 121 additions and 33 deletions.
1 change: 1 addition & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions nativelink-scheduler/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ tokio = { version = "1.37.0", features = ["sync", "rt", "parking_lot"] }
tokio-stream = { version = "0.1.15", features = ["sync"] }
tonic = { version = "0.11.0", features = ["gzip", "tls"] }
tracing = "0.1.40"
hex = "0.4.3"

[dev-dependencies]
pretty_assertions = "1.4.0"
16 changes: 16 additions & 0 deletions nativelink-scheduler/src/simple_scheduler.rs
Original file line number Diff line number Diff line change
Expand Up @@ -920,6 +920,22 @@ impl WorkerScheduler for SimpleScheduler {
inner.set_drain_worker(worker_id, is_draining)
}

/// Asks the worker identified by `worker_id` to kill the running action
/// identified by `unique_qualifier`.
///
/// The worker is looked up while holding the scheduler's inner lock and the
/// request is forwarded through `Worker::send_kill_action_request`.
///
/// # Errors
///
/// Returns a `Code::Internal` error when no worker with `worker_id` is
/// currently registered, and otherwise propagates whatever error
/// `send_kill_action_request` reports.
async fn kill_running_action_on_worker(
    &self,
    worker_id: &WorkerId,
    unique_qualifier: &ActionInfoHashKey,
) -> Result<(), Error> {
    let mut inner = self.get_inner_lock();
    // The worker is only looked up here; nothing is evicted from the pool,
    // so the error must describe a failed lookup, not a timeout/removal.
    let worker = inner.workers.workers.get_mut(worker_id).ok_or_else(|| {
        make_err!(
            Code::Internal,
            "Worker {worker_id} not found in pool, unable to send kill request for running action"
        )
    })?;
    worker.send_kill_action_request(unique_qualifier)
}

fn register_metrics(self: Arc<Self>, _registry: &mut Registry) {
// We do not register anything here because we only want to register metrics
// once and we rely on the `ActionScheduler::register_metrics()` to do that.
Expand Down
16 changes: 14 additions & 2 deletions nativelink-scheduler/src/worker.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,9 @@ use std::time::{SystemTime, UNIX_EPOCH};

use nativelink_error::{make_err, make_input_err, Code, Error, ResultExt};
use nativelink_proto::com::github::trace_machina::nativelink::remote_execution::{
update_for_worker, ConnectionResult, StartExecute, UpdateForWorker,
update_for_worker, ConnectionResult, KillActionRequest, StartExecute, UpdateForWorker,
};
use nativelink_util::action_messages::ActionInfo;
use nativelink_util::action_messages::{ActionInfo, ActionInfoHashKey};
use nativelink_util::metrics_utils::{
CollectorState, CounterWithTime, FuncCounterWrapper, MetricsComponent,
};
Expand Down Expand Up @@ -160,6 +160,18 @@ impl Worker {
}
}

/// Sends a `KillActionRequest` update to this worker for the action
/// identified by `unique_qualifier`.
///
/// The request's `action_id` is the hex encoding of the value returned by
/// `unique_qualifier.get_hash()`. The result of `send_msg_to_worker` is
/// returned unchanged.
pub fn send_kill_action_request(
    &mut self,
    unique_qualifier: &ActionInfoHashKey,
) -> Result<(), Error> {
    let action_id = hex::encode(unique_qualifier.get_hash());
    let kill_request = KillActionRequest { action_id };
    let update = update_for_worker::Update::KillActionRequest(kill_request);
    send_msg_to_worker(&mut self.tx, update)
}

/// Sends the initial connection information to the worker. This generally is just meta info.
/// This should only be sent once and should always be the first item in the stream.
pub fn send_initial_connection_result(&mut self) -> Result<(), Error> {
Expand Down
7 changes: 7 additions & 0 deletions nativelink-scheduler/src/worker_scheduler.rs
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,13 @@ pub trait WorkerScheduler: Sync + Send + Unpin {
timestamp: WorkerTimestamp,
) -> Result<(), Error>;

/// Sends a request to the worker identified by `worker_id` asking it to
/// kill the running action identified by `unique_qualifier`.
///
/// Returns `Ok(())` once the request has been handed off for delivery.
/// Implementations return an error when the request cannot be sent; for
/// example, `SimpleScheduler` fails when `worker_id` is not a registered
/// worker.
async fn kill_running_action_on_worker(
&self,
worker_id: &WorkerId,
unique_qualifier: &ActionInfoHashKey,
) -> Result<(), Error>;

/// Removes worker from pool and reschedule any tasks that might be running on it.
async fn remove_worker(&self, worker_id: WorkerId);

Expand Down
65 changes: 64 additions & 1 deletion nativelink-scheduler/tests/simple_scheduler_test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ mod utils {
}
use nativelink_proto::build::bazel::remote::execution::v2::{digest_function, ExecuteRequest};
use nativelink_proto::com::github::trace_machina::nativelink::remote_execution::{
update_for_worker, ConnectionResult, StartExecute, UpdateForWorker,
update_for_worker, ConnectionResult, KillActionRequest, StartExecute, UpdateForWorker,
};
use nativelink_scheduler::simple_scheduler::SimpleScheduler;
use nativelink_scheduler::worker::{Worker, WorkerId};
Expand Down Expand Up @@ -1602,4 +1602,67 @@ mod scheduler_tests {

Ok(())
}

/// Verifies that `kill_running_action_on_worker` succeeds for a registered
/// worker and that the worker's channel then receives a `KillActionRequest`
/// carrying the hex-encoded hash of the action's unique qualifier.
#[tokio::test]
async fn kill_running_action() -> Result<(), Error> {
    const WORKER_ID: WorkerId = WorkerId(0x1234_5678_9111);

    let scheduler_config = nativelink_config::schedulers::SimpleScheduler::default();
    let scheduler = SimpleScheduler::new_with_callback(&scheduler_config, || async move {});
    let action_digest = DigestInfo::new([99u8; 32], 512);

    // Register the worker first, then queue an action for it.
    let mut rx_from_worker =
        setup_new_worker(&scheduler, WORKER_ID, PlatformProperties::default()).await?;
    let insert_timestamp = make_system_time(1);
    let client_rx = setup_action(
        &scheduler,
        action_digest,
        PlatformProperties::default(),
        insert_timestamp,
    )
    .await?;

    // Capture the qualifier that identifies the queued action; it is the
    // handle used below to request the kill.
    let unique_qualifier = client_rx.borrow().unique_qualifier.clone();

    // The first message on the worker channel must be the execute command.
    let expected_execute_request = ExecuteRequest {
        instance_name: INSTANCE_NAME.to_string(),
        skip_cache_lookup: true,
        action_digest: Some(action_digest.into()),
        digest_function: digest_function::Value::Sha256.into(),
        ..Default::default()
    };
    let expected_start_msg = UpdateForWorker {
        update: Some(update_for_worker::Update::StartAction(StartExecute {
            execute_request: Some(expected_execute_request),
            salt: 0,
            queued_timestamp: Some(insert_timestamp.into()),
        })),
    };
    let start_msg = rx_from_worker.recv().await.unwrap();
    assert_eq!(start_msg, expected_start_msg);

    // Requesting the kill through the scheduler must succeed...
    let kill_result = scheduler
        .kill_running_action_on_worker(&WORKER_ID, &unique_qualifier)
        .await;
    assert_eq!(kill_result, Ok(()));

    // ...and the next message on the worker channel must be the kill request.
    let expected_kill_msg = UpdateForWorker {
        update: Some(update_for_worker::Update::KillActionRequest(
            KillActionRequest {
                action_id: hex::encode(unique_qualifier.get_hash()),
            },
        )),
    };
    let kill_msg = rx_from_worker.recv().await;
    assert_eq!(Some(expected_kill_msg), kill_msg);

    Ok(())
}
}
29 changes: 10 additions & 19 deletions nativelink-worker/src/local_worker.rs
Original file line number Diff line number Diff line change
Expand Up @@ -212,25 +212,16 @@ impl<'a, T: WorkerApiClientTrait, U: RunningActionsManager> LocalWorkerImpl<'a,
}
Update::KillActionRequest(kill_action_request) => {
let mut action_id = [0u8; 32];
let decode_result = hex::decode_to_slice(kill_action_request.action_id, &mut action_id as &mut [u8]);

if let Err(e) = decode_result {
return Err(make_input_err!(
"KillActionRequest failed to decode ActionId hex with error {}",
e
));
}
let kill_action_result = self.running_actions_manager
.kill_action(action_id)
.await;

if let Err(e) = kill_action_result {
return Err(make_input_err!(
"Kill action {} failed with error - {}",
hex::encode(action_id),
e
));
}
hex::decode_to_slice(kill_action_request.action_id, &mut action_id as &mut [u8])
.map_err(|e| make_input_err!(
"KillActionRequest failed to decode ActionId hex with error {}",
e
))?;

self.running_actions_manager
.kill_action(action_id)
.await
.err_tip(|| format!("Failed to send kill request for action {}", hex::encode(action_id)))?
}
Update::StartAction(start_execute) => {
self.metrics.start_actions_received.inc();
Expand Down
19 changes: 8 additions & 11 deletions nativelink-worker/src/running_actions_manager.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1786,23 +1786,20 @@ impl RunningActionsManager for RunningActionsManagerImpl {
}

async fn kill_action(&self, action_id: ActionId) -> Result<(), Error> {

let upgrade_or_err: Result<Arc<RunningActionImpl>, Error> = {
let running_actions = self.running_actions.lock();
running_actions.get(&action_id).map_or(
Err(make_input_err!(
"Failed to get running action {}",
hex::encode(action_id)
)),
|action| {
match action.upgrade() {
Some(upgraded_action) => Ok(upgraded_action),
None => Err(make_input_err!(
"Failed to upgrade after getting action {}",
hex::encode(action_id)
))
}
}
|action| match action.upgrade() {
Some(upgraded_action) => Ok(upgraded_action),
None => Err(make_input_err!(
"Failed to upgrade after getting action {}",
hex::encode(action_id)
)),
},
)
};

Expand All @@ -1811,7 +1808,7 @@ impl RunningActionsManager for RunningActionsManagerImpl {
Self::kill_action(upgrade).await;
Ok(())
}
Err(e) => Err(e)
Err(e) => Err(e),
}
}

Expand Down

0 comments on commit e13020b

Please sign in to comment.