Force spawned tasks to accept a cancellation token
rakanalh committed Sep 24, 2024
1 parent 8c63cba commit b21f310
Showing 6 changed files with 104 additions and 79 deletions.
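
In short, `TaskManager::spawn` now takes a closure that receives a `CancellationToken` instead of a bare future, so every spawned task can watch for shutdown, and `abort` becomes an async, graceful notification rather than a hard `JoinHandle::abort`. Below is a minimal task-side sketch of the pattern this enforces; it is illustrative only (not code from this repository) and assumes the `tokio` and `tokio-util` crates.

```rust
use std::time::Duration;

use tokio_util::sync::CancellationToken;

// Illustrative worker in the style this commit enforces: the task receives a
// CancellationToken and races each unit of work against cancellation.
async fn worker(cancellation_token: CancellationToken) {
    loop {
        tokio::select! {
            _ = cancellation_token.cancelled() => {
                // Shutdown requested; stop gracefully.
                return;
            }
            _ = tokio::time::sleep(Duration::from_millis(100)) => {
                // Placeholder for a real unit of work.
            }
        }
    }
}

#[tokio::main]
async fn main() {
    let token = CancellationToken::new();
    let handle = tokio::spawn(worker(token.child_token()));

    // Later, e.g. on ctrl-c: signal cancellation and wait for the task to exit.
    token.cancel();
    let _ = handle.await;
}
```
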
1 change: 1 addition & 0 deletions Cargo.lock

Some generated files are not rendered by default.

101 changes: 51 additions & 50 deletions crates/fullnode/src/runner.rs
@@ -179,42 +179,43 @@ where
.layer(citrea_common::rpc::get_cors_layer())
.layer(citrea_common::rpc::get_healthcheck_proxy_layer());

-self.task_manager.spawn(async move {
-let server = ServerBuilder::default()
-.max_connections(max_connections)
-.max_subscriptions_per_connection(max_subscriptions_per_connection)
-.max_request_body_size(max_request_body_size)
-.max_response_body_size(max_response_body_size)
-.set_batch_request_config(BatchRequestConfig::Limit(batch_requests_limit))
-.set_http_middleware(middleware)
-.build([listen_address].as_ref())
-.await;
-
-match server {
-Ok(server) => {
-let bound_address = match server.local_addr() {
-Ok(address) => address,
-Err(e) => {
-error!("{}", e);
-return;
-}
-};
-if let Some(channel) = channel {
-if let Err(e) = channel.send(bound_address) {
-error!("Could not send bound_address {}: {}", bound_address, e);
-return;
-}
-}
-info!("Starting RPC server at {} ", &bound_address);
-
-let _server_handle = server.start(methods);
-futures::future::pending::<()>().await;
-}
-Err(e) => {
-error!("Could not start RPC server: {}", e);
-}
-}
-});
+self.task_manager
+.spawn(move |cancellation_token| async move {
+let server = ServerBuilder::default()
+.max_connections(max_connections)
+.max_subscriptions_per_connection(max_subscriptions_per_connection)
+.max_request_body_size(max_request_body_size)
+.max_response_body_size(max_response_body_size)
+.set_batch_request_config(BatchRequestConfig::Limit(batch_requests_limit))
+.set_http_middleware(middleware)
+.build([listen_address].as_ref())
+.await;
+
+match server {
+Ok(server) => {
+let bound_address = match server.local_addr() {
+Ok(address) => address,
+Err(e) => {
+error!("{}", e);
+return;
+}
+};
+if let Some(channel) = channel {
+if let Err(e) = channel.send(bound_address) {
+error!("Could not send bound_address {}: {}", bound_address, e);
+return;
+}
+}
+info!("Starting RPC server at {} ", &bound_address);
+
+let _server_handle = server.start(methods);
+cancellation_token.cancelled().await;
+}
+Err(e) => {
+error!("Could not start RPC server: {}", e);
+}
+}
+});
}

async fn process_l2_block(
@@ -325,22 +326,22 @@ where
let accept_public_input_as_proven = self.accept_public_input_as_proven;
let l1_block_cache = self.l1_block_cache.clone();

-let cancellation_token = self.task_manager.child_token();
-self.task_manager.spawn(async move {
-let l1_block_handler = L1BlockHandler::<C, Vm, Da, Stf::StateRoot, DB>::new(
-ledger_db,
-da_service,
-sequencer_pub_key,
-sequencer_da_pub_key,
-prover_da_pub_key,
-code_commitments_by_spec,
-accept_public_input_as_proven,
-l1_block_cache.clone(),
-);
-l1_block_handler
-.run(start_l1_height, cancellation_token)
-.await
-});
+self.task_manager
+.spawn(move |cancellation_token| async move {
+let l1_block_handler = L1BlockHandler::<C, Vm, Da, Stf::StateRoot, DB>::new(
+ledger_db,
+da_service,
+sequencer_pub_key,
+sequencer_da_pub_key,
+prover_da_pub_key,
+code_commitments_by_spec,
+accept_public_input_as_proven,
+l1_block_cache.clone(),
+);
+l1_block_handler
+.run(start_l1_height, cancellation_token)
+.await
+});

let (l2_tx, mut l2_rx) = mpsc::channel(1);
let l2_sync_worker = sync_l2::<Da>(
@@ -397,7 +398,7 @@ where
},
_ = signal::ctrl_c() => {
info!("Shutting down");
-self.task_manager.abort();
+self.task_manager.abort().await;
return Ok(());
}
}
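
The RPC-server hunk above swaps `futures::future::pending::<()>().await` for `cancellation_token.cancelled().await`: the task still holds the server handle for as long as it runs, but it now returns (dropping the handle) once shutdown is signalled instead of living forever. A generic sketch of that keep-alive-until-cancelled shape, with a hypothetical `Guard` type standing in for the real server handle:

```rust
use tokio_util::sync::CancellationToken;

// Placeholder guard standing in for a resource tied to the task's lifetime
// (hypothetical type, for illustration only).
struct Guard;

impl Drop for Guard {
    fn drop(&mut self) {
        println!("guard dropped, resource shut down");
    }
}

async fn run_until_cancelled(cancellation_token: CancellationToken) {
    let _guard = Guard;
    // Before this commit the task awaited a future that never resolved and so never
    // returned; now it returns, and drops `_guard`, once cancellation is requested.
    cancellation_token.cancelled().await;
}

#[tokio::main]
async fn main() {
    let token = CancellationToken::new();
    let task = tokio::spawn(run_until_cancelled(token.child_token()));

    token.cancel();
    task.await.unwrap();
}
```
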
25 changes: 18 additions & 7 deletions crates/primitives/src/tasks/manager.rs
@@ -1,8 +1,12 @@
use std::future::Future;
+use std::time::Duration;

use tokio::task::JoinHandle;
+use tokio::time::sleep;
use tokio_util::sync::CancellationToken;

+const WAIT_DURATION: u64 = 5; // 5 seconds

/// TaskManager manages tasks spawned using tokio and keeps
/// track of handles so that these tasks are cancellable.
/// This provides a way to implement graceful shutdown of our
@@ -22,17 +26,24 @@ impl<T: Send + 'static> TaskManager<T> {
}

/// Spawn a new asynchronous task.
-pub fn spawn(&mut self, future: impl Future<Output = T> + Send + 'static) {
-let handle = tokio::spawn(future);
+///
+/// Tasks are forced to accept a cancellation token so that they can be notified
+/// about the cancellation using the passed token.
+pub fn spawn<F, Fut>(&mut self, callback: F)
+where
+F: FnOnce(CancellationToken) -> Fut,
+Fut: Future<Output = T> + Send + 'static,
+{
+let handle = tokio::spawn(callback(self.child_token()));
self.handles.push(handle);
}

-/// Drastically abort all running tasks
-pub fn abort(&self) {
+/// Notify all running tasks to stop.
+pub async fn abort(&self) {
self.cancellation_token.cancel();
-for handle in &self.handles {
-handle.abort();
-}
+
+// provide tasks with some time to finish existing work
+sleep(Duration::from_secs(WAIT_DURATION)).await;
}

/// Provides a child cancellation token.
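
For reference, here is how the reworked manager is meant to be driven end to end. This is an illustrative, trimmed-down re-implementation that mirrors the signatures in the hunk above (`spawn` taking a `FnOnce(CancellationToken) -> Fut` callback, `abort` becoming async with a grace period); it is not code from the repository.

```rust
use std::future::Future;
use std::time::Duration;

use tokio::task::JoinHandle;
use tokio::time::sleep;
use tokio_util::sync::CancellationToken;

// Trimmed-down stand-in for the TaskManager in this diff (for illustration only).
struct TaskManager<T> {
    handles: Vec<JoinHandle<T>>,
    cancellation_token: CancellationToken,
}

impl<T: Send + 'static> TaskManager<T> {
    fn new() -> Self {
        Self {
            handles: Vec::new(),
            cancellation_token: CancellationToken::new(),
        }
    }

    // Tasks must accept a cancellation token, as the new `spawn` bounds require.
    fn spawn<F, Fut>(&mut self, callback: F)
    where
        F: FnOnce(CancellationToken) -> Fut,
        Fut: Future<Output = T> + Send + 'static,
    {
        let handle = tokio::spawn(callback(self.cancellation_token.child_token()));
        self.handles.push(handle);
    }

    // Notify tasks, then give them a few seconds to wind down.
    async fn abort(&self) {
        self.cancellation_token.cancel();
        sleep(Duration::from_secs(5)).await;
    }
}

#[tokio::main]
async fn main() {
    let mut manager = TaskManager::<()>::new();

    manager.spawn(|cancellation_token| async move {
        cancellation_token.cancelled().await;
        println!("task observed cancellation and is shutting down");
    });

    manager.abort().await;
}
```

One consequence of this shape is that `abort().await` always waits the full grace period even if every task exits immediately; joining the stored handles with a timeout would be a natural refinement if faster shutdown matters.
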
9 changes: 4 additions & 5 deletions crates/prover/src/runner.rs
@@ -186,7 +186,7 @@ where
let middleware = tower::ServiceBuilder::new().layer(citrea_common::rpc::get_cors_layer());
// .layer(citrea_common::rpc::get_healthcheck_proxy_layer());

-self.task_manager.spawn(async move {
+self.task_manager.spawn(|cancellation_token| async move {
let server = ServerBuilder::default()
.max_connections(max_connections)
.max_subscriptions_per_connection(max_subscriptions_per_connection)
@@ -215,7 +215,7 @@ where
info!("Starting RPC server at {} ", &bound_address);

let _server_handle = server.start(methods);
-futures::future::pending::<()>().await;
+cancellation_token.cancelled().await;
}
Err(e) => {
error!("Could not start RPC server: {}", e);
@@ -251,8 +251,7 @@ where
let code_commitments_by_spec = self.code_commitments_by_spec.clone();
let l1_block_cache = self.l1_block_cache.clone();

-let cancellation_token = self.task_manager.child_token();
-self.task_manager.spawn(async move {
+self.task_manager.spawn(|cancellation_token| async move {
let l1_block_handler =
L1BlockHandler::<Vm, Da, Ps, DB, Stf::StateRoot, Stf::Witness>::new(
prover_config,
@@ -327,7 +326,7 @@ where
},
_ = signal::ctrl_c() => {
info!("Shutting down");
-self.task_manager.abort();
+self.task_manager.abort().await;
return Ok(());
}
}
1 change: 1 addition & 0 deletions crates/sequencer/Cargo.toml
@@ -30,6 +30,7 @@ schnellru = "0.2.1"
serde = { workspace = true }
serde_json = { workspace = true }
tokio = { workspace = true }
+tokio-util = { workspace = true }
tower = { workspace = true }
tower-http = { workspace = true }
tracing = { workspace = true }
46 changes: 29 additions & 17 deletions crates/sequencer/src/sequencer.rs
@@ -51,6 +51,7 @@ use tokio::signal;
use tokio::sync::mpsc;
use tokio::sync::oneshot::channel as oneshot_channel;
use tokio::time::{sleep, Instant};
+use tokio_util::sync::CancellationToken;
use tracing::{debug, error, info, instrument, trace, warn};

use crate::commitment_controller;
@@ -219,7 +220,7 @@ where
let middleware = tower::ServiceBuilder::new().layer(citrea_common::rpc::get_cors_layer());
// .layer(citrea_common::rpc::get_healthcheck_proxy_layer());

-self.task_manager.spawn(async move {
+self.task_manager.spawn(|cancellation_token| async move {
let server = ServerBuilder::default()
.max_connections(max_connections)
.max_subscriptions_per_connection(max_subscriptions_per_connection)
@@ -248,7 +249,7 @@ where
info!("Starting RPC server at {} ", &bound_address);

let _server_handle = server.start(methods);
-futures::future::pending::<()>().await;
+cancellation_token.cancelled().await;
}
Err(e) => {
error!("Could not start RPC server: {}", e);
@@ -859,11 +860,14 @@ where
let (da_height_update_tx, mut da_height_update_rx) = mpsc::channel(1);
let (da_commitment_tx, mut da_commitment_rx) = unbounded::<bool>();

-self.task_manager.spawn(da_block_monitor(
-self.da_service.clone(),
-da_height_update_tx,
-self.config.da_update_interval_ms,
-));
+self.task_manager.spawn(|cancellation_token| {
+da_block_monitor(
+self.da_service.clone(),
+da_height_update_tx,
+self.config.da_update_interval_ms,
+cancellation_token,
+)
+});

let target_block_time = Duration::from_millis(self.config.block_production_interval_ms);

@@ -969,7 +973,7 @@ where
},
_ = signal::ctrl_c() => {
info!("Shutting down sequencer");
-self.task_manager.abort();
+self.task_manager.abort().await;
return Ok(());
}
}
@@ -1193,21 +1197,29 @@ async fn da_block_monitor<Da>(
da_service: Arc<Da>,
sender: mpsc::Sender<L1Data<Da>>,
loop_interval: u64,
+cancellation_token: CancellationToken,
) where
Da: DaService,
{
loop {
-let l1_data = match get_da_block_data(da_service.clone()).await {
-Ok(l1_data) => l1_data,
-Err(e) => {
-error!("Could not fetch L1 data, {}", e);
-continue;
-}
-};
-
-let _ = sender.send(l1_data).await;
-
-sleep(Duration::from_millis(loop_interval)).await;
+tokio::select! {
+l1_data = get_da_block_data(da_service.clone()) => {
+let l1_data = match l1_data {
+Ok(l1_data) => l1_data,
+Err(e) => {
+error!("Could not fetch L1 data, {}", e);
+continue;
+}
+};
+
+let _ = sender.send(l1_data).await;
+
+sleep(Duration::from_millis(loop_interval)).await;
+},
+_ = cancellation_token.cancelled() => {
+return;
+}
+}
}
}

