From c4abe860318535625bce1bd2a7620335aa460f5a Mon Sep 17 00:00:00 2001 From: Ahmed Farghal Date: Thu, 16 May 2024 09:44:21 +0100 Subject: [PATCH] Unify RpcStyle message routing Introduces types that makes it easier to perform Rpc-like interactions with Networking. - `RpcMessage` trait marks messages that carry a `CorrelationId`. A sane default correlation id is provided as `RequestId`. - `RpcRequest` trait defines request messages and their response types. - `RpcRouter` enables sending rpc request and awaiting responses with auto eviction of dropped requests. - `ResponseTracker` is a helper that manages tracking tokens for in-flight requests, this can be used in the future to replace large portions of IngressDispatcher. - Macros to help define RPC messages to reduce code noise in node-protocol --- Cargo.lock | 1 + crates/ingress-dispatcher/src/dispatcher.rs | 17 +- crates/ingress-dispatcher/src/lib.rs | 9 +- crates/network/Cargo.toml | 1 + crates/network/src/lib.rs | 1 + crates/network/src/rpc_router.rs | 423 ++++++++++++++++++++ crates/node-protocol/src/common.rs | 48 +++ crates/node-protocol/src/ingress.rs | 54 ++- crates/node-protocol/src/lib.rs | 127 +++++- crates/node-protocol/src/metadata.rs | 33 +- 10 files changed, 634 insertions(+), 80 deletions(-) create mode 100644 crates/network/src/rpc_router.rs diff --git a/Cargo.lock b/Cargo.lock index 06070ace8b..316904314c 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -5667,6 +5667,7 @@ dependencies = [ "async-trait", "bincode", "bytes", + "dashmap", "drain", "enum-map", "enumset", diff --git a/crates/ingress-dispatcher/src/dispatcher.rs b/crates/ingress-dispatcher/src/dispatcher.rs index 75a7a2f9b2..c35a9ede76 100644 --- a/crates/ingress-dispatcher/src/dispatcher.rs +++ b/crates/ingress-dispatcher/src/dispatcher.rs @@ -10,15 +10,16 @@ use crate::error::IngressDispatchError; use crate::{ - IngressCorrelationId, IngressDispatcherRequest, IngressDispatcherRequestInner, - IngressDispatcherResponse, IngressRequestMode, IngressResponseSender, + IngressDispatcherRequest, IngressDispatcherRequestInner, IngressDispatcherResponse, + IngressRequestMode, IngressResponseSender, }; use dashmap::DashMap; use restate_bifrost::Bifrost; use restate_core::metadata; use restate_core::network::MessageHandler; use restate_node_protocol::codec::Targeted; -use restate_node_protocol::ingress::IngressMessage; +use restate_node_protocol::ingress::{IngressCorrelationId, IngressMessage}; +use restate_node_protocol::RpcMessage; use restate_storage_api::deduplication_table::DedupInformation; use restate_types::identifiers::{PartitionKey, WithPartitionKey}; use restate_types::message::MessageIndex; @@ -132,15 +133,7 @@ impl MessageHandler for IngressDispatcher { trace!("Processing message '{}' from '{}'", msg.kind(), peer); match msg { IngressMessage::InvocationResponse(invocation_response) => { - let correlation_id = invocation_response - .idempotency_id - .as_ref() - .map(|idempotency_id| { - IngressCorrelationId::IdempotencyId(idempotency_id.clone()) - }) - .unwrap_or_else(|| { - IngressCorrelationId::InvocationId(invocation_response.invocation_id) - }); + let correlation_id = invocation_response.correlation_id(); if let Some((_, sender)) = self.state.waiting_responses.remove(&correlation_id) { let dispatcher_response = IngressDispatcherResponse { // TODO we need to add back the expiration time for idempotent results diff --git a/crates/ingress-dispatcher/src/lib.rs b/crates/ingress-dispatcher/src/lib.rs index 7891cb81bc..5ed8b9c067 100644 --- a/crates/ingress-dispatcher/src/lib.rs +++ b/crates/ingress-dispatcher/src/lib.rs @@ -11,6 +11,7 @@ use bytes::Bytes; use bytestring::ByteString; use restate_core::metadata; +pub use restate_node_protocol::ingress::IngressCorrelationId; use restate_schema_api::subscription::{EventReceiverServiceType, Sink, Subscription}; use restate_types::identifiers::{ partitioner, IdempotencyId, InvocationId, PartitionKey, WithPartitionKey, @@ -32,14 +33,6 @@ pub use dispatcher::{DispatchIngressRequest, IngressDispatcher}; pub type IngressResponseSender = oneshot::Sender; pub type IngressResponseReceiver = oneshot::Receiver; -// TODO we could eventually remove this type and replace it with something simpler once -// https://github.com/restatedev/restate/issues/1329 is in place -#[derive(Debug, Clone, PartialEq, Eq, Hash)] -pub enum IngressCorrelationId { - InvocationId(InvocationId), - IdempotencyId(IdempotencyId), -} - #[derive(Debug)] enum IngressDispatcherRequestInner { Invoke(ServiceInvocation), diff --git a/crates/network/Cargo.toml b/crates/network/Cargo.toml index 8ac9e46dcf..837927f21b 100644 --- a/crates/network/Cargo.toml +++ b/crates/network/Cargo.toml @@ -20,6 +20,7 @@ anyhow = { workspace = true } async-trait = { workspace = true } bincode = { workspace = true } bytes = { workspace = true } +dashmap = { workspace = true } drain = { workspace = true } enum-map = { workspace = true } enumset = { workspace = true } diff --git a/crates/network/src/lib.rs b/crates/network/src/lib.rs index e3b3e8fe36..5d40fe1739 100644 --- a/crates/network/src/lib.rs +++ b/crates/network/src/lib.rs @@ -14,6 +14,7 @@ pub mod error; mod handshake; pub(crate) mod metric_definitions; mod networking; +pub mod rpc_router; pub use connection::ConnectionSender; pub use connection_manager::ConnectionManager; diff --git a/crates/network/src/rpc_router.rs b/crates/network/src/rpc_router.rs new file mode 100644 index 0000000000..23e458bab9 --- /dev/null +++ b/crates/network/src/rpc_router.rs @@ -0,0 +1,423 @@ +// Copyright (c) 2024 - Restate Software, Inc., Restate GmbH. +// All rights reserved. +// +// Use of this software is governed by the Business Source License +// included in the LICENSE file. +// +// As of the Change Date specified in that file, in accordance with +// the Business Source License, use of this software will be governed +// by the Apache License, Version 2.0. + +use std::sync::{Arc, Weak}; + +use dashmap::DashMap; +use futures::stream::BoxStream; +use futures::StreamExt; +use restate_core::{cancellation_watcher, ShutdownError}; +use restate_node_protocol::codec::{Targeted, WireDecode, WireEncode}; +use restate_types::NodeId; +use tokio::sync::oneshot; + +use restate_core::network::{ + MessageHandler, MessageRouterBuilder, NetworkSendError, NetworkSender, +}; +use restate_node_protocol::{MessageEnvelope, RpcMessage, RpcRequest}; +use tracing::warn; + +use crate::Networking; + +/// A router for sending and receiving RPC messages through Networking +/// +/// It's responsible for keeping track of in-flight requests, correlating responses, and dropping +/// tracking tokens if caller dropped the future. +/// +/// This type is designed to be used by senders of RpcRequest(s). +pub struct RpcRouter +where + T: RpcRequest, +{ + networking: Networking, + response_tracker: ResponseTracker, +} + +#[derive(thiserror::Error, Debug)] +#[error(transparent)] +pub enum RpcError { + #[error("correlation id {0} is already in-flight")] + CorrelationIdExists(String), + SendError(#[from] NetworkSendError), + Shutdown(#[from] ShutdownError), +} + +impl RpcRouter +where + T: RpcRequest + WireEncode + Send + Sync + 'static, + T::Response: WireDecode + Send + Sync + 'static, + ::CorrelationId: Send + Sync + From, +{ + pub fn new(networking: Networking, router_builder: &mut MessageRouterBuilder) -> Self { + let response_tracker = ResponseTracker::::default(); + router_builder.add_message_handler(response_tracker.clone()); + Self { + networking, + response_tracker, + } + } + + pub async fn call(&self, to: NodeId, msg: &T) -> Result, RpcError> + where + ::CorrelationId: Default, + { + let token = self + .response_tracker + .new_token(msg.correlation_id().into()) + .ok_or_else(|| RpcError::CorrelationIdExists(format!("{:?}", msg.correlation_id())))?; + self.networking.send(to, &msg).await?; + token + .recv() + .await + .map_err(|_| RpcError::Shutdown(ShutdownError)) + } + + pub fn num_in_flight(&self) -> usize { + self.response_tracker.num_in_flight() + } +} + +/// A tracker for responses but can be used to track responses for requests that were dispatched +/// via other mechanisms (e.g. ingress flow) +pub struct ResponseTracker +where + T: RpcMessage, +{ + inner: Arc>, +} + +impl Clone for ResponseTracker +where + T: RpcMessage, +{ + fn clone(&self) -> Self { + Self { + inner: self.inner.clone(), + } + } +} + +struct Inner +where + T: RpcMessage, +{ + in_flight: DashMap>, +} + +impl Default for ResponseTracker +where + T: RpcMessage, +{ + fn default() -> Self { + Self { + inner: Arc::new(Inner { + in_flight: Default::default(), + }), + } + } +} + +impl ResponseTracker +where + T: RpcMessage, +{ + pub fn num_in_flight(&self) -> usize { + self.inner.in_flight.len() + } + + /// Returns None if an in-flight request holds the same correlation_id. + pub fn new_token(&self, correlation_id: T::CorrelationId) -> Option> { + let (sender, receiver) = oneshot::channel(); + let existing = self + .inner + .in_flight + .insert(correlation_id.clone(), RpcTokenSender { sender }); + + if existing.is_some() { + // in this extraordinary case, we put the old token back even that it wouldn't really + // guarantee correctness since the response might have arrived by now, but we do it + // anyway as a best hope. + self.inner + .in_flight + .entry(correlation_id.clone()) + .and_modify(|val| *val = existing.unwrap()); + warn!( + "correlation id {:?} was already in-flight when this rpc was issued, this is an indicator that the correlation_id is not unique across RPC calls", + correlation_id + ); + return None; + } + Some(RpcToken { + correlation_id, + router: Arc::downgrade(&self.inner), + receiver: Some(receiver), + }) + } + + /// Returns None if an in-flight request holds the same correlation_id. + pub fn generate_token(&self) -> Option> + where + T::CorrelationId: Default, + { + let correlation_id = T::CorrelationId::default(); + self.new_token(correlation_id) + } + + /// Handle a message through this response tracker. + pub fn handle_message(&self, msg: MessageEnvelope) -> Option> { + // find the token and send, message is dropped on the floor if no valid match exist for the + // correlation id. + if let Some((_, token)) = self.inner.in_flight.remove(&msg.correlation_id()) { + let _ = token.sender.send(msg); + None + } else { + Some(msg) + } + } +} + +pub struct StreamingResponseTracker +where + T: RpcMessage, +{ + flight_tracker: ResponseTracker, + incoming_messages: BoxStream<'static, MessageEnvelope>, +} + +impl StreamingResponseTracker +where + T: RpcMessage, +{ + pub fn new(incoming_messages: BoxStream<'static, MessageEnvelope>) -> Self { + let flight_tracker = ResponseTracker::default(); + Self { + flight_tracker, + incoming_messages, + } + } + + /// Returns None if an in-flight request holds the same correlation_id. + pub fn new_token(&self, correlation_id: T::CorrelationId) -> Option> { + self.flight_tracker.new_token(correlation_id) + } + + /// Returns None if an in-flight request holds the same correlation_id. + pub fn generate_token(&self) -> Option> + where + T::CorrelationId: Default, + { + let correlation_id = T::CorrelationId::default(); + self.new_token(correlation_id) + } + + /// Handles the next message. This will **return** the message if no correlated request is + /// in-flight. Otherwise, it's handled by the corresponding token receiver. + pub async fn handle_next_or_get(&mut self) -> Option> { + tokio::select! { + Some(message) = self.incoming_messages.next() => { + self.flight_tracker.handle_message(message) + }, + _ = cancellation_watcher() => { None }, + else => { None } , + } + } +} + +struct RpcTokenSender { + sender: oneshot::Sender>, +} + +pub struct RpcToken +where + T: RpcMessage, +{ + correlation_id: T::CorrelationId, + router: Weak>, + // This is Option to get around Rust's borrow checker rules when a type implements the Drop + // trait. Without this, we cannot move receiver out. + receiver: Option>>, +} + +impl RpcToken +where + T: RpcMessage, +{ + pub fn correlation_id(&self) -> T::CorrelationId { + self.correlation_id.clone() + } + + /// Awaits the response to come for the associated request. Cancellation safe. + pub async fn recv(mut self) -> Result, ShutdownError> { + let receiver = std::mem::take(&mut self.receiver); + let res = match receiver { + Some(receiver) => { + tokio::select! { + _ = cancellation_watcher() => { + return Err(ShutdownError); + }, + res = receiver => { + res.map_err(|_| ShutdownError) + } + } + } + // Should never happen unless token was created with None which shouldn't be possible + None => Err(ShutdownError), + }; + // If we have received something, we don't need to run drop() since the flight tracker has + // already removed the sender token. + std::mem::forget(self); + res + } +} + +impl Drop for RpcToken +where + T: RpcMessage, +{ + fn drop(&mut self) { + // if the router is gone, we can't do anything. + let Some(router) = self.router.upgrade() else { + return; + }; + let _ = router.in_flight.remove(&self.correlation_id); + } +} + +impl MessageHandler for ResponseTracker +where + T: RpcMessage + WireDecode + Targeted, +{ + type MessageType = T; + + fn on_message( + &self, + msg: restate_node_protocol::MessageEnvelope, + ) -> impl std::future::Future + Send { + self.handle_message(msg); + std::future::ready(()) + } +} + +#[cfg(test)] +mod test { + use super::*; + use restate_node_protocol::common::TargetName; + use restate_types::GenerationalNodeId; + + #[derive(Debug, Clone, PartialEq, Eq, Hash)] + struct TestCorrelationId(u64); + + #[derive(Debug, Clone)] + struct TestResponse { + correlation_id: TestCorrelationId, + text: String, + } + + impl RpcMessage for TestResponse { + type CorrelationId = TestCorrelationId; + fn correlation_id(&self) -> Self::CorrelationId { + self.correlation_id.clone() + } + } + + impl Targeted for TestResponse { + const TARGET: TargetName = TargetName::Unknown; + fn kind(&self) -> &'static str { + "TestMessage" + } + } + + impl WireDecode for TestResponse { + fn decode( + _: &mut B, + _: restate_node_protocol::common::ProtocolVersion, + ) -> Result + where + Self: Sized, + { + unimplemented!() + } + } + + #[tokio::test(start_paused = true)] + async fn test_rpc_flight_tracker_drop() { + let tracker = ResponseTracker::::default(); + assert_eq!(tracker.num_in_flight(), 0); + let token = tracker.new_token(TestCorrelationId(1)).unwrap(); + assert_eq!(tracker.num_in_flight(), 1); + drop(token); + assert_eq!(tracker.num_in_flight(), 0); + + let token = tracker.new_token(TestCorrelationId(1)).unwrap(); + assert_eq!(tracker.num_in_flight(), 1); + // receive with timeout, this should drop the token + let start = tokio::time::Instant::now(); + let dur = std::time::Duration::from_millis(500); + let res = tokio::time::timeout(dur, token.recv()).await; + assert!(res.is_err()); + assert!(start.elapsed() >= dur); + assert_eq!(tracker.num_in_flight(), 0); + } + + #[tokio::test(start_paused = true)] + async fn test_rpc_flight_tracker_send_recv() { + let tracker = ResponseTracker::::default(); + assert_eq!(tracker.num_in_flight(), 0); + let token = tracker.new_token(TestCorrelationId(1)).unwrap(); + assert_eq!(tracker.num_in_flight(), 1); + + // dropped on the floor + tracker + .on_message(MessageEnvelope::new( + GenerationalNodeId::new(1, 1), + 22, + TestResponse { + correlation_id: TestCorrelationId(42), + text: "test".to_string(), + }, + )) + .await; + + assert_eq!(tracker.num_in_flight(), 1); + + let maybe_msg = tracker.handle_message(MessageEnvelope::new( + GenerationalNodeId::new(1, 1), + 22, + TestResponse { + correlation_id: TestCorrelationId(42), + text: "test".to_string(), + }, + )); + assert!(maybe_msg.is_some()); + + assert_eq!(tracker.num_in_flight(), 1); + + // matches correlation id + tracker + .on_message(MessageEnvelope::new( + GenerationalNodeId::new(1, 1), + 22, + TestResponse { + correlation_id: TestCorrelationId(1), + text: "a very real message".to_string(), + }, + )) + .await; + + // sender token is dropped + assert_eq!(tracker.num_in_flight(), 0); + + let msg = token.recv().await.unwrap(); + assert_eq!(TestCorrelationId(1), msg.correlation_id()); + let (from, msg) = msg.split(); + assert_eq!(GenerationalNodeId::new(1, 1), from); + assert_eq!("a very real message", msg.text); + } +} diff --git a/crates/node-protocol/src/common.rs b/crates/node-protocol/src/common.rs index da63dc83e9..d983b5ea5d 100644 --- a/crates/node-protocol/src/common.rs +++ b/crates/node-protocol/src/common.rs @@ -8,11 +8,46 @@ // the Business Source License, use of this software will be governed // by the Apache License, Version 2.0. +use std::sync::atomic::AtomicUsize; + include!(concat!(env!("OUT_DIR"), "/dev.restate.common.rs")); pub static MIN_SUPPORTED_PROTOCOL_VERSION: ProtocolVersion = ProtocolVersion::Flexbuffers; pub static CURRENT_PROTOCOL_VERSION: ProtocolVersion = ProtocolVersion::Flexbuffers; +/// Used to identify a request in a RPC-style call going through Networking. +#[derive( + Debug, + derive_more::Display, + PartialEq, + Eq, + Clone, + Copy, + Hash, + PartialOrd, + Ord, + serde::Serialize, + serde::Deserialize, +)] +pub struct RequestId(u64); +impl RequestId { + pub fn new() -> Self { + Default::default() + } +} + +impl Default for RequestId { + fn default() -> Self { + static NEXT_REQUEST_ID: AtomicUsize = AtomicUsize::new(1); + RequestId( + NEXT_REQUEST_ID + .fetch_add(1, std::sync::atomic::Ordering::Relaxed) + .try_into() + .unwrap(), + ) + } +} + pub const FILE_DESCRIPTOR_SET: &[u8] = include_bytes!(concat!(env!("OUT_DIR"), "/common_descriptor.bin")); @@ -69,3 +104,16 @@ impl From for NodeId { } } } + +// write tests for RequestId +#[cfg(test)] +mod tests { + use super::*; + #[test] + fn test_request_id() { + let request_id1 = RequestId::new(); + let request_id2 = RequestId::new(); + let request_id3 = RequestId::default(); + assert!(request_id1.0 < request_id2.0 && request_id2.0 < request_id3.0); + } +} diff --git a/crates/node-protocol/src/ingress.rs b/crates/node-protocol/src/ingress.rs index 2a3a430b19..c28824b294 100644 --- a/crates/node-protocol/src/ingress.rs +++ b/crates/node-protocol/src/ingress.rs @@ -8,14 +8,13 @@ // the Business Source License, use of this software will be governed // by the Apache License, Version 2.0. -use bytes::{Buf, BufMut}; use restate_types::identifiers::{IdempotencyId, InvocationId}; use restate_types::invocation::ResponseResult; use serde::{Deserialize, Serialize}; -use crate::codec::{decode_default, encode_default, Targeted, WireDecode, WireEncode}; -use crate::common::{ProtocolVersion, TargetName}; -use crate::CodecError; +use crate::common::TargetName; +use crate::define_message; +use crate::RpcMessage; #[derive( Debug, @@ -30,32 +29,9 @@ pub enum IngressMessage { InvocationResponse(InvocationResponse), } -impl Targeted for IngressMessage { - const TARGET: TargetName = TargetName::Ingress; - - fn kind(&self) -> &'static str { - self.into() - } -} - -impl WireEncode for IngressMessage { - fn encode( - &self, - buf: &mut B, - protocol_version: ProtocolVersion, - ) -> Result<(), CodecError> { - // serialize message into buf - encode_default(self, buf, protocol_version) - } -} - -impl WireDecode for IngressMessage { - fn decode(buf: &mut B, protocol_version: ProtocolVersion) -> Result - where - Self: Sized, - { - decode_default(buf, protocol_version) - } +define_message! { + @message = IngressMessage, + @target = TargetName::Ingress, } #[derive(Debug, Clone, Serialize, Deserialize)] @@ -64,3 +40,21 @@ pub struct InvocationResponse { pub idempotency_id: Option, pub response: ResponseResult, } + +// TODO we could eventually remove this type and replace it with something simpler once +// https://github.com/restatedev/restate/issues/1329 is in place +#[derive(Debug, Clone, PartialEq, Eq, Hash)] +pub enum IngressCorrelationId { + InvocationId(InvocationId), + IdempotencyId(IdempotencyId), +} + +impl RpcMessage for InvocationResponse { + type CorrelationId = IngressCorrelationId; + fn correlation_id(&self) -> Self::CorrelationId { + self.idempotency_id + .as_ref() + .map(|idempotency_id| IngressCorrelationId::IdempotencyId(idempotency_id.clone())) + .unwrap_or_else(|| IngressCorrelationId::InvocationId(self.invocation_id)) + } +} diff --git a/crates/node-protocol/src/lib.rs b/crates/node-protocol/src/lib.rs index 4ce5388c67..6f9762e1ce 100644 --- a/crates/node-protocol/src/lib.rs +++ b/crates/node-protocol/src/lib.rs @@ -22,10 +22,11 @@ pub use error::*; use restate_types::GenerationalNodeId; +use self::codec::Targeted; use self::codec::WireDecode; /// A wrapper for a message that includes the sender id -pub struct MessageEnvelope { +pub struct MessageEnvelope { peer: GenerationalNodeId, connection_id: u64, body: M, @@ -39,7 +40,9 @@ impl MessageEnvelope { body, } } +} +impl MessageEnvelope { pub fn connection_id(&self) -> u64 { self.connection_id } @@ -48,3 +51,125 @@ impl MessageEnvelope { (self.peer, self.body) } } + +impl MessageEnvelope { + /// A unique identifier used by RPC-style messages to correlated requests and responses + pub fn correlation_id(&self) -> M::CorrelationId { + self.body.correlation_id() + } +} + +pub trait RpcMessage { + type CorrelationId: Clone + Send + Eq + PartialEq + std::fmt::Debug + std::hash::Hash; + fn correlation_id(&self) -> Self::CorrelationId; +} + +pub trait RpcRequest: RpcMessage + Targeted { + type Response: RpcMessage + Targeted; +} + +// to define a message, we need +// - Message type +// - message target +// +// Example: +// ``` +// define_message! { +// @message = IngressMessage, +// @target = TargetName::Ingress, +// } +// ``` +macro_rules! define_message { + ( + @message = $message:ty, + @target = $target:expr, + ) => { + impl crate::codec::Targeted for $message { + const TARGET: TargetName = $target; + fn kind(&self) -> &'static str { + stringify!($message) + } + } + + impl crate::codec::WireEncode for $message { + fn encode( + &self, + buf: &mut B, + protocol_version: crate::common::ProtocolVersion, + ) -> Result<(), crate::CodecError> { + // serialize message into buf + crate::codec::encode_default(self, buf, protocol_version) + } + } + + impl crate::codec::WireDecode for $message { + fn decode( + buf: &mut B, + protocol_version: crate::common::ProtocolVersion, + ) -> Result + where + Self: Sized, + { + crate::codec::decode_default(buf, protocol_version) + } + } + }; +} + +// to define an RPC, we need +// - Request type +// - request target +// - Response type +// - response Target +// +// Example: +// ``` +// define_rpc! { +// @request = AttachRequest, +// @response = AttachResponse, +// @request_target = TargetName::ClusterController, +// @response_target = TargetName::AttachResponse, +// } +// ``` +#[allow(unused_macros)] +macro_rules! define_rpc { + ( + @request = $request:ty, + @response = $response:ty, + @request_target = $request_target:expr, + @response_target = $response_target:expr, + ) => { + impl crate::RpcRequest for $request { + type Response = $response; + } + + impl crate::RpcMessage for $request { + type CorrelationId = crate::common::RequestId; + + fn correlation_id(&self) -> Self::CorrelationId { + self.request_id + } + } + + impl crate::RpcMessage for $response { + type CorrelationId = crate::common::RequestId; + + fn correlation_id(&self) -> Self::CorrelationId { + self.request_id + } + } + + crate::define_message! { + @message = $request, + @target = $request_target, + } + + crate::define_message! { + @message = $response, + @target = $response_target, + } + }; +} + +#[allow(unused_imports)] +use {define_message, define_rpc}; diff --git a/crates/node-protocol/src/metadata.rs b/crates/node-protocol/src/metadata.rs index 2ed48098a5..30fbb744ca 100644 --- a/crates/node-protocol/src/metadata.rs +++ b/crates/node-protocol/src/metadata.rs @@ -8,7 +8,6 @@ // the Business Source License, use of this software will be governed // by the Apache License, Version 2.0. -use bytes::{Buf, BufMut}; use enum_map::Enum; pub use restate_schema::{Schema, UpdateableSchema}; use restate_types::logs::metadata::Logs; @@ -17,10 +16,8 @@ use restate_types::partition_table::FixedPartitionTable; use serde::{Deserialize, Serialize}; use strum_macros::EnumIter; -use crate::codec::{decode_default, encode_default, Targeted, WireDecode, WireEncode}; -use crate::common::ProtocolVersion; use crate::common::TargetName; -use crate::CodecError; +use crate::define_message; #[derive( Debug, @@ -36,31 +33,9 @@ pub enum MetadataMessage { MetadataUpdate(MetadataUpdate), } -impl Targeted for MetadataMessage { - const TARGET: TargetName = TargetName::MetadataManager; - - fn kind(&self) -> &'static str { - self.into() - } -} - -impl WireEncode for MetadataMessage { - fn encode( - &self, - buf: &mut B, - protocol_version: ProtocolVersion, - ) -> Result<(), CodecError> { - encode_default(self, buf, protocol_version) - } -} - -impl WireDecode for MetadataMessage { - fn decode(buf: &mut B, protocol_version: ProtocolVersion) -> Result - where - Self: Sized, - { - decode_default(buf, protocol_version) - } +define_message! { + @message = MetadataMessage, + @target = TargetName::MetadataManager, } /// The kind of versioned metadata that can be synchronized across nodes.