diff --git a/rs/replicated_state/src/canister_state/queues.rs b/rs/replicated_state/src/canister_state/queues.rs index 5d0b8db172c..17ea0b26eac 100644 --- a/rs/replicated_state/src/canister_state/queues.rs +++ b/rs/replicated_state/src/canister_state/queues.rs @@ -10,6 +10,7 @@ use self::message_pool::{ Context, InboundReference, Kind, MessagePool, OutboundReference, SomeReference, }; use self::queue::{CanisterQueue, IngressQueue, InputQueue, OutputQueue}; +use crate::page_map::int_map::MutableIntMap; use crate::replicated_state::MR_SYNTHETIC_REJECT_MESSAGE_MAX_LEN; use crate::{CanisterState, CheckpointLoadingMetrics, InputQueueType, InputSource, StateError}; use ic_base_types::PrincipalId; @@ -164,7 +165,7 @@ pub struct CanisterQueues { /// /// Used for response deduplication (whether due to a locally generated reject /// response to a best-effort call; or due to a malicious / buggy subnet). - callbacks_with_enqueued_response: BTreeSet, + callbacks_with_enqueued_response: MutableIntMap, } /// Circular iterator that consumes output queue messages: loops over output @@ -364,13 +365,13 @@ struct MessageStoreImpl { /// `CanisterInput::DeadlineExpired` by `peek_input()` / `pop_input()` (and /// "inflated" by `SystemState` into `SysUnknown` reject responses based on the /// callback). - expired_callbacks: BTreeMap, + expired_callbacks: MutableIntMap, /// Compact reject responses (`CallbackIds`) replacing best-effort responses /// that were shed. These are returned as `CanisterInput::ResponseDropped` by /// `peek_input()` / `pop_input()` (and "inflated" by `SystemState` into /// `SysUnknown` reject responses based on the callback). - shed_responses: BTreeMap, + shed_responses: MutableIntMap, } impl MessageStoreImpl { @@ -554,7 +555,7 @@ trait InboundMessageStore: MessageStore { fn callbacks_with_enqueued_response( &self, canister_queues: &BTreeMap, Arc)>, - ) -> Result, String>; + ) -> Result, String>; } impl InboundMessageStore for MessageStoreImpl { @@ -567,8 +568,8 @@ impl InboundMessageStore for MessageStoreImpl { fn callbacks_with_enqueued_response( &self, canister_queues: &BTreeMap, Arc)>, - ) -> Result, String> { - let mut callbacks = BTreeSet::new(); + ) -> Result, String> { + let mut callbacks = MutableIntMap::new(); canister_queues .values() .flat_map(|(input_queue, _)| input_queue.iter()) @@ -603,7 +604,7 @@ impl InboundMessageStore for MessageStoreImpl { } }; - if callbacks.insert(callback_id) { + if callbacks.insert(callback_id, ()).is_none() { Ok(()) } else { Err(format!( @@ -758,9 +759,10 @@ impl CanisterQueues { match self.canister_queues.get_mut(&sender) { Some((queue, _)) if queue.check_has_reserved_response_slot().is_ok() => { // Check against duplicate responses. - if !self + if self .callbacks_with_enqueued_response - .insert(response.originator_reply_callback) + .insert(response.originator_reply_callback, ()) + .is_some() { debug_assert_eq!(Ok(()), self.test_invariants()); if response.deadline == NO_DEADLINE { @@ -796,7 +798,8 @@ impl CanisterQueues { // aleady checked for a matching callback). Silently drop it. debug_assert!(self .callbacks_with_enqueued_response - .contains(&response.originator_reply_callback)); + .get(&response.originator_reply_callback) + .is_some()); return Ok(false); } } @@ -853,7 +856,11 @@ impl CanisterQueues { }; // Check against duplicate responses. - if !self.callbacks_with_enqueued_response.insert(callback_id) { + if self + .callbacks_with_enqueued_response + .insert(callback_id, ()) + .is_some() + { // There is already a response enqueued for the callback. return Ok(false); } @@ -920,7 +927,10 @@ impl CanisterQueues { if let Some(msg_) = &msg { if let Some(callback_id) = msg_.response_callback_id() { - assert!(self.callbacks_with_enqueued_response.remove(&callback_id)); + assert!(self + .callbacks_with_enqueued_response + .remove(&callback_id) + .is_some()); } debug_assert_eq!(Ok(()), self.test_invariants()); debug_assert_eq!(Ok(()), self.schedules_ok(&|_| InputQueueType::RemoteSubnet)); @@ -1559,7 +1569,8 @@ impl CanisterQueues { // request that was still in an output queue. assert!(self .callbacks_with_enqueued_response - .insert(response.originator_reply_callback)); + .insert(response.originator_reply_callback, ()) + .is_none()); let reference = self.store.insert_inbound(response.into()); Arc::make_mut(input_queue).push_response(reference); @@ -1742,7 +1753,7 @@ fn input_queue_type_fn<'a>( impl From<&CanisterQueues> for pb_queues::CanisterQueues { fn from(item: &CanisterQueues) -> Self { fn callback_references_to_proto( - callback_references: &BTreeMap, + callback_references: &MutableIntMap, ) -> Vec { callback_references .iter() @@ -1791,7 +1802,7 @@ impl TryFrom<(pb_queues::CanisterQueues, &dyn CheckpointLoadingMetrics)> for Can fn callback_references_try_from_proto( callback_references: Vec, - ) -> Result, ProxyDecodeError> + ) -> Result, ProxyDecodeError> { callback_references .into_iter() diff --git a/rs/replicated_state/src/canister_state/queues/message_pool.rs b/rs/replicated_state/src/canister_state/queues/message_pool.rs index f0507ccbf55..592d1bd8059 100644 --- a/rs/replicated_state/src/canister_state/queues/message_pool.rs +++ b/rs/replicated_state/src/canister_state/queues/message_pool.rs @@ -1,4 +1,5 @@ use super::CanisterInput; +use crate::page_map::int_map::{AsInt, MutableIntMap}; use ic_protobuf::proxy::{try_from_option_field, ProxyDecodeError}; use ic_protobuf::state::queues::v1 as pb_queues; use ic_types::messages::{ @@ -8,7 +9,7 @@ use ic_types::time::CoarseTime; use ic_types::{CountBytes, Time}; use ic_validate_eq::ValidateEq; use ic_validate_eq_derive::ValidateEq; -use std::collections::{BTreeMap, BTreeSet}; +use std::collections::BTreeSet; use std::marker::PhantomData; use std::ops::{AddAssign, SubAssign}; use std::sync::Arc; @@ -131,6 +132,33 @@ impl Id { } } +impl AsInt for Id { + type Repr = u64; + + #[inline] + fn as_int(&self) -> u64 { + self.0 + } +} + +impl AsInt for (CoarseTime, Id) { + type Repr = u128; + + #[inline] + fn as_int(&self) -> u128 { + (self.0.as_secs_since_unix_epoch() as u128) << 64 | self.1 .0 as u128 + } +} + +impl AsInt for (usize, Id) { + type Repr = u128; + + #[inline] + fn as_int(&self) -> u128 { + (self.0 as u128) << 64 | self.1 .0 as u128 + } +} + /// A typed reference -- inbound (`CanisterInput`) or outbound /// (`RequestOrResponse`) -- to a message in the `MessagePool`. #[derive(Debug)] @@ -214,6 +242,15 @@ impl From> for Id { } } +impl AsInt for Reference { + type Repr = u64; + + #[inline] + fn as_int(&self) -> u64 { + self.0 + } +} + /// A reference to an inbound message (returned as a `CanisterInput`). pub(super) type InboundReference = Reference; @@ -327,7 +364,7 @@ impl TryFrom for CallbackReferenc pub(super) struct MessagePool { /// Pool contents. #[validate_eq(CompareWithValidateEq)] - messages: BTreeMap, + messages: MutableIntMap, /// Records the (implicit) deadlines of all the outbound guaranteed response /// requests (only). @@ -337,7 +374,7 @@ pub(super) struct MessagePool { /// `outbound_guaranteed_request_deadlines.keys().collect() == messages.keys().filter(|id| (id.context(), id.class(), id.kind()) == (Context::Outbound, Class::GuaranteedResponse, Kind::Request)).collect()` /// * The deadline matches the one recorded in `deadline_queue`: /// `outbound_guaranteed_request_deadlines.iter().all(|(id, deadline)| deadline_queue.contains(&(deadline, id)))` - outbound_guaranteed_request_deadlines: BTreeMap, + outbound_guaranteed_request_deadlines: MutableIntMap, /// Running message stats for the pool. message_stats: MessageStats, @@ -348,13 +385,13 @@ pub(super) struct MessagePool { /// by deadline. /// /// Message IDs break ties, ensuring deterministic ordering. - deadline_queue: BTreeSet<(CoarseTime, Id)>, + deadline_queue: MutableIntMap<(CoarseTime, Id), ()>, /// Load shedding priority queue. Holds all best-effort messages, ordered by /// size. /// /// Message IDs break ties, ensuring deterministic ordering. - size_queue: BTreeSet<(usize, Id)>, + size_queue: MutableIntMap<(usize, Id), ()>, /// A monotonically increasing counter used to generate unique message IDs. message_id_generator: u64, @@ -470,7 +507,7 @@ impl MessagePool { // all best-effort messages except responses in input queues; plus guaranteed // response requests in output queues if actual_deadline != NO_DEADLINE { - self.deadline_queue.insert((actual_deadline, id)); + self.deadline_queue.insert((actual_deadline, id), ()); // Record in the outbound guaranteed response deadline map, iff it's an outbound // guaranteed response request. @@ -483,7 +520,7 @@ impl MessagePool { // Record in load shedding queue iff it's a best-effort message. if class == Class::BestEffort { - self.size_queue.insert((size_bytes, id)); + self.size_queue.insert((size_bytes, id), ()); } reference @@ -552,7 +589,7 @@ impl MessagePool { .outbound_guaranteed_request_deadlines .remove(&id) .unwrap(); - let removed = self.deadline_queue.remove(&(deadline, id)); + let removed = self.deadline_queue.remove(&(deadline, id)).is_some(); debug_assert!(removed); } @@ -564,7 +601,7 @@ impl MessagePool { // All other best-effort messages do expire. (_, BestEffort, _) => { - let removed = self.deadline_queue.remove(&(msg.deadline(), id)); + let removed = self.deadline_queue.remove(&(msg.deadline(), id)).is_some(); debug_assert!(removed); } } @@ -573,7 +610,7 @@ impl MessagePool { /// Removes the given message from the load shedding queue. fn remove_from_size_queue(&mut self, id: Id, msg: &RequestOrResponse) { if id.class() == Class::BestEffort { - let removed = self.size_queue.remove(&(msg.count_bytes(), id)); + let removed = self.size_queue.remove(&(msg.count_bytes(), id)).is_some(); debug_assert!(removed); } } @@ -582,7 +619,7 @@ impl MessagePool { /// /// Time complexity: `O(log(self.len()))`. pub(super) fn has_expired_deadlines(&self, now: Time) -> bool { - if let Some((deadline, _)) = self.deadline_queue.first() { + if let Some((deadline, _)) = self.deadline_queue.min_key() { let now = CoarseTime::floor(now); if *deadline < now { return true; @@ -602,7 +639,7 @@ impl MessagePool { } let now = CoarseTime::floor(now); - if self.deadline_queue.first().unwrap().0 >= now { + if self.deadline_queue.min_key().unwrap().0 >= now { // No expired messages, bail out. return Vec::new(); } @@ -614,7 +651,7 @@ impl MessagePool { // Take and return all expired messages. let expired = temp .into_iter() - .map(|(_, id)| { + .map(|((_, id), _)| { let msg = self.take_impl(id).unwrap(); if id.is_outbound_guaranteed_request() { self.outbound_guaranteed_request_deadlines.remove(&id); @@ -633,7 +670,8 @@ impl MessagePool { /// /// Time complexity: `O(log(self.len()))`. pub(super) fn shed_largest_message(&mut self) -> Option<(SomeReference, RequestOrResponse)> { - if let Some((_, id)) = self.size_queue.pop_last() { + if let Some(&(size_bytes, id)) = self.size_queue.max_key() { + self.size_queue.remove(&(size_bytes, id)).unwrap(); debug_assert_eq!(Class::BestEffort, id.class()); let msg = self.take_impl(id).unwrap(); @@ -661,7 +699,7 @@ impl MessagePool { /// `debug_assert!()` checks. /// /// Time complexity: `O(n)`. - fn calculate_message_stats(messages: &BTreeMap) -> MessageStats { + fn calculate_message_stats(messages: &MutableIntMap) -> MessageStats { let mut stats = MessageStats::default(); for (id, msg) in messages.iter() { stats += MessageStats::stats_delta(msg, id.context()); @@ -754,11 +792,14 @@ impl MessagePool { /// Time complexity: `O(n * log(n))`. #[allow(clippy::type_complexity)] fn calculate_priority_queues( - messages: &BTreeMap, - outbound_guaranteed_request_deadlines: &BTreeMap, - ) -> (BTreeSet<(CoarseTime, Id)>, BTreeSet<(usize, Id)>) { - let mut expected_deadline_queue = BTreeSet::new(); - let mut expected_size_queue = BTreeSet::new(); + messages: &MutableIntMap, + outbound_guaranteed_request_deadlines: &MutableIntMap, + ) -> ( + MutableIntMap<(CoarseTime, Id), ()>, + MutableIntMap<(usize, Id), ()>, + ) { + let mut expected_deadline_queue = MutableIntMap::new(); + let mut expected_size_queue = MutableIntMap::new(); messages.iter().for_each(|(id, msg)| { use Class::*; use Context::*; @@ -767,7 +808,7 @@ impl MessagePool { // Outbound guaranteed response requests have (separately recorded) deadlines. (Outbound, GuaranteedResponse, Request) => { let deadline = outbound_guaranteed_request_deadlines.get(id).unwrap(); - expected_deadline_queue.insert((*deadline, *id)); + expected_deadline_queue.insert((*deadline, *id), ()); } // All other guaranteed response messages neither expire nor can be shed. @@ -776,13 +817,13 @@ impl MessagePool { // Inbound best-effort responses don't have expiration deadlines, but can be // shed. (Inbound, BestEffort, Response) => { - expected_size_queue.insert((msg.count_bytes(), *id)); + expected_size_queue.insert((msg.count_bytes(), *id), ()); } // All other best-effort messages are enqueued in both priority queues. (_, BestEffort, _) => { - expected_deadline_queue.insert((msg.deadline(), *id)); - expected_size_queue.insert((msg.count_bytes(), *id)); + expected_deadline_queue.insert((msg.deadline(), *id), ()); + expected_size_queue.insert((msg.count_bytes(), *id), ()); } } }); @@ -821,7 +862,7 @@ impl TryFrom for MessagePool { fn try_from(item: pb_queues::MessagePool) -> Result { let message_count = item.messages.len(); - let messages: BTreeMap<_, _> = item + let messages: MutableIntMap<_, _> = item .messages .into_iter() .map(|entry| { diff --git a/rs/replicated_state/src/canister_state/queues/message_pool/tests.rs b/rs/replicated_state/src/canister_state/queues/message_pool/tests.rs index 3f2ab229095..48e850920f3 100644 --- a/rs/replicated_state/src/canister_state/queues/message_pool/tests.rs +++ b/rs/replicated_state/src/canister_state/queues/message_pool/tests.rs @@ -70,6 +70,9 @@ fn test_insert() { (time(50 + REQUEST_LIFETIME.as_secs() as u32), id5) }, pool.deadline_queue + .iter() + .map(|((t, id), _)| (*t, *id)) + .collect() ); // All best-effort messages should be in the load shedding queue. @@ -102,7 +105,7 @@ fn test_insert_outbound_request_deadline_rounding() { pool.insert_outbound_request(request(NO_DEADLINE).into(), current_time); - assert_eq!(expected_deadline, pool.deadline_queue.first().unwrap().0); + assert_eq!(expected_deadline, pool.deadline_queue.min_key().unwrap().0); } #[test] @@ -233,6 +236,9 @@ fn test_expiration() { (time(40 + REQUEST_LIFETIME.as_secs() as u32), id4) }, pool.deadline_queue + .iter() + .map(|((t, id), _)| (*t, *id)) + .collect() ); // There are expiring messages. assert!(pool.has_expired_deadlines(t_max)); @@ -1028,9 +1034,12 @@ fn time(seconds_since_unix_epoch: u32) -> CoarseTime { CoarseTime::from_secs_since_unix_epoch(seconds_since_unix_epoch) } -fn assert_exact_messages_in_queue(messages: BTreeSet, queue: &BTreeSet<(T, Id)>) { +fn assert_exact_messages_in_queue(messages: BTreeSet, queue: &MutableIntMap<(T, Id), ()>) +where + (T, Id): AsInt, +{ assert_eq!(messages.len(), queue.len()); - assert_eq!(messages, queue.iter().map(|(_, id)| *id).collect()) + assert_eq!(messages, queue.iter().map(|((_, id), ())| *id).collect()) } /// Generates an `InboundReference` for a request of the given class. diff --git a/rs/replicated_state/src/canister_state/system_state/call_context_manager.rs b/rs/replicated_state/src/canister_state/system_state/call_context_manager.rs index eb79a67cf09..3982278d4bf 100644 --- a/rs/replicated_state/src/canister_state/system_state/call_context_manager.rs +++ b/rs/replicated_state/src/canister_state/system_state/call_context_manager.rs @@ -1,6 +1,7 @@ #[cfg(test)] mod tests; +use crate::page_map::int_map::{AsInt, MutableIntMap}; use ic_interfaces::execution_environment::HypervisorError; use ic_management_canister_types::IC_00; use ic_protobuf::proxy::{try_from_option_field, ProxyDecodeError}; @@ -18,12 +19,14 @@ use ic_types::{ PrincipalId, Time, UserId, }; use serde::{Deserialize, Serialize}; -use std::collections::btree_map::Entry; -use std::collections::{BTreeMap, BTreeSet}; +use std::collections::BTreeSet; use std::convert::{From, TryFrom, TryInto}; use std::sync::Arc; use std::time::Duration; +#[cfg(test)] +use std::collections::BTreeMap; + /// Contains all context information related to an incoming call. #[derive(Clone, Eq, PartialEq, Debug)] pub struct CallContext { @@ -277,8 +280,8 @@ impl CallContextManagerStats { /// /// Time complexity: `O(n)`. pub(crate) fn calculate_stats( - call_contexts: &BTreeMap, - callbacks: &BTreeMap>, + call_contexts: &MutableIntMap, + callbacks: &MutableIntMap>, ) -> CallContextManagerStats { let unresponded_canister_update_call_contexts = call_contexts .values() @@ -320,11 +323,13 @@ impl CallContextManagerStats { /// (since this response was just delivered). /// /// Time complexity: `O(n)`. - #[allow(dead_code)] + #[cfg(test)] pub(crate) fn calculate_unresponded_callbacks_per_respondent( - callbacks: &BTreeMap>, + callbacks: &MutableIntMap>, aborted_or_paused_response: Option<&Response>, ) -> BTreeMap { + use std::collections::btree_map::Entry; + let mut callback_counts = callbacks.values().fold( BTreeMap::::new(), |mut counts, callback| { @@ -365,9 +370,9 @@ impl CallContextManagerStats { /// plus one for a paused or aborted canister request execution, if any. /// /// Time complexity: `O(n)`. - #[allow(dead_code)] + #[cfg(test)] pub(crate) fn calculate_unresponded_call_contexts_per_originator( - call_contexts: &BTreeMap, + call_contexts: &MutableIntMap, aborted_or_paused_request: Option<&Request>, ) -> BTreeMap { let mut unresponded_canister_update_call_contexts = call_contexts @@ -418,11 +423,14 @@ pub struct CallContextManager { next_callback_id: u64, /// Call contexts (including deleted ones) that still have open callbacks. - call_contexts: BTreeMap, + call_contexts: MutableIntMap, + + /// Counts of open callbacks per call context. + outstanding_callbacks: MutableIntMap, /// Callbacks still awaiting response, plus the callback of the currently /// paused or aborted DTS response execution, if any. - callbacks: BTreeMap>, + callbacks: MutableIntMap>, /// Callback deadline priority queue. Holds all not-yet-expired best-effort /// callbacks, ordered by deadline. `CallbackIds` break ties, ensuring @@ -430,7 +438,7 @@ pub struct CallContextManager { /// /// When a `CallbackId` is returned by `expired_callbacks()`, it is removed from /// the queue. This ensures that each callback is expired at most once. - unexpired_callbacks: BTreeSet<(CoarseTime, CallbackId)>, + unexpired_callbacks: MutableIntMap<(CoarseTime, CallbackId), ()>, /// Guaranteed response and overall callback and call context stats. stats: CallContextManagerStats, @@ -575,7 +583,7 @@ impl CallContextManager { /// Returns the currently open `CallContexts` maintained by this /// `CallContextManager`. - pub fn call_contexts(&self) -> &BTreeMap { + pub fn call_contexts(&self) -> &MutableIntMap { &self.call_contexts } @@ -594,18 +602,21 @@ impl CallContextManager { call_context_id: CallContextId, cycles: Cycles, ) -> Result<&CallContext, &str> { - let call_context = self + let mut call_context = self .call_contexts - .get_mut(&call_context_id) + .remove(&call_context_id) .ok_or("Canister accepted cycles from invalid call context")?; - call_context - .withdraw_cycles(cycles) - .map_err(|_| "Canister accepted more cycles than available from call context")?; - Ok(call_context) + let res = call_context.withdraw_cycles(cycles); + self.call_contexts.insert(call_context_id, call_context); + + match res { + Ok(()) => Ok(self.call_contexts.get(&call_context_id).unwrap()), + Err(()) => Err("Canister accepted more cycles than available from call context"), + } } /// Returns the `Callback`s maintained by this `CallContextManager`. - pub fn callbacks(&self) -> &BTreeMap> { + pub fn callbacks(&self) -> &MutableIntMap> { &self.callbacks } @@ -642,9 +653,9 @@ impl CallContextManager { OutstandingCalls::No }; - let context = self + let mut context = self .call_contexts - .get_mut(&call_context_id) + .remove(&call_context_id) .unwrap_or_else(|| panic!("no call context with ID={}", call_context_id)); // Update call context `instructions_executed += instructions_used` context.instructions_executed = context @@ -663,65 +674,56 @@ impl CallContextManager { let (action, call_context) = match (result, responded, outstanding_calls) { (Ok(None), Responded::No, OutstandingCalls::Yes) | (Err(_), Responded::No, OutstandingCalls::Yes) => { + self.call_contexts.insert(call_context_id, context); (CallContextAction::NotYetResponded, None) } (Ok(None), Responded::Yes, OutstandingCalls::Yes) | (Err(_), Responded::Yes, OutstandingCalls::Yes) => { + self.call_contexts.insert(call_context_id, context); (CallContextAction::AlreadyResponded, None) } (Ok(None), Responded::Yes, OutstandingCalls::No) - | (Err(_), Responded::Yes, OutstandingCalls::No) => ( - CallContextAction::AlreadyResponded, - self.call_contexts.remove(&call_context_id), - ), + | (Err(_), Responded::Yes, OutstandingCalls::No) => { + (CallContextAction::AlreadyResponded, Some(context)) + } (Ok(None), Responded::No, OutstandingCalls::No) => { self.stats.on_call_context_response(&context.call_origin); let refund = context.available_cycles; - ( - CallContextAction::NoResponse { refund }, - self.call_contexts.remove(&call_context_id), - ) + (CallContextAction::NoResponse { refund }, Some(context)) } (Ok(Some(WasmResult::Reply(payload))), Responded::No, OutstandingCalls::No) => { self.stats.on_call_context_response(&context.call_origin); let refund = context.available_cycles; - ( - CallContextAction::Reply { payload, refund }, - self.call_contexts.remove(&call_context_id), - ) + (CallContextAction::Reply { payload, refund }, Some(context)) } (Ok(Some(WasmResult::Reply(payload))), Responded::No, OutstandingCalls::Yes) => { self.stats.on_call_context_response(&context.call_origin); let refund = context.available_cycles; context.mark_responded(); + self.call_contexts.insert(call_context_id, context); (CallContextAction::Reply { payload, refund }, None) } (Ok(Some(WasmResult::Reject(payload))), Responded::No, OutstandingCalls::No) => { self.stats.on_call_context_response(&context.call_origin); let refund = context.available_cycles; - ( - CallContextAction::Reject { payload, refund }, - self.call_contexts.remove(&call_context_id), - ) + (CallContextAction::Reject { payload, refund }, Some(context)) } (Ok(Some(WasmResult::Reject(payload))), Responded::No, OutstandingCalls::Yes) => { self.stats.on_call_context_response(&context.call_origin); let refund = context.available_cycles; context.mark_responded(); + self.call_contexts.insert(call_context_id, context); (CallContextAction::Reject { payload, refund }, None) } (Err(error), Responded::No, OutstandingCalls::No) => { self.stats.on_call_context_response(&context.call_origin); let refund = context.available_cycles; - ( - CallContextAction::Fail { error, refund }, - self.call_contexts.remove(&call_context_id), - ) + (CallContextAction::Fail { error, refund }, Some(context)) } // The following can never happen since we handle at the SystemApi level if a canister @@ -748,18 +750,17 @@ impl CallContextManager { // TODO: Remove, this is only used in tests. #[cfg(test)] fn mark_responded(&mut self, call_context_id: CallContextId) -> Result<(), String> { - let call_context = self + let mut call_context = self .call_contexts - .get_mut(&call_context_id) + .remove(&call_context_id) .ok_or(format!("Call context not found: {}", call_context_id))?; - if call_context.responded { - return Ok(()); - } - - call_context.mark_responded(); + if !call_context.responded { + call_context.mark_responded(); - self.stats - .on_call_context_response(&call_context.call_origin); + self.stats + .on_call_context_response(&call_context.call_origin); + } + self.call_contexts.insert(call_context_id, call_context); debug_assert!(self.stats_ok()); Ok(()) @@ -773,10 +774,21 @@ impl CallContextManager { self.stats.on_register_callback(&callback); if callback.deadline != NO_DEADLINE { self.unexpired_callbacks - .insert((callback.deadline, callback_id)); + .insert((callback.deadline, callback_id), ()); } + self.outstanding_callbacks.insert( + callback.call_context_id, + self.outstanding_callbacks + .get(&callback.call_context_id) + .unwrap_or(&0) + + 1, + ); self.callbacks.insert(callback_id, Arc::new(callback)); + debug_assert_eq!( + calculate_outstanding_callbacks(&self.callbacks), + self.outstanding_callbacks + ); debug_assert!(self.stats_ok()); callback_id @@ -786,11 +798,27 @@ impl CallContextManager { /// the callback and return it. pub(super) fn unregister_callback(&mut self, callback_id: CallbackId) -> Option> { self.callbacks.remove(&callback_id).inspect(|callback| { + let outstanding_callbacks = *self + .outstanding_callbacks + .get(&callback.call_context_id) + .unwrap_or(&0); + if outstanding_callbacks <= 1 { + self.outstanding_callbacks.remove(&callback.call_context_id); + } else { + self.outstanding_callbacks + .insert(callback.call_context_id, outstanding_callbacks - 1); + } + self.stats.on_unregister_callback(callback); if callback.deadline != NO_DEADLINE { self.unexpired_callbacks .remove(&(callback.deadline, callback_id)); } + + debug_assert_eq!( + calculate_outstanding_callbacks(&self.callbacks), + self.outstanding_callbacks + ); debug_assert!(self.stats_ok()); }) } @@ -799,7 +827,7 @@ impl CallContextManager { /// whose deadlines are `< now`. pub(super) fn has_expired_callbacks(&self, now: CoarseTime) -> bool { self.unexpired_callbacks - .first() + .min_key() .map(|(deadline, _)| *deadline < now) .unwrap_or(false) } @@ -819,7 +847,7 @@ impl CallContextManager { expired_callbacks .into_iter() - .map(|(_, callback_id)| callback_id) + .map(|((_, callback_id), ())| callback_id) } /// Returns the call origin, which is either the message ID of the ingress @@ -839,14 +867,11 @@ impl CallContextManager { } /// Returns the number of outstanding calls for a given call context. - // - // TODO: This could be made more efficient by tracking the callback count per - // call context in a map. pub fn outstanding_calls(&self, call_context_id: CallContextId) -> usize { - self.callbacks - .iter() - .filter(|(_, callback)| callback.call_context_id == call_context_id) - .count() + *self + .outstanding_callbacks + .get(&call_context_id) + .unwrap_or(&0) } /// Expose the `next_callback_id` field so that the canister sandbox can @@ -954,7 +979,9 @@ impl CallContextManager { // subset of all best-effort callbacks. let all_callback_deadlines = calculate_callback_deadlines(&self.callbacks); debug_assert!( - all_callback_deadlines.is_superset(&self.unexpired_callbacks), + self.unexpired_callbacks + .iter() + .all(|(key, ())| all_callback_deadlines.contains(key)), "unexpired_callbacks: {:?}, all_callback_deadlines: {:?}", self.unexpired_callbacks, all_callback_deadlines @@ -972,21 +999,26 @@ impl CallContextManager { ) -> Vec { let mut reject_responses = Vec::new(); - for call_context in self.call_contexts.values_mut() { - if !call_context.has_responded() { - // Generate a reject response. - if let Some(response) = reject(call_context) { - reject_responses.push(response) - } + let call_contexts = std::mem::take(&mut self.call_contexts); + self.call_contexts = call_contexts + .into_iter() + .map(|(id, mut call_context)| { + if !call_context.has_responded() { + // Generate a reject response. + if let Some(response) = reject(&call_context) { + reject_responses.push(response) + } - call_context.mark_responded(); - self.stats - .on_call_context_response(&call_context.call_origin); - } + call_context.mark_responded(); + self.stats + .on_call_context_response(&call_context.call_origin); + } - // Mark the call context as deleted. - call_context.mark_deleted(); - } + // Mark the call context as deleted. + call_context.mark_deleted(); + (id, call_context) + }) + .collect(); debug_assert!(self.stats_ok()); reject_responses @@ -1041,7 +1073,7 @@ impl From<&CallContextManager> for pb::CallContextManager { unexpired_callbacks: item .unexpired_callbacks .iter() - .map(|(_, id)| id.get()) + .map(|((_, id), ())| id.get()) .collect(), } } @@ -1050,8 +1082,8 @@ impl From<&CallContextManager> for pb::CallContextManager { impl TryFrom for CallContextManager { type Error = ProxyDecodeError; fn try_from(value: pb::CallContextManager) -> Result { - let mut call_contexts = BTreeMap::::new(); - let mut callbacks = BTreeMap::>::new(); + let mut call_contexts = MutableIntMap::::new(); + let mut callbacks = MutableIntMap::>::new(); for pb::CallContextEntry { call_context_id, call_context, @@ -1075,6 +1107,7 @@ impl TryFrom for CallContextManager { )?), ); } + let outstanding_callbacks = calculate_outstanding_callbacks(&callbacks); let unexpired_callbacks = value .unexpired_callbacks .into_iter() @@ -1086,7 +1119,7 @@ impl TryFrom for CallContextManager { callback_id )) })?; - Ok((callback.deadline, callback_id)) + Ok(((callback.deadline, callback_id), ())) }) .collect::>()?; let stats = CallContextManagerStats::calculate_stats(&call_contexts, &callbacks); @@ -1095,6 +1128,7 @@ impl TryFrom for CallContextManager { next_call_context_id: value.next_call_context_id, next_callback_id: value.next_callback_id, call_contexts, + outstanding_callbacks, callbacks, unexpired_callbacks, stats, @@ -1109,7 +1143,7 @@ impl TryFrom for CallContextManager { /// /// Time complexity: `O(n)`. fn calculate_callback_deadlines( - callbacks: &BTreeMap>, + callbacks: &MutableIntMap>, ) -> BTreeSet<(CoarseTime, CallbackId)> { callbacks .iter() @@ -1118,6 +1152,36 @@ fn calculate_callback_deadlines( .collect() } +/// Calculates the counts of callbacks per call context. +/// +/// Time complexity: `O(n)`. +fn calculate_outstanding_callbacks( + callbacks: &MutableIntMap>, +) -> MutableIntMap { + callbacks + .iter() + .map(|(_, callback)| callback.call_context_id) + .fold( + MutableIntMap::::new(), + |mut counts, call_context_id| { + counts.insert( + call_context_id, + counts.get(&call_context_id).unwrap_or(&0) + 1, + ); + counts + }, + ) +} + +impl AsInt for (CoarseTime, CallbackId) { + type Repr = u128; + + #[inline] + fn as_int(&self) -> u128 { + (self.0.as_secs_since_unix_epoch() as u128) << 64 | self.1.get() as u128 + } +} + pub mod testing { use super::{CallContext, CallContextManager}; use ic_types::messages::CallContextId; diff --git a/rs/replicated_state/src/page_map/int_map.rs b/rs/replicated_state/src/page_map/int_map.rs index 1dcecde9410..238af11f507 100644 --- a/rs/replicated_state/src/page_map/int_map.rs +++ b/rs/replicated_state/src/page_map/int_map.rs @@ -1071,7 +1071,6 @@ impl<'a, K: AsInt, V: Clone> std::iter::Iterator for IntMapIter<'a, K, V> { // Find the leftmost subtree, pushing all the right hand side nodes onto the // stack. while let Tree::Branch { left, right, .. } = p { - debug_assert!(left.len() > 0 && right.len() > 0); self.0.push(right); p = left; } @@ -1106,7 +1105,6 @@ impl std::iter::Iterator for IntMapIntoIter { // Find the leftmost subtree, pushing all the right hand side nodes onto the // stack. while let Tree::Branch { left, right, .. } = p { - debug_assert!(left.len() > 0 && right.len() > 0); self.0.push(take_arc(right)); p = take_arc(left); }