diff --git a/crates/storage-api/src/outbox_table/mod.rs b/crates/storage-api/src/outbox_table/mod.rs index ca8fabf27a..f6f8e344b5 100644 --- a/crates/storage-api/src/outbox_table/mod.rs +++ b/crates/storage-api/src/outbox_table/mod.rs @@ -10,7 +10,9 @@ use crate::{GetFuture, PutFuture}; use restate_types::identifiers::{FullInvocationId, IngressDispatcherId, PartitionId}; -use restate_types::invocation::{InvocationResponse, ResponseResult, ServiceInvocation}; +use restate_types::invocation::{ + InvocationResponse, InvocationTermination, ResponseResult, ServiceInvocation, +}; use std::ops::Range; /// Types of outbox messages. @@ -29,8 +31,8 @@ pub enum OutboxMessage { response: ResponseResult, }, - /// Kill command to send to another partition processor - Kill(FullInvocationId), + /// Terminate invocation to send to another partition processor + InvocationTermination(InvocationTermination), } pub trait OutboxTable { diff --git a/crates/storage-proto/proto/dev/restate/storage/v1/domain.proto b/crates/storage-proto/proto/dev/restate/storage/v1/domain.proto index 9823026589..0fcab13204 100644 --- a/crates/storage-proto/proto/dev/restate/storage/v1/domain.proto +++ b/crates/storage-proto/proto/dev/restate/storage/v1/domain.proto @@ -14,6 +14,13 @@ message FullInvocationId { bytes invocation_uuid = 3; } +message MaybeFullInvocationId { + oneof kind { + FullInvocationId full_invocation_id = 1; + bytes invocation_id = 2; + } +} + // --------------------------------------------------------------------- // Service Invocation // --------------------------------------------------------------------- @@ -283,10 +290,7 @@ message OutboxMessage { } message OutboxServiceInvocationResponse { - oneof id { - FullInvocationId full_invocation_id = 1; - bytes invocation_id = 4; - } + MaybeFullInvocationId maybe_fid = 1; uint32 entry_index = 2; ResponseResult response_result = 3; } @@ -298,7 +302,11 @@ message OutboxMessage { } message OutboxKill { - FullInvocationId full_invocation_id = 1; + MaybeFullInvocationId maybe_full_invocation_id = 1; + } + + message OutboxCancel { + MaybeFullInvocationId maybe_full_invocation_id = 1; } oneof outbox_message { @@ -306,6 +314,7 @@ message OutboxMessage { OutboxServiceInvocationResponse service_invocation_response = 2; OutboxIngressResponse ingress_response = 3; OutboxKill kill = 4; + OutboxCancel cancel = 5; } } diff --git a/crates/storage-proto/src/lib.rs b/crates/storage-proto/src/lib.rs index 407febbca3..7e25445189 100644 --- a/crates/storage-proto/src/lib.rs +++ b/crates/storage-proto/src/lib.rs @@ -32,7 +32,7 @@ pub mod storage { completion_result, CompletionResult, Entry, Kind, }; use crate::storage::v1::outbox_message::{ - OutboxIngressResponse, OutboxKill, OutboxServiceInvocation, + OutboxCancel, OutboxIngressResponse, OutboxKill, OutboxServiceInvocation, OutboxServiceInvocationResponse, }; use crate::storage::v1::service_invocation_response_sink::{ @@ -40,12 +40,11 @@ pub mod storage { }; use crate::storage::v1::{ enriched_entry_header, invocation_resolution_result, invocation_status, - outbox_message, outbox_message::outbox_service_invocation_response, - response_result, span_relation, timer, BackgroundCallResolutionResult, - EnrichedEntryHeader, FullInvocationId, InboxEntry, InvocationResolutionResult, - InvocationStatus, JournalEntry, JournalMeta, OutboxMessage, ResponseResult, - SequencedTimer, ServiceInvocation, ServiceInvocationResponseSink, SpanContext, - SpanRelation, Timer, + maybe_full_invocation_id, outbox_message, response_result, span_relation, timer, + BackgroundCallResolutionResult, EnrichedEntryHeader, FullInvocationId, InboxEntry, + InvocationResolutionResult, InvocationStatus, JournalEntry, JournalMeta, + MaybeFullInvocationId, OutboxMessage, ResponseResult, SequencedTimer, + ServiceInvocation, ServiceInvocationResponseSink, SpanContext, SpanRelation, Timer, }; use anyhow::anyhow; use bytes::{Buf, Bytes}; @@ -53,7 +52,7 @@ pub mod storage { use opentelemetry_api::trace::TraceState; use restate_storage_api::StorageError; use restate_types::identifiers::ServiceId; - use restate_types::invocation::MaybeFullInvocationId; + use restate_types::invocation::{InvocationTermination, TerminationFlavor}; use restate_types::journal::enriched::AwakeableEnrichmentResult; use restate_types::time::MillisSinceEpoch; use std::collections::{HashSet, VecDeque}; @@ -589,6 +588,49 @@ pub mod storage { } } + impl TryFrom for restate_types::invocation::MaybeFullInvocationId { + type Error = ConversionError; + + fn try_from(value: MaybeFullInvocationId) -> Result { + match value.kind.ok_or(ConversionError::missing_field("kind"))? { + maybe_full_invocation_id::Kind::FullInvocationId(fid) => { + Ok(restate_types::invocation::MaybeFullInvocationId::Full( + restate_types::identifiers::FullInvocationId::try_from(fid)?, + )) + } + maybe_full_invocation_id::Kind::InvocationId(invocation_id) => { + Ok(restate_types::invocation::MaybeFullInvocationId::Partial( + restate_types::identifiers::InvocationId::from_slice( + &invocation_id, + ) + .map_err(|e| ConversionError::invalid_data(e))?, + )) + } + } + } + } + + impl From for MaybeFullInvocationId { + fn from(value: restate_types::invocation::MaybeFullInvocationId) -> Self { + match value { + restate_types::invocation::MaybeFullInvocationId::Full(fid) => { + MaybeFullInvocationId { + kind: Some(maybe_full_invocation_id::Kind::FullInvocationId( + FullInvocationId::from(fid), + )), + } + } + restate_types::invocation::MaybeFullInvocationId::Partial( + invocation_id, + ) => MaybeFullInvocationId { + kind: Some(maybe_full_invocation_id::Kind::InvocationId( + Bytes::copy_from_slice(&invocation_id.as_bytes()), + )), + }, + } + } + } + fn try_bytes_into_invocation_uuid( bytes: Bytes, ) -> Result { @@ -1248,26 +1290,10 @@ pub mod storage { ) => restate_storage_api::outbox_table::OutboxMessage::ServiceResponse( restate_types::invocation::InvocationResponse { entry_index: invocation_response.entry_index, - id: match invocation_response - .id - .ok_or(ConversionError::missing_field("id"))? - { - outbox_service_invocation_response::Id::FullInvocationId( - fid, - ) => MaybeFullInvocationId::Full( - restate_types::identifiers::FullInvocationId::try_from( - fid, - )?, - ), - outbox_service_invocation_response::Id::InvocationId( - invocation_id_bytes, - ) => MaybeFullInvocationId::Partial( - restate_types::identifiers::InvocationId::from_slice( - &invocation_id_bytes, - ) - .map_err(ConversionError::invalid_data)?, - ), - }, + id: invocation_response + .maybe_fid + .ok_or(ConversionError::missing_field("maybe_fid"))? + .try_into()?, result: restate_types::invocation::ResponseResult::try_from( invocation_response .response_result @@ -1294,11 +1320,27 @@ pub mod storage { } } outbox_message::OutboxMessage::Kill(outbox_kill) => { - let fid = outbox_kill - .full_invocation_id - .ok_or(ConversionError::missing_field("full_invocation_id"))?; - restate_storage_api::outbox_table::OutboxMessage::Kill( - restate_types::identifiers::FullInvocationId::try_from(fid)?, + let maybe_fid = outbox_kill.maybe_full_invocation_id.ok_or( + ConversionError::missing_field("maybe_full_invocation_id"), + )?; + restate_storage_api::outbox_table::OutboxMessage::InvocationTermination( + InvocationTermination::kill( + restate_types::invocation::MaybeFullInvocationId::try_from( + maybe_fid, + )?, + ), + ) + } + outbox_message::OutboxMessage::Cancel(outbox_cancel) => { + let maybe_fid = outbox_cancel.maybe_full_invocation_id.ok_or( + ConversionError::missing_field("maybe_full_invocation_id"), + )?; + restate_storage_api::outbox_table::OutboxMessage::InvocationTermination( + InvocationTermination::cancel( + restate_types::invocation::MaybeFullInvocationId::try_from( + maybe_fid, + )?, + ), ) } }; @@ -1324,18 +1366,9 @@ pub mod storage { ) => outbox_message::OutboxMessage::ServiceInvocationResponse( OutboxServiceInvocationResponse { entry_index: invocation_response.entry_index, - id: Some(match invocation_response.id { - MaybeFullInvocationId::Partial(iid) => { - outbox_service_invocation_response::Id::InvocationId( - Bytes::copy_from_slice(&iid.as_bytes()), - ) - } - MaybeFullInvocationId::Full(fid) => { - outbox_service_invocation_response::Id::FullInvocationId( - FullInvocationId::from(fid), - ) - } - }), + maybe_fid: Some(MaybeFullInvocationId::from( + invocation_response.id, + )), response_result: Some(ResponseResult::from( invocation_response.result, )), @@ -1354,11 +1387,24 @@ pub mod storage { response_result: Some(ResponseResult::from(response)), }) } - restate_storage_api::outbox_table::OutboxMessage::Kill(fid) => { - outbox_message::OutboxMessage::Kill(OutboxKill { - full_invocation_id: Some(FullInvocationId::from(fid)), - }) - } + restate_storage_api::outbox_table::OutboxMessage::InvocationTermination( + invocation_termination, + ) => match invocation_termination.flavor { + TerminationFlavor::Kill => { + outbox_message::OutboxMessage::Kill(OutboxKill { + maybe_full_invocation_id: Some(MaybeFullInvocationId::from( + invocation_termination.maybe_fid, + )), + }) + } + TerminationFlavor::Cancel => { + outbox_message::OutboxMessage::Cancel(OutboxCancel { + maybe_full_invocation_id: Some(MaybeFullInvocationId::from( + invocation_termination.maybe_fid, + )), + }) + } + }, }; OutboxMessage { diff --git a/crates/types/src/errors.rs b/crates/types/src/errors.rs index 4b17e852e9..c2d17fb584 100644 --- a/crates/types/src/errors.rs +++ b/crates/types/src/errors.rs @@ -322,6 +322,13 @@ pub const KILLED_INVOCATION_ERROR: InvocationError = InvocationError::new_static "killed", ); +// TODO: Once we want to distinguish server side cancellations from user code returning the +// UserErrorCode::Cancelled, we need to add a new RestateErrorCode. +pub const CANCELED_INVOCATION_ERROR: InvocationError = InvocationError::new_static( + InvocationErrorCode::User(UserErrorCode::Cancelled), + "canceled", +); + #[cfg(feature = "tonic_conversions")] mod tonic_conversions_impl { use super::{InvocationError, InvocationErrorCode}; diff --git a/crates/types/src/invocation.rs b/crates/types/src/invocation.rs index 80fd872fc5..21a8dd0a3f 100644 --- a/crates/types/src/invocation.rs +++ b/crates/types/src/invocation.rs @@ -379,6 +379,38 @@ impl SpanRelation { } } +/// Message to terminate an invocation. +#[derive(Debug, Clone, PartialEq)] +pub struct InvocationTermination { + pub maybe_fid: MaybeFullInvocationId, + pub flavor: TerminationFlavor, +} + +impl InvocationTermination { + pub fn kill(maybe_fid: impl Into) -> Self { + Self { + maybe_fid: maybe_fid.into(), + flavor: TerminationFlavor::Kill, + } + } + + pub fn cancel(maybe_fid: impl Into) -> Self { + Self { + maybe_fid: maybe_fid.into(), + flavor: TerminationFlavor::Cancel, + } + } +} + +/// Flavor of the termination. Can be kill (hard stop) or graceful cancel. +#[derive(Debug, Clone, Copy, PartialEq)] +pub enum TerminationFlavor { + /// hard termination, no clean up + Kill, + /// graceful termination allowing the invocation to clean up + Cancel, +} + #[cfg(any(test, feature = "mocks"))] mod mocks { use super::*; diff --git a/crates/worker/src/network_integration.rs b/crates/worker/src/network_integration.rs index 2f57f379b4..e3eaa71625 100644 --- a/crates/worker/src/network_integration.rs +++ b/crates/worker/src/network_integration.rs @@ -139,7 +139,7 @@ mod shuffle_integration { use restate_types::errors::InvocationError; use restate_types::identifiers::WithPartitionKey; use restate_types::identifiers::{PartitionId, PartitionKey, PeerId}; - use restate_types::invocation::{MaybeFullInvocationId, ResponseResult}; + use restate_types::invocation::ResponseResult; use restate_types::message::MessageIndex; #[derive(Debug)] @@ -159,7 +159,9 @@ mod shuffle_integration { shuffle::PartitionProcessorMessage::Response(response) => { response.id.partition_key() } - PartitionProcessorMessage::Kill(fid) => fid.partition_key(), + PartitionProcessorMessage::InvocationTermination(invocation_termination) => { + invocation_termination.maybe_fid.partition_key() + } } } } @@ -192,12 +194,12 @@ mod shuffle_integration { deduplication_source, ) } - shuffle::PartitionProcessorMessage::Kill(fid) => { - partition::StateMachineAckCommand::dedup( - partition::StateMachineCommand::Kill(MaybeFullInvocationId::from(fid)), - deduplication_source, - ) - } + shuffle::PartitionProcessorMessage::InvocationTermination( + invocation_termination, + ) => partition::StateMachineAckCommand::dedup( + partition::StateMachineCommand::TerminateInvocation(invocation_termination), + deduplication_source, + ), } } } diff --git a/crates/worker/src/partition/shuffle.rs b/crates/worker/src/partition/shuffle.rs index d2b6d576a5..e0768adbfd 100644 --- a/crates/worker/src/partition/shuffle.rs +++ b/crates/worker/src/partition/shuffle.rs @@ -12,7 +12,9 @@ use crate::partition::shuffle::state_machine::StateMachine; use futures::future::BoxFuture; use restate_storage_api::outbox_table::OutboxMessage; use restate_types::identifiers::{FullInvocationId, IngressDispatcherId, PartitionId, PeerId}; -use restate_types::invocation::{InvocationResponse, ResponseResult, ServiceInvocation}; +use restate_types::invocation::{ + InvocationResponse, InvocationTermination, ResponseResult, ServiceInvocation, +}; use restate_types::message::{AckKind, MessageIndex}; use std::time::Duration; use tokio::sync::mpsc; @@ -53,7 +55,7 @@ pub(crate) struct ShuffleInput(pub(crate) AckKind); pub(crate) enum PartitionProcessorMessage { Invocation(ServiceInvocation), Response(InvocationResponse), - Kill(FullInvocationId), + InvocationTermination(InvocationTermination), } #[derive(Debug, Clone)] @@ -126,8 +128,10 @@ impl From for ShuffleMessageDestination { PartitionProcessorMessage::Invocation(invocation), ) } - OutboxMessage::Kill(fid) => { - ShuffleMessageDestination::PartitionProcessor(PartitionProcessorMessage::Kill(fid)) + OutboxMessage::InvocationTermination(invocation_termination) => { + ShuffleMessageDestination::PartitionProcessor( + PartitionProcessorMessage::InvocationTermination(invocation_termination), + ) } } } diff --git a/crates/worker/src/partition/state_machine/command_interpreter.rs b/crates/worker/src/partition/state_machine/command_interpreter.rs index d3348e25af..169a7d5cf0 100644 --- a/crates/worker/src/partition/state_machine/command_interpreter.rs +++ b/crates/worker/src/partition/state_machine/command_interpreter.rs @@ -9,6 +9,7 @@ // by the Apache License, Version 2.0. use super::Error; +use std::collections::HashSet; use crate::partition::services::deterministic; use crate::partition::services::non_deterministic::{Effect as NBISEffect, Effects as NBISEffects}; @@ -25,20 +26,25 @@ use futures::StreamExt; use restate_storage_api::inbox_table::InboxEntry; use restate_storage_api::journal_table::JournalEntry; use restate_storage_api::outbox_table::OutboxMessage; -use restate_storage_api::status_table::{InvocationMetadata, InvocationStatus}; +use restate_storage_api::status_table::{InvocationMetadata, InvocationStatus, NotificationTarget}; use restate_storage_api::timer_table::Timer; -use restate_types::errors::{InvocationError, InvocationErrorCode, KILLED_INVOCATION_ERROR}; +use restate_types::errors::{ + InvocationError, InvocationErrorCode, UserErrorCode, CANCELED_INVOCATION_ERROR, + KILLED_INVOCATION_ERROR, +}; use restate_types::identifiers::{ EntryIndex, FullInvocationId, InvocationId, InvocationUuid, ServiceId, }; use restate_types::invocation::{ - InvocationResponse, MaybeFullInvocationId, ResponseResult, ServiceInvocation, - ServiceInvocationResponseSink, ServiceInvocationSpanContext, SpanRelation, SpanRelationCause, + InvocationResponse, InvocationTermination, MaybeFullInvocationId, ResponseResult, + ServiceInvocation, ServiceInvocationResponseSink, ServiceInvocationSpanContext, SpanRelation, + SpanRelationCause, TerminationFlavor, }; use restate_types::journal::enriched::{ AwakeableEnrichmentResult, EnrichedEntryHeader, EnrichedRawEntry, InvokeEnrichmentResult, }; use restate_types::journal::raw::RawEntryCodec; +use restate_types::journal::Completion; use restate_types::journal::*; use restate_types::message::MessageIndex; use restate_types::time::MillisSinceEpoch; @@ -191,7 +197,10 @@ where Ok((None, SpanRelation::None)) } Command::Timer(timer) => self.on_timer(timer, state, effects).await, - Command::Kill(maybe_fid) => self.try_kill_invocation(maybe_fid, state, effects).await, + Command::TerminateInvocation(invocation_termination) => { + self.try_terminate_invocation(invocation_termination, state, effects) + .await + } Command::BuiltInInvoker(nbis_effects) => { self.try_built_in_invoker_effect(effects, state, nbis_effects) .await @@ -353,14 +362,29 @@ where Ok(()) } + async fn try_terminate_invocation( + &mut self, + InvocationTermination { + maybe_fid, + flavor: termination_flavor, + }: InvocationTermination, + state: &mut State, + effects: &mut Effects, + ) -> Result<(Option, SpanRelation), Error> { + match termination_flavor { + TerminationFlavor::Kill => self.try_kill_invocation(maybe_fid, state, effects).await, + TerminationFlavor::Cancel => { + self.try_cancel_invocation(maybe_fid, state, effects).await + } + } + } + async fn try_kill_invocation( &mut self, maybe_fid: MaybeFullInvocationId, state: &mut State, effects: &mut Effects, ) -> Result<(Option, SpanRelation), Error> { - // TODO: Introduce distinction between invocation_status and service_instance_status to - // properly handle case when the given invocation is not executing + avoid cloning maybe_fid let (full_invocation_id, status) = Self::read_invocation_status(maybe_fid.clone(), state).await?; @@ -368,10 +392,12 @@ where InvocationStatus::Invoked(metadata) | InvocationStatus::Suspended { metadata, .. } if metadata.invocation_uuid == full_invocation_id.invocation_uuid => { - let (fid, related_span) = self - .kill_invocation(full_invocation_id, metadata, state, effects) + let related_span = metadata.journal_metadata.span_context.as_parent(); + + self.kill_invocation(full_invocation_id.clone(), metadata, state, effects) .await?; - Ok((Some(fid), related_span)) + + Ok((Some(full_invocation_id), related_span)) } InvocationStatus::Virtual { kill_notification_target, @@ -397,30 +423,129 @@ where Ok((Some(full_invocation_id), SpanRelation::None)) } _ => { - // check if service invocation is in inbox - let inbox_entry = state.get_inbox_entry(maybe_fid).await?; + self.try_terminate_inboxed_invocation( + TerminationFlavor::Kill, + maybe_fid, + state, + effects, + full_invocation_id, + ) + .await? + } + } + } - if let Some(inbox_entry) = inbox_entry { - self.kill_inboxed_invocation(effects, inbox_entry) - } else { - trace!("Received kill command for unknown invocation with fid '{full_invocation_id}'."); - // We still try to send the abort signal to the invoker, - // as it might be the case that previously the user sent an abort signal - // but some message was still between the invoker/PP queues. - // This can happen because the invoke/resume and the abort invoker messages end up in different queues, - // and the abort message can overtake the invoke/resume. - // Consequently the invoker might have not received the abort and the user tried to send it again. - effects.abort_invocation(full_invocation_id.clone()); - Ok((Some(full_invocation_id), SpanRelation::None)) + async fn try_terminate_inboxed_invocation( + &mut self, + termination_flavor: TerminationFlavor, + maybe_fid: MaybeFullInvocationId, + state: &mut State, + effects: &mut Effects, + full_invocation_id: FullInvocationId, + ) -> Result, SpanRelation), Error>, Error> { + let (termination_command, error) = match termination_flavor { + TerminationFlavor::Kill => ("kill", KILLED_INVOCATION_ERROR), + TerminationFlavor::Cancel => ("cancel", CANCELED_INVOCATION_ERROR), + }; + + // check if service invocation is in inbox + let inbox_entry = state.get_inbox_entry(maybe_fid).await?; + + Ok(if let Some(inbox_entry) = inbox_entry { + self.terminate_inboxed_invocation(inbox_entry, error, effects) + } else { + trace!("Received {termination_command} command for unknown invocation with fid '{full_invocation_id}'."); + // We still try to send the abort signal to the invoker, + // as it might be the case that previously the user sent an abort signal + // but some message was still between the invoker/PP queues. + // This can happen because the invoke/resume and the abort invoker messages end up in different queues, + // and the abort message can overtake the invoke/resume. + // Consequently the invoker might have not received the abort and the user tried to send it again. + effects.abort_invocation(full_invocation_id.clone()); + Ok((Some(full_invocation_id), SpanRelation::None)) + }) + } + + async fn try_cancel_invocation( + &mut self, + maybe_fid: MaybeFullInvocationId, + state: &mut State, + effects: &mut Effects, + ) -> Result<(Option, SpanRelation), Error> { + let (full_invocation_id, status) = + Self::read_invocation_status(maybe_fid.clone(), state).await?; + + match status { + InvocationStatus::Invoked(metadata) + if metadata.invocation_uuid == full_invocation_id.invocation_uuid => + { + let related_span = metadata.journal_metadata.span_context.as_parent(); + + self.cancel_journal_leaves( + full_invocation_id.clone(), + InvocationStatusProjection::Invoked, + metadata.journal_metadata.length, + state, + effects, + ) + .await?; + + Ok((Some(full_invocation_id), related_span)) + } + InvocationStatus::Suspended { + metadata, + waiting_for_completed_entries, + } if metadata.invocation_uuid == full_invocation_id.invocation_uuid => { + let related_span = metadata.journal_metadata.span_context.as_parent(); + + if self + .cancel_journal_leaves( + full_invocation_id.clone(), + InvocationStatusProjection::Suspended(waiting_for_completed_entries), + metadata.journal_metadata.length, + state, + effects, + ) + .await? + { + effects.resume_service(full_invocation_id.service_id.clone(), metadata); } + + Ok((Some(full_invocation_id), related_span)) + } + InvocationStatus::Virtual { + journal_metadata, + completion_notification_target, + .. + } => { + self.cancel_journal_leaves( + full_invocation_id.clone(), + InvocationStatusProjection::Virtual(completion_notification_target), + journal_metadata.length, + state, + effects, + ) + .await?; + Ok((Some(full_invocation_id), SpanRelation::None)) + } + _ => { + self.try_terminate_inboxed_invocation( + TerminationFlavor::Cancel, + maybe_fid, + state, + effects, + full_invocation_id, + ) + .await? } } } - fn kill_inboxed_invocation( + fn terminate_inboxed_invocation( &mut self, - effects: &mut Effects, inbox_entry: InboxEntry, + error: InvocationError, + effects: &mut Effects, ) -> Result<(Option, SpanRelation), Error> { // remove service invocation from inbox and send failure response let service_invocation = inbox_entry.service_invocation; @@ -428,22 +553,14 @@ where let span_context = service_invocation.span_context; let parent_span = span_context.as_parent(); - self.try_send_failure_response( - effects, - &fid, - service_invocation.response_sink, - &KILLED_INVOCATION_ERROR, - ); + self.try_send_failure_response(effects, &fid, service_invocation.response_sink, &error); self.notify_invocation_result( &fid, service_invocation.method_name, span_context, MillisSinceEpoch::now(), - Err(( - KILLED_INVOCATION_ERROR.code(), - KILLED_INVOCATION_ERROR.to_string(), - )), + Err((error.code(), error.to_string())), effects, ); @@ -458,9 +575,7 @@ where metadata: InvocationMetadata, state: &mut State, effects: &mut Effects, - ) -> Result<(FullInvocationId, SpanRelation), Error> { - let related_span = metadata.journal_metadata.span_context.as_parent(); - + ) -> Result<(), Error> { self.kill_child_invocations( &full_invocation_id, state, @@ -477,9 +592,8 @@ where KILLED_INVOCATION_ERROR, ) .await?; - effects.abort_invocation(full_invocation_id.clone()); - - Ok((full_invocation_id, related_span)) + effects.abort_invocation(full_invocation_id); + Ok(()) } async fn kill_child_invocations( @@ -506,7 +620,12 @@ where enrichment_result.service_key, enrichment_result.invocation_uuid, ); - self.send_message(OutboxMessage::Kill(target_fid), effects); + self.send_message( + OutboxMessage::InvocationTermination(InvocationTermination::kill( + target_fid, + )), + effects, + ); } // we neither kill background calls nor delayed calls since we are considering them detached from this // call tree. In the future we want to support a mode which also kills these calls (causally related). @@ -518,6 +637,92 @@ where Ok(()) } + async fn cancel_journal_leaves( + &mut self, + full_invocation_id: FullInvocationId, + invocation_status: InvocationStatusProjection, + journal_length: EntryIndex, + state: &mut State, + effects: &mut Effects, + ) -> Result { + let mut journal = state.get_journal(&full_invocation_id.service_id, journal_length); + + let canceled_result = CompletionResult::Failure( + UserErrorCode::from(CANCELED_INVOCATION_ERROR.code()), + CANCELED_INVOCATION_ERROR.message().into(), + ); + + let mut resume_invocation = false; + + while let Some(journal_entry) = journal.next().await { + let (journal_index, journal_entry) = journal_entry?; + + if let JournalEntry::Entry(journal_entry) = journal_entry { + let (header, _) = journal_entry.into_inner(); + match header { + // cancel uncompleted invocations + EnrichedEntryHeader::Invoke { + is_completed, + enrichment_result: Some(enrichment_result), + } if !is_completed => { + let target_fid = FullInvocationId::new( + enrichment_result.service_name, + enrichment_result.service_key, + enrichment_result.invocation_uuid, + ); + + self.send_message( + OutboxMessage::InvocationTermination(InvocationTermination::cancel( + target_fid, + )), + effects, + ); + } + EnrichedEntryHeader::Awakeable { is_completed } + | EnrichedEntryHeader::GetState { is_completed } + | EnrichedEntryHeader::Sleep { is_completed } + | EnrichedEntryHeader::PollInputStream { is_completed } + if !is_completed => + { + match &invocation_status { + InvocationStatusProjection::Invoked => { + Self::handle_completion_for_invoked( + full_invocation_id.clone(), + Completion::new(journal_index, canceled_result.clone()), + effects, + ) + } + InvocationStatusProjection::Suspended(waiting_for_completed_entry) => { + resume_invocation |= Self::handle_completion_for_suspended( + full_invocation_id.clone(), + Completion::new(journal_index, canceled_result.clone()), + waiting_for_completed_entry, + effects, + ); + } + InvocationStatusProjection::Virtual(notification_target) => { + Self::handle_completion_for_virtual( + full_invocation_id.clone(), + Completion::new(journal_index, canceled_result.clone()), + notification_target.clone(), + effects, + ) + } + } + } + header => { + assert!( + header.is_completed().unwrap_or(true), + "All non canceled journal entries must be completed." + ); + } + } + } + } + + Ok(resume_invocation) + } + async fn on_timer( &mut self, TimerValue { @@ -999,8 +1204,11 @@ where match status { InvocationStatus::Invoked(metadata) => { if metadata.invocation_uuid == full_invocation_id.invocation_uuid { - effects.store_completion(full_invocation_id.clone(), completion.clone()); - effects.forward_completion(full_invocation_id.clone(), completion); + Self::handle_completion_for_invoked( + full_invocation_id.clone(), + completion, + effects, + ); related_sid = Some(full_invocation_id); span_relation = metadata.journal_metadata.span_context.as_parent(); } else { @@ -1018,9 +1226,13 @@ where } => { if metadata.invocation_uuid == full_invocation_id.invocation_uuid { span_relation = metadata.journal_metadata.span_context.as_parent(); - effects.store_completion(full_invocation_id.clone(), completion.clone()); - if waiting_for_completed_entries.contains(&completion.entry_index) { + if Self::handle_completion_for_suspended( + full_invocation_id.clone(), + completion, + &waiting_for_completed_entries, + effects, + ) { effects.resume_service(full_invocation_id.service_id.clone(), metadata); } related_sid = Some(full_invocation_id); @@ -1042,16 +1254,14 @@ where ) } InvocationStatus::Virtual { - invocation_uuid, completion_notification_target, .. } => { - effects.store_completion(full_invocation_id.clone(), completion.clone()); - effects.notify_virtual_journal_completion( - completion_notification_target.service, - completion_notification_target.method, - invocation_uuid, + Self::handle_completion_for_virtual( + full_invocation_id, completion, + completion_notification_target, + effects, ); } } @@ -1059,6 +1269,46 @@ where Ok((related_sid, span_relation)) } + fn handle_completion_for_virtual( + full_invocation_id: FullInvocationId, + completion: Completion, + completion_notification_target: NotificationTarget, + effects: &mut Effects, + ) { + let invocation_uuid = full_invocation_id.invocation_uuid; + + effects.store_completion(full_invocation_id, completion.clone()); + effects.notify_virtual_journal_completion( + completion_notification_target.service, + completion_notification_target.method, + invocation_uuid, + completion, + ); + } + + fn handle_completion_for_suspended( + full_invocation_id: FullInvocationId, + completion: Completion, + waiting_for_completed_entries: &HashSet, + effects: &mut Effects, + ) -> bool { + let resume_invocation = waiting_for_completed_entries.contains(&completion.entry_index); + effects.store_completion(full_invocation_id, completion); + + resume_invocation + } + + fn handle_completion_for_invoked( + full_invocation_id: FullInvocationId, + completion: Completion, + effects: &mut Effects, + ) { + effects.store_completion(full_invocation_id.clone(), completion.clone()); + effects.forward_completion(full_invocation_id, completion); + } + + // TODO: Introduce distinction between invocation_status and service_instance_status to + // properly handle case when the given invocation is not executing + avoid cloning maybe_fid async fn read_invocation_status( maybe_full_invocation_id: MaybeFullInvocationId, state: &mut State, @@ -1180,6 +1430,13 @@ where } } +/// Projected [`InvocationStatus`] for cancellation purposes. +enum InvocationStatusProjection { + Invoked, + Suspended(HashSet), + Virtual(NotificationTarget), +} + fn extract_span_relation(status: &InvocationStatus) -> SpanRelation { match status { InvocationStatus::Invoked(metadata) => metadata.journal_metadata.span_context.as_parent(), @@ -1202,7 +1459,8 @@ mod tests { use bytestring::ByteString; use futures::future::ok; use futures::{stream, FutureExt}; - use googletest::{all, any, assert_that, pat}; + use googletest::matcher::Matcher; + use googletest::{all, any, assert_that, pat, unordered_elements_are}; use restate_invoker_api::EffectKind; use restate_service_protocol::awakeable_id::AwakeableIdentifier; use restate_service_protocol::codec::ProtobufRawEntryCodec; @@ -1224,26 +1482,91 @@ mod tests { } impl StateReaderMock { - fn register_invocation_status( + fn register_invoked_status(&mut self, fid: FullInvocationId, journal: Vec) { + let invocation_uuid = fid.invocation_uuid; + + self.register_invocation_status( + fid, + InvocationStatus::Invoked(Self::mock_invocation_metadata( + u32::try_from(journal.len()).unwrap(), + invocation_uuid, + )), + journal, + ); + } + + fn mock_invocation_metadata( + journal_length: u32, + invocation_uuid: InvocationUuid, + ) -> InvocationMetadata { + InvocationMetadata { + invocation_uuid, + journal_metadata: JournalMetadata { + length: journal_length, + span_context: ServiceInvocationSpanContext::empty(), + }, + deployment_id: None, + method: ByteString::from("".to_string()), + response_sink: None, + timestamps: StatusTimestamps::now(), + } + } + + fn register_suspended_status( &mut self, fid: FullInvocationId, + waiting_for_completed_entries: impl IntoIterator, journal: Vec, ) { - let service_id = fid.service_id.clone(); - self.invocations.insert( - fid.service_id, - InvocationStatus::Invoked(InvocationMetadata { - invocation_uuid: fid.invocation_uuid, + let invocation_uuid = fid.invocation_uuid; + + self.register_invocation_status( + fid, + InvocationStatus::Suspended { + metadata: Self::mock_invocation_metadata( + u32::try_from(journal.len()).unwrap(), + invocation_uuid, + ), + waiting_for_completed_entries: HashSet::from_iter( + waiting_for_completed_entries, + ), + }, + journal, + ); + } + + fn register_virtual_status( + &mut self, + fid: FullInvocationId, + completion_notification_target: NotificationTarget, + kill_notification_target: NotificationTarget, + journal: Vec, + ) { + let invocation_uuid = fid.invocation_uuid; + self.register_invocation_status( + fid, + InvocationStatus::Virtual { + invocation_uuid, journal_metadata: JournalMetadata { length: u32::try_from(journal.len()).unwrap(), span_context: ServiceInvocationSpanContext::empty(), }, - deployment_id: None, - method: ByteString::from("".to_string()), - response_sink: None, + completion_notification_target, + kill_notification_target, timestamps: StatusTimestamps::now(), - }), + }, + journal, ); + } + + fn register_invocation_status( + &mut self, + fid: FullInvocationId, + invocation_status: InvocationStatus, + journal: Vec, + ) { + let service_id = fid.service_id.clone(); + self.invocations.insert(fid.service_id, invocation_status); self.journals.insert(service_id, journal); } @@ -1408,7 +1731,7 @@ mod tests { }, }); - state_reader.register_invocation_status( + state_reader.register_invoked_status( sid_caller.clone(), vec![JournalEntry::Entry(EnrichedRawEntry::new( EnrichedEntryHeader::Awakeable { @@ -1466,7 +1789,7 @@ mod tests { }, }); - state_reader.register_invocation_status( + state_reader.register_invoked_status( sid_caller.clone(), vec![JournalEntry::Entry(EnrichedRawEntry::new( EnrichedEntryHeader::Awakeable { @@ -1511,7 +1834,7 @@ mod tests { result: ResponseResult::Success(Bytes::from_static(b"hello")), }); - state_reader.register_invocation_status(fid.clone(), vec![]); + state_reader.register_invoked_status(fid.clone(), vec![]); state_machine .on_apply(cmd, &mut effects, &mut state_reader) @@ -1542,7 +1865,7 @@ mod tests { let inboxed_fid = FullInvocationId::generate("svc", "key"); let caller_fid = FullInvocationId::mock_random(); - state_reader.register_invocation_status(fid, vec![]); + state_reader.register_invoked_status(fid, vec![]); state_reader.enqueue_into_inbox( inboxed_fid.service_id.clone(), InboxEntry { @@ -1560,7 +1883,9 @@ mod tests { command_interpreter .on_apply( - Command::Kill(MaybeFullInvocationId::from(inboxed_fid.clone())), + Command::TerminateInvocation(InvocationTermination::kill( + MaybeFullInvocationId::from(inboxed_fid.clone()), + )), &mut effects, &mut state_reader, ) @@ -1604,50 +1929,20 @@ mod tests { let background_fid = FullInvocationId::mock_random(); let finished_call_fid = FullInvocationId::mock_random(); - state_reader.register_invocation_status( + state_reader.register_invoked_status( fid.clone(), vec![ - JournalEntry::Entry(EnrichedRawEntry::new( - EnrichedEntryHeader::Invoke { - is_completed: false, - enrichment_result: Some(InvokeEnrichmentResult { - invocation_uuid: call_fid.invocation_uuid, - service_key: call_fid.service_id.key.clone(), - service_name: call_fid.service_id.service_name.clone(), - span_context: ServiceInvocationSpanContext::empty(), - }), - }, - Bytes::default(), - )), - JournalEntry::Entry(EnrichedRawEntry::new( - EnrichedEntryHeader::BackgroundInvoke { - enrichment_result: InvokeEnrichmentResult { - invocation_uuid: background_fid.invocation_uuid, - service_key: background_fid.service_id.key.clone(), - service_name: background_fid.service_id.service_name.clone(), - span_context: ServiceInvocationSpanContext::empty(), - }, - }, - Bytes::default(), - )), - JournalEntry::Entry(EnrichedRawEntry::new( - EnrichedEntryHeader::Invoke { - is_completed: true, - enrichment_result: Some(InvokeEnrichmentResult { - invocation_uuid: finished_call_fid.invocation_uuid, - service_key: finished_call_fid.service_id.key.clone(), - service_name: finished_call_fid.service_id.service_name.clone(), - span_context: ServiceInvocationSpanContext::empty(), - }), - }, - Bytes::default(), - )), + uncompleted_invoke_entry(call_fid.clone()), + background_invoke_entry(background_fid.clone()), + completed_invoke_entry(finished_call_fid.clone()), ], ); command_interpreter .on_apply( - Command::Kill(MaybeFullInvocationId::from(fid.clone())), + Command::TerminateInvocation(InvocationTermination::kill( + MaybeFullInvocationId::from(fid.clone()), + )), &mut effects, &mut state_reader, ) @@ -1662,19 +1957,315 @@ mod tests { contains(pat!(Effect::DropJournalAndFreeService { service_id: eq(fid.service_id.clone()), })), - contains(pat!(Effect::EnqueueIntoOutbox { - message: pat!(restate_storage_api::outbox_table::OutboxMessage::Kill(eq( - call_fid - ))) - })), + contains(terminate_invocation_outbox_message_matcher( + call_fid, + TerminationFlavor::Kill + )), not(contains(pat!(Effect::EnqueueIntoOutbox { - message: pat!(restate_storage_api::outbox_table::OutboxMessage::Kill( - any!(eq(background_fid), eq(finished_call_fid)) - )) + message: pat!( + restate_storage_api::outbox_table::OutboxMessage::InvocationTermination( + pat!(InvocationTermination { + maybe_fid: any!( + eq(MaybeFullInvocationId::from(background_fid)), + eq(MaybeFullInvocationId::from(finished_call_fid)) + ) + }) + ) + ) }))) ) ); Ok(()) } + + fn completed_invoke_entry(target_fid: FullInvocationId) -> JournalEntry { + JournalEntry::Entry(EnrichedRawEntry::new( + EnrichedEntryHeader::Invoke { + is_completed: true, + enrichment_result: Some(InvokeEnrichmentResult { + invocation_uuid: target_fid.invocation_uuid, + service_key: target_fid.service_id.key, + service_name: target_fid.service_id.service_name, + span_context: ServiceInvocationSpanContext::empty(), + }), + }, + Bytes::default(), + )) + } + + fn background_invoke_entry(target_fid: FullInvocationId) -> JournalEntry { + JournalEntry::Entry(EnrichedRawEntry::new( + EnrichedEntryHeader::BackgroundInvoke { + enrichment_result: InvokeEnrichmentResult { + invocation_uuid: target_fid.invocation_uuid, + service_key: target_fid.service_id.key, + service_name: target_fid.service_id.service_name, + span_context: ServiceInvocationSpanContext::empty(), + }, + }, + Bytes::default(), + )) + } + + fn uncompleted_invoke_entry(target_fid: FullInvocationId) -> JournalEntry { + JournalEntry::Entry(EnrichedRawEntry::new( + EnrichedEntryHeader::Invoke { + is_completed: false, + enrichment_result: Some(InvokeEnrichmentResult { + invocation_uuid: target_fid.invocation_uuid, + service_key: target_fid.service_id.key, + service_name: target_fid.service_id.service_name, + span_context: ServiceInvocationSpanContext::empty(), + }), + }, + Bytes::default(), + )) + } + + #[test(tokio::test)] + async fn cancel_invoked_invocation() -> Result<(), Error> { + let mut command_interpreter = CommandInterpreter::::new(0, 0); + let mut state_reader = StateReaderMock::default(); + let mut effects = Effects::default(); + + let fid = FullInvocationId::mock_random(); + let call_fid = FullInvocationId::mock_random(); + let background_fid = FullInvocationId::mock_random(); + let finished_call_fid = FullInvocationId::mock_random(); + + state_reader.register_invoked_status( + fid.clone(), + create_termination_journal( + call_fid.clone(), + background_fid.clone(), + finished_call_fid.clone(), + ), + ); + + command_interpreter + .on_apply( + Command::TerminateInvocation(InvocationTermination::cancel( + MaybeFullInvocationId::from(fid.clone()), + )), + &mut effects, + &mut state_reader, + ) + .await?; + + let effects = effects.into_inner(); + + assert_that!( + effects, + unordered_elements_are![ + terminate_invocation_outbox_message_matcher(call_fid, TerminationFlavor::Cancel), + store_canceled_completion_matcher(3), + store_canceled_completion_matcher(4), + store_canceled_completion_matcher(5), + store_canceled_completion_matcher(6), + forward_canceled_completion_matcher(3), + forward_canceled_completion_matcher(4), + forward_canceled_completion_matcher(5), + forward_canceled_completion_matcher(6), + ] + ); + + Ok(()) + } + + #[test(tokio::test)] + async fn cancel_suspended_invocation() -> Result<(), Error> { + let mut command_interpreter = CommandInterpreter::::new(0, 0); + let mut state_reader = StateReaderMock::default(); + let mut effects = Effects::default(); + + let fid = FullInvocationId::mock_random(); + let call_fid = FullInvocationId::mock_random(); + let background_fid = FullInvocationId::mock_random(); + let finished_call_fid = FullInvocationId::mock_random(); + + let journal = create_termination_journal( + call_fid.clone(), + background_fid.clone(), + finished_call_fid.clone(), + ); + state_reader.register_suspended_status(fid.clone(), vec![3, 4, 5, 6], journal); + + command_interpreter + .on_apply( + Command::TerminateInvocation(InvocationTermination::cancel( + MaybeFullInvocationId::from(fid.clone()), + )), + &mut effects, + &mut state_reader, + ) + .await?; + + let effects = effects.into_inner(); + + assert_that!( + effects, + unordered_elements_are![ + terminate_invocation_outbox_message_matcher(call_fid, TerminationFlavor::Cancel), + store_canceled_completion_matcher(3), + store_canceled_completion_matcher(4), + store_canceled_completion_matcher(5), + store_canceled_completion_matcher(6), + pat!(Effect::ResumeService { + service_id: eq(fid.service_id), + }), + ] + ); + + Ok(()) + } + + #[test(tokio::test)] + async fn cancel_virtual_invocation() -> Result<(), Error> { + let mut command_interpreter = CommandInterpreter::::new(0, 0); + let mut state_reader = StateReaderMock::default(); + let mut effects = Effects::default(); + + let fid = FullInvocationId::mock_random(); + let call_fid = FullInvocationId::mock_random(); + let background_fid = FullInvocationId::mock_random(); + let finished_call_fid = FullInvocationId::mock_random(); + + let notification_service_id = ServiceId::new("notification", "key"); + + let completion_notification_target = NotificationTarget { + service: notification_service_id.clone(), + method: "completion".to_owned(), + }; + let kill_notification_target = NotificationTarget { + service: notification_service_id, + method: "kill".to_owned(), + }; + + state_reader.register_virtual_status( + fid.clone(), + completion_notification_target, + kill_notification_target, + create_termination_journal( + call_fid.clone(), + background_fid.clone(), + finished_call_fid.clone(), + ), + ); + + command_interpreter + .on_apply( + Command::TerminateInvocation(InvocationTermination::cancel( + MaybeFullInvocationId::from(fid.clone()), + )), + &mut effects, + &mut state_reader, + ) + .await?; + + let effects = effects.into_inner(); + + assert_that!( + effects, + unordered_elements_are![ + terminate_invocation_outbox_message_matcher(call_fid, TerminationFlavor::Cancel), + store_canceled_completion_matcher(3), + store_canceled_completion_matcher(4), + store_canceled_completion_matcher(5), + store_canceled_completion_matcher(6), + notify_virtual_journal_canceled_completion_matcher(3), + notify_virtual_journal_canceled_completion_matcher(4), + notify_virtual_journal_canceled_completion_matcher(5), + notify_virtual_journal_canceled_completion_matcher(6), + ] + ); + + Ok(()) + } + + fn create_termination_journal( + call_fid: FullInvocationId, + background_fid: FullInvocationId, + finished_call_fid: FullInvocationId, + ) -> Vec { + vec![ + uncompleted_invoke_entry(call_fid), + completed_invoke_entry(finished_call_fid), + background_invoke_entry(background_fid), + JournalEntry::Entry(EnrichedRawEntry::new( + EnrichedEntryHeader::PollInputStream { + is_completed: false, + }, + Bytes::default(), + )), + JournalEntry::Entry(EnrichedRawEntry::new( + EnrichedEntryHeader::GetState { + is_completed: false, + }, + Bytes::default(), + )), + JournalEntry::Entry(EnrichedRawEntry::new( + EnrichedEntryHeader::Sleep { + is_completed: false, + }, + Bytes::default(), + )), + JournalEntry::Entry(EnrichedRawEntry::new( + EnrichedEntryHeader::Awakeable { + is_completed: false, + }, + Bytes::default(), + )), + ] + } + + fn canceled_completion_matcher(entry_index: EntryIndex) -> impl Matcher { + pat!(Completion { + entry_index: eq(entry_index), + result: pat!(CompletionResult::Failure( + eq(UserErrorCode::Cancelled), + eq(ByteString::from_static("canceled")) + )) + }) + } + + fn store_canceled_completion_matcher( + entry_index: EntryIndex, + ) -> impl Matcher { + pat!(Effect::StoreCompletion { + completion: canceled_completion_matcher(entry_index), + }) + } + + fn forward_canceled_completion_matcher( + entry_index: EntryIndex, + ) -> impl Matcher { + pat!(Effect::ForwardCompletion { + completion: canceled_completion_matcher(entry_index), + }) + } + + fn notify_virtual_journal_canceled_completion_matcher( + entry_index: EntryIndex, + ) -> impl Matcher { + pat!(Effect::NotifyVirtualJournalCompletion { + completion: canceled_completion_matcher(entry_index), + }) + } + + fn terminate_invocation_outbox_message_matcher( + target_fid: impl Into, + termination_flavor: TerminationFlavor, + ) -> impl Matcher { + pat!(Effect::EnqueueIntoOutbox { + message: pat!( + restate_storage_api::outbox_table::OutboxMessage::InvocationTermination(pat!( + InvocationTermination { + maybe_fid: eq(target_fid.into()), + flavor: eq(termination_flavor) + } + )) + ) + }) + } } diff --git a/crates/worker/src/partition/state_machine/commands.rs b/crates/worker/src/partition/state_machine/commands.rs index 274e18ae87..96b4db64dd 100644 --- a/crates/worker/src/partition/state_machine/commands.rs +++ b/crates/worker/src/partition/state_machine/commands.rs @@ -11,7 +11,7 @@ use crate::partition::services::non_deterministic::Effects as NBISEffects; use crate::partition::types::{InvokerEffect, TimerValue}; use restate_types::identifiers::{IngressDispatcherId, PartitionId, PeerId}; -use restate_types::invocation::{InvocationResponse, MaybeFullInvocationId, ServiceInvocation}; +use restate_types::invocation::{InvocationResponse, InvocationTermination, ServiceInvocation}; use restate_types::message::{AckKind, MessageIndex}; /// Envelope for [`partition::Command`] that might require an explicit acknowledge. @@ -202,7 +202,7 @@ pub struct IngressAckResponse { /// State machine input commands #[derive(Debug)] pub enum Command { - Kill(MaybeFullInvocationId), + TerminateInvocation(InvocationTermination), Invoker(InvokerEffect), Timer(TimerValue), OutboxTruncation(MessageIndex), diff --git a/crates/worker/src/partition/state_machine/effects.rs b/crates/worker/src/partition/state_machine/effects.rs index 4c0397298c..c7a0662a54 100644 --- a/crates/worker/src/partition/state_machine/effects.rs +++ b/crates/worker/src/partition/state_machine/effects.rs @@ -284,13 +284,13 @@ impl Effect { ), Effect::EnqueueIntoOutbox { seq_number, - message: OutboxMessage::Kill(fid), + message: OutboxMessage::InvocationTermination(invocation_termination), } => debug_if_leader!( is_leader, - rpc.service = %fid.service_id.service_name, - restate.invocation.id = %fid, + restate.invocation.id = %invocation_termination.maybe_fid, restate.outbox.seq = seq_number, - "Effect: Send kill command to partition processor", + "Effect: Send invocation termination command '{:?}' to partition processor", + invocation_termination.flavor ), Effect::EnqueueIntoOutbox { seq_number, diff --git a/crates/worker/src/partition/state_machine/mod.rs b/crates/worker/src/partition/state_machine/mod.rs index c86bc1aab9..e7721bbc36 100644 --- a/crates/worker/src/partition/state_machine/mod.rs +++ b/crates/worker/src/partition/state_machine/mod.rs @@ -106,8 +106,8 @@ mod tests { FullInvocationId, InvocationUuid, PartitionId, PartitionKey, ServiceId, WithPartitionKey, }; use restate_types::invocation::{ - InvocationResponse, MaybeFullInvocationId, ResponseResult, ServiceInvocation, - ServiceInvocationResponseSink, ServiceInvocationSpanContext, + InvocationResponse, InvocationTermination, MaybeFullInvocationId, ResponseResult, + ServiceInvocation, ServiceInvocationResponseSink, ServiceInvocationSpanContext, }; use restate_types::journal::enriched::EnrichedRawEntry; use restate_types::journal::{Completion, CompletionResult}; @@ -346,8 +346,8 @@ mod tests { assert!(result.is_some()); let actions = state_machine - .apply_cmd(Command::Kill(MaybeFullInvocationId::from( - inboxed_fid.clone(), + .apply_cmd(Command::TerminateInvocation(InvocationTermination::kill( + MaybeFullInvocationId::from(inboxed_fid.clone()), ))) .await; diff --git a/crates/worker/src/partition/storage/invoker.rs b/crates/worker/src/partition/storage/invoker.rs index 45b4cbc474..9e19af2a55 100644 --- a/crates/worker/src/partition/storage/invoker.rs +++ b/crates/worker/src/partition/storage/invoker.rs @@ -62,9 +62,8 @@ where let journal_stream = transaction .get_journal(&fid.service_id, journal_metadata.length) .map(|entry| { - entry - .map_err(InvokerStorageReaderError::Storage) - .map(|(_, journal_entry)| match journal_entry { + entry.map_err(InvokerStorageReaderError::Storage).map( + |(_, journal_entry)| match journal_entry { JournalEntry::Entry(entry) => entry.erase_enrichment(), JournalEntry::Completion(_) => { panic!("should only read entries when reading the journal") diff --git a/crates/worker/src/services.rs b/crates/worker/src/services.rs index 3e9452cf38..1eb07398fa 100644 --- a/crates/worker/src/services.rs +++ b/crates/worker/src/services.rs @@ -17,7 +17,7 @@ use restate_consensus::ProposalSender; use restate_network::PartitionTableError; use restate_types::identifiers::InvocationId; use restate_types::identifiers::WithPartitionKey; -use restate_types::invocation::MaybeFullInvocationId; +use restate_types::invocation::{InvocationTermination, MaybeFullInvocationId}; use restate_types::message::PeerTarget; use tokio::sync::mpsc; use tracing::debug; @@ -131,7 +131,7 @@ where let target_peer_id = partition_table .partition_key_to_target_peer(invocation_id.partition_key()) .await?; - let msg = StateMachineAckCommand::no_ack(StateMachineCommand::Kill(MaybeFullInvocationId::from(invocation_id))); + let msg = StateMachineAckCommand::no_ack(StateMachineCommand::TerminateInvocation(InvocationTermination::kill(MaybeFullInvocationId::from(invocation_id)))); proposal_tx.send((target_peer_id, msg)).await.map_err(|_| Error::ConsensusClosed)? } }