Skip to content

Commit

Permalink
Implement kill action requests on scheduler
Browse files Browse the repository at this point in the history
Implements scheduler API for sending a kill action request to a worker.
Allows for action cancellation during scenarios such as client
disconnection.

towards TraceMachina#338
  • Loading branch information
Zach Birenbaum committed Apr 8, 2024
1 parent 6c4fb7e commit 62c2ba7
Show file tree
Hide file tree
Showing 7 changed files with 105 additions and 3 deletions.
1 change: 1 addition & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 2 additions & 0 deletions nativelink-scheduler/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ rust_library(
"@crates//:blake3",
"@crates//:futures",
"@crates//:hashbrown",
"@crates//:hex",
"@crates//:lru",
"@crates//:parking_lot",
"@crates//:prost",
Expand Down Expand Up @@ -71,6 +72,7 @@ rust_test_suite(
"//nativelink-store",
"//nativelink-util",
"@crates//:futures",
"@crates//:hex",
"@crates//:pretty_assertions",
"@crates//:prost",
"@crates//:tokio",
Expand Down
1 change: 1 addition & 0 deletions nativelink-scheduler/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ tokio = { version = "1.37.0", features = ["sync", "rt", "parking_lot"] }
tokio-stream = { version = "0.1.15", features = ["sync"] }
tonic = { version = "0.11.0", features = ["gzip", "tls"] }
tracing = "0.1.40"
hex = "0.4.3"

[dev-dependencies]
pretty_assertions = "1.4.0"
16 changes: 16 additions & 0 deletions nativelink-scheduler/src/simple_scheduler.rs
Original file line number Diff line number Diff line change
Expand Up @@ -920,6 +920,22 @@ impl WorkerScheduler for SimpleScheduler {
inner.set_drain_worker(worker_id, is_draining)
}

/// Asks the worker identified by `worker_id` to kill the action identified by
/// `unique_qualifier`.
///
/// # Errors
///
/// Returns an `Internal` error when no worker with `worker_id` is present in the
/// worker map, and otherwise forwards whatever `Worker::send_kill_action_request`
/// returns.
async fn kill_running_action_on_worker(
    &self,
    worker_id: &WorkerId,
    unique_qualifier: &ActionInfoHashKey,
) -> Result<(), Error> {
    // No `.await` happens while the guard returned by `get_inner_lock()` is held.
    let mut inner = self.get_inner_lock();
    let Some(worker) = inner.workers.workers.get_mut(worker_id) else {
        // The lookup failed; nothing here detects a timeout or evicts a worker,
        // so report the condition that actually occurred.
        return Err(make_err!(
            Code::Internal,
            "Worker {worker_id} not found, unable to send kill action request"
        ));
    };
    worker.send_kill_action_request(unique_qualifier)
}

fn register_metrics(self: Arc<Self>, _registry: &mut Registry) {
// We do not register anything here because we only want to register metrics
// once and we rely on the `ActionScheduler::register_metrics()` to do that.
Expand Down
16 changes: 14 additions & 2 deletions nativelink-scheduler/src/worker.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,9 @@ use std::time::{SystemTime, UNIX_EPOCH};

use nativelink_error::{make_err, make_input_err, Code, Error, ResultExt};
use nativelink_proto::com::github::trace_machina::nativelink::remote_execution::{
update_for_worker, ConnectionResult, StartExecute, UpdateForWorker,
update_for_worker, ConnectionResult, KillActionRequest, StartExecute, UpdateForWorker,
};
use nativelink_util::action_messages::ActionInfo;
use nativelink_util::action_messages::{ActionInfo, ActionInfoHashKey};
use nativelink_util::metrics_utils::{
CollectorState, CounterWithTime, FuncCounterWrapper, MetricsComponent,
};
Expand Down Expand Up @@ -160,6 +160,18 @@ impl Worker {
}
}

/// Sends a `KillActionRequest` update to this worker over its `tx` channel.
///
/// The request's `action_id` is the hex encoding of `unique_qualifier.get_hash()`.
/// The result of `send_msg_to_worker` is returned unchanged.
pub fn send_kill_action_request(
    &mut self,
    unique_qualifier: &ActionInfoHashKey,
) -> Result<(), Error> {
    let action_id = hex::encode(unique_qualifier.get_hash());
    let kill_update = update_for_worker::Update::KillActionRequest(KillActionRequest { action_id });
    send_msg_to_worker(&mut self.tx, kill_update)
}

/// Sends the initial connection information to the worker. This generally is just meta info.
/// This should only be sent once and should always be the first item in the stream.
pub fn send_initial_connection_result(&mut self) -> Result<(), Error> {
Expand Down
7 changes: 7 additions & 0 deletions nativelink-scheduler/src/worker_scheduler.rs
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,13 @@ pub trait WorkerScheduler: Sync + Send + Unpin {
timestamp: WorkerTimestamp,
) -> Result<(), Error>;

/// Sends a request to the worker identified by `worker_id` asking it to kill the
/// running action identified by `unique_qualifier`.
///
/// Returns `Ok(())` on success or an `Error` on failure.
/// NOTE(review): presumably success only means the request was handed to the
/// worker, not that the action has already stopped — confirm against implementors.
async fn kill_running_action_on_worker(
&self,
worker_id: &WorkerId,
unique_qualifier: &ActionInfoHashKey,
) -> Result<(), Error>;

/// Removes worker from pool and reschedule any tasks that might be running on it.
async fn remove_worker(&self, worker_id: WorkerId);

Expand Down
65 changes: 64 additions & 1 deletion nativelink-scheduler/tests/simple_scheduler_test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ mod utils {
}
use nativelink_proto::build::bazel::remote::execution::v2::{digest_function, ExecuteRequest};
use nativelink_proto::com::github::trace_machina::nativelink::remote_execution::{
update_for_worker, ConnectionResult, StartExecute, UpdateForWorker,
update_for_worker, ConnectionResult, KillActionRequest, StartExecute, UpdateForWorker,
};
use nativelink_scheduler::simple_scheduler::SimpleScheduler;
use nativelink_scheduler::worker::{Worker, WorkerId};
Expand Down Expand Up @@ -1602,4 +1602,67 @@ mod scheduler_tests {

Ok(())
}

/// Checks that `kill_running_action_on_worker` returns `Ok(())` and that the worker
/// then receives a `KillActionRequest` whose `action_id` is the hex-encoded hash of
/// the scheduled action's unique qualifier.
#[tokio::test]
async fn kill_running_action() -> Result<(), Error> {
    const WORKER_ID: WorkerId = WorkerId(0x1234_5678_9111);

    let scheduler = SimpleScheduler::new_with_callback(
        &nativelink_config::schedulers::SimpleScheduler::default(),
        || async move {},
    );
    let digest = DigestInfo::new([99u8; 32], 512);

    let mut worker_rx =
        setup_new_worker(&scheduler, WORKER_ID, PlatformProperties::default()).await?;
    let queued_at = make_system_time(1);
    // This receiver is intentionally kept alive until the end of the test.
    let action_listener =
        setup_action(&scheduler, digest, PlatformProperties::default(), queued_at).await?;

    // Capture the key of the scheduled action; the kill request is addressed with it.
    let action_key = action_listener.borrow().unique_qualifier.clone();

    // The first message delivered to the worker should be the execute command.
    let start_request = ExecuteRequest {
        instance_name: INSTANCE_NAME.to_string(),
        skip_cache_lookup: true,
        action_digest: Some(digest.into()),
        digest_function: digest_function::Value::Sha256.into(),
        ..Default::default()
    };
    let expected_start = UpdateForWorker {
        update: Some(update_for_worker::Update::StartAction(StartExecute {
            execute_request: Some(start_request),
            salt: 0,
            queued_timestamp: Some(queued_at.into()),
        })),
    };
    let first_msg = worker_rx.recv().await.unwrap();
    assert_eq!(first_msg, expected_start);

    // Request the kill through the scheduler API under test.
    let kill_result = scheduler
        .kill_running_action_on_worker(&WORKER_ID, &action_key)
        .await;
    assert_eq!(kill_result, Ok(()));

    // The next message on the worker's channel should be the kill request.
    let expected_kill = UpdateForWorker {
        update: Some(update_for_worker::Update::KillActionRequest(
            KillActionRequest {
                action_id: hex::encode(action_key.get_hash()),
            },
        )),
    };
    assert_eq!(Some(expected_kill), worker_rx.recv().await);
    Ok(())
}
}

0 comments on commit 62c2ba7

Please sign in to comment.