From 04408b8c4f5bf96bd81bbc4ce4e96f670ea5ecac Mon Sep 17 00:00:00 2001 From: Nikolai Golub Date: Thu, 24 Aug 2023 20:36:42 +0200 Subject: [PATCH] Fixing `where` bounds for rpc_gen (#726) --- Cargo.lock | 1 + module-system/sov-modules-macros/Cargo.toml | 10 +- .../sov-modules-macros/src/rpc/expose_rpc.rs | 2 +- .../sov-modules-macros/src/rpc/rpc_gen.rs | 17 ++-- .../sov-modules-macros/tests/all_tests.rs | 5 +- .../tests/{ => rpc}/derive_rpc.rs | 2 +- .../tests/rpc/derive_rpc_with_where.rs | 99 +++++++++++++++++++ 7 files changed, 120 insertions(+), 16 deletions(-) rename module-system/sov-modules-macros/tests/{ => rpc}/derive_rpc.rs (98%) create mode 100644 module-system/sov-modules-macros/tests/rpc/derive_rpc_with_where.rs diff --git a/Cargo.lock b/Cargo.lock index 765ddc594..839d787ec 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -6911,6 +6911,7 @@ dependencies = [ "serde_json", "sov-bank", "sov-modules-api", + "sov-modules-macros", "sov-state", "syn 1.0.109", "tempfile", diff --git a/module-system/sov-modules-macros/Cargo.toml b/module-system/sov-modules-macros/Cargo.toml index e94ebc1a4..c90ca0c39 100644 --- a/module-system/sov-modules-macros/Cargo.toml +++ b/module-system/sov-modules-macros/Cargo.toml @@ -20,16 +20,18 @@ name = "tests" path = "tests/all_tests.rs" [dev-dependencies] -serde_json = "1" -tempfile = "3" + +clap = { workspace = true } jsonrpsee = { workspace = true, features = ["macros", "http-client", "server"] } +serde = { workspace = true } +serde_json = { workspace = true } +tempfile = { workspace = true } trybuild = "1.0" sov-modules-api = { path = "../sov-modules-api", version = "0.1" } sov-state = { path = "../sov-state", version = "0.1" } sov-bank = { path = "../module-implementations/sov-bank", version = "0.1", features = ["native"] } -serde = { workspace = true } -clap = { workspace = true } +sov-modules-macros = { path = ".", features = ["native"] } [dependencies] anyhow = { workspace = true } diff --git a/module-system/sov-modules-macros/src/rpc/expose_rpc.rs b/module-system/sov-modules-macros/src/rpc/expose_rpc.rs index 54829eff7..0691eff4d 100644 --- a/module-system/sov-modules-macros/src/rpc/expose_rpc.rs +++ b/module-system/sov-modules-macros/src/rpc/expose_rpc.rs @@ -90,7 +90,7 @@ impl ExposeRpcMacro { } let get_rpc_methods: proc_macro2::TokenStream = quote! { - pub fn get_rpc_methods #impl_generics (storage: <#context_type as ::sov_modules_api::Spec>::Storage) -> jsonrpsee::RpcModule<()> #where_clause{ + pub fn get_rpc_methods #impl_generics (storage: <#context_type as ::sov_modules_api::Spec>::Storage) -> jsonrpsee::RpcModule<()> #where_clause { let mut module = jsonrpsee::RpcModule::new(()); let r = RpcStorage::<#context_type> { storage: storage.clone(), diff --git a/module-system/sov-modules-macros/src/rpc/rpc_gen.rs b/module-system/sov-modules-macros/src/rpc/rpc_gen.rs index fd1e027de..8a0b8d56f 100644 --- a/module-system/sov-modules-macros/src/rpc/rpc_gen.rs +++ b/module-system/sov-modules-macros/src/rpc/rpc_gen.rs @@ -6,7 +6,7 @@ use syn::parse::{Parse, ParseStream}; use syn::punctuated::Punctuated; use syn::{ parenthesized, Attribute, FnArg, ImplItem, Meta, MetaList, PatType, Path, PathSegment, - Signature, Type, + Signature, }; /// Returns an attribute with the name `rpc_method` replaced with `method`, and the index @@ -64,7 +64,7 @@ fn find_working_set_argument(sig: &Signature) -> Option<(usize, syn::Type)> { struct RpcImplBlock { pub(crate) type_name: Ident, pub(crate) methods: Vec, - pub(crate) working_set_type: Option, + pub(crate) working_set_type: Option, pub(crate) generics: syn::Generics, } @@ -164,14 +164,14 @@ impl RpcImplBlock { let rpc_impl_trait = if let Some(ref working_set_type) = self.working_set_type { quote! { - pub trait #impl_trait_name #generics { + pub trait #impl_trait_name #generics #where_clause { fn get_working_set(&self) -> #working_set_type; #(#impl_trait_methods)* } } } else { quote! { - pub trait #impl_trait_name #generics { + pub trait #impl_trait_name #generics #where_clause { #(#impl_trait_methods)* } } @@ -303,13 +303,16 @@ fn build_rpc_trait( #input }; + let where_clause = &generics.where_clause; + let rpc_output = quote! { #simplified_impl #impl_rpc_trait_impl + #rpc_attribute - pub trait #intermediate_trait_name #generics { + pub trait #intermediate_trait_name #generics #where_clause { #(#intermediate_trait_items)* @@ -335,12 +338,12 @@ pub(crate) fn rpc_gen( build_rpc_trait(attrs, type_name.clone(), input) } -struct TypeList(pub Punctuated); +struct TypeList(pub Punctuated); impl Parse for TypeList { fn parse(input: ParseStream) -> syn::Result { let content; parenthesized!(content in input); - Ok(TypeList(content.parse_terminated(Type::parse)?)) + Ok(TypeList(content.parse_terminated(syn::Type::parse)?)) } } diff --git a/module-system/sov-modules-macros/tests/all_tests.rs b/module-system/sov-modules-macros/tests/all_tests.rs index 09cfafdd3..8d30942cb 100644 --- a/module-system/sov-modules-macros/tests/all_tests.rs +++ b/module-system/sov-modules-macros/tests/all_tests.rs @@ -23,14 +23,13 @@ fn module_dispatch_tests() { t.compile_fail("tests/dispatch/missing_serialization.rs"); } -#[cfg(feature = "native")] #[test] fn rpc_tests() { let t = trybuild::TestCases::new(); - t.pass("tests/derive_rpc.rs"); + t.pass("tests/rpc/derive_rpc.rs"); + t.pass("tests/rpc/derive_rpc_with_where.rs"); } -#[cfg(feature = "native")] #[test] fn cli_wallet_arg_tests() { let t: trybuild::TestCases = trybuild::TestCases::new(); diff --git a/module-system/sov-modules-macros/tests/derive_rpc.rs b/module-system/sov-modules-macros/tests/rpc/derive_rpc.rs similarity index 98% rename from module-system/sov-modules-macros/tests/derive_rpc.rs rename to module-system/sov-modules-macros/tests/rpc/derive_rpc.rs index 760dcdfe8..09c20f729 100644 --- a/module-system/sov-modules-macros/tests/derive_rpc.rs +++ b/module-system/sov-modules-macros/tests/rpc/derive_rpc.rs @@ -107,5 +107,5 @@ fn main() { assert_eq!(result, ()); } - println!("All tests passed!") + println!("All tests passed!"); } diff --git a/module-system/sov-modules-macros/tests/rpc/derive_rpc_with_where.rs b/module-system/sov-modules-macros/tests/rpc/derive_rpc_with_where.rs new file mode 100644 index 000000000..716d8f0e6 --- /dev/null +++ b/module-system/sov-modules-macros/tests/rpc/derive_rpc_with_where.rs @@ -0,0 +1,99 @@ +use std::hash::Hasher; + +use jsonrpsee::core::RpcResult; +use sov_modules_api::default_context::ZkDefaultContext; +use sov_modules_api::macros::rpc_gen; +use sov_modules_api::{Context, ModuleInfo}; +use sov_state::{WorkingSet, ZkStorage}; + +#[derive(ModuleInfo)] +pub struct TestStruct +where + D: std::hash::Hash + + std::clone::Clone + + borsh::BorshSerialize + + borsh::BorshDeserialize + + serde::Serialize + + serde::de::DeserializeOwned + + 'static, +{ + #[address] + pub(crate) address: C::Address, + #[state] + pub(crate) data: ::sov_state::StateValue, +} + +#[rpc_gen(client, server, namespace = "test")] +impl TestStruct +where + D: std::hash::Hash + + std::clone::Clone + + borsh::BorshSerialize + + borsh::BorshDeserialize + + serde::Serialize + + serde::de::DeserializeOwned + + 'static, +{ + #[rpc_method(name = "firstMethod")] + pub fn first_method(&self, _working_set: &mut WorkingSet) -> RpcResult { + Ok(11) + } + + #[rpc_method(name = "secondMethod")] + pub fn second_method( + &self, + result: D, + _working_set: &mut WorkingSet, + ) -> RpcResult<(D, u64)> { + let mut hasher = std::collections::hash_map::DefaultHasher::new(); + let value = result.clone(); + value.hash(&mut hasher); + let hashed_value = hasher.finish(); + + Ok((value, hashed_value)) + } +} + +pub struct TestRuntime { + test_struct: TestStruct, +} + +// This is generated by a macro annotating the state transition runner, +// but we do not have that in scope here so generating the struct manually. +struct RpcStorage { + pub storage: C::Storage, +} + +impl TestStructRpcImpl for RpcStorage { + fn get_working_set( + &self, + ) -> ::sov_state::WorkingSet<::Storage> { + ::sov_state::WorkingSet::new(self.storage.clone()) + } +} + +fn main() { + let storage = ZkStorage::new([1u8; 32]); + let r: RpcStorage = RpcStorage { + storage: storage.clone(), + }; + { + let result = + as TestStructRpcServer>::first_method( + &r, + ) + .unwrap(); + assert_eq!(result, 11); + } + + { + let result = + as TestStructRpcServer>::second_method( + &r, 22, + ) + .unwrap(); + assert_eq!(result, (22, 15733059416522709050)); + } + + println!("All tests passed!"); +}