Implement API to drain a specific worker (#413)
steedmicro authored Dec 5, 2023
1 parent 8b49bac commit fbf1e44
Showing 7 changed files with 198 additions and 11 deletions.
3 changes: 2 additions & 1 deletion native-link-config/examples/basic_cas.json
@@ -137,7 +137,8 @@
// are a frontend api.
"worker_api": {
"scheduler": "MAIN_SCHEDULER",
}
},
"admin": {}
}
}],
"global": {
15 changes: 15 additions & 0 deletions native-link-config/src/cas_server.rs
@@ -150,6 +150,17 @@ pub struct PrometheusConfig {
pub path: String,
}

#[derive(Deserialize, Debug, Default)]
pub struct AdminConfig {
/// Path to register the admin API. If path is "/admin", and your
/// domain is "example.com", you can reach the endpoint with:
/// <http://example.com/admin>.
///
/// Default: "/admin"
#[serde(default)]
pub path: String,
}

#[derive(Deserialize, Debug)]
pub struct ServicesConfig {
/// The Content Addressable Storage (CAS) backend config.
@@ -189,6 +200,10 @@ pub struct ServicesConfig {
/// Prometheus metrics configuration. Metrics are gathered as a singleton
/// but may be served on multiple endpoints.
pub prometheus: Option<PrometheusConfig>,

/// Configuration for the admin service, which exposes a REST API
/// for administrative tasks such as draining workers.
pub admin: Option<AdminConfig>,
}

#[derive(Deserialize, Debug)]
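For reference, here is a minimal sketch of what an explicit admin entry could look like in a server's services block. It is illustrative only: the basic_cas.json change above uses the empty form ("admin": {}), which falls back to the default "/admin" path documented on AdminConfig, and the config format accepts //-style comments as seen in that example.

"services": {
  // ... other services such as "cas", "ac", and "worker_api" ...
  "admin": {
    // Optional; an empty or missing path defaults to "/admin".
    "path": "/admin"
  }
}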
54 changes: 44 additions & 10 deletions native-link-scheduler/src/simple_scheduler.rs
@@ -142,15 +142,14 @@ impl Workers {
assert!(matches!(awaited_action.current_state.stage, ActionStage::Queued));
let action_properties = &awaited_action.action_info.platform_properties;
let mut workers_iter = self.workers.iter_mut();
let workers_iter =
match self.allocation_strategy {
// Use rfind to get the least recently used that satisfies the properties.
WorkerAllocationStrategy::LeastRecentlyUsed => workers_iter
.rfind(|(_, w)| !w.is_paused && action_properties.is_satisfied_by(&w.platform_properties)),
// Use find to get the most recently used that satisfies the properties.
WorkerAllocationStrategy::MostRecentlyUsed => workers_iter
.find(|(_, w)| !w.is_paused && action_properties.is_satisfied_by(&w.platform_properties)),
};
let workers_iter = match self.allocation_strategy {
// Use rfind to get the least recently used that satisfies the properties.
WorkerAllocationStrategy::LeastRecentlyUsed => workers_iter
.rfind(|(_, w)| w.can_accept_work() && action_properties.is_satisfied_by(&w.platform_properties)),
// Use find to get the most recently used that satisfies the properties.
WorkerAllocationStrategy::MostRecentlyUsed => workers_iter
.find(|(_, w)| w.can_accept_work() && action_properties.is_satisfied_by(&w.platform_properties)),
};
let worker_id = workers_iter.map(|(_, w)| &w.id);
// We need to "touch" the worker to ensure it gets re-ordered in the LRUCache, since it was selected.
if let Some(&worker_id) = worker_id {
@@ -385,6 +384,19 @@ impl SimpleSchedulerImpl {
self.tasks_or_workers_change_notify.notify_one();
}

/// Sets whether the worker is draining.
fn set_drain_worker(&mut self, worker_id: WorkerId, is_draining: bool) -> Result<(), Error> {
let worker = self
.workers
.workers
.get_mut(&worker_id)
.err_tip(|| format!("Worker {worker_id} doesn't exist in the pool"))?;
self.metrics.workers_drained.inc();
worker.is_draining = is_draining;
self.tasks_or_workers_change_notify.notify_one();
Ok(())
}

// TODO(blaise.bruer) This is an O(n*m) (aka n^2) algorithm. In theory we can create a map
// of capabilities of each worker and then try and match the actions to the worker using
// the map lookup (ie. map reduce).
@@ -486,7 +498,7 @@ impl SimpleSchedulerImpl {

// Clear this action from the current worker.
if let Some(worker) = self.workers.workers.get_mut(worker_id) {
let was_paused = worker.is_paused;
let was_paused = !worker.can_accept_work();
// This unpauses, but since we're completing with an error, don't
// unpause unless all actions have completed.
worker.complete_action(&action_info);
@@ -678,6 +690,17 @@ impl SimpleScheduler {
inner.workers.workers.contains(worker_id)
}

/// Checks to see if the worker can accept work. Should only be used in unit tests.
pub fn can_worker_accept_work_for_test(&self, worker_id: &WorkerId) -> Result<bool, Error> {
let mut inner = self.get_inner_lock();
let worker = inner
.workers
.workers
.get_mut(worker_id)
.ok_or_else(|| make_input_err!("WorkerId '{}' does not exist in workers map", worker_id))?;
Ok(worker.can_accept_work())
}

/// A unit test function used to send the keep alive message to the worker from the server.
pub fn send_keep_alive_to_worker_for_test(&self, worker_id: &WorkerId) -> Result<(), Error> {
let mut inner = self.get_inner_lock();
@@ -827,6 +850,11 @@ impl WorkerScheduler for SimpleScheduler {
})
}

async fn set_drain_worker(&self, worker_id: WorkerId, is_draining: bool) -> Result<(), Error> {
let mut inner = self.get_inner_lock();
inner.set_drain_worker(worker_id, is_draining)
}

fn register_metrics(self: Arc<Self>, _registry: &mut Registry) {
// We do not register anything here because we only want to register metrics
// once and we rely on the `ActionScheduler::register_metrics()` to do that.
@@ -979,6 +1007,7 @@ struct Metrics {
update_action_with_internal_error_from_wrong_worker: CounterWithTime,
workers_evicted: CounterWithTime,
workers_evicted_with_running_action: CounterWithTime,
workers_drained: CounterWithTime,
retry_action: CounterWithTime,
retry_action_max_attempts_reached: CounterWithTime,
retry_action_no_more_listeners: CounterWithTime,
@@ -1083,6 +1112,11 @@ impl Metrics {
&self.workers_evicted_with_running_action,
"The number of jobs cancelled because worker was evicted from scheduler.",
);
c.publish(
"workers_drained_total",
&self.workers_drained,
"The number of workers drained from scheduler.",
);
{
c.publish_with_labels(
"retry_action",
14 changes: 14 additions & 0 deletions native-link-scheduler/src/worker.rs
@@ -95,6 +95,9 @@ pub struct Worker {
/// Whether the worker rejected the last action due to back pressure.
pub is_paused: bool,

/// Whether the worker is draining.
pub is_draining: bool,

/// Stats about the worker.
metrics: Arc<Metrics>,
}
@@ -133,6 +136,7 @@ impl Worker {
running_action_infos: HashSet::new(),
last_update_timestamp: timestamp,
is_paused: false,
is_draining: false,
metrics: Arc::new(Metrics {
connected_timestamp: SystemTime::now().duration_since(UNIX_EPOCH).unwrap().as_secs(),
actions_completed: CounterWithTime::default(),
@@ -215,6 +219,10 @@ impl Worker {
}
}
}

pub fn can_accept_work(&self) -> bool {
!self.is_paused && !self.is_draining
}
}

impl PartialEq for Worker {
@@ -280,6 +288,12 @@ impl MetricsComponent for Worker {
"If this worker is paused.",
vec![("worker_id".into(), format!("{}", self.id).into())],
);
c.publish_with_labels(
"is_draining",
&self.is_draining,
"If this worker is draining.",
vec![("worker_id".into(), format!("{}", self.id).into())],
);
for action_info in self.running_action_infos.iter() {
let action_name = action_info.unique_qualifier.action_name().to_string();
c.publish_with_labels(
3 changes: 3 additions & 0 deletions native-link-scheduler/src/worker_scheduler.rs
@@ -59,6 +59,9 @@ pub trait WorkerScheduler: Sync + Send + Unpin {
/// external source.
async fn remove_timedout_workers(&self, now_timestamp: WorkerTimestamp) -> Result<(), Error>;

/// Sets whether the worker is draining.
async fn set_drain_worker(&self, worker_id: WorkerId, is_draining: bool) -> Result<(), Error>;

/// Register the metrics for the worker scheduler.
fn register_metrics(self: Arc<Self>, _registry: &mut Registry) {}
}
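To show how this new trait method is meant to be consumed, a minimal hedged sketch follows; the scheduler handle and worker id are placeholders, and the actual call site added by this commit is the admin route in src/bin/cas.rs further below.

// Assumes some scheduler: Arc<dyn WorkerScheduler> and a parsed worker_id: WorkerId.
// Draining only flips the flag: the scheduler stops handing new actions to that worker
// until it is un-drained, while actions already running on it are left untouched.
scheduler.set_drain_worker(worker_id, true).await?;
// ... later, return the worker to service:
scheduler.set_drain_worker(worker_id, false).await?;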
73 changes: 73 additions & 0 deletions native-link-scheduler/tests/simple_scheduler_test.rs
@@ -349,6 +349,79 @@ mod scheduler_tests {
Ok(())
}

#[tokio::test]
async fn set_drain_worker_pauses_and_resumes_worker_test() -> Result<(), Error> {
const WORKER_ID: WorkerId = WorkerId(0x1234_5678_9111);

let scheduler = SimpleScheduler::new_with_callback(
&native_link_config::schedulers::SimpleScheduler::default(),
|| async move {},
);
let action_digest = DigestInfo::new([99u8; 32], 512);

let mut rx_from_worker = setup_new_worker(&scheduler, WORKER_ID, PlatformProperties::default()).await?;
let insert_timestamp = make_system_time(1);
let mut client_rx = setup_action(
&scheduler,
action_digest,
PlatformProperties::default(),
insert_timestamp,
)
.await?;

{
// Other tests check full data. We only care if we got StartAction.
match rx_from_worker.recv().await.unwrap().update {
Some(update_for_worker::Update::StartAction(_)) => { /* Success */ }
v => panic!("Expected StartAction, got : {v:?}"),
}
// Other tests check full data. We only care if client thinks we are Executing.
assert_eq!(client_rx.borrow_and_update().stage, ActionStage::Executing);
}

// Set the worker draining.
scheduler.set_drain_worker(WORKER_ID, true).await?;
tokio::task::yield_now().await;

let action_digest = DigestInfo::new([88u8; 32], 512);
let insert_timestamp = make_system_time(14);
let mut client_rx = setup_action(
&scheduler,
action_digest,
PlatformProperties::default(),
insert_timestamp,
)
.await?;

{
// Client should get notification saying it's been queued.
let action_state = client_rx.borrow_and_update();
let expected_action_state = ActionState {
// Name is a random string, so we ignore it and just make it the same.
unique_qualifier: action_state.unique_qualifier.clone(),
stage: ActionStage::Queued,
};
assert_eq!(action_state.as_ref(), &expected_action_state);
}

// Set the worker not draining.
scheduler.set_drain_worker(WORKER_ID, false).await?;
tokio::task::yield_now().await;

{
// Client should get notification saying it's being executed.
let action_state = client_rx.borrow_and_update();
let expected_action_state = ActionState {
// Name is a random string, so we ignore it and just make it the same.
unique_qualifier: action_state.unique_qualifier.clone(),
stage: ActionStage::Executing,
};
assert_eq!(action_state.as_ref(), &expected_action_state);
}

Ok(())
}

#[tokio::test]
async fn worker_should_not_queue_if_properties_dont_match_test() -> Result<(), Error> {
const WORKER_ID1: WorkerId = WorkerId(0x0010_0001);
47 changes: 47 additions & 0 deletions src/bin/cas.rs
@@ -29,6 +29,7 @@ use native_link_config::cas_server::{
CasConfig, CompressionAlgorithm, ConfigDigestHashFunction, GlobalConfig, ServerConfig, WorkerConfig,
};
use native_link_scheduler::default_scheduler_factory::scheduler_factory;
use native_link_scheduler::worker::WorkerId;
use native_link_service::ac_server::AcServer;
use native_link_service::bytestream_server::ByteStreamServer;
use native_link_service::capabilities_server::CapabilitiesServer;
@@ -58,6 +59,9 @@ use tower::util::ServiceExt;
/// Note: This must be kept in sync with the documentation in `PrometheusConfig::path`.
const DEFAULT_PROMETHEUS_METRICS_PATH: &str = "/metrics";

/// Note: This must be kept in sync with the documentation in `AdminConfig::path`.
const DEFAULT_ADMIN_API_PATH: &str = "/admin";

/// Name of environment variable to disable metrics.
const METRICS_DISABLE_ENV: &str = "NATIVE_LINK_DISABLE_METRICS";

@@ -390,6 +394,49 @@ async fn inner_main(cfg: CasConfig, server_start_timestamp: u64) -> Result<(), B
)
}

if let Some(admin_config) = services.admin {
let path = if admin_config.path.is_empty() {
DEFAULT_ADMIN_API_PATH
} else {
&admin_config.path
};
let worker_schedulers = Arc::new(worker_schedulers.clone());
svc = svc.nest_service(
path,
Router::new().route(
"/scheduler/:instance_name/set_drain_worker/:worker_id/:is_draining",
axum::routing::post(
move |params: axum::extract::Path<(String, String, String)>| async move {
let (instance_name, worker_id, is_draining) = params.0;
(async move {
let is_draining = match is_draining.as_str() {
"0" => false,
"1" => true,
_ => return Err(make_err!(Code::Internal, "{} is neither 0 nor 1", is_draining)),
};
worker_schedulers
.get(&instance_name)
.err_tip(|| {
format!("Can not get an instance with the name of '{}'", &instance_name)
})?
.clone()
.set_drain_worker(WorkerId::try_from(worker_id.clone())?, is_draining)
.await?;
Ok::<_, Error>(format!("Draining worker {worker_id}"))
})
.await
.map_err(|e| {
Err::<String, _>((
axum::http::StatusCode::INTERNAL_SERVER_ERROR,
format!("Error: {e:?}"),
))
})
},
),
),
)
}

// Configure our TLS acceptor if we have TLS configured.
let maybe_tls_acceptor = server_cfg.tls.map_or(Ok(None), |tls_config| {
let mut cert_reader = std::io::BufReader::new(
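With the admin service enabled, draining a worker is a plain HTTP POST against the nested router above. A hedged usage sketch follows: the port is illustrative, MAIN_SCHEDULER is the scheduler name from the example config earlier in this commit (assuming the route's instance segment is keyed by that name), and <worker_id> stands for a real worker id known to the scheduler. The final path segment must be "1" to start draining or "0" to stop; anything else is rejected.

# Start draining: the worker finishes its current actions but receives no new ones.
curl -X POST http://localhost:50051/admin/scheduler/MAIN_SCHEDULER/set_drain_worker/<worker_id>/1

# Return the worker to service.
curl -X POST http://localhost:50051/admin/scheduler/MAIN_SCHEDULER/set_drain_worker/<worker_id>/0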
