Skip to content

Commit

Permalink
Make the PP aware of Ingress request ids. With this we fix all the co…
Browse files Browse the repository at this point in the history
…rner cases where mixing requests/attach/send didn't work previously.
  • Loading branch information
slinkydeveloper committed May 28, 2024
1 parent 3ef3eba commit bf1e637
Show file tree
Hide file tree
Showing 15 changed files with 457 additions and 568 deletions.
175 changes: 54 additions & 121 deletions crates/ingress-dispatcher/src/dispatcher.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,8 @@
// by the Apache License, Version 2.0.

use super::{
error::IngressDispatchError, IngressCorrelationId, IngressDispatcherRequest,
IngressDispatcherRequestInner, IngressInvocationResponse, IngressInvocationResponseSender,
IngressRequestMode, IngressResponseKey, IngressResponseWaiterId,
error::IngressDispatchError, IngressDispatcherRequest, IngressDispatcherRequestInner,
IngressInvocationResponse, IngressInvocationResponseSender, IngressRequestMode,
IngressSubmittedInvocationNotificationSender, SubmittedInvocationNotification,
};

Expand All @@ -22,23 +21,20 @@ use restate_core::network::MessageHandler;
use restate_node_protocol::codec::Targeted;
use restate_node_protocol::ingress::IngressMessage;
use restate_storage_api::deduplication_table::DedupInformation;
use restate_types::identifiers::{InvocationId, PartitionKey, WithPartitionKey};
use restate_types::ingress::InvocationResponseCorrelationIds;
use restate_types::invocation::{AttachInvocationRequest, ServiceInvocationResponseSink};
use restate_types::identifiers::{IngressRequestId, PartitionKey, WithPartitionKey};
use restate_types::message::MessageIndex;
use restate_types::GenerationalNodeId;
use restate_wal_protocol::{
append_envelope_to_bifrost, Command, Destination, Envelope, Header, Source,
};
use std::collections::HashMap;
use std::sync::atomic::AtomicU64;
use std::sync::Arc;
use tracing::{debug, trace};

/// Dispatches a request from ingress to bifrost
pub trait DispatchIngressRequest {
fn evict_pending_response(&self, ingress_response_key: IngressResponseKey);
fn evict_pending_submit_notification(&self, invocation_id: InvocationId);
fn evict_pending_response(&self, req_id: IngressRequestId);
fn evict_pending_submit_notification(&self, req_id: IngressRequestId);
fn dispatch_ingress_request(
&self,
ingress_request: IngressDispatcherRequest,
Expand All @@ -49,21 +45,20 @@ pub trait DispatchIngressRequest {
struct IngressDispatcherState {
msg_index: AtomicU64,

// TODO those two maps below can be replaced with the ResponseTracker from the network module

// This map can be unbounded, because we enforce concurrency limits in the ingress
// services using the global semaphore
waiting_responses: DashMap<
IngressCorrelationId,
HashMap<IngressResponseWaiterId, IngressInvocationResponseSender>,
>,
waiting_responses: DashMap<IngressRequestId, IngressInvocationResponseSender>,

// This map can be unbounded, because we enforce concurrency limits in the ingress
// services using the global semaphore
waiting_submit_notification:
DashMap<InvocationId, IngressSubmittedInvocationNotificationSender>,
DashMap<IngressRequestId, IngressSubmittedInvocationNotificationSender>,
}

impl IngressDispatcherState {
pub fn get_and_increment_msg_index(&self) -> MessageIndex {
fn get_and_increment_msg_index(&self) -> MessageIndex {
self.msg_index
.fetch_add(1, std::sync::atomic::Ordering::Relaxed)
}
Expand All @@ -81,65 +76,15 @@ impl IngressDispatcher {
state: Arc::new(IngressDispatcherState::default()),
}
}

fn pop_waiters(
&self,
invocation_response_correlation_ids: InvocationResponseCorrelationIds,
) -> Vec<IngressInvocationResponseSender> {
let (invocation_id, idempotency_id, service_id) =
invocation_response_correlation_ids.into_inner();

let mut waiting_responses = vec![];

if let Some(waiter) = invocation_id.and_then(|invocation_id| {
self.state
.waiting_responses
.remove(&IngressCorrelationId::InvocationId(invocation_id))
}) {
waiting_responses.extend(waiter.1.into_values());
}

if let Some(waiter) = service_id.and_then(|service_id| {
self.state
.waiting_responses
.remove(&IngressCorrelationId::ServiceId(service_id.clone()))
}) {
waiting_responses.extend(waiter.1.into_values());
}

if let Some(waiter) = idempotency_id.and_then(|idempotency_id| {
self.state
.waiting_responses
.remove(&IngressCorrelationId::IdempotencyId(idempotency_id.clone()))
}) {
waiting_responses.extend(waiter.1.into_values());
}

waiting_responses
}
}

impl DispatchIngressRequest for IngressDispatcher {
fn evict_pending_response(&self, ingress_response_key: IngressResponseKey) {
let IngressResponseKey(invocation_id_or_idempotency_id, waiter_id) = ingress_response_key;
let e = self
.state
.waiting_responses
.entry(invocation_id_or_idempotency_id)
.and_modify(|h| {
h.remove(&waiter_id);
});
if let dashmap::mapref::entry::Entry::Occupied(o) = e {
if o.get().is_empty() {
o.remove();
}
}
fn evict_pending_response(&self, req_id: IngressRequestId) {
self.state.waiting_responses.remove(&req_id);
}

fn evict_pending_submit_notification(&self, invocation_id: InvocationId) {
self.state
.waiting_submit_notification
.remove(&invocation_id);
fn evict_pending_submit_notification(&self, req_id: IngressRequestId) {
self.state.waiting_submit_notification.remove(&req_id);
}

async fn dispatch_ingress_request(
Expand All @@ -156,9 +101,7 @@ impl DispatchIngressRequest for IngressDispatcher {
IngressRequestMode::RequestResponse(ingress_response_key, response_sender) => {
self.state
.waiting_responses
.entry(ingress_response_key.0)
.or_default()
.insert(ingress_response_key.1, response_sender);
.insert(ingress_response_key, response_sender);
(None, self.state.get_and_increment_msg_index(), None)
}
IngressRequestMode::WaitSubmitNotification(id, tx) => {
Expand Down Expand Up @@ -209,53 +152,42 @@ impl MessageHandler for IngressDispatcher {

match msg {
IngressMessage::InvocationResponse(invocation_response) => {
// We have the following situations to handle here:
// * Responses to regular calls.
// We correlate it by InvocationId
// * Responses to idempotent calls.
// We correlate it by IdempotencyId
// * Responses to invocation attach.
// We correlate it by InvocationId or IdempotencyId
// * Responses to workflow attach.
// We correlate it by ServiceId

let waiting_responses =
self.pop_waiters(invocation_response.correlation_ids.clone());

if waiting_responses.is_empty() {
debug!(
"Ignoring response '{:?}' because no handler was found locally waiting",
&invocation_response.correlation_ids
);
}

for sender in waiting_responses {
if let Some((_, tx)) = self
.state
.waiting_responses
.remove(&invocation_response.request_id)
{
let dispatcher_response = IngressInvocationResponse {
// TODO we need to add back the expiration time for idempotent results
idempotency_expiry_time: None,
result: invocation_response.response.clone(),
invocation_id: invocation_response.correlation_ids.invocation_id,
invocation_id: invocation_response.invocation_id,
};
if let Err(response) = sender.send(dispatcher_response) {
if let Err(response) = tx.send(dispatcher_response) {
debug!(
"Ignoring response '{:?}' because the handler has been \
closed, probably caused by the client connection that went away",
response
);
} else {
debug!(
trace!(
partition_processor_peer = %peer,
"Sent response of invocation {:?} out",
invocation_response.correlation_ids
invocation_response.invocation_id
);
}
} else {
debug!(
"Ignoring response to request id '{}' and invocation id '{:?}' because no handler was found locally waiting",
&invocation_response.request_id, invocation_response.invocation_id
);
}
}
IngressMessage::SubmittedInvocationNotification(attach_idempotent_invocation) => {
if let Some((_, sender)) = self
.state
.waiting_submit_notification
.remove(&attach_idempotent_invocation.original_invocation_id)
.remove(&attach_idempotent_invocation.request_id)
{
if let Err(response) = sender.send(SubmittedInvocationNotification {
invocation_id: attach_idempotent_invocation.attached_invocation_id,
Expand Down Expand Up @@ -306,11 +238,8 @@ fn wrap_service_invocation_in_envelope(
IngressDispatcherRequestInner::InvocationResponse(ir) => {
Command::InvocationResponse(ir)
}
IngressDispatcherRequestInner::Attach(invocation_query) => {
Command::AttachInvocation(AttachInvocationRequest {
invocation_query,
response_sink: ServiceInvocationResponseSink::Ingress(from_node_id),
})
IngressDispatcherRequestInner::Attach(attach_invocation_req) => {
Command::AttachInvocation(attach_invocation_req)
}
},
)
Expand All @@ -326,10 +255,11 @@ mod tests {
use restate_core::network::NetworkSender;
use restate_core::TestCoreEnvBuilder;
use restate_test_util::{let_assert, matchers::*};
use restate_types::identifiers::{IdempotencyId, InvocationId, WithPartitionKey};
use restate_types::identifiers::{InvocationId, WithPartitionKey};
use restate_types::ingress::{IngressResponseResult, InvocationResponse};
use restate_types::invocation::{
InvocationQuery, InvocationTarget, ServiceInvocation, VirtualObjectHandlerType,
AttachInvocationRequest, InvocationQuery, InvocationTarget, ServiceInvocation,
ServiceInvocationResponseSink, VirtualObjectHandlerType,
};
use restate_types::logs::{LogId, Lsn, SequenceNumber};
use restate_types::partition_table::{FindPartition, FixedPartitionTable};
Expand Down Expand Up @@ -371,11 +301,6 @@ mod tests {
&invocation_target,
Some(idempotency_key.clone()),
);
let idempotency_id = IdempotencyId::combine(
invocation_id,
&invocation_target,
idempotency_key.clone(),
);

let mut invocation = ServiceInvocation::initialize(
invocation_id,
Expand Down Expand Up @@ -416,6 +341,10 @@ mod tests {
completion_retention_time: some(eq(Duration::from_secs(60)))
})
);
let_assert!(
Some(ServiceInvocationResponseSink::Ingress { request_id, .. }) =
service_invocation.response_sink
);

// Now check we get the response is routed back to the handler correctly
let response = Bytes::from_static(b"vmoaifnuei");
Expand All @@ -424,14 +353,12 @@ mod tests {
.send(
metadata().my_node_id().into(),
&IngressMessage::InvocationResponse(InvocationResponse {
correlation_ids: InvocationResponseCorrelationIds::from_invocation_id(
service_invocation.invocation_id,
)
.with_idempotency_id(Some(idempotency_id)),
request_id,
response: IngressResponseResult::Success(
invocation_target.clone(),
response.clone(),
),
invocation_id: Some(service_invocation.invocation_id),
}),
)
.await?;
Expand Down Expand Up @@ -487,11 +414,18 @@ mod tests {
let output_message_1 =
Envelope::from_bytes(bifrost_messages[0].record.payload().unwrap().body())?;

let_assert!(
Command::AttachInvocation(attach_invocation_req) = output_message_1.command
);
assert_that!(
output_message_1.command,
pat!(Command::AttachInvocation(pat!(AttachInvocationRequest {
invocation_query: eq(InvocationQuery::Invocation(invocation_id))
})))
attach_invocation_req,
pat!(AttachInvocationRequest {
invocation_query: eq(InvocationQuery::Invocation(invocation_id)),
})
);
let_assert!(
ServiceInvocationResponseSink::Ingress { request_id, .. } =
attach_invocation_req.response_sink
);

// Now send the attach response
Expand All @@ -501,13 +435,12 @@ mod tests {
.send(
metadata().my_node_id().into(),
&IngressMessage::InvocationResponse(InvocationResponse {
correlation_ids: InvocationResponseCorrelationIds::from_invocation_id(
invocation_id,
),
request_id,
response: IngressResponseResult::Success(
InvocationTarget::mock_service(),
response.clone(),
),
invocation_id: Some(invocation_id),
}),
)
.await?;
Expand Down
Loading

0 comments on commit bf1e637

Please sign in to comment.