Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Invoker concurrency quota #548

Merged
merged 8 commits into from
Jul 3, 2023
134 changes: 61 additions & 73 deletions src/invoker/src/service/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,6 @@ use crate::service::invocation_task::InvocationTask;
use crate::service::status_store::InvocationStatusStore;
use codederror::CodedError;
use drain::ReleaseShutdown;
use futures::future::BoxFuture;
use futures::FutureExt;
use input_command::{InputCommand, InvokeCommand};
use invocation_state_machine::InvocationStateMachine;
use invocation_task::{InvocationTaskOutput, InvocationTaskOutputInner};
Expand All @@ -29,15 +27,15 @@ use std::pin::Pin;
use std::time::{Duration, SystemTime};
use std::{cmp, panic};
use tokio::sync::mpsc;
use tokio::task::JoinSet;
use tokio::task::{AbortHandle, JoinSet};
use tracing::instrument;
use tracing::{debug, trace};

mod input_command;
mod invocation_state_machine;
mod invocation_task;
mod quota;
mod state_machine_tree;
mod state_machine_manager;
mod status_store;

pub use input_command::ChannelServiceHandle;
Expand Down Expand Up @@ -89,6 +87,7 @@ type HttpsClient = hyper::Client<
// -- InvocationTask factory: we use this to mock the state machine in tests

trait InvocationTaskRunner {
#[allow(clippy::too_many_arguments)]
fn start_invocation_task(
&self,
partition: PartitionLeaderEpoch,
Expand All @@ -97,7 +96,8 @@ trait InvocationTaskRunner {
invoker_tx: mpsc::UnboundedSender<InvocationTaskOutput>,
invoker_rx: Option<mpsc::UnboundedReceiver<Completion>>,
input_journal: InvokeInputJournal,
) -> BoxFuture<'static, ()>;
task_pool: &mut JoinSet<()>,
) -> AbortHandle;
}

#[derive(Debug)]
Expand Down Expand Up @@ -129,26 +129,28 @@ where
invoker_tx: mpsc::UnboundedSender<InvocationTaskOutput>,
invoker_rx: Option<mpsc::UnboundedReceiver<Completion>>,
input_journal: InvokeInputJournal,
) -> BoxFuture<'static, ()> {
InvocationTask::new(
self.client.clone(),
partition,
sid,
0,
endpoint_metadata,
self.suspension_timeout,
self.response_abort_timeout,
self.disable_eager_state,
self.message_size_warning,
self.message_size_limit,
self.journal_reader.clone(),
self.state_reader.clone(),
self.entry_enricher.clone(),
invoker_tx,
invoker_rx,
task_pool: &mut JoinSet<()>,
) -> AbortHandle {
task_pool.spawn(
InvocationTask::new(
self.client.clone(),
partition,
sid,
0,
endpoint_metadata,
self.suspension_timeout,
self.response_abort_timeout,
self.disable_eager_state,
self.message_size_warning,
self.message_size_limit,
self.journal_reader.clone(),
self.state_reader.clone(),
self.entry_enricher.clone(),
invoker_tx,
invoker_rx,
)
.run(input_journal),
)
.run(input_journal)
.boxed()
}
}

Expand Down Expand Up @@ -212,7 +214,7 @@ impl<JR, SR, EE, SER> Service<JR, SR, EE, SER> {
retry_timers: Default::default(),
quota: quota::InvokerConcurrencyQuota::new(concurrency_limit),
status_store: Default::default(),
invocation_state_machines_tree: Default::default(),
ism_manager: Default::default(),
},
}
}
Expand Down Expand Up @@ -303,7 +305,7 @@ struct ServiceInner<ServiceEndpointRegistry, InvocationTaskRunner> {
retry_timers: TimerQueue<(PartitionLeaderEpoch, ServiceInvocationId)>,
quota: quota::InvokerConcurrencyQuota,
status_store: InvocationStatusStore,
invocation_state_machines_tree: state_machine_tree::InvocationStateMachineTree,
ism_manager: state_machine_manager::InvocationStateMachineManager,
slinkydeveloper marked this conversation as resolved.
Show resolved Hide resolved
}

impl<SER, ITR> ServiceInner<SER, ITR>
Expand Down Expand Up @@ -417,8 +419,7 @@ where
partition: PartitionLeaderEpoch,
sender: mpsc::Sender<Effect>,
) {
self.invocation_state_machines_tree
.register_partition(partition, sender);
self.ism_manager.register_partition(partition, sender);
}

#[instrument(
Expand All @@ -436,9 +437,9 @@ where
service_invocation_id: ServiceInvocationId,
journal: InvokeInputJournal,
) {
debug_assert!(self.invocation_state_machines_tree.has_partition(partition));
debug_assert!(self.ism_manager.has_partition(partition));
debug_assert!(self
.invocation_state_machines_tree
.ism_manager
.resolve_invocation(partition, &service_invocation_id)
.is_none());

Expand Down Expand Up @@ -515,7 +516,7 @@ where
entry: EnrichedRawEntry,
) {
if let Some((output_tx, ism)) = self
.invocation_state_machines_tree
.ism_manager
.resolve_invocation(partition, &service_invocation_id)
{
ism.notify_new_entry(entry_index);
Expand Down Expand Up @@ -551,7 +552,7 @@ where
completion: Completion,
) {
if let Some((_, ism)) = self
.invocation_state_machines_tree
.ism_manager
.resolve_invocation(partition, &service_invocation_id)
{
trace!(
Expand Down Expand Up @@ -580,7 +581,7 @@ where
service_invocation_id: ServiceInvocationId,
) {
if let Some((sender, _)) = self
.invocation_state_machines_tree
.ism_manager
.remove_invocation(partition, &service_invocation_id)
{
trace!("Invocation task closed correctly");
Expand Down Expand Up @@ -614,7 +615,7 @@ where
entry_indexes: HashSet<EntryIndex>,
) {
if let Some((sender, _)) = self
.invocation_state_machines_tree
.ism_manager
.remove_invocation(partition, &service_invocation_id)
{
trace!("Suspending invocation");
Expand Down Expand Up @@ -650,7 +651,7 @@ where
error: impl InvokerError + CodedError + Send + Sync + 'static,
) {
if let Some((_, ism)) = self
.invocation_state_machines_tree
.ism_manager
.remove_invocation(partition, &service_invocation_id)
{
self.handle_error_event(partition, service_invocation_id, error, ism)
Expand All @@ -676,7 +677,7 @@ where
service_invocation_id: ServiceInvocationId,
) {
if let Some((_, mut ism)) = self
.invocation_state_machines_tree
.ism_manager
.remove_invocation(partition, &service_invocation_id)
{
trace!(
Expand Down Expand Up @@ -705,10 +706,7 @@ where
)
)]
fn handle_abort_partition(&mut self, partition: PartitionLeaderEpoch) {
if let Some(invocation_state_machines) = self
.invocation_state_machines_tree
.remove_partition(partition)
{
if let Some(invocation_state_machines) = self.ism_manager.remove_partition(partition) {
for (sid, mut ism) in invocation_state_machines.into_iter() {
trace!(
rpc.service = %sid.service_id.service_name,
Expand All @@ -729,7 +727,7 @@ where

#[instrument(level = "trace", skip_all)]
fn handle_shutdown(&mut self) {
let partitions = self.invocation_state_machines_tree.registered_partitions();
let partitions = self.ism_manager.registered_partitions();
for partition in partitions {
self.handle_abort_partition(partition);
}
Expand All @@ -755,11 +753,8 @@ where
);
self.status_store
.on_failure(partition, service_invocation_id.clone(), &error);
self.invocation_state_machines_tree.register_invocation(
partition,
service_invocation_id.clone(),
ism,
);
self.ism_manager
.register_invocation(partition, service_invocation_id.clone(), ism);
self.retry_timers.sleep_until(
SystemTime::now() + next_retry_timer_duration,
(partition, service_invocation_id),
Expand All @@ -770,7 +765,7 @@ where
self.quota.unreserve_slot();
self.status_store.on_end(&partition, &service_invocation_id);
let _ = self
.invocation_state_machines_tree
.ism_manager
.resolve_partition_sender(partition)
.expect("Partition should be registered")
.send(Effect {
Expand Down Expand Up @@ -828,16 +823,15 @@ where
(Some(tx), Some(rx))
}
};
let abort_handle =
self.invocation_tasks
.spawn(self.invocation_task_runner.start_invocation_task(
partition,
service_invocation_id.clone(),
endpoint_metadata,
self.invocation_tasks_tx.clone(),
completions_rx,
journal,
));
let abort_handle = self.invocation_task_runner.start_invocation_task(
partition,
service_invocation_id.clone(),
endpoint_metadata,
self.invocation_tasks_tx.clone(),
completions_rx,
journal,
&mut self.invocation_tasks,
);

// Transition the state machine, and store it
self.status_store
Expand All @@ -847,11 +841,8 @@ where
"Invocation task started state. Invocation state: {:?}",
ism.invocation_state_debug()
);
self.invocation_state_machines_tree.register_invocation(
partition,
service_invocation_id,
ism,
);
self.ism_manager
.register_invocation(partition, service_invocation_id, ism);
}

async fn handle_retry_event<FN>(
Expand All @@ -863,7 +854,7 @@ where
FN: FnOnce(&mut InvocationStateMachine),
{
if let Some((_, mut ism)) = self
.invocation_state_machines_tree
.ism_manager
.remove_invocation(partition, &service_invocation_id)
{
f(&mut ism);
Expand All @@ -883,11 +874,8 @@ where
ism.invocation_state_debug()
);
// Not ready for retrying yet
self.invocation_state_machines_tree.register_invocation(
partition,
service_invocation_id,
ism,
);
self.ism_manager
.register_invocation(partition, service_invocation_id, ism);
}
} else {
// If no state machine is registered, the PP will send a new invoke
Expand Down Expand Up @@ -934,7 +922,7 @@ mod tests {
retry_timers: Default::default(),
quota: InvokerConcurrencyQuota::new(concurrency_limit),
status_store: Default::default(),
invocation_state_machines_tree: Default::default(),
ism_manager: Default::default(),
};
(input_tx, service_inner)
}
Expand Down Expand Up @@ -970,16 +958,16 @@ mod tests {
invoker_tx: mpsc::UnboundedSender<InvocationTaskOutput>,
invoker_rx: Option<mpsc::UnboundedReceiver<Completion>>,
input_journal: InvokeInputJournal,
) -> BoxFuture<'static, ()> {
(*self)(
task_pool: &mut JoinSet<()>,
) -> AbortHandle {
task_pool.spawn((*self)(
partition,
sid,
endpoint_metadata,
invoker_tx,
invoker_rx,
input_journal,
)
.boxed()
))
}
}

Expand Down
2 changes: 1 addition & 1 deletion src/invoker/src/service/quota.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ impl InvokerConcurrencyQuota {
}

pub(in crate::service) fn reserve_slot(&mut self) {
debug_assert!(self.is_slot_available());
assert!(self.is_slot_available());
match self {
Self::Unlimited => {}
Self::Limited { available_slots } => *available_slots -= 1,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ use crate::service::*;

/// Tree of [InvocationStateMachine] held by the [Service].
#[derive(Debug, Default)]
pub(in crate::service) struct InvocationStateMachineTree {
pub(in crate::service) struct InvocationStateMachineManager {
partitions: HashMap<PartitionLeaderEpoch, PartitionInvocationStateMachineCoordinator>,
}

Expand All @@ -13,7 +13,7 @@ struct PartitionInvocationStateMachineCoordinator {
invocation_state_machines: HashMap<ServiceInvocationId, InvocationStateMachine>,
}

impl InvocationStateMachineTree {
impl InvocationStateMachineManager {
#[inline]
pub(in crate::service) fn has_partition(&self, partition: PartitionLeaderEpoch) -> bool {
self.partitions.contains_key(&partition)
Expand Down