diff --git a/Cargo.lock b/Cargo.lock index b96488b5f9..bb35e838e1 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -6026,11 +6026,8 @@ dependencies = [ "bytes-utils", "codederror", "hyper 0.14.28", - "jsonptr", "paste", - "prettyplease", "prost", - "prost-build", "regress 0.9.0", "restate-base64-util", "restate-errors", @@ -6038,18 +6035,15 @@ dependencies = [ "restate-service-client", "restate-test-util", "restate-types", - "schemars", "serde", "serde_json", "size", "strum_macros 0.26.2", - "syn 2.0.55", "test-log", "thiserror", "tokio", "tracing", "tracing-subscriber", - "typify", "uuid", ] @@ -6226,19 +6220,26 @@ dependencies = [ "hostname", "http 0.2.12", "humantime", + "jsonptr", "num-traits", "once_cell", "opentelemetry", + "prettyplease", + "prost", + "prost-build", "rand", + "regress 0.9.0", "restate-base64-util", "restate-serde-util", "restate-test-util", "schemars", "serde", + "serde_json", "serde_with", "sha2", "strum 0.26.2", "strum_macros 0.26.2", + "syn 2.0.55", "sync_wrapper", "tempfile", "test-log", @@ -6248,6 +6249,7 @@ dependencies = [ "toml", "tracing", "tracing-opentelemetry", + "typify", "ulid", "uuid", "xxhash-rust", @@ -6730,6 +6732,7 @@ dependencies = [ "bytes", "mlua", "restate-service-protocol", + "restate-types", "thiserror", ] diff --git a/crates/admin/src/schema_registry/error.rs b/crates/admin/src/schema_registry/error.rs index 708e5a49fb..713b29a674 100644 --- a/crates/admin/src/schema_registry/error.rs +++ b/crates/admin/src/schema_registry/error.rs @@ -14,7 +14,7 @@ use http::Uri; use restate_core::metadata_store::ReadModifyWriteError; use restate_core::ShutdownError; use restate_schema_api::invocation_target::BadInputContentType; -use restate_service_protocol::discovery::schema; +use restate_types::endpoint_manifest; use restate_types::errors::GenericError; use restate_types::identifiers::DeploymentId; use restate_types::invocation::ServiceType; @@ -91,7 +91,7 @@ pub enum ServiceError { BadOutputContentType(String, InvalidHeaderValue), #[error("invalid combination of service type and handler type '({0}, {1:?})'")] #[code(unknown)] - BadServiceAndHandlerType(ServiceType, Option), + BadServiceAndHandlerType(ServiceType, Option), #[error("modifying retention time for service type {0} is unsupported")] #[code(unknown)] CannotModifyRetentionTime(ServiceType), diff --git a/crates/admin/src/schema_registry/updater.rs b/crates/admin/src/schema_registry/updater.rs index b69301f814..a73f9df5a0 100644 --- a/crates/admin/src/schema_registry/updater.rs +++ b/crates/admin/src/schema_registry/updater.rs @@ -24,7 +24,7 @@ use restate_schema_api::invocation_target::{ use restate_schema_api::subscription::{ EventReceiverServiceType, Sink, Source, Subscription, SubscriptionValidator, }; -use restate_service_protocol::discovery::schema; +use restate_types::endpoint_manifest; use restate_types::identifiers::{DeploymentId, SubscriptionId}; use restate_types::invocation::{ InvocationTargetType, ServiceType, VirtualObjectHandlerType, WorkflowHandlerType, @@ -65,7 +65,7 @@ impl SchemaUpdater { &mut self, requested_deployment_id: Option, deployment_metadata: DeploymentMetadata, - services: Vec, + services: Vec, force: bool, ) -> Result { let deployment_id: Option; @@ -450,22 +450,23 @@ struct DiscoveredHandlerMetadata { impl DiscoveredHandlerMetadata { fn from_schema( service_type: ServiceType, - handler: schema::Handler, + handler: endpoint_manifest::Handler, ) -> Result { let ty = match (service_type, handler.ty) { - (ServiceType::Service, None | Some(schema::HandlerType::Shared)) => { + (ServiceType::Service, None | Some(endpoint_manifest::HandlerType::Shared)) => { InvocationTargetType::Service } - (ServiceType::VirtualObject, None | Some(schema::HandlerType::Exclusive)) => { - InvocationTargetType::VirtualObject(VirtualObjectHandlerType::Exclusive) - } - (ServiceType::VirtualObject, Some(schema::HandlerType::Shared)) => { + ( + ServiceType::VirtualObject, + None | Some(endpoint_manifest::HandlerType::Exclusive), + ) => InvocationTargetType::VirtualObject(VirtualObjectHandlerType::Exclusive), + (ServiceType::VirtualObject, Some(endpoint_manifest::HandlerType::Shared)) => { InvocationTargetType::VirtualObject(VirtualObjectHandlerType::Shared) } - (ServiceType::Workflow, None | Some(schema::HandlerType::Shared)) => { + (ServiceType::Workflow, None | Some(endpoint_manifest::HandlerType::Shared)) => { InvocationTargetType::Workflow(WorkflowHandlerType::Shared) } - (ServiceType::Workflow, Some(schema::HandlerType::Workflow)) => { + (ServiceType::Workflow, Some(endpoint_manifest::HandlerType::Workflow)) => { InvocationTargetType::Workflow(WorkflowHandlerType::Workflow) } _ => { @@ -494,7 +495,7 @@ impl DiscoveredHandlerMetadata { fn input_rules_from_schema( handler_name: &str, - schema: schema::InputPayload, + schema: endpoint_manifest::InputPayload, ) -> Result { let required = schema.required.unwrap_or(false); @@ -527,7 +528,7 @@ impl DiscoveredHandlerMetadata { } fn output_rules_from_schema( - schema: schema::OutputPayload, + schema: endpoint_manifest::OutputPayload, ) -> Result { Ok(if let Some(ct) = schema.content_type { OutputRules { @@ -589,11 +590,11 @@ mod tests { const GREETER_SERVICE_NAME: &str = "greeter.Greeter"; const ANOTHER_GREETER_SERVICE_NAME: &str = "greeter.AnotherGreeter"; - fn greeter_service() -> schema::Service { - schema::Service { - ty: schema::ServiceType::Service, + fn greeter_service() -> endpoint_manifest::Service { + endpoint_manifest::Service { + ty: endpoint_manifest::ServiceType::Service, name: GREETER_SERVICE_NAME.parse().unwrap(), - handlers: vec![schema::Handler { + handlers: vec![endpoint_manifest::Handler { name: "greet".parse().unwrap(), ty: None, input: None, @@ -602,11 +603,11 @@ mod tests { } } - fn greeter_virtual_object() -> schema::Service { - schema::Service { - ty: schema::ServiceType::VirtualObject, + fn greeter_virtual_object() -> endpoint_manifest::Service { + endpoint_manifest::Service { + ty: endpoint_manifest::ServiceType::VirtualObject, name: GREETER_SERVICE_NAME.parse().unwrap(), - handlers: vec![schema::Handler { + handlers: vec![endpoint_manifest::Handler { name: "greet".parse().unwrap(), ty: None, input: None, @@ -615,11 +616,11 @@ mod tests { } } - fn another_greeter_service() -> schema::Service { - schema::Service { - ty: schema::ServiceType::Service, + fn another_greeter_service() -> endpoint_manifest::Service { + endpoint_manifest::Service { + ty: endpoint_manifest::ServiceType::Service, name: ANOTHER_GREETER_SERVICE_NAME.parse().unwrap(), - handlers: vec![schema::Handler { + handlers: vec![endpoint_manifest::Handler { name: "another_greeter".parse().unwrap(), ty: None, input: None, @@ -918,18 +919,18 @@ mod tests { use restate_test_util::{check, let_assert}; use test_log::test; - fn greeter_v1_service() -> schema::Service { - schema::Service { - ty: schema::ServiceType::Service, + fn greeter_v1_service() -> endpoint_manifest::Service { + endpoint_manifest::Service { + ty: endpoint_manifest::ServiceType::Service, name: GREETER_SERVICE_NAME.parse().unwrap(), handlers: vec![ - schema::Handler { + endpoint_manifest::Handler { name: "greet".parse().unwrap(), ty: None, input: None, output: None, }, - schema::Handler { + endpoint_manifest::Handler { name: "doSomething".parse().unwrap(), ty: None, input: None, @@ -939,11 +940,11 @@ mod tests { } } - fn greeter_v2_service() -> schema::Service { - schema::Service { - ty: schema::ServiceType::Service, + fn greeter_v2_service() -> endpoint_manifest::Service { + endpoint_manifest::Service { + ty: endpoint_manifest::ServiceType::Service, name: GREETER_SERVICE_NAME.parse().unwrap(), - handlers: vec![schema::Handler { + handlers: vec![endpoint_manifest::Handler { name: "greet".parse().unwrap(), ty: None, input: None, diff --git a/crates/errors/src/error_codes/RT0013.md b/crates/errors/src/error_codes/RT0013.md new file mode 100644 index 0000000000..e54af9a7f0 --- /dev/null +++ b/crates/errors/src/error_codes/RT0013.md @@ -0,0 +1,8 @@ +## RT0013 + +The service endpoint does not support any of the supported service protocol versions of the server. Therefore, the server cannot talk to this endpoint. Please make sure that the service endpoint's SDK and the Restate server are compatible. + +Suggestions: + +* Register a service endpoint which uses an SDK which is compatible with the used server +* Upgrade the server to a version which is compatible with the used SDK \ No newline at end of file diff --git a/crates/errors/src/error_codes/RT0014.md b/crates/errors/src/error_codes/RT0014.md new file mode 100644 index 0000000000..c30a3a02a3 --- /dev/null +++ b/crates/errors/src/error_codes/RT0014.md @@ -0,0 +1,8 @@ +## RT0014 + +The server cannot resume an in-flight invocation which has been started with a now incompatible service protocol version. Restate does not support upgrading service protocols yet. + +Suggestions: + +* Downgrade the server to a version which is compatible with the used service protocol version +* Kill the affected invocation via the CLI. \ No newline at end of file diff --git a/crates/errors/src/lib.rs b/crates/errors/src/lib.rs index 4dcf95f00f..513270ced1 100644 --- a/crates/errors/src/lib.rs +++ b/crates/errors/src/lib.rs @@ -35,8 +35,9 @@ mod helper; // META are meta related errors. declare_restate_error_codes!( - RT0001, RT0002, RT0003, RT0004, RT0005, RT0006, RT0007, RT0009, RT0010, RT0011, RT0012, - META0003, META0004, META0005, META0006, META0009, META0010, META0011, META0012, META0013 + RT0001, RT0002, RT0003, RT0004, RT0005, RT0006, RT0007, RT0009, RT0010, RT0011, RT0012, RT0013, + RT0014, META0003, META0004, META0005, META0006, META0009, META0010, META0011, META0012, + META0013 ); // -- Some commonly used errors diff --git a/crates/invoker-api/Cargo.toml b/crates/invoker-api/Cargo.toml index 2bdf534d8c..d1a22fcba4 100644 --- a/crates/invoker-api/Cargo.toml +++ b/crates/invoker-api/Cargo.toml @@ -13,8 +13,9 @@ mocks = [] serde = ["dep:serde"] [dependencies] -restate-types = { workspace = true } restate-errors = { workspace = true } +restate-types = { workspace = true } + anyhow = { workspace = true } bytes = { workspace = true } diff --git a/crates/invoker-api/src/effects.rs b/crates/invoker-api/src/effects.rs index dc3126d142..1e861ae20b 100644 --- a/crates/invoker-api/src/effects.rs +++ b/crates/invoker-api/src/effects.rs @@ -8,9 +8,10 @@ // the Business Source License, use of this software will be governed // by the Apache License, Version 2.0. +use restate_types::deployment::PinnedDeployment; use restate_types::errors::InvocationError; +use restate_types::identifiers::EntryIndex; use restate_types::identifiers::InvocationId; -use restate_types::identifiers::{DeploymentId, EntryIndex}; use restate_types::journal::enriched::EnrichedRawEntry; use std::collections::HashSet; @@ -25,7 +26,7 @@ pub struct Effect { #[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))] pub enum EffectKind { /// This is sent before any new entry is created by the invoker. This won't be sent if the deployment_id is already set. - SelectedDeployment(DeploymentId), + PinnedDeployment(PinnedDeployment), JournalEntry { entry_index: EntryIndex, entry: EnrichedRawEntry, diff --git a/crates/invoker-api/src/journal_reader.rs b/crates/invoker-api/src/journal_reader.rs index a98f7ccfb9..cd0300de68 100644 --- a/crates/invoker-api/src/journal_reader.rs +++ b/crates/invoker-api/src/journal_reader.rs @@ -9,7 +9,8 @@ // by the Apache License, Version 2.0. use futures::Stream; -use restate_types::identifiers::{DeploymentId, InvocationId}; +use restate_types::deployment::PinnedDeployment; +use restate_types::identifiers::InvocationId; use restate_types::invocation::ServiceInvocationSpanContext; use restate_types::journal::raw::PlainRawEntry; use restate_types::journal::EntryIndex; @@ -20,17 +21,17 @@ use std::future::Future; pub struct JournalMetadata { pub length: EntryIndex, pub span_context: ServiceInvocationSpanContext, - pub deployment_id: Option, + pub pinned_deployment: Option, } impl JournalMetadata { pub fn new( length: EntryIndex, span_context: ServiceInvocationSpanContext, - deployment_id: Option, + pinned_deployment: Option, ) -> Self { Self { - deployment_id, + pinned_deployment, span_context, length, } diff --git a/crates/invoker-impl/src/invocation_state_machine.rs b/crates/invoker-impl/src/invocation_state_machine.rs index 417366fd34..b2b27498ca 100644 --- a/crates/invoker-impl/src/invocation_state_machine.rs +++ b/crates/invoker-impl/src/invocation_state_machine.rs @@ -10,7 +10,6 @@ use super::*; -use restate_types::identifiers::DeploymentId; use restate_types::journal::Completion; use restate_types::retries; use std::fmt; @@ -85,7 +84,7 @@ enum InvocationState { entries_to_ack: HashSet, // If Some, we need to notify the deployment id to the partition processor - chosen_deployment: Option, + pinned_deployment: Option, }, WaitingRetry { @@ -154,7 +153,7 @@ impl InvocationStateMachine { journal_tracker: Default::default(), abort_handle, entries_to_ack: Default::default(), - chosen_deployment: None, + pinned_deployment: None, }; } @@ -164,34 +163,34 @@ impl InvocationStateMachine { } } - pub(super) fn notify_chosen_deployment(&mut self, endpoint_id: DeploymentId) { + pub(super) fn notify_pinned_deployment(&mut self, deployment: PinnedDeployment) { debug_assert!(matches!( &self.invocation_state, InvocationState::InFlight { - chosen_deployment: None, + pinned_deployment: None, .. } )); if let InvocationState::InFlight { - chosen_deployment, .. + pinned_deployment, .. } = &mut self.invocation_state { - *chosen_deployment = Some(endpoint_id); + *pinned_deployment = Some(deployment); } } - pub(super) fn chosen_deployment_to_notify(&mut self) -> Option { + pub(super) fn pinned_deployment_to_notify(&mut self) -> Option { debug_assert!(matches!( &self.invocation_state, InvocationState::InFlight { .. } )); if let InvocationState::InFlight { - chosen_deployment, .. + pinned_deployment, .. } = &mut self.invocation_state { - chosen_deployment.take() + pinned_deployment.take() } else { None } diff --git a/crates/invoker-impl/src/invocation_task.rs b/crates/invoker-impl/src/invocation_task/mod.rs similarity index 50% rename from crates/invoker-impl/src/invocation_task.rs rename to crates/invoker-impl/src/invocation_task/mod.rs index db727deae1..4d800b16cc 100644 --- a/crates/invoker-impl/src/invocation_task.rs +++ b/crates/invoker-impl/src/invocation_task/mod.rs @@ -8,54 +8,49 @@ // the Business Source License, use of this software will be governed // by the Apache License, Version 2.0. +mod service_protocol_runner; + use super::Notification; +use crate::invocation_task::service_protocol_runner::ServiceProtocolRunner; use bytes::Bytes; -use futures::future::FusedFuture; -use futures::{future, stream, FutureExt, Stream, StreamExt}; -use hyper::body::Sender; +use futures::{future, stream, FutureExt, StreamExt}; use hyper::http::response::Parts as ResponseParts; -use hyper::http::uri::PathAndQuery; use hyper::http::{HeaderName, HeaderValue}; -use hyper::{http, Body, HeaderMap, Response}; -use opentelemetry::propagation::TextMapPropagator; -use opentelemetry_http::HeaderInjector; -use opentelemetry_sdk::propagation::TraceContextPropagator; -use restate_errors::warn_it; +use hyper::{http, Body, Response}; use restate_invoker_api::{ EagerState, EntryEnricher, InvocationErrorReport, InvokeInputJournal, JournalReader, StateReader, }; -use restate_schema_api::deployment::{ - DeploymentMetadata, DeploymentResolver, DeploymentType, ProtocolType, -}; -use restate_service_client::{Endpoint, Parts, Request, ServiceClient, ServiceClientError}; -use restate_service_protocol::message::{ - Decoder, Encoder, EncodingError, MessageHeader, MessageType, ProtocolMessage, -}; +use restate_schema_api::deployment::DeploymentResolver; +use restate_service_client::{Request, ServiceClient, ServiceClientError}; +use restate_service_protocol::message::{EncodingError, MessageType}; +use restate_types::deployment::PinnedDeployment; use restate_types::errors::InvocationError; use restate_types::identifiers::{DeploymentId, EntryIndex, InvocationId, PartitionLeaderEpoch}; -use restate_types::invocation::{InvocationTarget, ServiceInvocationSpanContext}; +use restate_types::invocation::InvocationTarget; use restate_types::journal::enriched::EnrichedRawEntry; -use restate_types::journal::raw::PlainRawEntry; use restate_types::journal::EntryType; +use restate_types::service_protocol::ServiceProtocolVersion; +use restate_types::service_protocol::{MAX_SERVICE_PROTOCOL_VERSION, MIN_SERVICE_PROTOCOL_VERSION}; use std::collections::HashSet; use std::error::Error; -use std::future::{poll_fn, Future}; +use std::future::Future; use std::iter; +use std::ops::RangeInclusive; use std::pin::Pin; use std::task::{ready, Context, Poll}; use std::time::Duration; use tokio::sync::mpsc; use tokio::task::JoinError; use tokio::task::JoinHandle; -use tracing::{debug, info, instrument, trace, warn, Span}; -use tracing_opentelemetry::OpenTelemetrySpanExt; +use tracing::instrument; // Clippy false positive, might be caused by Bytes contained within HeaderValue. // https://github.com/rust-lang/rust/issues/40543#issuecomment-1212981256 #[allow(clippy::declare_interior_mutable_const)] -const APPLICATION_RESTATE: HeaderValue = HeaderValue::from_static("application/restate"); +const SERVICE_PROTOCOL_VERSION_V1: HeaderValue = + HeaderValue::from_static("application/vnd.restate.invocation.v1"); #[allow(clippy::declare_interior_mutable_const)] const X_RESTATE_SERVER: HeaderName = HeaderName::from_static("x-restate-server"); @@ -72,9 +67,9 @@ pub(crate) enum InvocationTaskError { #[error("unexpected http status code: {0}")] #[code(restate_errors::RT0012)] UnexpectedResponse(http::StatusCode), - #[error("unexpected content type: {0:?}")] + #[error("unexpected content type '{0:?}'; expected content type '{1:?}'")] #[code(restate_errors::RT0012)] - UnexpectedContentType(Option), + UnexpectedContentType(Option, HeaderValue), #[error("received unexpected message: {0:?}")] #[code(restate_errors::RT0012)] UnexpectedMessage(MessageType), @@ -127,6 +122,12 @@ pub(crate) enum InvocationTaskError { Option, #[source] InvocationError, ), + #[error("cannot talk to service endpoint '{0}' because its service protocol versions [{}, {}] are incompatible with the server's service protocol versions [{}, {}].", .1.start(), .1.end(), i32::from(MIN_SERVICE_PROTOCOL_VERSION), i32::from(MAX_SERVICE_PROTOCOL_VERSION))] + #[code(restate_errors::RT0013)] + IncompatibleServiceEndpoint(DeploymentId, RangeInclusive), + #[error("cannot resume invocation because it was created with an incompatible service protocol version '{}' and the server does not support upgrading versions yet", .0.as_repr())] + #[code(restate_errors::RT0014)] + UnsupportedServiceProtocolVersion(ServiceProtocolVersion), } #[derive(Debug, Default)] @@ -216,7 +217,7 @@ pub(super) struct InvocationTaskOutput { pub(super) enum InvocationTaskOutputInner { // `has_changed` indicates if we believe this is a freshly selected endpoint or not. - SelectedDeployment(DeploymentId, /* has_changed: */ bool), + PinnedDeployment(PinnedDeployment, /* has_changed: */ bool), ServerHeaderReceived(String), NewEntry { entry_index: EntryIndex, @@ -251,6 +252,8 @@ pub(super) struct InvocationTask { inactivity_timeout: Duration, abort_timeout: Duration, disable_eager_state: bool, + message_size_warning: usize, + message_size_limit: Option, // Invoker tx/rx state_reader: SR, @@ -259,13 +262,6 @@ pub(super) struct InvocationTask { deployment_metadata_resolver: DMR, invoker_tx: mpsc::UnboundedSender, invoker_rx: mpsc::UnboundedReceiver, - - // Encoder/Decoder - encoder: Encoder, - decoder: Decoder, - - // Task state - next_journal_index: EntryIndex, } /// This is needed to split the run_internal in multiple loop functions and have shortcircuiting. @@ -285,7 +281,8 @@ impl> From> for TerminalLoopState { match TerminalLoopState::from($value) { @@ -312,7 +309,6 @@ where partition: PartitionLeaderEpoch, invocation_id: InvocationId, invocation_target: InvocationTarget, - protocol_version: u16, inactivity_timeout: Duration, abort_timeout: Duration, disable_eager_state: bool, @@ -333,15 +329,14 @@ where inactivity_timeout, abort_timeout, disable_eager_state, - next_journal_index: 0, state_reader, journal_reader, entry_enricher, deployment_metadata_resolver, invoker_tx, invoker_rx, - encoder: Encoder::new(protocol_version), - decoder: Decoder::new(message_size_warning, message_size_limit), + message_size_limit, + message_size_warning, } } @@ -349,15 +344,7 @@ where #[instrument(level = "debug", name = "invoker_invocation_task", fields(rpc.system = "restate", rpc.service = %self.invocation_target.service_name(), restate.invocation.id = %self.invocation_id, restate.invocation.target = %self.invocation_target), skip_all)] pub async fn run(mut self, input_journal: InvokeInputJournal) { // Execute the task - let terminal_state = self.run_internal(input_journal).await; - - // Sanity check of the stream decoder - if self.decoder.has_remaining() { - warn_it!( - InvocationTaskError::WriteAfterEndOfStream, - "The read buffer is non empty after the stream has been closed." - ); - } + let terminal_state = self.select_protocol_version_and_run(input_journal).await; // Sanity check of the final state let inner = match terminal_state { @@ -372,7 +359,10 @@ where self.send_invoker_tx(inner); } - async fn run_internal(&mut self, input_journal: InvokeInputJournal) -> TerminalLoopState<()> { + async fn select_protocol_version_and_run( + &mut self, + input_journal: InvokeInputJournal, + ) -> TerminalLoopState<()> { // Resolve journal and its metadata let read_journal_future = async { Ok(match input_journal { @@ -409,15 +399,31 @@ where shortcircuit!(tokio::try_join!(read_journal_future, read_state_future)); // Resolve the deployment metadata - let (deployment, deployment_changed) = - if let Some(deployment_id) = journal_metadata.deployment_id { + let (deployment, chosen_service_protocol_version, deployment_changed) = + if let Some(pinned_deployment) = &journal_metadata.pinned_deployment { // We have a pinned deployment that we can't change even if newer // deployments have been registered for the same service. let deployment_metadata = shortcircuit!(self .deployment_metadata_resolver - .get_deployment(&deployment_id) - .ok_or_else(|| InvocationTaskError::UnknownDeployment(deployment_id))); - (deployment_metadata, /* has_changed= */ false) + .get_deployment(&pinned_deployment.deployment_id) + .ok_or_else(|| InvocationTaskError::UnknownDeployment( + pinned_deployment.deployment_id + ))); + + // todo: We should support resuming an invocation with a newer protocol version if + // the endpoint supports it + if !ServiceProtocolVersion::is_supported(pinned_deployment.service_protocol_version) + { + shortcircuit!(Err(InvocationTaskError::UnsupportedServiceProtocolVersion( + pinned_deployment.service_protocol_version + ))); + } + + ( + deployment_metadata, + pinned_deployment.service_protocol_version, + /* has_changed= */ false, + ) } else { // We can choose the freshest deployment for the latest revision // of the registered service. @@ -425,418 +431,42 @@ where .deployment_metadata_resolver .resolve_latest_deployment_for_service(self.invocation_target.service_name()) .ok_or(InvocationTaskError::NoDeploymentForService)); - (deployment, /* has_changed= */ true) + + let chosen_service_protocol_version = + shortcircuit!(ServiceProtocolVersion::choose_max_supported_version( + &deployment.metadata.supported_protocol_versions, + ) + .ok_or_else(|| { + InvocationTaskError::IncompatibleServiceEndpoint( + deployment.id, + deployment.metadata.supported_protocol_versions.clone(), + ) + })); + + ( + deployment, + chosen_service_protocol_version, + /* has_changed= */ true, + ) }; - self.send_invoker_tx(InvocationTaskOutputInner::SelectedDeployment( - deployment.id, + self.send_invoker_tx(InvocationTaskOutputInner::PinnedDeployment( + PinnedDeployment::new(deployment.id, chosen_service_protocol_version), deployment_changed, )); - // Figure out the protocol type. Force RequestResponse if inactivity_timeout is zero - let protocol_type = if self.inactivity_timeout.is_zero() { - ProtocolType::RequestResponse - } else { - deployment.metadata.ty.protocol_type() - }; + // create a correctly versioned service protocol runner + let service_protocol_runner = + ServiceProtocolRunner::new(self, chosen_service_protocol_version); - // Close the invoker_rx in case it's request response, this avoids further buffering of messages in this channel. - if protocol_type == ProtocolType::RequestResponse { - self.invoker_rx.close(); - } - - let path: PathAndQuery = format!( - "/invoke/{}/{}", - self.invocation_target.service_name(), - self.invocation_target.handler_name() - ) - .try_into() - .expect("must be able to build a valid invocation path"); - - let journal_size = journal_metadata.length; - - // Attach parent and uri to the current span - let invocation_task_span = Span::current(); - journal_metadata - .span_context - .as_parent() - .attach_to_span(&invocation_task_span); - - info!( - deployment.address = %deployment.metadata.address_display(), - path = %path, - "Executing invocation at deployment" - ); - - // Create an arc of the parent SpanContext. - // We send this with every journal entry to correctly link new spans generated from journal entries. - let service_invocation_span_context = journal_metadata.span_context; - - // Prepare the request and send start message - let (mut http_stream_tx, request) = self.prepare_request(path, deployment.metadata); - shortcircuit!( - self.write_start(&mut http_stream_tx, journal_size, state_iter) - .await - ); - - // Initialize the response stream state - let mut http_stream_rx = ResponseStreamState::initialize(&self.client, request); - - // Execute the replay - shortcircuit!( - self.replay_loop(&mut http_stream_tx, &mut http_stream_rx, journal_stream) - .await - ); - - // Check all the entries have been replayed - debug_assert_eq!(self.next_journal_index, journal_size); - - // If we have the invoker_rx and the protocol type is bidi stream, - // then we can use the bidi_stream loop reading the invoker_rx and the http_stream_rx - if protocol_type == ProtocolType::BidiStream { - shortcircuit!( - self.bidi_stream_loop( - &service_invocation_span_context, - http_stream_tx, - &mut http_stream_rx, - ) - .await - ); - } else { - // Drop the http_stream_tx. - // This is required in HTTP/1.1 to let the deployment send the headers back - drop(http_stream_tx) - } - - // We don't have the invoker_rx, so we simply consume the response - self.response_stream_loop(&service_invocation_span_context, &mut http_stream_rx) + service_protocol_runner + .run(journal_metadata, deployment, journal_stream, state_iter) .await } +} - // --- Loops - - /// This loop concurrently pushes journal entries and waits for the response headers and end of replay. - async fn replay_loop( - &mut self, - http_stream_tx: &mut Sender, - http_stream_rx: &mut ResponseStreamState, - journal_stream: JournalStream, - ) -> TerminalLoopState<()> - where - JournalStream: Stream + Unpin, - { - let mut journal_stream = journal_stream.fuse(); - let got_headers_future = poll_fn(|cx| http_stream_rx.poll_only_headers(cx)).fuse(); - tokio::pin!(got_headers_future); - - loop { - tokio::select! { - got_headers_res = got_headers_future.as_mut(), if !got_headers_future.is_terminated() => { - // The reason we want to poll headers in this function is - // to exit early in case an error is returned during replays. - let headers = shortcircuit!(got_headers_res); - shortcircuit!(self.handle_response_headers(headers)); - }, - opt_je = journal_stream.next() => { - match opt_je { - Some(je) => { - shortcircuit!(self.write(http_stream_tx, ProtocolMessage::UnparsedEntry(je)).await); - self.next_journal_index += 1; - }, - None => { - // No need to wait for the headers to continue - trace!("Finished to replay the journal"); - return TerminalLoopState::Continue(()) - } - } - } - } - } - } - - /// This loop concurrently reads the http response stream and journal completions from the invoker. - async fn bidi_stream_loop( - &mut self, - parent_span_context: &ServiceInvocationSpanContext, - mut http_stream_tx: Sender, - http_stream_rx: &mut ResponseStreamState, - ) -> TerminalLoopState<()> { - loop { - tokio::select! { - opt_completion = self.invoker_rx.recv() => { - match opt_completion { - Some(Notification::Completion(completion)) => { - trace!("Sending the completion to the wire"); - shortcircuit!(self.write(&mut http_stream_tx, completion.into()).await); - }, - Some(Notification::Ack(entry_index)) => { - trace!("Sending the ack to the wire"); - shortcircuit!(self.write(&mut http_stream_tx, ProtocolMessage::new_entry_ack(entry_index)).await); - }, - None => { - // Completion channel is closed, - // the invoker main loop won't send completions anymore. - // Response stream might still be open though. - return TerminalLoopState::Continue(()) - }, - } - }, - chunk = poll_fn(|cx| http_stream_rx.poll_next_chunk(cx)) => { - match shortcircuit!(chunk) { - ResponseChunk::Parts(parts) => shortcircuit!(self.handle_response_headers(parts)), - ResponseChunk::Data(buf) => shortcircuit!(self.handle_read(parent_span_context, buf)), - ResponseChunk::End => { - // Response stream was closed without SuspensionMessage, EndMessage or ErrorMessage - return TerminalLoopState::Failed(InvocationTaskError::ErrorMessageReceived( - None, - InvocationError::default() - )) - } - } - }, - _ = tokio::time::sleep(self.inactivity_timeout) => { - debug!("Inactivity detected, going to suspend invocation"); - // Just return. This will drop the invoker_rx and http_stream_tx, - // closing the request stream and the invoker input channel. - return TerminalLoopState::Continue(()) - }, - } - } - } - - async fn response_stream_loop( - &mut self, - parent_span_context: &ServiceInvocationSpanContext, - http_stream_rx: &mut ResponseStreamState, - ) -> TerminalLoopState<()> { - loop { - tokio::select! { - chunk = poll_fn(|cx| http_stream_rx.poll_next_chunk(cx)) => { - match shortcircuit!(chunk) { - ResponseChunk::Parts(parts) => shortcircuit!(self.handle_response_headers(parts)), - ResponseChunk::Data(buf) => shortcircuit!(self.handle_read(parent_span_context, buf)), - ResponseChunk::End => { - // Response stream was closed without SuspensionMessage, EndMessage or ErrorMessage - return TerminalLoopState::Failed(InvocationTaskError::ErrorMessageReceived( - None, - InvocationError::default() - )) - } - } - }, - _ = tokio::time::sleep(self.abort_timeout) => { - warn!("Inactivity detected, going to close invocation"); - return TerminalLoopState::Failed(InvocationTaskError::ResponseTimeout) - }, - } - } - } - - // --- Read and write methods - - async fn write_start>( - &mut self, - http_stream_tx: &mut Sender, - journal_size: u32, - state_entries: EagerState, - ) -> Result<(), InvocationTaskError> { - let is_partial = state_entries.is_partial(); - - // Send the invoke frame - self.write( - http_stream_tx, - ProtocolMessage::new_start_message( - Bytes::copy_from_slice(&self.invocation_id.to_bytes()), - self.invocation_id.to_string(), - self.invocation_target.key().map(|bs| bs.as_bytes().clone()), - journal_size, - is_partial, - state_entries, - ), - ) - .await - } - - async fn write( - &mut self, - http_stream_tx: &mut Sender, - msg: ProtocolMessage, - ) -> Result<(), InvocationTaskError> { - trace!(restate.protocol.message = ?msg, "Sending message"); - let buf = self.encoder.encode(msg); - - if let Err(hyper_err) = http_stream_tx.send_data(buf).await { - // is_closed() is try only if the request channel (Sender) has been closed. - // This can happen if the deployment is suspending. - if !hyper_err.is_closed() { - return Err(InvocationTaskError::Client(ServiceClientError::Http( - hyper_err.into(), - ))); - } - }; - Ok(()) - } - - fn handle_response_headers( - &mut self, - mut parts: ResponseParts, - ) -> Result<(), InvocationTaskError> { - if !parts.status.is_success() { - return Err(InvocationTaskError::UnexpectedResponse(parts.status)); - } - - let content_type = parts.headers.remove(http::header::CONTENT_TYPE); - match content_type { - // Check content type is application/restate - Some(ct) => - { - #[allow(clippy::borrow_interior_mutable_const)] - if ct != APPLICATION_RESTATE { - return Err(InvocationTaskError::UnexpectedContentType(Some(ct))); - } - } - None => return Err(InvocationTaskError::UnexpectedContentType(None)), - } - - if let Some(hv) = parts.headers.remove(X_RESTATE_SERVER) { - self.send_invoker_tx(InvocationTaskOutputInner::ServerHeaderReceived( - hv.to_str() - .map_err(|e| InvocationTaskError::BadHeader(X_RESTATE_SERVER, e))? - .to_owned(), - )) - } - - Ok(()) - } - - fn handle_read( - &mut self, - parent_span_context: &ServiceInvocationSpanContext, - buf: Bytes, - ) -> TerminalLoopState<()> { - self.decoder.push(buf); - - while let Some((frame_header, frame)) = shortcircuit!(self.decoder.consume_next()) { - shortcircuit!(self.handle_message(parent_span_context, frame_header, frame)); - } - - TerminalLoopState::Continue(()) - } - - fn handle_message( - &mut self, - parent_span_context: &ServiceInvocationSpanContext, - mh: MessageHeader, - message: ProtocolMessage, - ) -> TerminalLoopState<()> { - trace!(restate.protocol.message_header = ?mh, restate.protocol.message = ?message, "Received message"); - match message { - ProtocolMessage::Start { .. } => TerminalLoopState::Failed( - InvocationTaskError::UnexpectedMessage(MessageType::Start), - ), - ProtocolMessage::Completion(_) => TerminalLoopState::Failed( - InvocationTaskError::UnexpectedMessage(MessageType::Completion), - ), - ProtocolMessage::EntryAck(_) => TerminalLoopState::Failed( - InvocationTaskError::UnexpectedMessage(MessageType::EntryAck), - ), - ProtocolMessage::Suspension(suspension) => { - let suspension_indexes = HashSet::from_iter(suspension.entry_indexes); - // We currently don't support empty suspension_indexes set - if suspension_indexes.is_empty() { - return TerminalLoopState::Failed(InvocationTaskError::EmptySuspensionMessage); - } - // Sanity check on the suspension indexes - if *suspension_indexes.iter().max().unwrap() >= self.next_journal_index { - return TerminalLoopState::Failed(InvocationTaskError::BadSuspensionMessage( - suspension_indexes, - self.next_journal_index, - )); - } - TerminalLoopState::Suspended(suspension_indexes) - } - ProtocolMessage::Error(e) => { - TerminalLoopState::Failed(InvocationTaskError::ErrorMessageReceived( - Some(InvocationErrorRelatedEntry { - related_entry_index: e.related_entry_index, - related_entry_name: e.related_entry_name.clone(), - related_entry_type: e - .related_entry_type - .and_then(|t| u16::try_from(t).ok()) - .and_then(|idx| MessageType::try_from(idx).ok()) - .and_then(|mt| EntryType::try_from(mt).ok()), - }), - InvocationError::from(e), - )) - } - ProtocolMessage::End(_) => TerminalLoopState::Closed, - ProtocolMessage::UnparsedEntry(entry) => { - let entry_type = entry.header().as_entry_type(); - let enriched_entry = shortcircuit!(self - .entry_enricher - .enrich_entry(entry, &self.invocation_target, parent_span_context) - .map_err(|e| InvocationTaskError::EntryEnrichment( - self.next_journal_index, - entry_type, - e - ))); - self.send_invoker_tx(InvocationTaskOutputInner::NewEntry { - entry_index: self.next_journal_index, - entry: enriched_entry, - requires_ack: mh - .requires_ack() - .expect("All entry messages support requires_ack"), - }); - self.next_journal_index += 1; - TerminalLoopState::Continue(()) - } - } - } - - fn prepare_request( - &mut self, - path: PathAndQuery, - deployment_metadata: DeploymentMetadata, - ) -> (Sender, Request) { - let (http_stream_tx, req_body) = Body::channel(); - - let mut headers = HeaderMap::from_iter([ - (http::header::CONTENT_TYPE, APPLICATION_RESTATE), - (http::header::ACCEPT, APPLICATION_RESTATE), - ]); - - // Inject OpenTelemetry context - TraceContextPropagator::new().inject_context( - &Span::current().context(), - &mut HeaderInjector(&mut headers), - ); - - let address = match deployment_metadata.ty { - DeploymentType::Lambda { - arn, - assume_role_arn, - } => Endpoint::Lambda(arn, assume_role_arn), - DeploymentType::Http { - address, - protocol_type, - } => Endpoint::Http( - address, - match protocol_type { - ProtocolType::RequestResponse => http::Version::default(), - ProtocolType::BidiStream => http::Version::HTTP_2, - }, - ), - }; - - headers.extend(deployment_metadata.delivery_options.additional_headers); - - ( - http_stream_tx, - Request::new(Parts::new(address, path, headers), req_body), - ) - } - - fn send_invoker_tx(&mut self, invocation_task_output_inner: InvocationTaskOutputInner) { +impl InvocationTask { + fn send_invoker_tx(&self, invocation_task_output_inner: InvocationTaskOutputInner) { let _ = self.invoker_tx.send(InvocationTaskOutput { partition: self.partition, invocation_id: self.invocation_id, @@ -845,6 +475,17 @@ where } } +fn service_protocol_version_to_header_value( + service_protocol_version: ServiceProtocolVersion, +) -> HeaderValue { + match service_protocol_version { + ServiceProtocolVersion::Unspecified => { + unreachable!("unknown protocol version should never be chosen") + } + ServiceProtocolVersion::V1 => SERVICE_PROTOCOL_VERSION_V1, + } +} + enum ResponseChunk { Parts(ResponseParts), Data(Bytes), diff --git a/crates/invoker-impl/src/invocation_task/service_protocol_runner.rs b/crates/invoker-impl/src/invocation_task/service_protocol_runner.rs new file mode 100644 index 0000000000..a67411001d --- /dev/null +++ b/crates/invoker-impl/src/invocation_task/service_protocol_runner.rs @@ -0,0 +1,538 @@ +// Copyright (c) 2024 - Restate Software, Inc., Restate GmbH. +// All rights reserved. +// +// Use of this software is governed by the Business Source License +// included in the LICENSE file. +// +// As of the Change Date specified in that file, in accordance with +// the Business Source License, use of this software will be governed +// by the Apache License, Version 2.0. + +use crate::invocation_task::{ + service_protocol_version_to_header_value, InvocationErrorRelatedEntry, InvocationTask, + InvocationTaskError, InvocationTaskOutputInner, ResponseChunk, ResponseStreamState, + TerminalLoopState, X_RESTATE_SERVER, +}; +use crate::Notification; +use bytes::Bytes; +use futures::future::FusedFuture; +use futures::{FutureExt, Stream, StreamExt}; +use hyper::body::Sender; +use hyper::http::uri::PathAndQuery; +use hyper::{http, Body, HeaderMap}; +use opentelemetry::propagation::TextMapPropagator; +use opentelemetry_http::HeaderInjector; +use opentelemetry_sdk::propagation::TraceContextPropagator; +use restate_errors::warn_it; +use restate_invoker_api::{EagerState, EntryEnricher, JournalMetadata}; +use restate_schema_api::deployment::{ + Deployment, DeploymentMetadata, DeploymentType, ProtocolType, +}; +use restate_service_client::{Endpoint, Parts, Request, ServiceClientError}; +use restate_service_protocol::message::{ + Decoder, Encoder, MessageHeader, MessageType, ProtocolMessage, +}; +use restate_types::errors::InvocationError; +use restate_types::identifiers::EntryIndex; +use restate_types::invocation::ServiceInvocationSpanContext; +use restate_types::journal::raw::PlainRawEntry; +use restate_types::journal::EntryType; +use restate_types::service_protocol::ServiceProtocolVersion; +use std::collections::HashSet; +use std::future::poll_fn; +use tracing::log::warn; +use tracing::{debug, info, trace, Span}; +use tracing_opentelemetry::OpenTelemetrySpanExt; + +/// Runs the interaction between the server and the service endpoint. +pub struct ServiceProtocolRunner<'a, SR, JR, EE, DMR> { + invocation_task: &'a mut InvocationTask, + + service_protocol_version: ServiceProtocolVersion, + + // Encoder/Decoder + encoder: Encoder, + decoder: Decoder, + + // task state + next_journal_index: EntryIndex, +} + +impl<'a, SR, JR, EE, DMR> ServiceProtocolRunner<'a, SR, JR, EE, DMR> +where + EE: EntryEnricher, +{ + pub fn new( + invocation_task: &'a mut InvocationTask, + service_protocol_version: ServiceProtocolVersion, + ) -> Self { + let encoder = Encoder::new(service_protocol_version); + let decoder = Decoder::new( + service_protocol_version, + invocation_task.message_size_warning, + invocation_task.message_size_limit, + ); + + Self { + invocation_task, + service_protocol_version, + encoder, + decoder, + next_journal_index: 0, + } + } + + pub async fn run( + mut self, + journal_metadata: JournalMetadata, + deployment: Deployment, + journal_stream: JournalStream, + state_iter: EagerState, + ) -> TerminalLoopState<()> + where + JournalStream: Stream + Unpin, + StateIter: Iterator, + { + // Figure out the protocol type. Force RequestResponse if inactivity_timeout is zero + let protocol_type = if self.invocation_task.inactivity_timeout.is_zero() { + ProtocolType::RequestResponse + } else { + deployment.metadata.ty.protocol_type() + }; + + // Close the invoker_rx in case it's request response, this avoids further buffering of messages in this channel. + if protocol_type == ProtocolType::RequestResponse { + self.invocation_task.invoker_rx.close(); + } + + let path: PathAndQuery = format!( + "/invoke/{}/{}", + self.invocation_task.invocation_target.service_name(), + self.invocation_task.invocation_target.handler_name() + ) + .try_into() + .expect("must be able to build a valid invocation path"); + + let journal_size = journal_metadata.length; + + // Attach parent and uri to the current span + let invocation_task_span = Span::current(); + journal_metadata + .span_context + .as_parent() + .attach_to_span(&invocation_task_span); + + info!( + deployment.address = %deployment.metadata.address_display(), + deployment.service_protocol_version = %self.service_protocol_version.as_repr(), + path = %path, + "Executing invocation at deployment" + ); + + // Create an arc of the parent SpanContext. + // We send this with every journal entry to correctly link new spans generated from journal entries. + let service_invocation_span_context = journal_metadata.span_context; + + // Prepare the request and send start message + let (mut http_stream_tx, request) = + Self::prepare_request(path, deployment.metadata, self.service_protocol_version); + + crate::shortcircuit!( + self.write_start(&mut http_stream_tx, journal_size, state_iter) + .await + ); + + // Initialize the response stream state + let mut http_stream_rx = + ResponseStreamState::initialize(&self.invocation_task.client, request); + + // Execute the replay + crate::shortcircuit!( + self.replay_loop(&mut http_stream_tx, &mut http_stream_rx, journal_stream) + .await + ); + + // Check all the entries have been replayed + debug_assert_eq!(self.next_journal_index, journal_size); + + // If we have the invoker_rx and the protocol type is bidi stream, + // then we can use the bidi_stream loop reading the invoker_rx and the http_stream_rx + if protocol_type == ProtocolType::BidiStream { + crate::shortcircuit!( + self.bidi_stream_loop( + &service_invocation_span_context, + http_stream_tx, + &mut http_stream_rx, + ) + .await + ); + } else { + // Drop the http_stream_tx. + // This is required in HTTP/1.1 to let the deployment send the headers back + drop(http_stream_tx) + } + + // We don't have the invoker_rx, so we simply consume the response + let result = self + .response_stream_loop(&service_invocation_span_context, &mut http_stream_rx) + .await; + + // Sanity check of the stream decoder + if self.decoder.has_remaining() { + warn_it!( + InvocationTaskError::WriteAfterEndOfStream, + "The read buffer is non empty after the stream has been closed." + ); + } + + result + } + + fn prepare_request( + path: PathAndQuery, + deployment_metadata: DeploymentMetadata, + service_protocol_version: ServiceProtocolVersion, + ) -> (Sender, Request) { + let (http_stream_tx, req_body) = Body::channel(); + + let service_protocol_header_value = + service_protocol_version_to_header_value(service_protocol_version); + + let mut headers = HeaderMap::from_iter([ + ( + http::header::CONTENT_TYPE, + service_protocol_header_value.clone(), + ), + (http::header::ACCEPT, service_protocol_header_value), + ]); + + // Inject OpenTelemetry context + TraceContextPropagator::new().inject_context( + &Span::current().context(), + &mut HeaderInjector(&mut headers), + ); + + let address = match deployment_metadata.ty { + DeploymentType::Lambda { + arn, + assume_role_arn, + } => Endpoint::Lambda(arn, assume_role_arn), + DeploymentType::Http { + address, + protocol_type, + } => Endpoint::Http( + address, + match protocol_type { + ProtocolType::RequestResponse => http::Version::default(), + ProtocolType::BidiStream => http::Version::HTTP_2, + }, + ), + }; + + headers.extend(deployment_metadata.delivery_options.additional_headers); + + ( + http_stream_tx, + Request::new(Parts::new(address, path, headers), req_body), + ) + } + + // --- Loops + + /// This loop concurrently pushes journal entries and waits for the response headers and end of replay. + async fn replay_loop( + &mut self, + http_stream_tx: &mut Sender, + http_stream_rx: &mut ResponseStreamState, + journal_stream: JournalStream, + ) -> TerminalLoopState<()> + where + JournalStream: Stream + Unpin, + { + let mut journal_stream = journal_stream.fuse(); + let got_headers_future = poll_fn(|cx| http_stream_rx.poll_only_headers(cx)).fuse(); + tokio::pin!(got_headers_future); + + loop { + tokio::select! { + got_headers_res = got_headers_future.as_mut(), if !got_headers_future.is_terminated() => { + // The reason we want to poll headers in this function is + // to exit early in case an error is returned during replays. + let headers = crate::shortcircuit!(got_headers_res); + crate::shortcircuit!(self.handle_response_headers(headers)); + }, + opt_je = journal_stream.next() => { + match opt_je { + Some(je) => { + crate::shortcircuit!(self.write(http_stream_tx, ProtocolMessage::UnparsedEntry(je)).await); + self.next_journal_index += 1; + }, + None => { + // No need to wait for the headers to continue + trace!("Finished to replay the journal"); + return TerminalLoopState::Continue(()) + } + } + } + } + } + } + + /// This loop concurrently reads the http response stream and journal completions from the invoker. + async fn bidi_stream_loop( + &mut self, + parent_span_context: &ServiceInvocationSpanContext, + mut http_stream_tx: Sender, + http_stream_rx: &mut ResponseStreamState, + ) -> TerminalLoopState<()> { + loop { + tokio::select! { + opt_completion = self.invocation_task.invoker_rx.recv() => { + match opt_completion { + Some(Notification::Completion(completion)) => { + trace!("Sending the completion to the wire"); + crate::shortcircuit!(self.write(&mut http_stream_tx, completion.into()).await); + }, + Some(Notification::Ack(entry_index)) => { + trace!("Sending the ack to the wire"); + crate::shortcircuit!(self.write(&mut http_stream_tx, ProtocolMessage::new_entry_ack(entry_index)).await); + }, + None => { + // Completion channel is closed, + // the invoker main loop won't send completions anymore. + // Response stream might still be open though. + return TerminalLoopState::Continue(()) + }, + } + }, + chunk = poll_fn(|cx| http_stream_rx.poll_next_chunk(cx)) => { + match crate::shortcircuit!(chunk) { + ResponseChunk::Parts(parts) => crate::shortcircuit!(self.handle_response_headers(parts)), + ResponseChunk::Data(buf) => crate::shortcircuit!(self.handle_read(parent_span_context, buf)), + ResponseChunk::End => { + // Response stream was closed without SuspensionMessage, EndMessage or ErrorMessage + return TerminalLoopState::Failed(InvocationTaskError::ErrorMessageReceived( + None, + InvocationError::default() + )) + } + } + }, + _ = tokio::time::sleep(self.invocation_task.inactivity_timeout) => { + debug!("Inactivity detected, going to suspend invocation"); + // Just return. This will drop the invoker_rx and http_stream_tx, + // closing the request stream and the invoker input channel. + return TerminalLoopState::Continue(()) + }, + } + } + } + + async fn response_stream_loop( + &mut self, + parent_span_context: &ServiceInvocationSpanContext, + http_stream_rx: &mut ResponseStreamState, + ) -> TerminalLoopState<()> { + loop { + tokio::select! { + chunk = poll_fn(|cx| http_stream_rx.poll_next_chunk(cx)) => { + match crate::shortcircuit!(chunk) { + ResponseChunk::Parts(parts) => crate::shortcircuit!(self.handle_response_headers(parts)), + ResponseChunk::Data(buf) => crate::shortcircuit!(self.handle_read(parent_span_context, buf)), + ResponseChunk::End => { + // Response stream was closed without SuspensionMessage, EndMessage or ErrorMessage + return TerminalLoopState::Failed(InvocationTaskError::ErrorMessageReceived( + None, + InvocationError::default() + )) + } + } + }, + _ = tokio::time::sleep(self.invocation_task.abort_timeout) => { + warn!("Inactivity detected, going to close invocation"); + return TerminalLoopState::Failed(InvocationTaskError::ResponseTimeout) + }, + } + } + } + + // --- Read and write methods + + async fn write_start>( + &mut self, + http_stream_tx: &mut Sender, + journal_size: u32, + state_entries: EagerState, + ) -> Result<(), InvocationTaskError> { + let is_partial = state_entries.is_partial(); + + // Send the invoke frame + self.write( + http_stream_tx, + ProtocolMessage::new_start_message( + Bytes::copy_from_slice(&self.invocation_task.invocation_id.to_bytes()), + self.invocation_task.invocation_id.to_string(), + self.invocation_task + .invocation_target + .key() + .map(|bs| bs.as_bytes().clone()), + journal_size, + is_partial, + state_entries, + ), + ) + .await + } + + async fn write( + &mut self, + http_stream_tx: &mut Sender, + msg: ProtocolMessage, + ) -> Result<(), InvocationTaskError> { + trace!(restate.protocol.message = ?msg, "Sending message"); + let buf = self.encoder.encode(msg); + + if let Err(hyper_err) = http_stream_tx.send_data(buf).await { + // is_closed() is try only if the request channel (Sender) has been closed. + // This can happen if the deployment is suspending. + if !hyper_err.is_closed() { + return Err(InvocationTaskError::Client(ServiceClientError::Http( + hyper_err.into(), + ))); + } + }; + Ok(()) + } + + fn handle_response_headers( + &mut self, + mut parts: http::response::Parts, + ) -> Result<(), InvocationTaskError> { + if !parts.status.is_success() { + return Err(InvocationTaskError::UnexpectedResponse(parts.status)); + } + + let content_type = parts.headers.remove(http::header::CONTENT_TYPE); + let expected_content_type = + service_protocol_version_to_header_value(self.service_protocol_version); + match content_type { + Some(ct) => + { + #[allow(clippy::borrow_interior_mutable_const)] + if ct != expected_content_type { + return Err(InvocationTaskError::UnexpectedContentType( + Some(ct), + expected_content_type, + )); + } + } + None => { + return Err(InvocationTaskError::UnexpectedContentType( + None, + expected_content_type, + )) + } + } + + if let Some(hv) = parts.headers.remove(X_RESTATE_SERVER) { + self.invocation_task + .send_invoker_tx(InvocationTaskOutputInner::ServerHeaderReceived( + hv.to_str() + .map_err(|e| InvocationTaskError::BadHeader(X_RESTATE_SERVER, e))? + .to_owned(), + )) + } + + Ok(()) + } + + fn handle_read( + &mut self, + parent_span_context: &ServiceInvocationSpanContext, + buf: Bytes, + ) -> TerminalLoopState<()> { + self.decoder.push(buf); + + while let Some((frame_header, frame)) = crate::shortcircuit!(self.decoder.consume_next()) { + crate::shortcircuit!(self.handle_message(parent_span_context, frame_header, frame)); + } + + TerminalLoopState::Continue(()) + } + + fn handle_message( + &mut self, + parent_span_context: &ServiceInvocationSpanContext, + mh: MessageHeader, + message: ProtocolMessage, + ) -> TerminalLoopState<()> { + trace!(restate.protocol.message_header = ?mh, restate.protocol.message = ?message, "Received message"); + match message { + ProtocolMessage::Start { .. } => TerminalLoopState::Failed( + InvocationTaskError::UnexpectedMessage(MessageType::Start), + ), + ProtocolMessage::Completion(_) => TerminalLoopState::Failed( + InvocationTaskError::UnexpectedMessage(MessageType::Completion), + ), + ProtocolMessage::EntryAck(_) => TerminalLoopState::Failed( + InvocationTaskError::UnexpectedMessage(MessageType::EntryAck), + ), + ProtocolMessage::Suspension(suspension) => { + let suspension_indexes = HashSet::from_iter(suspension.entry_indexes); + // We currently don't support empty suspension_indexes set + if suspension_indexes.is_empty() { + return TerminalLoopState::Failed(InvocationTaskError::EmptySuspensionMessage); + } + // Sanity check on the suspension indexes + if *suspension_indexes.iter().max().unwrap() >= self.next_journal_index { + return TerminalLoopState::Failed(InvocationTaskError::BadSuspensionMessage( + suspension_indexes, + self.next_journal_index, + )); + } + TerminalLoopState::Suspended(suspension_indexes) + } + ProtocolMessage::Error(e) => { + TerminalLoopState::Failed(InvocationTaskError::ErrorMessageReceived( + Some(InvocationErrorRelatedEntry { + related_entry_index: e.related_entry_index, + related_entry_name: e.related_entry_name.clone(), + related_entry_type: e + .related_entry_type + .and_then(|t| u16::try_from(t).ok()) + .and_then(|idx| MessageType::try_from(idx).ok()) + .and_then(|mt| EntryType::try_from(mt).ok()), + }), + InvocationError::from(e), + )) + } + ProtocolMessage::End(_) => TerminalLoopState::Closed, + ProtocolMessage::UnparsedEntry(entry) => { + let entry_type = entry.header().as_entry_type(); + let enriched_entry = crate::shortcircuit!(self + .invocation_task + .entry_enricher + .enrich_entry( + entry, + &self.invocation_task.invocation_target, + parent_span_context + ) + .map_err(|e| InvocationTaskError::EntryEnrichment( + self.next_journal_index, + entry_type, + e + ))); + self.invocation_task + .send_invoker_tx(InvocationTaskOutputInner::NewEntry { + entry_index: self.next_journal_index, + entry: enriched_entry, + requires_ack: mh + .requires_ack() + .expect("All entry messages support requires_ack"), + }); + self.next_journal_index += 1; + TerminalLoopState::Continue(()) + } + } + } +} diff --git a/crates/invoker-impl/src/lib.rs b/crates/invoker-impl/src/lib.rs index 9a020b1204..c52eaf6f14 100644 --- a/crates/invoker-impl/src/lib.rs +++ b/crates/invoker-impl/src/lib.rs @@ -56,7 +56,7 @@ use crate::invocation_task::InvocationTaskError; pub use input_command::ChannelStatusReader; pub use input_command::InvokerHandle; use restate_service_client::{AssumeRoleCacheMode, ServiceClient}; -use restate_service_protocol::RESTATE_SERVICE_PROTOCOL_VERSION; +use restate_types::deployment::PinnedDeployment; use restate_types::invocation::InvocationTarget; use crate::metric_definitions::{ @@ -100,7 +100,7 @@ where SR: JournalReader + StateReader + Clone + Send + Sync + 'static, ::JournalStream: Unpin + Send + 'static, ::StateIter: Send, - EE: EntryEnricher + Clone + Send + 'static, + EE: EntryEnricher + Clone + Send + Sync + 'static, DMR: DeploymentResolver + Clone + Send + 'static, { fn start_invocation_task( @@ -121,7 +121,6 @@ where partition, invocation_id, invocation_target, - RESTATE_SERVICE_PROTOCOL_VERSION, opts.inactivity_timeout.into(), opts.abort_timeout.into(), opts.disable_eager_state, @@ -235,7 +234,7 @@ where SR: JournalReader + StateReader + Clone + Send + Sync + 'static, ::JournalStream: Unpin + Send + 'static, ::StateIter: Send, - EE: EntryEnricher + Clone + Send + 'static, + EE: EntryEnricher + Clone + Send + Sync + 'static, EMR: DeploymentResolver + Clone + Send + 'static, { pub fn handle(&self) -> InvokerHandle { @@ -374,11 +373,11 @@ where inner } = invocation_task_msg; match inner { - InvocationTaskOutputInner::SelectedDeployment(deployment_id, has_changed) => { - self.handle_selected_deployment( + InvocationTaskOutputInner::PinnedDeployment(deployment_metadata, has_changed) => { + self.handle_pinned_deployment( partition, invocation_id, - deployment_id, + deployment_metadata, has_changed, ).await } @@ -550,14 +549,14 @@ where fields( restate.invocation.id = %invocation_id, restate.invoker.partition_leader_epoch = ?partition, - restate.deployment.id = %deployment_id, + restate.deployment.id = %pinned_deployment.deployment_id, ) )] - async fn handle_selected_deployment( + async fn handle_pinned_deployment( &mut self, partition: PartitionLeaderEpoch, invocation_id: InvocationId, - deployment_id: DeploymentId, + pinned_deployment: PinnedDeployment, has_changed: bool, ) { if let Some((_, ism)) = self @@ -566,17 +565,20 @@ where { trace!( restate.invocation.target = %ism.invocation_target, - "Chosen deployment {}. Invocation state: {:?}", - deployment_id, + "Pinned deployment '{}'. Invocation state: {:?}", + pinned_deployment, ism.invocation_state_debug() ); - self.status_store - .on_deployment_chosen(&partition, &invocation_id, deployment_id); + self.status_store.on_deployment_chosen( + &partition, + &invocation_id, + pinned_deployment.deployment_id, + ); // If we think this selected deployment has been freshly picked, otherwise // we assume that we have stored it previously. if has_changed { - ism.notify_chosen_deployment(deployment_id); + ism.notify_pinned_deployment(pinned_deployment); } } else { // If no state machine, this might be an event for an aborted invocation. @@ -648,11 +650,11 @@ where "Received a new entry. Invocation state: {:?}", ism.invocation_state_debug() ); - if let Some(deployment_id) = ism.chosen_deployment_to_notify() { + if let Some(pinned_deployment) = ism.pinned_deployment_to_notify() { let _ = output_tx .send(Effect { invocation_id, - kind: EffectKind::SelectedDeployment(deployment_id), + kind: EffectKind::PinnedDeployment(pinned_deployment), }) .await; } diff --git a/crates/partition-store/tests/invocation_status_table_test/mod.rs b/crates/partition-store/tests/invocation_status_table_test/mod.rs index e6f7c9364c..b800b09dfc 100644 --- a/crates/partition-store/tests/invocation_status_table_test/mod.rs +++ b/crates/partition-store/tests/invocation_status_table_test/mod.rs @@ -60,7 +60,7 @@ fn invoked_status(invocation_target: InvocationTarget) -> InvocationStatus { InvocationStatus::Invoked(InFlightInvocationMetadata { invocation_target, journal_metadata: JournalMetadata::initialize(ServiceInvocationSpanContext::empty()), - deployment_id: None, + pinned_deployment: None, response_sinks: HashSet::new(), timestamps: StatusTimestamps::new(MillisSinceEpoch::new(0), MillisSinceEpoch::new(0)), source: Source::Ingress, @@ -74,7 +74,7 @@ fn suspended_status(invocation_target: InvocationTarget) -> InvocationStatus { metadata: InFlightInvocationMetadata { invocation_target, journal_metadata: JournalMetadata::initialize(ServiceInvocationSpanContext::empty()), - deployment_id: None, + pinned_deployment: None, response_sinks: HashSet::new(), timestamps: StatusTimestamps::new(MillisSinceEpoch::new(0), MillisSinceEpoch::new(0)), source: Source::Ingress, diff --git a/crates/schema-api/src/lib.rs b/crates/schema-api/src/lib.rs index 88e86c0e32..70e1e79cd7 100644 --- a/crates/schema-api/src/lib.rs +++ b/crates/schema-api/src/lib.rs @@ -10,8 +10,6 @@ //! This crate contains all the different APIs for accessing schemas. -pub const MAX_SERVICE_PROTOCOL_VERSION_VALUE: i32 = i32::MAX; - #[cfg(feature = "invocation_target")] pub mod invocation_target; @@ -194,7 +192,7 @@ pub mod deployment { pub mod mocks { use super::*; - use crate::MAX_SERVICE_PROTOCOL_VERSION_VALUE; + use restate_types::service_protocol::MAX_SERVICE_PROTOCOL_VERSION_VALUE; use std::collections::HashMap; impl Deployment { diff --git a/crates/service-protocol/Cargo.toml b/crates/service-protocol/Cargo.toml index b21de1782b..b939f7999e 100644 --- a/crates/service-protocol/Cargo.toml +++ b/crates/service-protocol/Cargo.toml @@ -11,11 +11,10 @@ publish = false default = [] awakeable-id = ["dep:base64", "dep:restate-base64-util", "dep:restate-types"] -codec = ["protocol", "dep:restate-types", "dep:paste"] +codec = ["dep:restate-types", "dep:paste"] discovery = ["dep:serde", "dep:serde_json", "dep:regress", "dep:tracing", "dep:codederror", "dep:restate-errors", "dep:restate-schema-api", "dep:hyper", "dep:restate-service-client", "dep:restate-types", "dep:tokio"] -message = ["protocol", "dep:restate-types", "dep:bytes-utils", "dep:codederror", "dep:restate-errors", "dep:size", "dep:tracing"] +message = ["dep:restate-types", "dep:bytes-utils", "dep:codederror", "dep:restate-errors", "dep:size", "dep:tracing"] mocks = ["awakeable-id"] -protocol = [] [dependencies] restate-base64-util = { workspace = true, optional = true } @@ -42,17 +41,9 @@ regress = { version = "0.9", optional = true } [dev-dependencies] restate-test-util = { workspace = true } +restate-types = { workspace = true, features = ["test-util"] } test-log = { workspace = true } tokio = { workspace = true } tracing-subscriber = { workspace = true } uuid = { workspace = true } - -[build-dependencies] -prost-build = { workspace = true } -prettyplease = "0.2" -schemars = { workspace = true } -serde_json = { workspace = true } -syn = "2.0" -typify = { version = "0.0.16" } -jsonptr = "0.4.7" diff --git a/crates/service-protocol/README.md b/crates/service-protocol/README.md deleted file mode 100644 index 318da58592..0000000000 --- a/crates/service-protocol/README.md +++ /dev/null @@ -1,7 +0,0 @@ -# Service protocol implementation - -To update the subtree, from the root directory of the project: - -```shell -git subtree pull --prefix crates/service-protocol/service-protocol git@github.com:restatedev/service-protocol.git main --squash -``` \ No newline at end of file diff --git a/crates/service-protocol/src/codec.rs b/crates/service-protocol/src/codec.rs index 7b95ec59e8..74ed8e44d3 100644 --- a/crates/service-protocol/src/codec.rs +++ b/crates/service-protocol/src/codec.rs @@ -8,14 +8,13 @@ // the Business Source License, use of this software will be governed // by the Apache License, Version 2.0. -use super::pb::protocol; - use bytes::{Buf, BufMut, Bytes, BytesMut}; use prost::Message; use restate_types::invocation::Header; use restate_types::journal::enriched::{EnrichedEntryHeader, EnrichedRawEntry}; use restate_types::journal::raw::*; use restate_types::journal::{CompletionResult, Entry, EntryType}; +use restate_types::service_protocol; use std::fmt::Debug; use std::mem; @@ -27,7 +26,7 @@ macro_rules! match_decode { ($ty:expr, $buf:expr, { $($variant:ident),* }) => { match $ty { $(EntryType::$variant { .. } => paste::paste! { - protocol::[<$variant EntryMessage>]::decode($buf) + service_protocol::[<$variant EntryMessage>]::decode($buf) .map_err(|e| RawEntryCodecError::new($ty.clone(), ErrorKind::Decode { source: Some(e.into()) })) .and_then(|msg| msg.try_into().map_err(|f| RawEntryCodecError::new($ty.clone(), ErrorKind::MissingField(f)))) },)* @@ -50,10 +49,10 @@ impl RawEntryCodec for ProtobufRawEntryCodec { fn serialize_as_input_entry(headers: Vec
, value: Bytes) -> EnrichedRawEntry { RawEntry::new( EnrichedEntryHeader::Input {}, - protocol::InputEntryMessage { + service_protocol::InputEntryMessage { headers: headers .into_iter() - .map(|h| protocol::Header { + .map(|h| service_protocol::Header { key: h.name.to_string(), value: h.value.to_string(), }) @@ -68,7 +67,7 @@ impl RawEntryCodec for ProtobufRawEntryCodec { fn serialize_get_state_keys_completion(keys: Vec) -> CompletionResult { CompletionResult::Success( - protocol::get_state_keys_entry_message::StateKeys { keys } + service_protocol::get_state_keys_entry_message::StateKeys { keys } .encode_to_vec() .into(), ) @@ -127,11 +126,11 @@ impl RawEntryCodec for ProtobufRawEntryCodec { // Prepare the result to serialize in protobuf let completion_result_message = match completion_result { CompletionResult::Empty => { - protocol::completion_message::Result::Empty(protocol::Empty {}) + service_protocol::completion_message::Result::Empty(service_protocol::Empty {}) } - CompletionResult::Success(b) => protocol::completion_message::Result::Value(b), + CompletionResult::Success(b) => service_protocol::completion_message::Result::Value(b), CompletionResult::Failure(code, message) => { - protocol::completion_message::Result::Failure(protocol::Failure { + service_protocol::completion_message::Result::Failure(service_protocol::Failure { code: code.into(), message: message.to_string(), }) @@ -166,13 +165,6 @@ mod mocks { use super::*; use crate::awakeable_id::AwakeableIdentifier; - use crate::pb::protocol::{ - awakeable_entry_message, call_entry_message, complete_awakeable_entry_message, - get_state_entry_message, get_state_keys_entry_message, output_entry_message, - AwakeableEntryMessage, CallEntryMessage, ClearAllStateEntryMessage, ClearStateEntryMessage, - CompleteAwakeableEntryMessage, Failure, GetStateEntryMessage, GetStateKeysEntryMessage, - InputEntryMessage, OneWayCallEntryMessage, OutputEntryMessage, SetStateEntryMessage, - }; use restate_types::identifiers::InvocationId; use restate_types::invocation::{InvocationTarget, VirtualObjectHandlerType}; use restate_types::journal::enriched::{ @@ -182,6 +174,13 @@ mod mocks { AwakeableEntry, CompletableEntry, CompleteAwakeableEntry, EntryResult, GetStateKeysEntry, GetStateKeysResult, GetStateResult, InputEntry, OutputEntry, }; + use restate_types::service_protocol::{ + awakeable_entry_message, call_entry_message, complete_awakeable_entry_message, + get_state_entry_message, get_state_keys_entry_message, output_entry_message, + AwakeableEntryMessage, CallEntryMessage, ClearAllStateEntryMessage, ClearStateEntryMessage, + CompleteAwakeableEntryMessage, Failure, GetStateEntryMessage, GetStateKeysEntryMessage, + InputEntryMessage, OneWayCallEntryMessage, OutputEntryMessage, SetStateEntryMessage, + }; impl ProtobufRawEntryCodec { pub fn serialize(entry: Entry) -> PlainRawEntry { @@ -221,7 +220,7 @@ mod mocks { key: entry.key, result: entry.value.map(|value| match value { GetStateResult::Empty => { - get_state_entry_message::Result::Empty(protocol::Empty {}) + get_state_entry_message::Result::Empty(service_protocol::Empty {}) } GetStateResult::Result(v) => get_state_entry_message::Result::Value(v), GetStateResult::Failure(code, reason) => { @@ -451,12 +450,12 @@ mod tests { is_completed: false, enrichment_result: None, }, - protocol::CallEntryMessage { + service_protocol::CallEntryMessage { service_name: "MySvc".to_string(), handler_name: "MyMethod".to_string(), parameter: Bytes::from_static(b"input"), - ..protocol::CallEntryMessage::default() + ..service_protocol::CallEntryMessage::default() } .encode_to_vec() .into(), diff --git a/crates/service-protocol/src/discovery.rs b/crates/service-protocol/src/discovery.rs index 5a621d13a3..f4102b80c6 100644 --- a/crates/service-protocol/src/discovery.rs +++ b/crates/service-protocol/src/discovery.rs @@ -8,8 +8,6 @@ // the Business Source License, use of this software will be governed // by the Apache License, Version 2.0. -use crate::pb::protocol; -use crate::{MAX_SERVICE_PROTOCOL_VERSION, MIN_SERVICE_PROTOCOL_VERSION}; use bytes::Bytes; use codederror::CodedError; use hyper::header::{ACCEPT, CONTENT_TYPE}; @@ -19,9 +17,13 @@ use hyper::http::{HeaderName, HeaderValue}; use hyper::{Body, HeaderMap, StatusCode}; use restate_errors::{META0003, META0012, META0013}; use restate_schema_api::deployment::ProtocolType; -use restate_schema_api::MAX_SERVICE_PROTOCOL_VERSION_VALUE; use restate_service_client::{Endpoint, Parts, Request, ServiceClient, ServiceClientError}; +use restate_types::endpoint_manifest; use restate_types::retries::{RetryIter, RetryPolicy}; +use restate_types::service_protocol::{ + ServiceProtocolVersion, MAX_SERVICE_PROTOCOL_VERSION, MAX_SERVICE_PROTOCOL_VERSION_VALUE, + MIN_SERVICE_PROTOCOL_VERSION, +}; use std::borrow::Cow; use std::collections::HashMap; use std::fmt::Display; @@ -32,24 +34,6 @@ const APPLICATION_JSON: HeaderValue = HeaderValue::from_static("application/json const DISCOVER_PATH: &str = "/discover"; -pub mod schema { - #![allow(warnings)] - #![allow(clippy::all)] - #![allow(unknown_lints)] - - include!(concat!(env!("OUT_DIR"), "/deployment.rs")); - - impl From for restate_types::invocation::ServiceType { - fn from(value: ServiceType) -> Self { - match value { - ServiceType::VirtualObject => restate_types::invocation::ServiceType::VirtualObject, - ServiceType::Service => restate_types::invocation::ServiceType::Service, - ServiceType::Workflow => restate_types::invocation::ServiceType::Workflow, - } - } - } -} - #[derive(Clone)] pub struct DiscoverEndpoint(Endpoint, HashMap); @@ -77,7 +61,7 @@ impl DiscoverEndpoint { #[derive(Debug)] pub struct DiscoveredMetadata { pub protocol_type: ProtocolType, - pub services: Vec, + pub services: Vec, // type is i32 because the generated ServiceProtocolVersion enum uses this as its representation // and we need to represent unknown later versions pub supported_protocol_versions: RangeInclusive, @@ -177,18 +161,18 @@ impl ServiceDiscovery { } // Parse the response - let response: schema::Endpoint = + let response: endpoint_manifest::Endpoint = serde_json::from_slice(&body).map_err(|e| DiscoveryError::Decode(e, body))?; Self::create_discovered_metadata_from_endpoint_response(response) } fn create_discovered_metadata_from_endpoint_response( - endpoint_response: schema::Endpoint, + endpoint_response: endpoint_manifest::Endpoint, ) -> Result { let protocol_type = match endpoint_response.protocol_mode { - Some(schema::ProtocolMode::BidiStream) => ProtocolType::BidiStream, - Some(schema::ProtocolMode::RequestResponse) => ProtocolType::RequestResponse, + Some(endpoint_manifest::ProtocolMode::BidiStream) => ProtocolType::BidiStream, + Some(endpoint_manifest::ProtocolMode::RequestResponse) => ProtocolType::RequestResponse, None => { return Err(DiscoveryError::BadResponse("missing protocol mode".into())); } @@ -228,7 +212,7 @@ impl ServiceDiscovery { let min_version = endpoint_response.min_protocol_version as i32; let max_version = endpoint_response.max_protocol_version as i32; - if !protocol::ServiceProtocolVersion::is_supported(min_version, max_version) { + if !ServiceProtocolVersion::is_compatible(min_version, max_version) { return Err(DiscoveryError::UnsupportedServiceProtocol { min_version, max_version, @@ -299,13 +283,14 @@ impl ServiceDiscovery { #[cfg(test)] mod tests { - use crate::discovery::schema::ProtocolMode; - use crate::discovery::{schema, DiscoveryError, ServiceDiscovery}; - use crate::MAX_SERVICE_PROTOCOL_VERSION; + use crate::discovery::endpoint_manifest::ProtocolMode; + use crate::discovery::{DiscoveryError, ServiceDiscovery}; + use restate_types::endpoint_manifest; + use restate_types::service_protocol::MAX_SERVICE_PROTOCOL_VERSION; #[test] fn fail_on_invalid_min_protocol_version_with_bad_response() { - let response = schema::Endpoint { + let response = endpoint_manifest::Endpoint { min_protocol_version: 0, max_protocol_version: 1, services: Vec::new(), @@ -320,7 +305,7 @@ mod tests { #[test] fn fail_on_invalid_max_protocol_version_with_bad_response() { - let response = schema::Endpoint { + let response = endpoint_manifest::Endpoint { min_protocol_version: 1, max_protocol_version: i64::MAX, services: Vec::new(), @@ -335,7 +320,7 @@ mod tests { #[test] fn fail_on_max_protocol_version_smaller_than_min_protocol_version_with_bad_response() { - let response = schema::Endpoint { + let response = endpoint_manifest::Endpoint { min_protocol_version: 10, max_protocol_version: 9, services: Vec::new(), @@ -351,7 +336,7 @@ mod tests { #[test] fn fail_with_unsupported_protocol_version() { let unsupported_version = i32::from(MAX_SERVICE_PROTOCOL_VERSION) + 1; - let response = schema::Endpoint { + let response = endpoint_manifest::Endpoint { min_protocol_version: unsupported_version as i64, max_protocol_version: unsupported_version as i64, services: Vec::new(), diff --git a/crates/service-protocol/src/lib.rs b/crates/service-protocol/src/lib.rs index f13affaf2e..63b8f7d9cb 100644 --- a/crates/service-protocol/src/lib.rs +++ b/crates/service-protocol/src/lib.rs @@ -10,16 +10,8 @@ //! This crate contains the code-generated structs of [service-protocol](https://github.com/restatedev/service-protocol) and the codec to use them. -use crate::pb::protocol; - pub const RESTATE_SERVICE_PROTOCOL_VERSION: u16 = 2; -// Range of supported service protocol versions by this server -pub const MIN_SERVICE_PROTOCOL_VERSION: protocol::ServiceProtocolVersion = - protocol::ServiceProtocolVersion::V1; -pub const MAX_SERVICE_PROTOCOL_VERSION: protocol::ServiceProtocolVersion = - protocol::ServiceProtocolVersion::V1; - #[cfg(feature = "codec")] pub mod codec; #[cfg(feature = "discovery")] @@ -29,239 +21,3 @@ pub mod message; #[cfg(feature = "awakeable-id")] pub mod awakeable_id; - -#[cfg(any(feature = "protocol", test))] -pub mod pb { - pub mod protocol { - #![allow(warnings)] - #![allow(clippy::all)] - #![allow(unknown_lints)] - - use crate::{MAX_SERVICE_PROTOCOL_VERSION, MIN_SERVICE_PROTOCOL_VERSION}; - include!(concat!(env!("OUT_DIR"), "/dev.restate.service.protocol.rs")); - - impl ServiceProtocolVersion { - pub fn is_supported(min_version: i32, max_version: i32) -> bool { - min_version <= i32::from(MAX_SERVICE_PROTOCOL_VERSION) - && max_version >= i32::from(MIN_SERVICE_PROTOCOL_VERSION) - } - - pub fn max_supported_version( - min_version: i32, - max_version: i32, - ) -> Option { - if ServiceProtocolVersion::is_supported(min_version, max_version) { - ServiceProtocolVersion::from_repr(std::cmp::min( - max_version, - i32::from(MAX_SERVICE_PROTOCOL_VERSION), - )) - } else { - None - } - } - } - } - - pub mod discovery { - #![allow(warnings)] - #![allow(clippy::all)] - #![allow(unknown_lints)] - include!(concat!( - env!("OUT_DIR"), - "/dev.restate.service.discovery.rs" - )); - } -} - -/// This module implements conversions back and forth from proto messages to [`journal::Entry`] model. -/// These are used by the [`codec::ProtobufRawEntryCodec`]. -#[cfg(feature = "codec")] -mod pb_into { - use super::pb::protocol::*; - - use restate_types::journal::*; - - impl TryFrom for Entry { - type Error = &'static str; - - fn try_from(msg: InputEntryMessage) -> Result { - Ok(Self::Input(InputEntry { value: msg.value })) - } - } - - impl TryFrom for Entry { - type Error = &'static str; - - fn try_from(msg: OutputEntryMessage) -> Result { - Ok(Entry::Output(OutputEntry { - result: match msg.result.ok_or("result")? { - output_entry_message::Result::Value(r) => EntryResult::Success(r), - output_entry_message::Result::Failure(Failure { code, message }) => { - EntryResult::Failure(code.into(), message.into()) - } - }, - })) - } - } - - impl TryFrom for Entry { - type Error = &'static str; - - fn try_from(msg: GetStateEntryMessage) -> Result { - Ok(Self::GetState(GetStateEntry { - key: msg.key, - value: msg.result.map(|v| match v { - get_state_entry_message::Result::Empty(_) => GetStateResult::Empty, - get_state_entry_message::Result::Value(b) => GetStateResult::Result(b), - get_state_entry_message::Result::Failure(failure) => { - GetStateResult::Failure(failure.code.into(), failure.message.into()) - } - }), - })) - } - } - - impl TryFrom for Entry { - type Error = &'static str; - - fn try_from(msg: SetStateEntryMessage) -> Result { - Ok(Self::SetState(SetStateEntry { - key: msg.key, - value: msg.value, - })) - } - } - - impl TryFrom for Entry { - type Error = &'static str; - - fn try_from(msg: ClearStateEntryMessage) -> Result { - Ok(Self::ClearState(ClearStateEntry { key: msg.key })) - } - } - - impl TryFrom for Entry { - type Error = &'static str; - - fn try_from(msg: GetStateKeysEntryMessage) -> Result { - Ok(Self::GetStateKeys(GetStateKeysEntry { - value: msg.result.map(|v| match v { - get_state_keys_entry_message::Result::Value(b) => { - GetStateKeysResult::Result(b.keys) - } - get_state_keys_entry_message::Result::Failure(failure) => { - GetStateKeysResult::Failure(failure.code.into(), failure.message.into()) - } - }), - })) - } - } - - impl TryFrom for Entry { - type Error = &'static str; - - fn try_from(_: ClearAllStateEntryMessage) -> Result { - Ok(Self::ClearAllState) - } - } - - impl TryFrom for Entry { - type Error = &'static str; - - fn try_from(msg: SleepEntryMessage) -> Result { - Ok(Self::Sleep(SleepEntry { - wake_up_time: msg.wake_up_time, - result: msg.result.map(|r| match r { - sleep_entry_message::Result::Empty(_) => SleepResult::Fired, - sleep_entry_message::Result::Failure(failure) => { - SleepResult::Failure(failure.code.into(), failure.message.into()) - } - }), - })) - } - } - - impl TryFrom for Entry { - type Error = &'static str; - - fn try_from(msg: CallEntryMessage) -> Result { - Ok(Self::Call(InvokeEntry { - request: InvokeRequest { - service_name: msg.service_name.into(), - handler_name: msg.handler_name.into(), - parameter: msg.parameter, - key: msg.key.into(), - }, - result: msg.result.map(|v| match v { - call_entry_message::Result::Value(r) => EntryResult::Success(r), - call_entry_message::Result::Failure(Failure { code, message }) => { - EntryResult::Failure(code.into(), message.into()) - } - }), - })) - } - } - - impl TryFrom for Entry { - type Error = &'static str; - - fn try_from(msg: OneWayCallEntryMessage) -> Result { - Ok(Self::OneWayCall(OneWayCallEntry { - request: InvokeRequest { - service_name: msg.service_name.into(), - handler_name: msg.handler_name.into(), - parameter: msg.parameter, - key: msg.key.into(), - }, - invoke_time: msg.invoke_time, - })) - } - } - - impl TryFrom for Entry { - type Error = &'static str; - - fn try_from(msg: AwakeableEntryMessage) -> Result { - Ok(Self::Awakeable(AwakeableEntry { - result: msg.result.map(|v| match v { - awakeable_entry_message::Result::Value(r) => EntryResult::Success(r), - awakeable_entry_message::Result::Failure(Failure { code, message }) => { - EntryResult::Failure(code.into(), message.into()) - } - }), - })) - } - } - - impl TryFrom for Entry { - type Error = &'static str; - - fn try_from(msg: CompleteAwakeableEntryMessage) -> Result { - Ok(Self::CompleteAwakeable(CompleteAwakeableEntry { - id: msg.id.into(), - result: match msg.result.ok_or("result")? { - complete_awakeable_entry_message::Result::Value(r) => EntryResult::Success(r), - complete_awakeable_entry_message::Result::Failure(Failure { - code, - message, - }) => EntryResult::Failure(code.into(), message.into()), - }, - })) - } - } - - impl TryFrom for Entry { - type Error = &'static str; - - fn try_from(msg: RunEntryMessage) -> Result { - Ok(Self::Run(RunEntry { - result: match msg.result.ok_or("result")? { - run_entry_message::Result::Value(r) => EntryResult::Success(r), - run_entry_message::Result::Failure(Failure { code, message }) => { - EntryResult::Failure(code.into(), message.into()) - } - }, - })) - } - } -} diff --git a/crates/service-protocol/src/message/encoding.rs b/crates/service-protocol/src/message/encoding.rs index 664e90ee02..a18e8058ea 100644 --- a/crates/service-protocol/src/message/encoding.rs +++ b/crates/service-protocol/src/message/encoding.rs @@ -16,6 +16,7 @@ use std::mem; use bytes::{Buf, BufMut, Bytes, BytesMut}; use bytes_utils::SegmentedBuf; use restate_types::journal::raw::{PlainEntryHeader, RawEntry}; +use restate_types::service_protocol::ServiceProtocolVersion; use size::Size; use tracing::warn; @@ -33,13 +34,16 @@ pub enum EncodingError { // --- Input message encoder -pub struct Encoder { - protocol_version: u16, -} +pub struct Encoder {} impl Encoder { - pub fn new(protocol_version: u16) -> Self { - Self { protocol_version } + pub fn new(service_protocol_version: ServiceProtocolVersion) -> Self { + assert_eq!( + service_protocol_version, + ServiceProtocolVersion::V1, + "Encoder only supports service protocol version V1" + ); + Self {} } /// Encodes a message to bytes @@ -63,7 +67,7 @@ impl Encoder { mut buf: impl BufMut, msg: ProtocolMessage, ) -> Result<(), prost::EncodeError> { - let header = generate_header(&msg, self.protocol_version); + let header = generate_header(&msg); buf.put_u64(header.into()); // Note: @@ -74,13 +78,13 @@ impl Encoder { } } -fn generate_header(msg: &ProtocolMessage, protocol_version: u16) -> MessageHeader { +fn generate_header(msg: &ProtocolMessage) -> MessageHeader { let len: u32 = msg .encoded_len() .try_into() .expect("Protocol messages can't be larger than u32"); match msg { - ProtocolMessage::Start(_) => MessageHeader::new_start(protocol_version, len), + ProtocolMessage::Start(_) => MessageHeader::new_start(len), ProtocolMessage::Completion(_) => MessageHeader::new(MessageType::Completion, len), ProtocolMessage::Suspension(_) => MessageHeader::new(MessageType::Suspension, len), ProtocolMessage::Error(_) => MessageHeader::new(MessageType::Error, len), @@ -122,14 +126,17 @@ pub struct Decoder { message_size_limit: usize, } -impl Default for Decoder { - fn default() -> Self { - Decoder::new(usize::MAX, None) - } -} - impl Decoder { - pub fn new(message_size_warning: usize, message_size_limit: Option) -> Self { + pub fn new( + service_protocol_version: ServiceProtocolVersion, + message_size_warning: usize, + message_size_limit: Option, + ) -> Self { + assert_eq!( + service_protocol_version, + ServiceProtocolVersion::V1, + "Decoder only supports service protocol version V1" + ); Self { buf: SegmentedBuf::new(), state: DecoderState::WaitingHeader, @@ -234,17 +241,17 @@ fn decode_protocol_message( mut buf: impl Buf, ) -> Result { Ok(match header.message_type() { - MessageType::Start => ProtocolMessage::Start(pb::protocol::StartMessage::decode(buf)?), + MessageType::Start => ProtocolMessage::Start(service_protocol::StartMessage::decode(buf)?), MessageType::Completion => { - ProtocolMessage::Completion(pb::protocol::CompletionMessage::decode(buf)?) + ProtocolMessage::Completion(service_protocol::CompletionMessage::decode(buf)?) } MessageType::Suspension => { - ProtocolMessage::Suspension(pb::protocol::SuspensionMessage::decode(buf)?) + ProtocolMessage::Suspension(service_protocol::SuspensionMessage::decode(buf)?) } - MessageType::Error => ProtocolMessage::Error(pb::protocol::ErrorMessage::decode(buf)?), - MessageType::End => ProtocolMessage::End(pb::protocol::EndMessage::decode(buf)?), + MessageType::Error => ProtocolMessage::Error(service_protocol::ErrorMessage::decode(buf)?), + MessageType::End => ProtocolMessage::End(service_protocol::EndMessage::decode(buf)?), MessageType::EntryAck => { - ProtocolMessage::EntryAck(pb::protocol::EntryAckMessage::decode(buf)?) + ProtocolMessage::EntryAck(service_protocol::EntryAckMessage::decode(buf)?) } _ => ProtocolMessage::UnparsedEntry(RawEntry::new( message_header_to_raw_header(header), @@ -347,9 +354,8 @@ mod tests { #[test] fn fill_decoder_with_several_messages() { - let protocol_version = 1; - let encoder = Encoder::new(protocol_version); - let mut decoder = Decoder::default(); + let encoder = Encoder::new(ServiceProtocolVersion::V1); + let mut decoder = Decoder::new(ServiceProtocolVersion::V1, usize::MAX, None); let expected_msg_0 = ProtocolMessage::new_start_message( "key".into(), @@ -377,10 +383,6 @@ mod tests { decoder.push(encoder.encode(expected_msg_2.clone())); let (actual_msg_header_0, actual_msg_0) = decoder.consume_next().unwrap().unwrap(); - assert_eq!( - actual_msg_header_0.protocol_version(), - Some(protocol_version) - ); assert_eq!(actual_msg_header_0.message_type(), MessageType::Start); assert_eq!(actual_msg_0, expected_msg_0); @@ -407,8 +409,8 @@ mod tests { } fn partial_decoding_test(split_index: usize) { - let encoder = Encoder::new(0); - let mut decoder = Decoder::default(); + let encoder = Encoder::new(ServiceProtocolVersion::V1); + let mut decoder = Decoder::new(ServiceProtocolVersion::V1, usize::MAX, None); let expected_msg: ProtocolMessage = ProtobufRawEntryCodec::serialize_as_input_entry( vec![], @@ -433,9 +435,13 @@ mod tests { #[test] fn hit_message_size_limit() { - let mut decoder = Decoder::new((u8::MAX / 2) as usize, Some(u8::MAX as usize)); + let mut decoder = Decoder::new( + ServiceProtocolVersion::V1, + (u8::MAX / 2) as usize, + Some(u8::MAX as usize), + ); - let encoder = Encoder::new(0); + let encoder = Encoder::new(ServiceProtocolVersion::V1); let message = ProtocolMessage::from( ProtobufRawEntryCodec::serialize_as_input_entry( vec![], diff --git a/crates/service-protocol/src/message/header.rs b/crates/service-protocol/src/message/header.rs index dc94936544..fb2179557a 100644 --- a/crates/service-protocol/src/message/header.rs +++ b/crates/service-protocol/src/message/header.rs @@ -11,7 +11,6 @@ use restate_types::journal::EntryType; const CUSTOM_MESSAGE_MASK: u16 = 0xFC00; -const VERSION_MASK: u64 = 0x03FF_0000_0000; const COMPLETED_MASK: u64 = 0x0001_0000_0000; const REQUIRES_ACK_MASK: u64 = 0x8000_0000_0000; @@ -87,10 +86,6 @@ impl MessageType { ) } - fn has_protocol_version(&self) -> bool { - *self == MessageType::Start - } - fn has_requires_ack_flag(&self) -> bool { matches!( self.kind(), @@ -215,9 +210,6 @@ pub struct MessageHeader { length: u32, // --- Flags - /// Only `StartMessage` has protocol_version. - protocol_version: Option, - /// Only `CompletableEntries` have completed flag. See [`MessageType#allows_completed_flag`]. completed_flag: Option, /// All Entry messages may have requires ack flag. @@ -227,18 +219,12 @@ pub struct MessageHeader { impl MessageHeader { #[inline] pub fn new(ty: MessageType, length: u32) -> Self { - Self::_new(ty, None, None, None, length) + Self::_new(ty, None, None, length) } #[inline] - pub fn new_start(protocol_version: u16, length: u32) -> Self { - Self::_new( - MessageType::Start, - Some(protocol_version), - None, - None, - length, - ) + pub fn new_start(length: u32) -> Self { + Self::_new(MessageType::Start, None, None, length) } #[inline] @@ -252,7 +238,6 @@ impl MessageHeader { MessageHeader { ty, length, - protocol_version: None, completed_flag, // It is always false when sending entries from the runtime requires_ack_flag: Some(false), @@ -262,7 +247,6 @@ impl MessageHeader { #[inline] fn _new( ty: MessageType, - protocol_version: Option, completed_flag: Option, requires_ack_flag: Option, length: u32, @@ -270,7 +254,6 @@ impl MessageHeader { MessageHeader { ty, length, - protocol_version, completed_flag, requires_ack_flag, } @@ -286,11 +269,6 @@ impl MessageHeader { self.ty } - #[inline] - pub fn protocol_version(&self) -> Option { - self.protocol_version - } - #[inline] pub fn completed(&self) -> Option { self.completed_flag @@ -325,18 +303,13 @@ impl TryFrom for MessageHeader { fn try_from(value: u64) -> Result { let ty_code = (value >> 48) as u16; let ty: MessageType = ty_code.try_into()?; - let protocol_version = if ty.has_protocol_version() { - Some(((value & VERSION_MASK) >> 32) as u16) - } else { - None - }; + let completed_flag = read_flag_if!(ty.has_completed_flag(), value, COMPLETED_MASK); let requires_ack_flag = read_flag_if!(ty.has_requires_ack_flag(), value, REQUIRES_ACK_MASK); let length = value as u32; Ok(MessageHeader::_new( ty, - protocol_version, completed_flag, requires_ack_flag, length, @@ -359,9 +332,6 @@ impl From for u64 { let mut res = ((u16::from(message_header.ty) as u64) << 48) | (message_header.length as u64); - if let Some(protocol_version) = message_header.protocol_version { - res |= (protocol_version as u64) << 32; - } write_flag!(message_header.completed_flag, &mut res, COMPLETED_MASK); write_flag!( message_header.requires_ack_flag, @@ -445,7 +415,6 @@ mod tests { assert_eq!(header.message_type(), $ty); assert_eq!(header.message_kind(), $kind); assert_eq!(header.completed(), $completed); - assert_eq!(header.protocol_version(), $protocol_version); assert_eq!(header.requires_ack(), $requires_ack); assert_eq!(header.frame_length(), $len); } @@ -454,7 +423,7 @@ mod tests { roundtrip_test!( start, - MessageHeader::new_start(1, 25), + MessageHeader::new_start(25), Start, Core, 25, @@ -501,7 +470,7 @@ mod tests { roundtrip_test!( set_state_with_requires_ack, - MessageHeader::_new(SetStateEntry, None, None, Some(true), 10341), + MessageHeader::_new(SetStateEntry, None, Some(true), 10341), SetStateEntry, State, 10341, @@ -519,7 +488,7 @@ mod tests { roundtrip_test!( custom_entry_with_requires_ack, - MessageHeader::_new(MessageType::CustomEntry(0xFC00), None, None, Some(true), 10341), + MessageHeader::_new(MessageType::CustomEntry(0xFC00), None, Some(true), 10341), MessageType::CustomEntry(0xFC00), MessageKind::CustomEntry, 10341, diff --git a/crates/service-protocol/src/message/mod.rs b/crates/service-protocol/src/message/mod.rs index 2ce30d6187..9c6aa76a43 100644 --- a/crates/service-protocol/src/message/mod.rs +++ b/crates/service-protocol/src/message/mod.rs @@ -11,11 +11,8 @@ //! Module containing definitions of Protocol messages, //! including encoding and decoding of headers and message payloads. -use super::pb; - use bytes::Bytes; use prost::Message; -use restate_types::errors::InvocationError; use restate_types::journal::raw::PlainRawEntry; use restate_types::journal::CompletionResult; use restate_types::journal::{Completion, EntryIndex}; @@ -25,16 +22,17 @@ mod header; pub use encoding::{Decoder, Encoder, EncodingError}; pub use header::{MessageHeader, MessageKind, MessageType}; +use restate_types::service_protocol; #[derive(Debug, Clone, PartialEq)] pub enum ProtocolMessage { // Core - Start(pb::protocol::StartMessage), - Completion(pb::protocol::CompletionMessage), - Suspension(pb::protocol::SuspensionMessage), - Error(pb::protocol::ErrorMessage), - End(pb::protocol::EndMessage), - EntryAck(pb::protocol::EntryAckMessage), + Start(service_protocol::StartMessage), + Completion(service_protocol::CompletionMessage), + Suspension(service_protocol::SuspensionMessage), + Error(service_protocol::ErrorMessage), + End(service_protocol::EndMessage), + EntryAck(service_protocol::EntryAckMessage), // Entries are not parsed at this point UnparsedEntry(PlainRawEntry), @@ -49,14 +47,14 @@ impl ProtocolMessage { partial_state: bool, state_map_entries: impl IntoIterator, ) -> Self { - Self::Start(pb::protocol::StartMessage { + Self::Start(service_protocol::StartMessage { id, debug_id, known_entries, partial_state, state_map: state_map_entries .into_iter() - .map(|(key, value)| pb::protocol::start_message::StateEntry { key, value }) + .map(|(key, value)| service_protocol::start_message::StateEntry { key, value }) .collect(), key: key .and_then(|b| String::from_utf8(b.to_vec()).ok()) @@ -65,7 +63,7 @@ impl ProtocolMessage { } pub fn new_entry_ack(entry_index: EntryIndex) -> ProtocolMessage { - Self::EntryAck(pb::protocol::EntryAckMessage { entry_index }) + Self::EntryAck(service_protocol::EntryAckMessage { entry_index }) } pub(crate) fn encoded_len(&self) -> usize { @@ -85,24 +83,24 @@ impl From for ProtocolMessage { fn from(completion: Completion) -> Self { match completion.result { CompletionResult::Empty => { - ProtocolMessage::Completion(pb::protocol::CompletionMessage { + ProtocolMessage::Completion(service_protocol::CompletionMessage { entry_index: completion.entry_index, - result: Some(pb::protocol::completion_message::Result::Empty( - pb::protocol::Empty {}, + result: Some(service_protocol::completion_message::Result::Empty( + service_protocol::Empty {}, )), }) } CompletionResult::Success(b) => { - ProtocolMessage::Completion(pb::protocol::CompletionMessage { + ProtocolMessage::Completion(service_protocol::CompletionMessage { entry_index: completion.entry_index, - result: Some(pb::protocol::completion_message::Result::Value(b)), + result: Some(service_protocol::completion_message::Result::Value(b)), }) } CompletionResult::Failure(code, message) => { - ProtocolMessage::Completion(pb::protocol::CompletionMessage { + ProtocolMessage::Completion(service_protocol::CompletionMessage { entry_index: completion.entry_index, - result: Some(pb::protocol::completion_message::Result::Failure( - pb::protocol::Failure { + result: Some(service_protocol::completion_message::Result::Failure( + service_protocol::Failure { code: code.into(), message: message.to_string(), }, @@ -118,13 +116,3 @@ impl From for ProtocolMessage { Self::UnparsedEntry(value) } } - -impl From for InvocationError { - fn from(value: pb::protocol::ErrorMessage) -> Self { - if value.description.is_empty() { - InvocationError::new(value.code, value.message) - } else { - InvocationError::new(value.code, value.message).with_description(value.description) - } - } -} diff --git a/crates/storage-api/build.rs b/crates/storage-api/build.rs index 96789e48f2..2defa008ae 100644 --- a/crates/storage-api/build.rs +++ b/crates/storage-api/build.rs @@ -13,5 +13,12 @@ fn main() -> std::io::Result<()> { .bytes(["."]) // allow older protobuf compiler to be used .protoc_arg("--experimental_allow_proto3_optional") - .compile_protos(&["proto/dev/restate/storage/v1/domain.proto"], &["proto"]) + .extern_path( + ".dev.restate.service.protocol", + "::restate_types::service_protocol", + ) + .compile_protos( + &["proto/dev/restate/storage/v1/domain.proto"], + &["proto", "../types/service-protocol"], + ) } diff --git a/crates/storage-api/proto/dev/restate/storage/v1/domain.proto b/crates/storage-api/proto/dev/restate/storage/v1/domain.proto index 6f1da6690c..a28cbbf98d 100644 --- a/crates/storage-api/proto/dev/restate/storage/v1/domain.proto +++ b/crates/storage-api/proto/dev/restate/storage/v1/domain.proto @@ -1,6 +1,7 @@ syntax = "proto3"; import "google/protobuf/empty.proto"; +import "dev/restate/service/protocol.proto"; package dev.restate.storage.domain.v1; @@ -83,13 +84,11 @@ message InvocationStatus { repeated ServiceInvocationResponseSink response_sinks = 3; uint64 creation_time = 4; uint64 modification_time = 5; - oneof deployment_id { - google.protobuf.Empty none = 7; - string value = 8; - } - Source source = 9; - Duration completion_retention_time = 10; - optional string idempotency_key = 11; + optional string deployment_id = 7; + Source source = 8; + Duration completion_retention_time = 9; + optional string idempotency_key = 10; + optional dev.restate.service.protocol.ServiceProtocolVersion service_protocol_version = 11; } message Suspended { @@ -99,13 +98,11 @@ message InvocationStatus { uint64 creation_time = 4; uint64 modification_time = 5; repeated uint32 waiting_for_completed_entries = 6; - oneof deployment_id { - google.protobuf.Empty none = 8; - string value = 9; - } - Source source = 10; - Duration completion_retention_time = 11; - optional string idempotency_key = 12; + optional string deployment_id = 7; + Source source = 8; + Duration completion_retention_time = 9; + optional string idempotency_key = 10; + optional dev.restate.service.protocol.ServiceProtocolVersion service_protocol_version = 11; } message Completed { diff --git a/crates/storage-api/src/invocation_status_table/mod.rs b/crates/storage-api/src/invocation_status_table/mod.rs index 238e64652e..c99a1f050f 100644 --- a/crates/storage-api/src/invocation_status_table/mod.rs +++ b/crates/storage-api/src/invocation_status_table/mod.rs @@ -12,7 +12,8 @@ use crate::{protobuf_storage_encode_decode, Result}; use bytes::Bytes; use bytestring::ByteString; use futures_util::Stream; -use restate_types::identifiers::{DeploymentId, EntryIndex, InvocationId, PartitionKey}; +use restate_types::deployment::PinnedDeployment; +use restate_types::identifiers::{EntryIndex, InvocationId, PartitionKey}; use restate_types::invocation::{ Header, InvocationInput, InvocationTarget, ResponseResult, ServiceInvocation, ServiceInvocationResponseSink, ServiceInvocationSpanContext, Source, @@ -250,7 +251,7 @@ impl InboxedInvocation { pub struct InFlightInvocationMetadata { pub invocation_target: InvocationTarget, pub journal_metadata: JournalMetadata, - pub deployment_id: Option, + pub pinned_deployment: Option, pub response_sinks: HashSet, pub timestamps: StatusTimestamps, pub source: Source, @@ -267,7 +268,7 @@ impl InFlightInvocationMetadata { Self { invocation_target: service_invocation.invocation_target, journal_metadata: JournalMetadata::initialize(service_invocation.span_context), - deployment_id: None, + pinned_deployment: None, response_sinks: service_invocation.response_sink.into_iter().collect(), timestamps: StatusTimestamps::now(), source: service_invocation.source, @@ -292,7 +293,7 @@ impl InFlightInvocationMetadata { Self { invocation_target: inboxed_invocation.invocation_target, journal_metadata: JournalMetadata::initialize(inboxed_invocation.span_context), - deployment_id: None, + pinned_deployment: None, response_sinks: inboxed_invocation.response_sinks, timestamps: inboxed_invocation.timestamps, source: inboxed_invocation.source, @@ -306,12 +307,12 @@ impl InFlightInvocationMetadata { ) } - pub fn set_deployment_id(&mut self, deployment_id: DeploymentId) { + pub fn set_pinned_deployment(&mut self, pinned_deployment: PinnedDeployment) { debug_assert_eq!( - self.deployment_id, None, - "No deployment_id should be fixed for the current invocation" + self.pinned_deployment, None, + "No deployment should be chosen for the current invocation" ); - self.deployment_id = Some(deployment_id); + self.pinned_deployment = Some(pinned_deployment); self.timestamps.update(); } } @@ -386,7 +387,7 @@ mod mocks { VirtualObjectHandlerType::Exclusive, ), journal_metadata: JournalMetadata::initialize(ServiceInvocationSpanContext::empty()), - deployment_id: None, + pinned_deployment: None, response_sinks: HashSet::new(), timestamps: StatusTimestamps::now(), source: Source::Ingress, diff --git a/crates/storage-api/src/storage.rs b/crates/storage-api/src/storage.rs index 6b28b08274..a7bc20bb14 100644 --- a/crates/storage-api/src/storage.rs +++ b/crates/storage-api/src/storage.rs @@ -78,11 +78,13 @@ pub mod v1 { use bytestring::ByteString; use opentelemetry::trace::TraceState; use prost::Message; + use restate_types::deployment::PinnedDeployment; use restate_types::errors::{IdDecodeError, InvocationError}; - use restate_types::identifiers::WithPartitionKey; + use restate_types::identifiers::{DeploymentId, WithPartitionKey}; use restate_types::invocation::{InvocationTermination, TerminationFlavor}; use restate_types::journal::enriched::AwakeableEnrichmentResult; + use restate_types::service_protocol::ServiceProtocolVersion; use restate_types::storage::{ StorageCodecKind, StorageDecode, StorageDecodeError, StorageEncode, StorageEncodeError, }; @@ -292,6 +294,34 @@ pub mod v1 { } } + fn derive_pinned_deployment( + deployment_id: Option, + service_protocol_version: Option, + ) -> Result, ConversionError> { + let deployment_id = deployment_id + .map(|deployment_id| deployment_id.parse().expect("valid deployment id")); + + if let Some(deployment_id) = deployment_id { + let service_protocol_version = + service_protocol_version.ok_or(ConversionError::invalid_data(anyhow!( + "service_protocol_version has not been set" + )))?; + let service_protocol_version = ServiceProtocolVersion::from_repr( + service_protocol_version, + ) + .ok_or(ConversionError::unexpected_enum_variant( + "service_protocol_version", + service_protocol_version, + ))?; + Ok(Some(PinnedDeployment::new( + deployment_id, + service_protocol_version, + ))) + } else { + Ok(None) + } + } + impl TryFrom for crate::invocation_status_table::InFlightInvocationMetadata { type Error = ConversionError; @@ -302,15 +332,8 @@ pub mod v1 { .ok_or(ConversionError::missing_field("invocation_target"))?, )?; - let deployment_id = - value.deployment_id.and_then( - |one_of_deployment_id| match one_of_deployment_id { - invocation_status::invoked::DeploymentId::None(_) => None, - invocation_status::invoked::DeploymentId::Value(id) => { - Some(id.parse().expect("valid deployment id")) - } - }, - ); + let pinned_deployment = + derive_pinned_deployment(value.deployment_id, value.service_protocol_version)?; let journal_metadata = crate::invocation_status_table::JournalMetadata::try_from( value @@ -344,7 +367,7 @@ pub mod v1 { Ok(crate::invocation_status_table::InFlightInvocationMetadata { invocation_target, journal_metadata, - deployment_id, + pinned_deployment, response_sinks, timestamps: crate::invocation_status_table::StatusTimestamps::new( MillisSinceEpoch::new(value.creation_time), @@ -361,7 +384,7 @@ pub mod v1 { fn from(value: crate::invocation_status_table::InFlightInvocationMetadata) -> Self { let crate::invocation_status_table::InFlightInvocationMetadata { invocation_target, - deployment_id, + pinned_deployment, response_sinks, journal_metadata, timestamps, @@ -370,24 +393,28 @@ pub mod v1 { idempotency_key, } = value; + let (deployment_id, service_protocol_version) = match pinned_deployment { + None => (None, None), + Some(pinned_deployment) => ( + Some(pinned_deployment.deployment_id.to_string()), + Some(pinned_deployment.service_protocol_version.as_repr()), + ), + }; + Invoked { invocation_target: Some(invocation_target.into()), response_sinks: response_sinks .into_iter() .map(|s| ServiceInvocationResponseSink::from(Some(s))) .collect(), - deployment_id: Some(match deployment_id { - None => invocation_status::invoked::DeploymentId::None(()), - Some(deployment_id) => invocation_status::invoked::DeploymentId::Value( - deployment_id.to_string(), - ), - }), + deployment_id, + service_protocol_version, journal_meta: Some(JournalMeta::from(journal_metadata)), creation_time: timestamps.creation_time().as_u64(), modification_time: timestamps.modification_time().as_u64(), source: Some(Source::from(source)), completion_retention_time: Some(Duration::from(completion_retention_time)), - idempotency_key: idempotency_key.map(|s| s.to_string()), + idempotency_key: idempotency_key.map(|key| key.to_string()), } } } @@ -407,13 +434,8 @@ pub mod v1 { .ok_or(ConversionError::missing_field("invocation_target"))?, )?; - let deployment_id = - value.deployment_id.and_then( - |one_of_deployment_id| match one_of_deployment_id { - invocation_status::suspended::DeploymentId::None(_) => None, - invocation_status::suspended::DeploymentId::Value(id) => Some(id), - }, - ); + let pinned_deployment = + derive_pinned_deployment(value.deployment_id, value.service_protocol_version)?; let journal_metadata = crate::invocation_status_table::JournalMetadata::try_from( value @@ -451,8 +473,7 @@ pub mod v1 { crate::invocation_status_table::InFlightInvocationMetadata { invocation_target, journal_metadata, - deployment_id: deployment_id - .map(|d| d.parse().expect("valid deployment id")), + pinned_deployment, response_sinks, timestamps: crate::invocation_status_table::StatusTimestamps::new( MillisSinceEpoch::new(value.creation_time), @@ -483,6 +504,14 @@ pub mod v1 { let waiting_for_completed_entries = waiting_for_completed_entries.into_iter().collect(); + let (deployment_id, service_protocol_version) = match metadata.pinned_deployment { + None => (None, None), + Some(pinned_deployment) => ( + Some(pinned_deployment.deployment_id.to_string()), + Some(pinned_deployment.service_protocol_version.as_repr()), + ), + }; + Suspended { invocation_target: Some(metadata.invocation_target.into()), response_sinks: metadata @@ -491,12 +520,8 @@ pub mod v1 { .map(|s| ServiceInvocationResponseSink::from(Some(s))) .collect(), journal_meta: Some(journal_meta), - deployment_id: Some(match metadata.deployment_id { - None => invocation_status::suspended::DeploymentId::None(()), - Some(deployment_id) => invocation_status::suspended::DeploymentId::Value( - deployment_id.to_string(), - ), - }), + deployment_id, + service_protocol_version, creation_time: metadata.timestamps.creation_time().as_u64(), modification_time: metadata.timestamps.modification_time().as_u64(), waiting_for_completed_entries, @@ -504,7 +529,7 @@ pub mod v1 { completion_retention_time: Some(Duration::from( metadata.completion_retention_time, )), - idempotency_key: metadata.idempotency_key.map(|s| s.to_string()), + idempotency_key: metadata.idempotency_key.map(|key| key.to_string()), } } } diff --git a/crates/storage-query-datafusion/src/invocation_status/row.rs b/crates/storage-query-datafusion/src/invocation_status/row.rs index 7e4e9dba20..1239a42fa5 100644 --- a/crates/storage-query-datafusion/src/invocation_status/row.rs +++ b/crates/storage-query-datafusion/src/invocation_status/row.rs @@ -90,7 +90,7 @@ fn fill_in_flight_invocation_metadata( meta: InFlightInvocationMetadata, ) { // journal_metadata and stats are filled by other functions - if let Some(deployment_id) = meta.deployment_id { + if let Some(deployment_id) = meta.pinned_deployment { row.pinned_deployment_id(deployment_id.to_string()); } fill_invoked_by(row, output, meta.source) diff --git a/crates/storage-query-datafusion/src/journal/tests.rs b/crates/storage-query-datafusion/src/journal/tests.rs index 10efa20036..59d5f8962c 100644 --- a/crates/storage-query-datafusion/src/journal/tests.rs +++ b/crates/storage-query-datafusion/src/journal/tests.rs @@ -27,6 +27,7 @@ use restate_types::journal::enriched::{ CallEnrichmentResult, EnrichedEntryHeader, EnrichedRawEntry, }; use restate_types::journal::{Entry, EntryType, InputEntry}; +use restate_types::service_protocol; #[tokio::test] async fn get_entries() { @@ -74,7 +75,7 @@ async fn get_entries() { 2, JournalEntry::Entry(EnrichedRawEntry::new( EnrichedEntryHeader::Run {}, - restate_service_protocol::pb::protocol::RunEntryMessage { + service_protocol::RunEntryMessage { name: "my-side-effect".to_string(), result: None, } diff --git a/crates/types/Cargo.toml b/crates/types/Cargo.toml index a1dd1f75bf..eeadf227c2 100644 --- a/crates/types/Cargo.toml +++ b/crates/types/Cargo.toml @@ -33,9 +33,12 @@ http = { workspace = true } humantime = { workspace = true } once_cell = { workspace = true } opentelemetry = { workspace = true } +prost = { workspace = true } rand = { workspace = true } +regress = { version = "0.9" } schemars = { workspace = true, optional = true } serde = { workspace = true, features = ["rc"] } +serde_json = { workspace = true } serde_with = { workspace = true } sha2 = { workspace = true } strum = { workspace = true } @@ -65,3 +68,12 @@ googletest = { workspace = true } rand = { workspace = true } test-log = { workspace = true } tokio = { workspace = true, features = ["test-util"] } + +[build-dependencies] +prost-build = { workspace = true } +prettyplease = "0.2" +schemars = { workspace = true } +serde_json = { workspace = true } +syn = "2.0" +typify = { version = "0.0.16" } +jsonptr = "0.4.7" diff --git a/crates/types/README.md b/crates/types/README.md new file mode 100644 index 0000000000..2543cb4d3d --- /dev/null +++ b/crates/types/README.md @@ -0,0 +1,7 @@ +# Service protocol types + +To update the subtree, from the root directory of the project: + +```shell +git subtree pull --prefix crates/types/service-protocol git@github.com:restatedev/service-protocol.git main --squash +``` \ No newline at end of file diff --git a/crates/service-protocol/build.rs b/crates/types/build.rs similarity index 94% rename from crates/service-protocol/build.rs rename to crates/types/build.rs index 3b5de11dd4..d1bda9278d 100644 --- a/crates/service-protocol/build.rs +++ b/crates/types/build.rs @@ -20,7 +20,7 @@ fn main() -> std::io::Result<()> { .protoc_arg("--experimental_allow_proto3_optional") .enum_attribute( "protocol.ServiceProtocolVersion", - "#[derive(::strum_macros::FromRepr)]", + "#[derive(::serde::Serialize, ::serde::Deserialize, ::strum_macros::FromRepr)]", ) .compile_protos( &[ @@ -72,6 +72,6 @@ fn main() -> std::io::Result<()> { ); let mut out_file = Path::new(&env::var("OUT_DIR").unwrap()).to_path_buf(); - out_file.push("deployment.rs"); + out_file.push("endpoint_manifest.rs"); std::fs::write(out_file, contents) } diff --git a/crates/service-protocol/service-protocol/.github/workflows/lint.yaml b/crates/types/service-protocol/.github/workflows/lint.yaml similarity index 89% rename from crates/service-protocol/service-protocol/.github/workflows/lint.yaml rename to crates/types/service-protocol/.github/workflows/lint.yaml index 4f006e3ac1..8b2d8e5fc1 100644 --- a/crates/service-protocol/service-protocol/.github/workflows/lint.yaml +++ b/crates/types/service-protocol/.github/workflows/lint.yaml @@ -12,7 +12,7 @@ jobs: steps: - name: Checkout Code - uses: actions/checkout@v4 + uses: actions/checkout@v3 - name: Run protolint uses: plexsystems/protolint-action@v0.7.0 diff --git a/crates/service-protocol/service-protocol/.gitignore b/crates/types/service-protocol/.gitignore similarity index 100% rename from crates/service-protocol/service-protocol/.gitignore rename to crates/types/service-protocol/.gitignore diff --git a/crates/service-protocol/service-protocol/.prettierrc.toml b/crates/types/service-protocol/.prettierrc.toml similarity index 100% rename from crates/service-protocol/service-protocol/.prettierrc.toml rename to crates/types/service-protocol/.prettierrc.toml diff --git a/crates/service-protocol/service-protocol/.protolint.yaml b/crates/types/service-protocol/.protolint.yaml similarity index 100% rename from crates/service-protocol/service-protocol/.protolint.yaml rename to crates/types/service-protocol/.protolint.yaml diff --git a/crates/service-protocol/service-protocol/LICENSE b/crates/types/service-protocol/LICENSE similarity index 100% rename from crates/service-protocol/service-protocol/LICENSE rename to crates/types/service-protocol/LICENSE diff --git a/crates/service-protocol/service-protocol/README.md b/crates/types/service-protocol/README.md similarity index 100% rename from crates/service-protocol/service-protocol/README.md rename to crates/types/service-protocol/README.md diff --git a/crates/service-protocol/service-protocol/dev/restate/service/discovery.proto b/crates/types/service-protocol/dev/restate/service/discovery.proto similarity index 100% rename from crates/service-protocol/service-protocol/dev/restate/service/discovery.proto rename to crates/types/service-protocol/dev/restate/service/discovery.proto diff --git a/crates/service-protocol/service-protocol/dev/restate/service/protocol.proto b/crates/types/service-protocol/dev/restate/service/protocol.proto similarity index 100% rename from crates/service-protocol/service-protocol/dev/restate/service/protocol.proto rename to crates/types/service-protocol/dev/restate/service/protocol.proto diff --git a/crates/service-protocol/service-protocol/endpoint_manifest_schema.json b/crates/types/service-protocol/endpoint_manifest_schema.json similarity index 100% rename from crates/service-protocol/service-protocol/endpoint_manifest_schema.json rename to crates/types/service-protocol/endpoint_manifest_schema.json diff --git a/crates/service-protocol/service-protocol/service-invocation-protocol.md b/crates/types/service-protocol/service-invocation-protocol.md similarity index 100% rename from crates/service-protocol/service-protocol/service-invocation-protocol.md rename to crates/types/service-protocol/service-invocation-protocol.md diff --git a/crates/types/src/deployment.rs b/crates/types/src/deployment.rs index 2847749818..b17f924186 100644 --- a/crates/types/src/deployment.rs +++ b/crates/types/src/deployment.rs @@ -9,6 +9,7 @@ // by the Apache License, Version 2.0. use std::fmt; +use std::fmt::{Display, Formatter}; use std::mem::size_of; use std::str::FromStr; @@ -18,6 +19,7 @@ use crate::base62_util::base62_max_length_for_type; use crate::errors::IdDecodeError; use crate::id_util::{IdDecoder, IdEncoder, IdResourceType}; use crate::identifiers::{DeploymentId, ResourceId, TimestampAwareId}; +use crate::service_protocol::ServiceProtocolVersion; use crate::time::MillisSinceEpoch; impl ResourceId for DeploymentId { @@ -78,6 +80,36 @@ impl schemars::JsonSchema for DeploymentId { } } +/// Deployment which was chosen to run an invocation on. +#[derive(Debug, Clone, PartialEq, Eq, serde::Serialize, serde::Deserialize)] +pub struct PinnedDeployment { + pub deployment_id: DeploymentId, + pub service_protocol_version: ServiceProtocolVersion, +} + +impl Display for PinnedDeployment { + fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { + write!( + f, + "id: {}, service protocol version: {}", + self.deployment_id, + self.service_protocol_version.as_repr() + ) + } +} + +impl PinnedDeployment { + pub fn new( + deployment_id: DeploymentId, + service_protocol_version: ServiceProtocolVersion, + ) -> Self { + Self { + deployment_id, + service_protocol_version, + } + } +} + #[cfg(test)] mod tests { use super::*; diff --git a/crates/types/src/endpoint_manifest.rs b/crates/types/src/endpoint_manifest.rs new file mode 100644 index 0000000000..68a983c0c1 --- /dev/null +++ b/crates/types/src/endpoint_manifest.rs @@ -0,0 +1,27 @@ +// Copyright (c) 2024 - Restate Software, Inc., Restate GmbH. +// All rights reserved. +// +// Use of this software is governed by the Business Source License +// included in the LICENSE file. +// +// As of the Change Date specified in that file, in accordance with +// the Business Source License, use of this software will be governed +// by the Apache License, Version 2.0. + +mod generated { + #![allow(clippy::clone_on_copy)] + + include!(concat!(env!("OUT_DIR"), "/endpoint_manifest.rs")); +} + +pub use generated::*; + +impl From for crate::invocation::ServiceType { + fn from(value: ServiceType) -> Self { + match value { + ServiceType::VirtualObject => crate::invocation::ServiceType::VirtualObject, + ServiceType::Service => crate::invocation::ServiceType::Service, + ServiceType::Workflow => crate::invocation::ServiceType::Workflow, + } + } +} diff --git a/crates/types/src/lib.rs b/crates/types/src/lib.rs index 03095ed827..42d129abd9 100644 --- a/crates/types/src/lib.rs +++ b/crates/types/src/lib.rs @@ -20,6 +20,7 @@ pub mod arc_util; pub mod art; pub mod config; pub mod deployment; +pub mod endpoint_manifest; pub mod epoch; pub mod errors; pub mod identifiers; @@ -33,6 +34,8 @@ pub mod net; pub mod nodes_config; pub mod partition_table; pub mod retries; +pub mod service_discovery; +pub mod service_protocol; pub mod state_mut; pub mod storage; pub mod subscription; diff --git a/crates/types/src/service_discovery.rs b/crates/types/src/service_discovery.rs new file mode 100644 index 0000000000..9cd1061cf6 --- /dev/null +++ b/crates/types/src/service_discovery.rs @@ -0,0 +1,14 @@ +// Copyright (c) 2024 - Restate Software, Inc., Restate GmbH. +// All rights reserved. +// +// Use of this software is governed by the Business Source License +// included in the LICENSE file. +// +// As of the Change Date specified in that file, in accordance with +// the Business Source License, use of this software will be governed +// by the Apache License, Version 2.0. + +include!(concat!( + env!("OUT_DIR"), + "/dev.restate.service.discovery.rs" +)); diff --git a/crates/types/src/service_protocol.rs b/crates/types/src/service_protocol.rs new file mode 100644 index 0000000000..55ee8ffc56 --- /dev/null +++ b/crates/types/src/service_protocol.rs @@ -0,0 +1,255 @@ +// Copyright (c) 2024 - Restate Software, Inc., Restate GmbH. +// All rights reserved. +// +// Use of this software is governed by the Business Source License +// included in the LICENSE file. +// +// As of the Change Date specified in that file, in accordance with +// the Business Source License, use of this software will be governed +// by the Apache License, Version 2.0. + +use crate::errors::InvocationError; +use std::ops::RangeInclusive; + +// Range of supported service protocol versions by this server +pub const MIN_SERVICE_PROTOCOL_VERSION: ServiceProtocolVersion = ServiceProtocolVersion::V1; +pub const MAX_SERVICE_PROTOCOL_VERSION: ServiceProtocolVersion = ServiceProtocolVersion::V1; + +pub const MAX_SERVICE_PROTOCOL_VERSION_VALUE: i32 = i32::MAX; + +include!(concat!(env!("OUT_DIR"), "/dev.restate.service.protocol.rs")); + +impl ServiceProtocolVersion { + pub fn as_repr(&self) -> i32 { + i32::from(*self) + } + + pub fn is_compatible(min_version: i32, max_version: i32) -> bool { + min_version <= i32::from(MAX_SERVICE_PROTOCOL_VERSION) + && max_version >= i32::from(MIN_SERVICE_PROTOCOL_VERSION) + } + + pub fn is_supported(version: ServiceProtocolVersion) -> bool { + MIN_SERVICE_PROTOCOL_VERSION <= version && version <= MAX_SERVICE_PROTOCOL_VERSION + } + + pub fn choose_max_supported_version( + versions: &RangeInclusive, + ) -> Option { + if ServiceProtocolVersion::is_compatible(*versions.start(), *versions.end()) { + ServiceProtocolVersion::from_repr(std::cmp::min( + *versions.end(), + i32::from(MAX_SERVICE_PROTOCOL_VERSION), + )) + } else { + None + } + } +} + +impl From for InvocationError { + fn from(value: ErrorMessage) -> Self { + if value.description.is_empty() { + InvocationError::new(value.code, value.message) + } else { + InvocationError::new(value.code, value.message).with_description(value.description) + } + } +} + +/// This module implements conversions back and forth from proto messages to [`journal::Entry`] model. +/// These are used by the [`codec::ProtobufRawEntryCodec`]. +mod pb_into { + use super::*; + + use crate::journal::{ + AwakeableEntry, ClearStateEntry, CompleteAwakeableEntry, Entry, EntryResult, GetStateEntry, + GetStateKeysEntry, GetStateKeysResult, GetStateResult, InputEntry, InvokeEntry, + InvokeRequest, OneWayCallEntry, OutputEntry, RunEntry, SetStateEntry, SleepEntry, + SleepResult, + }; + + impl TryFrom for Entry { + type Error = &'static str; + + fn try_from(msg: InputEntryMessage) -> Result { + Ok(Self::Input(InputEntry { value: msg.value })) + } + } + + impl TryFrom for Entry { + type Error = &'static str; + + fn try_from(msg: OutputEntryMessage) -> Result { + Ok(Entry::Output(OutputEntry { + result: match msg.result.ok_or("result")? { + output_entry_message::Result::Value(r) => EntryResult::Success(r), + output_entry_message::Result::Failure(Failure { code, message }) => { + EntryResult::Failure(code.into(), message.into()) + } + }, + })) + } + } + + impl TryFrom for Entry { + type Error = &'static str; + + fn try_from(msg: GetStateEntryMessage) -> Result { + Ok(Self::GetState(GetStateEntry { + key: msg.key, + value: msg.result.map(|v| match v { + get_state_entry_message::Result::Empty(_) => GetStateResult::Empty, + get_state_entry_message::Result::Value(b) => GetStateResult::Result(b), + get_state_entry_message::Result::Failure(failure) => { + GetStateResult::Failure(failure.code.into(), failure.message.into()) + } + }), + })) + } + } + + impl TryFrom for Entry { + type Error = &'static str; + + fn try_from(msg: SetStateEntryMessage) -> Result { + Ok(Self::SetState(SetStateEntry { + key: msg.key, + value: msg.value, + })) + } + } + + impl TryFrom for Entry { + type Error = &'static str; + + fn try_from(msg: ClearStateEntryMessage) -> Result { + Ok(Self::ClearState(ClearStateEntry { key: msg.key })) + } + } + + impl TryFrom for Entry { + type Error = &'static str; + + fn try_from(msg: GetStateKeysEntryMessage) -> Result { + Ok(Self::GetStateKeys(GetStateKeysEntry { + value: msg.result.map(|v| match v { + get_state_keys_entry_message::Result::Value(b) => { + GetStateKeysResult::Result(b.keys) + } + get_state_keys_entry_message::Result::Failure(failure) => { + GetStateKeysResult::Failure(failure.code.into(), failure.message.into()) + } + }), + })) + } + } + + impl TryFrom for Entry { + type Error = &'static str; + + fn try_from(_: ClearAllStateEntryMessage) -> Result { + Ok(Self::ClearAllState) + } + } + + impl TryFrom for Entry { + type Error = &'static str; + + fn try_from(msg: SleepEntryMessage) -> Result { + Ok(Self::Sleep(SleepEntry { + wake_up_time: msg.wake_up_time, + result: msg.result.map(|r| match r { + sleep_entry_message::Result::Empty(_) => SleepResult::Fired, + sleep_entry_message::Result::Failure(failure) => { + SleepResult::Failure(failure.code.into(), failure.message.into()) + } + }), + })) + } + } + + impl TryFrom for Entry { + type Error = &'static str; + + fn try_from(msg: CallEntryMessage) -> Result { + Ok(Self::Call(InvokeEntry { + request: InvokeRequest { + service_name: msg.service_name.into(), + handler_name: msg.handler_name.into(), + parameter: msg.parameter, + key: msg.key.into(), + }, + result: msg.result.map(|v| match v { + call_entry_message::Result::Value(r) => EntryResult::Success(r), + call_entry_message::Result::Failure(Failure { code, message }) => { + EntryResult::Failure(code.into(), message.into()) + } + }), + })) + } + } + + impl TryFrom for Entry { + type Error = &'static str; + + fn try_from(msg: OneWayCallEntryMessage) -> Result { + Ok(Self::OneWayCall(OneWayCallEntry { + request: InvokeRequest { + service_name: msg.service_name.into(), + handler_name: msg.handler_name.into(), + parameter: msg.parameter, + key: msg.key.into(), + }, + invoke_time: msg.invoke_time, + })) + } + } + + impl TryFrom for Entry { + type Error = &'static str; + + fn try_from(msg: AwakeableEntryMessage) -> Result { + Ok(Self::Awakeable(AwakeableEntry { + result: msg.result.map(|v| match v { + awakeable_entry_message::Result::Value(r) => EntryResult::Success(r), + awakeable_entry_message::Result::Failure(Failure { code, message }) => { + EntryResult::Failure(code.into(), message.into()) + } + }), + })) + } + } + + impl TryFrom for Entry { + type Error = &'static str; + + fn try_from(msg: CompleteAwakeableEntryMessage) -> Result { + Ok(Self::CompleteAwakeable(CompleteAwakeableEntry { + id: msg.id.into(), + result: match msg.result.ok_or("result")? { + complete_awakeable_entry_message::Result::Value(r) => EntryResult::Success(r), + complete_awakeable_entry_message::Result::Failure(Failure { + code, + message, + }) => EntryResult::Failure(code.into(), message.into()), + }, + })) + } + } + + impl TryFrom for Entry { + type Error = &'static str; + + fn try_from(msg: RunEntryMessage) -> Result { + Ok(Self::Run(RunEntry { + result: match msg.result.ok_or("result")? { + run_entry_message::Result::Value(r) => EntryResult::Success(r), + run_entry_message::Result::Failure(Failure { code, message }) => { + EntryResult::Failure(code.into(), message.into()) + } + }, + })) + } + } +} diff --git a/crates/worker/Cargo.toml b/crates/worker/Cargo.toml index 1bdb15f775..5044625f5c 100644 --- a/crates/worker/Cargo.toml +++ b/crates/worker/Cargo.toml @@ -38,7 +38,7 @@ restate-schema = { workspace = true } restate-schema-api = { workspace = true, features = [ "service", "subscription"] } restate-serde-util = { workspace = true, features = ["proto"] } restate-service-client = { workspace = true } -restate-service-protocol = { workspace = true, features = [ "codec", "awakeable-id", "protocol", "message" ] } +restate-service-protocol = { workspace = true, features = [ "codec", "awakeable-id", "message" ] } restate-storage-api = { workspace = true } restate-storage-query-datafusion = { workspace = true } restate-storage-query-postgres = { workspace = true } diff --git a/crates/worker/src/partition/state_machine/command_interpreter/mod.rs b/crates/worker/src/partition/state_machine/command_interpreter/mod.rs index 2d1987b3ef..930cf5d16f 100644 --- a/crates/worker/src/partition/state_machine/command_interpreter/mod.rs +++ b/crates/worker/src/partition/state_machine/command_interpreter/mod.rs @@ -959,8 +959,12 @@ where invocation_metadata: InFlightInvocationMetadata, ) -> Result<(), Error> { match kind { - InvokerEffectKind::SelectedDeployment(deployment_id) => { - effects.store_chosen_deployment(invocation_id, deployment_id, invocation_metadata); + InvokerEffectKind::PinnedDeployment(pinned_deployment) => { + effects.store_pinned_deployment( + invocation_id, + pinned_deployment, + invocation_metadata, + ); } InvokerEffectKind::JournalEntry { entry_index, entry } => { self.handle_journal_entry( diff --git a/crates/worker/src/partition/state_machine/command_interpreter/tests.rs b/crates/worker/src/partition/state_machine/command_interpreter/tests.rs index 6222e092f5..a162a507f4 100644 --- a/crates/worker/src/partition/state_machine/command_interpreter/tests.rs +++ b/crates/worker/src/partition/state_machine/command_interpreter/tests.rs @@ -18,7 +18,6 @@ use prost::Message; use restate_invoker_api::EffectKind; use restate_service_protocol::awakeable_id::AwakeableIdentifier; use restate_service_protocol::codec::ProtobufRawEntryCodec; -use restate_service_protocol::pb::protocol::SleepEntryMessage; use restate_storage_api::idempotency_table::IdempotencyMetadata; use restate_storage_api::inbox_table::SequenceNumberInboxEntry; use restate_storage_api::invocation_status_table::{JournalMetadata, StatusTimestamps}; @@ -31,6 +30,7 @@ use restate_types::identifiers::{InvocationUuid, WithPartitionKey}; use restate_types::invocation::InvocationTarget; use restate_types::journal::EntryResult; use restate_types::journal::{CompleteAwakeableEntry, Entry}; +use restate_types::service_protocol; use std::collections::HashMap; use test_log::test; @@ -736,7 +736,7 @@ fn create_termination_journal( EnrichedEntryHeader::Sleep { is_completed: false, }, - SleepEntryMessage { + service_protocol::SleepEntryMessage { wake_up_time: 1337, result: None, ..Default::default() diff --git a/crates/worker/src/partition/state_machine/effect_interpreter.rs b/crates/worker/src/partition/state_machine/effect_interpreter.rs index 72323deaf9..1913b3f7ed 100644 --- a/crates/worker/src/partition/state_machine/effect_interpreter.rs +++ b/crates/worker/src/partition/state_machine/effect_interpreter.rs @@ -337,12 +337,12 @@ impl EffectInterpreter { state_storage.delete_timer(&timer_key).await?; collector.push(Action::DeleteTimer { timer_key }); } - Effect::StoreDeploymentId { + Effect::StorePinnedDeployment { invocation_id, - deployment_id, + pinned_deployment, mut metadata, } => { - metadata.set_deployment_id(deployment_id); + metadata.set_pinned_deployment(pinned_deployment); // We recreate the InvocationStatus in Invoked state as the invoker can notify the // chosen deployment_id only when the invocation is in-flight. diff --git a/crates/worker/src/partition/state_machine/effects.rs b/crates/worker/src/partition/state_machine/effects.rs index d9f6e6b9e7..41ad086438 100644 --- a/crates/worker/src/partition/state_machine/effects.rs +++ b/crates/worker/src/partition/state_machine/effects.rs @@ -18,10 +18,9 @@ use restate_storage_api::invocation_status_table::{ use restate_storage_api::invocation_status_table::{InvocationStatus, JournalMetadata}; use restate_storage_api::outbox_table::OutboxMessage; use restate_storage_api::timer_table::{Timer, TimerKey}; +use restate_types::deployment::PinnedDeployment; use restate_types::errors::InvocationErrorCode; -use restate_types::identifiers::{ - DeploymentId, EntryIndex, IdempotencyId, InvocationId, ServiceId, -}; +use restate_types::identifiers::{EntryIndex, IdempotencyId, InvocationId, ServiceId}; use restate_types::ingress::IngressResponse; use restate_types::invocation::{ InvocationResponse, InvocationTarget, ResponseResult, ServiceInvocation, @@ -108,9 +107,9 @@ pub(crate) enum Effect { DeleteTimer(TimerKey), // Journal operations - StoreDeploymentId { + StorePinnedDeployment { invocation_id: InvocationId, - deployment_id: DeploymentId, + pinned_deployment: PinnedDeployment, metadata: InFlightInvocationMetadata, }, AppendResponseSink { @@ -477,10 +476,13 @@ impl Effect { "Effect: Delete timer" ) } - Effect::StoreDeploymentId { deployment_id, .. } => debug_if_leader!( + Effect::StorePinnedDeployment { + pinned_deployment, .. + } => debug_if_leader!( is_leader, - restate.deployment.id = %deployment_id, - "Effect: Store deployment id to storage" + restate.deployment.id = %pinned_deployment.deployment_id, + restate.deployment.service_protocol_version = %pinned_deployment.service_protocol_version.as_repr(), + "Effect: Store chosen deployment to storage" ), Effect::AppendJournalEntry { journal_entry, @@ -854,15 +856,15 @@ impl Effects { self.effects.push(Effect::DeleteTimer(timer_key)); } - pub(crate) fn store_chosen_deployment( + pub(crate) fn store_pinned_deployment( &mut self, invocation_id: InvocationId, - deployment_id: DeploymentId, + pinned_deployment: PinnedDeployment, metadata: InFlightInvocationMetadata, ) { - self.effects.push(Effect::StoreDeploymentId { + self.effects.push(Effect::StorePinnedDeployment { invocation_id, - deployment_id, + pinned_deployment, metadata, }) } diff --git a/crates/worker/src/partition/storage/invoker.rs b/crates/worker/src/partition/storage/invoker.rs index ccc2b28d40..d8d1cd073a 100644 --- a/crates/worker/src/partition/storage/invoker.rs +++ b/crates/worker/src/partition/storage/invoker.rs @@ -55,7 +55,7 @@ where let journal_metadata = JournalMetadata::new( invoked_status.journal_metadata.length, invoked_status.journal_metadata.span_context, - invoked_status.deployment_id, + invoked_status.pinned_deployment, ); let journal_stream = self .0 diff --git a/tools/service-protocol-wireshark-dissector/Cargo.toml b/tools/service-protocol-wireshark-dissector/Cargo.toml index 129a0151cc..19fad2b98a 100644 --- a/tools/service-protocol-wireshark-dissector/Cargo.toml +++ b/tools/service-protocol-wireshark-dissector/Cargo.toml @@ -22,6 +22,7 @@ luajit = ["mlua/luajit"] [dependencies] # Dependencies needed to decode packets restate-service-protocol = { workspace = true, features = ["codec", "message"] } +restate-types = { workspace = true } bytes = { workspace = true } thiserror = { workspace = true } diff --git a/tools/service-protocol-wireshark-dissector/src/lib.rs b/tools/service-protocol-wireshark-dissector/src/lib.rs index 88e2eb0858..535c42986d 100644 --- a/tools/service-protocol-wireshark-dissector/src/lib.rs +++ b/tools/service-protocol-wireshark-dissector/src/lib.rs @@ -14,6 +14,7 @@ use mlua::{Table, Value}; use restate_service_protocol::codec::ProtobufRawEntryCodec; use restate_service_protocol::message::{Decoder, MessageType, ProtocolMessage}; +use restate_types::service_protocol::ServiceProtocolVersion; #[derive(Debug, thiserror::Error)] #[error("unexpected lua value received")] @@ -31,7 +32,7 @@ fn decode_packages<'lua>(lua: &'lua Lua, buf_lua: Value<'lua>) -> LuaResult(lua: &'lua Lua, buf_lua: Value<'lua>) -> LuaResult protocol_version); - } if let Some(completed) = header.completed() { set_table_values!(message_table, "completed" => completed); }