Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix case when scheduler drops action on client reconnect #1198

Merged
merged 1 commit into from
Jul 26, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

3 changes: 2 additions & 1 deletion nativelink-config/src/stores.rs
Original file line number Diff line number Diff line change
Expand Up @@ -712,7 +712,8 @@ pub struct EvictionPolicy {
#[serde(default, deserialize_with = "convert_data_size_with_shellexpand")]
pub evict_bytes: usize,

/// Maximum number of seconds for an entry to live before an eviction.
/// Maximum number of seconds for an entry to live since it was last
/// accessed before it is evicted.
/// Default: 0. Zero means never evict based on time.
#[serde(default, deserialize_with = "convert_duration_with_shellexpand")]
pub max_seconds: u32,
Expand Down
1 change: 1 addition & 0 deletions nativelink-scheduler/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,7 @@ rust_test_suite(
"//nativelink-store",
"//nativelink-util",
"@crates//:futures",
"@crates//:mock_instant",
"@crates//:pretty_assertions",
"@crates//:prost",
"@crates//:tokio",
Expand Down
1 change: 1 addition & 0 deletions nativelink-scheduler/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ uuid = { version = "1.8.0", features = ["v4"] }
futures = "0.3.30"
hashbrown = "0.14"
lru = "0.12.3"
mock_instant = "0.3.2"
parking_lot = "0.12.2"
rand = "0.8.5"
scopeguard = "1.2.0"
Expand Down
146 changes: 92 additions & 54 deletions nativelink-scheduler/src/memory_awaited_action_db.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
use std::collections::{BTreeMap, BTreeSet, HashMap};
use std::ops::{Bound, RangeBounds};
use std::sync::Arc;
use std::time::{Duration, Instant, SystemTime};
use std::time::Duration;

use async_lock::Mutex;
use async_trait::async_trait;
Expand All @@ -28,6 +28,7 @@ use nativelink_util::action_messages::{
};
use nativelink_util::chunked_stream::ChunkedStream;
use nativelink_util::evicting_map::{EvictingMap, LenEntry};
use nativelink_util::instant_wrapper::InstantWrapper;
use nativelink_util::metrics_utils::{CollectorState, MetricsComponent};
use nativelink_util::operation_state_manager::ActionStateResult;
use nativelink_util::spawn;
Expand All @@ -48,21 +49,21 @@ const CLIENT_KEEPALIVE_DURATION: Duration = Duration::from_secs(10);

/// Represents a client that is currently listening to an action.
/// When the client is dropped, it will send the [`AwaitedAction`] to the
/// `drop_tx` if there are other cleanups needed.
/// `event_tx` if there are other cleanups needed.
#[derive(Debug)]
struct ClientAwaitedAction {
/// The OperationId that the client is listening to.
operation_id: OperationId,

/// The sender to notify of this struct being dropped.
drop_tx: mpsc::UnboundedSender<ActionEvent>,
event_tx: mpsc::UnboundedSender<ActionEvent>,
}

impl ClientAwaitedAction {
pub fn new(operation_id: OperationId, drop_tx: mpsc::UnboundedSender<ActionEvent>) -> Self {
pub fn new(operation_id: OperationId, event_tx: mpsc::UnboundedSender<ActionEvent>) -> Self {
Self {
operation_id,
drop_tx,
event_tx,
}
}

Expand All @@ -74,7 +75,7 @@ impl ClientAwaitedAction {
impl Drop for ClientAwaitedAction {
fn drop(&mut self) {
// If we failed to send, it means no one is listening.
let _ = self.drop_tx.send(ActionEvent::ClientDroppedOperation(
let _ = self.event_tx.send(ActionEvent::ClientDroppedOperation(
self.operation_id.clone(),
));
}
Expand Down Expand Up @@ -105,50 +106,61 @@ pub(crate) enum ActionEvent {

/// Information required to track an individual client
/// keep alive config and state.
struct ClientKeepAlive {
struct ClientInfo<I: InstantWrapper, NowFn: Fn() -> I> {
/// The client operation id.
client_operation_id: ClientOperationId,
/// The last time a keep alive was sent.
last_keep_alive: Instant,
/// The sender to notify of this struct being dropped.
drop_tx: mpsc::UnboundedSender<ActionEvent>,
last_keep_alive: I,
/// The function to get the current time.
now_fn: NowFn,
/// The sender used to notify listeners that this struct had an event.
event_tx: mpsc::UnboundedSender<ActionEvent>,
}

/// Subscriber that can be used to monitor when AwaitedActions change.
pub struct MemoryAwaitedActionSubscriber {
/// Subscriber that clients can use to monitor when AwaitedActions change.
pub struct MemoryAwaitedActionSubscriber<I: InstantWrapper, NowFn: Fn() -> I> {
/// The receiver to listen for changes.
awaited_action_rx: watch::Receiver<AwaitedAction>,
/// The client operation id and keep alive information.
client_operation_info: Option<ClientKeepAlive>,
/// If a client id is known this is the info needed to keep the client
/// action alive.
client_info: Option<ClientInfo<I, NowFn>>,
}

impl MemoryAwaitedActionSubscriber {
impl<I: InstantWrapper, NowFn: Fn() -> I> MemoryAwaitedActionSubscriber<I, NowFn> {
pub fn new(mut awaited_action_rx: watch::Receiver<AwaitedAction>) -> Self {
awaited_action_rx.mark_changed();
Self {
awaited_action_rx,
client_operation_info: None,
client_info: None,
}
}

pub fn new_with_client(
mut awaited_action_rx: watch::Receiver<AwaitedAction>,
client_operation_id: ClientOperationId,
drop_tx: mpsc::UnboundedSender<ActionEvent>,
) -> Self {
event_tx: mpsc::UnboundedSender<ActionEvent>,
now_fn: NowFn,
) -> Self
where
NowFn: Fn() -> I,
{
awaited_action_rx.mark_changed();
Self {
awaited_action_rx,
client_operation_info: Some(ClientKeepAlive {
client_info: Some(ClientInfo {
client_operation_id,
last_keep_alive: Instant::now(),
drop_tx,
last_keep_alive: I::from_secs(0),
now_fn,
event_tx,
}),
}
}
}

impl AwaitedActionSubscriber for MemoryAwaitedActionSubscriber {
impl<I, NowFn> AwaitedActionSubscriber for MemoryAwaitedActionSubscriber<I, NowFn>
where
I: InstantWrapper,
NowFn: Fn() -> I + Send + Sync + 'static,
{
async fn changed(&mut self) -> Result<AwaitedAction, Error> {
{
let changed_fut = self.awaited_action_rx.changed().map(|r| {
Expand All @@ -159,25 +171,26 @@ impl AwaitedActionSubscriber for MemoryAwaitedActionSubscriber {
)
})
});
let Some(client_keep_alive) = self.client_operation_info.as_mut() else {
let Some(client_info) = self.client_info.as_mut() else {
changed_fut.await?;
return Ok(self.awaited_action_rx.borrow().clone());
};
tokio::pin!(changed_fut);
loop {
if client_keep_alive.last_keep_alive.elapsed() > CLIENT_KEEPALIVE_DURATION {
client_keep_alive.last_keep_alive = Instant::now();
if client_info.last_keep_alive.elapsed() > CLIENT_KEEPALIVE_DURATION {
client_info.last_keep_alive = (client_info.now_fn)();
// Failing to send just means our receiver dropped.
let _ = client_keep_alive.drop_tx.send(ActionEvent::ClientKeepAlive(
client_keep_alive.client_operation_id.clone(),
let _ = client_info.event_tx.send(ActionEvent::ClientKeepAlive(
client_info.client_operation_id.clone(),
));
}
let sleep_fut = (client_info.now_fn)().sleep(CLIENT_KEEPALIVE_DURATION);
tokio::select! {
result = &mut changed_fut => {
result?;
break;
}
_ = tokio::time::sleep(CLIENT_KEEPALIVE_DURATION) => {
_ = sleep_fut => {
// If we haven't received any updates for a while, we should
// let the database know that we are still listening to prevent
// the action from being dropped.
Expand Down Expand Up @@ -329,10 +342,9 @@ impl SortedAwaitedActions {
}

/// The database for storing the state of all actions.
pub struct AwaitedActionDbImpl {
pub struct AwaitedActionDbImpl<I: InstantWrapper, NowFn: Fn() -> I> {
/// A lookup table to lookup the state of an action by its client operation id.
client_operation_to_awaited_action:
EvictingMap<ClientOperationId, Arc<ClientAwaitedAction>, SystemTime>,
client_operation_to_awaited_action: EvictingMap<ClientOperationId, Arc<ClientAwaitedAction>, I>,

/// A lookup table to lookup the state of an action by its worker operation id.
operation_id_to_awaited_action: BTreeMap<OperationId, watch::Sender<AwaitedAction>>,
Expand All @@ -351,13 +363,16 @@ pub struct AwaitedActionDbImpl {

/// Where to send notifications about important events related to actions.
action_event_tx: mpsc::UnboundedSender<ActionEvent>,

/// The function to get the current time.
now_fn: NowFn,
}

impl AwaitedActionDbImpl {
impl<I: InstantWrapper, NowFn: Fn() -> I + Clone + Send + Sync> AwaitedActionDbImpl<I, NowFn> {
async fn get_awaited_action_by_id(
&self,
client_operation_id: &ClientOperationId,
) -> Result<Option<MemoryAwaitedActionSubscriber>, Error> {
) -> Result<Option<MemoryAwaitedActionSubscriber<I, NowFn>>, Error> {
let maybe_client_awaited_action = self
.client_operation_to_awaited_action
.get(client_operation_id)
Expand All @@ -369,7 +384,14 @@ impl AwaitedActionDbImpl {

self.operation_id_to_awaited_action
.get(client_awaited_action.operation_id())
.map(|tx| Some(MemoryAwaitedActionSubscriber::new(tx.subscribe())))
.map(|tx| {
Some(MemoryAwaitedActionSubscriber::new_with_client(
tx.subscribe(),
client_operation_id.clone(),
self.action_event_tx.clone(),
self.now_fn.clone(),
))
})
.ok_or_else(|| {
make_err!(
Code::Internal,
Expand Down Expand Up @@ -487,32 +509,38 @@ impl AwaitedActionDbImpl {
&self,
start: Bound<&OperationId>,
end: Bound<&OperationId>,
) -> impl Iterator<Item = (&'_ OperationId, MemoryAwaitedActionSubscriber)> {
) -> impl Iterator<Item = (&'_ OperationId, MemoryAwaitedActionSubscriber<I, NowFn>)> {
self.operation_id_to_awaited_action
.range((start, end))
.map(|(operation_id, tx)| {
(
operation_id,
MemoryAwaitedActionSubscriber::new(tx.subscribe()),
MemoryAwaitedActionSubscriber::<I, NowFn>::new(tx.subscribe()),
)
})
}

fn get_by_operation_id(
&self,
operation_id: &OperationId,
) -> Option<MemoryAwaitedActionSubscriber> {
) -> Option<MemoryAwaitedActionSubscriber<I, NowFn>> {
self.operation_id_to_awaited_action
.get(operation_id)
.map(|tx| MemoryAwaitedActionSubscriber::new(tx.subscribe()))
.map(|tx| MemoryAwaitedActionSubscriber::<I, NowFn>::new(tx.subscribe()))
}

fn get_range_of_actions<'a, 'b>(
&'a self,
state: SortedAwaitedActionState,
range: impl RangeBounds<SortedAwaitedAction> + 'b,
) -> impl DoubleEndedIterator<
Item = Result<(&'a SortedAwaitedAction, MemoryAwaitedActionSubscriber), Error>,
Item = Result<
(
&'a SortedAwaitedAction,
MemoryAwaitedActionSubscriber<I, NowFn>,
),
Error,
>,
> + 'a {
let btree = match state {
SortedAwaitedActionState::CacheCheck => &self.sorted_action_info_hash_keys.cache_check,
Expand Down Expand Up @@ -674,7 +702,7 @@ impl AwaitedActionDbImpl {
&mut self,
client_operation_id: ClientOperationId,
action_info: Arc<ActionInfo>,
) -> Result<MemoryAwaitedActionSubscriber, Error> {
) -> Result<MemoryAwaitedActionSubscriber<I, NowFn>, Error> {
// Check to see if the action is already known and subscribe if it is.
let subscription_result = self
.try_subscribe(
Expand Down Expand Up @@ -738,6 +766,7 @@ impl AwaitedActionDbImpl {
rx,
client_operation_id,
self.action_event_tx.clone(),
self.now_fn.clone(),
))
}

Expand All @@ -749,7 +778,7 @@ impl AwaitedActionDbImpl {
// removed the ability to upgrade priorities of actions.
// we should add priority upgrades back in.
_priority: i32,
) -> Result<Option<MemoryAwaitedActionSubscriber>, Error> {
) -> Result<Option<MemoryAwaitedActionSubscriber<I, NowFn>>, Error> {
let unique_key = match unique_qualifier {
ActionUniqueQualifier::Cachable(unique_key) => unique_key,
ActionUniqueQualifier::Uncachable(_unique_key) => return Ok(None),
Expand Down Expand Up @@ -795,28 +824,33 @@ impl AwaitedActionDbImpl {
)
.await;

Ok(Some(MemoryAwaitedActionSubscriber::new(subscription)))
Ok(Some(MemoryAwaitedActionSubscriber::new_with_client(
subscription,
client_operation_id.clone(),
self.action_event_tx.clone(),
self.now_fn.clone(),
)))
}
}

pub struct MemoryAwaitedActionDb {
inner: Arc<Mutex<AwaitedActionDbImpl>>,
pub struct MemoryAwaitedActionDb<I: InstantWrapper, NowFn: Fn() -> I> {
inner: Arc<Mutex<AwaitedActionDbImpl<I, NowFn>>>,
_handle_awaited_action_events: JoinHandleDropGuard<()>,
}

impl MemoryAwaitedActionDb {
pub fn new(eviction_config: &EvictionPolicy) -> Self {
impl<I: InstantWrapper, NowFn: Fn() -> I + Clone + Send + Sync + 'static>
MemoryAwaitedActionDb<I, NowFn>
{
pub fn new(eviction_config: &EvictionPolicy, now_fn: NowFn) -> Self {
let (action_event_tx, mut action_event_rx) = mpsc::unbounded_channel();
let inner = Arc::new(Mutex::new(AwaitedActionDbImpl {
client_operation_to_awaited_action: EvictingMap::new(
eviction_config,
SystemTime::now(),
),
client_operation_to_awaited_action: EvictingMap::new(eviction_config, (now_fn)()),
operation_id_to_awaited_action: BTreeMap::new(),
action_info_hash_key_to_awaited_action: HashMap::new(),
sorted_action_info_hash_keys: SortedAwaitedActions::default(),
connected_clients_for_operation_id: HashMap::new(),
action_event_tx,
now_fn,
}));
let weak_inner = Arc::downgrade(&inner);
Self {
Expand All @@ -841,8 +875,10 @@ impl MemoryAwaitedActionDb {
}
}

impl AwaitedActionDb for MemoryAwaitedActionDb {
type Subscriber = MemoryAwaitedActionSubscriber;
impl<I: InstantWrapper, NowFn: Fn() -> I + Clone + Send + Sync + 'static> AwaitedActionDb
for MemoryAwaitedActionDb<I, NowFn>
{
type Subscriber = MemoryAwaitedActionSubscriber<I, NowFn>;

async fn get_awaited_action_by_id(
&self,
Expand Down Expand Up @@ -943,7 +979,9 @@ impl AwaitedActionDb for MemoryAwaitedActionDb {
}
}

impl MetricsComponent for MemoryAwaitedActionDb {
impl<I: InstantWrapper, NowFn: Fn() -> I + Send + Sync + 'static> MetricsComponent
for MemoryAwaitedActionDb<I, NowFn>
{
fn gather_metrics(&self, c: &mut CollectorState) {
let inner = self.inner.lock_blocking();
c.publish(
Expand Down
Loading