diff --git a/nativelink-error/src/lib.rs b/nativelink-error/src/lib.rs index 967bf963a..0c4f69d2d 100644 --- a/nativelink-error/src/lib.rs +++ b/nativelink-error/src/lib.rs @@ -441,3 +441,17 @@ impl From for std::io::ErrorKind { } } } + +// Allows for mapping this type into a generic serialization error. +impl serde::ser::Error for Error { + fn custom(msg: T) -> Self { + Self::new(Code::InvalidArgument, msg.to_string()) + } +} + +// Allows for mapping this type into a generic deserialization error. +impl serde::de::Error for Error { + fn custom(msg: T) -> Self { + Self::new(Code::InvalidArgument, msg.to_string()) + } +} diff --git a/nativelink-util/src/action_messages.rs b/nativelink-util/src/action_messages.rs index 2a1ae9cbf..031ddc849 100644 --- a/nativelink-util/src/action_messages.rs +++ b/nativelink-util/src/action_messages.rs @@ -29,6 +29,7 @@ use nativelink_proto::google::rpc::Status; use prost::bytes::Bytes; use prost::Message; use prost_types::Any; +use serde::ser::Error as SerdeError; use serde::{Deserialize, Serialize}; use uuid::Uuid; @@ -746,7 +747,7 @@ impl Default for ActionResult { // TODO(allada) Remove the need for clippy argument by making the ActionResult and ProtoActionResult // a Box. /// The execution status/stage. This should match `ExecutionStage::Value` in `remote_execution.proto`. -#[derive(PartialEq, Debug, Clone)] +#[derive(PartialEq, Debug, Clone, Serialize, Deserialize)] #[allow(clippy::large_enum_variant)] pub enum ActionStage { /// Stage is unknown. @@ -762,9 +763,22 @@ pub enum ActionStage { /// Worker completed the work with result. Completed(ActionResult), /// Result was found from cache, don't decode the proto just to re-encode it. + #[serde(serialize_with = "serialize_proto_result", skip_deserializing)] + // The serialization step decodes this to an ActionResult which is serializable. + // Since it will always be serialized as an ActionResult, we do not need to support + // deserialization on this type at all. + // In theory, serializing this should never happen so performance shouldn't be affected. CompletedFromCache(ProtoActionResult), } +fn serialize_proto_result(v: &ProtoActionResult, serializer: S) -> Result +where + S: serde::Serializer, +{ + let s = ActionResult::try_from(v.clone()).map_err(S::Error::custom)?; + s.serialize(serializer) +} + impl ActionStage { pub const fn has_action_result(&self) -> bool { match self { @@ -1075,7 +1089,7 @@ where /// Current state of the action. /// This must be 100% compatible with `Operation` in `google/longrunning/operations.proto`. -#[derive(PartialEq, Debug, Clone)] +#[derive(PartialEq, Debug, Clone, Serialize, Deserialize)] pub struct ActionState { pub stage: ActionStage, pub id: OperationId,