From fa79f9737facc4617ac8a2b9961bc8bd9cf584b2 Mon Sep 17 00:00:00 2001 From: Milos Djurica Date: Mon, 30 Dec 2024 14:17:07 +0100 Subject: [PATCH] changed codebase to work with Option instead of OnceCell --- Cargo.toml | 2 +- benches/compressed-snark.rs | 16 ++-- benches/ppsnark.rs | 4 +- benches/recursive-snark.rs | 12 +-- benches/sha256.rs | 6 +- examples/and.rs | 14 ++-- examples/hashchain.rs | 14 ++-- examples/minroot.rs | 14 ++-- src/lib.rs | 150 ++++++++++++++++++++---------------- src/r1cs/mod.rs | 32 +++++--- src/spartan/direct.rs | 10 +-- src/spartan/ppsnark.rs | 40 ++++++---- src/spartan/snark.rs | 44 +++++++---- src/traits/snark.rs | 5 +- 14 files changed, 206 insertions(+), 157 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 1a0ccd78..e579327f 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -48,7 +48,7 @@ bincode = { version = "1.3", optional = true, default-features = false } rayon = "1.10" halo2curves = { version = "0.6.0", features = ["bits", "derive_serde"] } # once_cell = { version = "1.18.0", default-features = false } -once_cell = "1.18.0" +# once_cell = "1.18.0" [target.'cfg(target_arch = "wasm32")'.dependencies] diff --git a/benches/compressed-snark.rs b/benches/compressed-snark.rs index 5119bf4e..e44878e8 100644 --- a/benches/compressed-snark.rs +++ b/benches/compressed-snark.rs @@ -67,7 +67,7 @@ fn bench_compressed_snark_internal, S2: RelaxedR1C let c_secondary = TrivialCircuit::default(); // Produce public parameters - let pp = PublicParams::::setup( + let mut pp = PublicParams::::setup( &c_primary, &c_secondary, &*S1::ck_floor(), @@ -76,12 +76,12 @@ fn bench_compressed_snark_internal, S2: RelaxedR1C .unwrap(); // Produce prover and verifier keys for CompressedSNARK - let (pk, vk) = CompressedSNARK::<_, _, _, _, S1, S2>::setup(&pp).unwrap(); + let (pk, mut vk) = CompressedSNARK::<_, _, _, _, S1, S2>::setup(&mut pp).unwrap(); // produce a recursive SNARK let num_steps = 3; let mut recursive_snark: RecursiveSNARK = RecursiveSNARK::new( - &pp, + &mut pp, &c_primary, &c_secondary, &[::Scalar::from(2u64)], @@ -90,12 +90,12 @@ fn bench_compressed_snark_internal, S2: RelaxedR1C .unwrap(); for i in 0..num_steps { - let res = recursive_snark.prove_step(&pp, &c_primary, &c_secondary); + let res = recursive_snark.prove_step(&mut pp, &c_primary, &c_secondary); assert!(res.is_ok()); // verify the recursive snark at each step of recursion let res = recursive_snark.verify( - &pp, + &mut pp, i + 1, &[::Scalar::from(2u64)], &[::Scalar::from(2u64)], @@ -107,14 +107,14 @@ fn bench_compressed_snark_internal, S2: RelaxedR1C group.bench_function("Prove", |b| { b.iter(|| { assert!(CompressedSNARK::<_, _, _, _, S1, S2>::prove( - black_box(&pp), + black_box(&mut pp), black_box(&pk), black_box(&recursive_snark), ) .is_ok()); }) }); - let res = CompressedSNARK::<_, _, _, _, S1, S2>::prove(&pp, &pk, &recursive_snark); + let res = CompressedSNARK::<_, _, _, _, S1, S2>::prove(&mut pp, &pk, &recursive_snark); assert!(res.is_ok()); let compressed_snark = res.unwrap(); @@ -123,7 +123,7 @@ fn bench_compressed_snark_internal, S2: RelaxedR1C b.iter(|| { assert!(black_box(&compressed_snark) .verify( - black_box(&vk), + black_box(&mut vk), black_box(num_steps), black_box(&[::Scalar::from(2u64)]), black_box(&[::Scalar::from(2u64)]), diff --git a/benches/ppsnark.rs b/benches/ppsnark.rs index b9f512db..6ee772d9 100644 --- a/benches/ppsnark.rs +++ b/benches/ppsnark.rs @@ -49,7 +49,7 @@ fn bench_ppsnark(c: &mut Criterion) { let input = vec![::Scalar::from(42)]; // produce keys - let (pk, vk) = + let (pk, mut vk) = DirectSNARK::::Scalar>>::setup(c.clone()).unwrap(); // Bench time to produce a ppSNARK; @@ -75,7 +75,7 @@ fn bench_ppsnark(c: &mut Criterion) { let ppsnark = DirectSNARK::prove(&pk, c.clone(), &input).unwrap(); group.bench_function("Verify", |b| { b.iter(|| { - assert!(ppsnark.verify(black_box(&vk), black_box(&io),).is_ok()); + assert!(ppsnark.verify(black_box(&mut vk), black_box(&io),).is_ok()); }); }); group.finish(); diff --git a/benches/recursive-snark.rs b/benches/recursive-snark.rs index 9ced2703..3f9dfdc3 100644 --- a/benches/recursive-snark.rs +++ b/benches/recursive-snark.rs @@ -70,7 +70,7 @@ fn bench_recursive_snark(c: &mut Criterion) { let c_secondary = TrivialCircuit::default(); // Produce public parameters - let pp = PublicParams::::setup( + let mut pp = PublicParams::::setup( &c_primary, &c_secondary, &*default_ck_hint(), @@ -84,7 +84,7 @@ fn bench_recursive_snark(c: &mut Criterion) { // a lot of zeros in the satisfying assignment let num_warmup_steps = 10; let mut recursive_snark: RecursiveSNARK = RecursiveSNARK::new( - &pp, + &mut pp, &c_primary, &c_secondary, &[::Scalar::from(2u64)], @@ -93,12 +93,12 @@ fn bench_recursive_snark(c: &mut Criterion) { .unwrap(); for i in 0..num_warmup_steps { - let res = recursive_snark.prove_step(&pp, &c_primary, &c_secondary); + let res = recursive_snark.prove_step(&mut pp, &c_primary, &c_secondary); assert!(res.is_ok()); // verify the recursive snark at each step of recursion let res = recursive_snark.verify( - &pp, + &mut pp, i + 1, &[::Scalar::from(2u64)], &[::Scalar::from(2u64)], @@ -111,7 +111,7 @@ fn bench_recursive_snark(c: &mut Criterion) { // produce a recursive SNARK for a step of the recursion assert!(black_box(&mut recursive_snark.clone()) .prove_step( - black_box(&pp), + black_box(&mut pp), black_box(&c_primary), black_box(&c_secondary), ) @@ -124,7 +124,7 @@ fn bench_recursive_snark(c: &mut Criterion) { b.iter(|| { assert!(black_box(&recursive_snark) .verify( - black_box(&pp), + black_box(&mut pp), black_box(num_warmup_steps), black_box(&[::Scalar::from(2u64)]), black_box(&[::Scalar::from(2u64)]), diff --git a/benches/sha256.rs b/benches/sha256.rs index bfa17336..8266fe10 100644 --- a/benches/sha256.rs +++ b/benches/sha256.rs @@ -155,7 +155,7 @@ fn bench_recursive_snark(c: &mut Criterion) { // Produce public parameters let ttc = TrivialCircuit::default(); - let pp = PublicParams::::setup( + let mut pp = PublicParams::::setup( &circuit_primary, &ttc, &*default_ck_hint(), @@ -170,7 +170,7 @@ fn bench_recursive_snark(c: &mut Criterion) { group.bench_function("Prove", |b| { b.iter(|| { let mut recursive_snark = RecursiveSNARK::new( - black_box(&pp), + black_box(&mut pp), black_box(&circuit_primary), black_box(&circuit_secondary), black_box(&z0_primary), @@ -181,7 +181,7 @@ fn bench_recursive_snark(c: &mut Criterion) { // produce a recursive SNARK for a step of the recursion assert!(recursive_snark .prove_step( - black_box(&pp), + black_box(&mut pp), black_box(&circuit_primary), black_box(&circuit_secondary), ) diff --git a/examples/and.rs b/examples/and.rs index dff56191..1a256f8d 100644 --- a/examples/and.rs +++ b/examples/and.rs @@ -220,7 +220,7 @@ fn main() { // produce public parameters let start = Instant::now(); println!("Producing public parameters..."); - let pp = PublicParams::< + let mut pp = PublicParams::< E1, E2, AndCircuit<::GE>, @@ -264,7 +264,7 @@ fn main() { println!("Generating a RecursiveSNARK..."); let mut recursive_snark: RecursiveSNARK = RecursiveSNARK::::new( - &pp, + &mut pp, &circuits[0], &circuit_secondary, &[::Scalar::zero()], @@ -274,7 +274,7 @@ fn main() { let start = Instant::now(); for circuit_primary in circuits.iter() { - let res = recursive_snark.prove_step(&pp, circuit_primary, &circuit_secondary); + let res = recursive_snark.prove_step(&mut pp, circuit_primary, &circuit_secondary); assert!(res.is_ok()); } println!( @@ -286,7 +286,7 @@ fn main() { // verify the recursive SNARK println!("Verifying a RecursiveSNARK..."); let res = recursive_snark.verify( - &pp, + &mut pp, num_steps, &[::Scalar::ZERO], &[::Scalar::ZERO], @@ -296,11 +296,11 @@ fn main() { // produce a compressed SNARK println!("Generating a CompressedSNARK using Spartan with HyperKZG..."); - let (pk, vk) = CompressedSNARK::<_, _, _, _, S1, S2>::setup(&pp).unwrap(); + let (pk, mut vk) = CompressedSNARK::<_, _, _, _, S1, S2>::setup(&mut pp).unwrap(); let start = Instant::now(); - let res = CompressedSNARK::<_, _, _, _, S1, S2>::prove(&pp, &pk, &recursive_snark); + let res = CompressedSNARK::<_, _, _, _, S1, S2>::prove(&mut pp, &pk, &recursive_snark); println!( "CompressedSNARK::prove: {:?}, took {:?}", res.is_ok(), @@ -324,7 +324,7 @@ fn main() { println!("Verifying a CompressedSNARK..."); let start = Instant::now(); let res = compressed_snark.verify( - &vk, + &mut vk, num_steps, &[::Scalar::ZERO], &[::Scalar::ZERO], diff --git a/examples/hashchain.rs b/examples/hashchain.rs index 0398890f..afcae2d4 100644 --- a/examples/hashchain.rs +++ b/examples/hashchain.rs @@ -113,7 +113,7 @@ fn main() { // produce public parameters let start = Instant::now(); println!("Producing public parameters..."); - let pp = PublicParams::< + let mut pp = PublicParams::< E1, E2, HashChainCircuit<::GE>, @@ -159,7 +159,7 @@ fn main() { ); let mut recursive_snark: RecursiveSNARK = RecursiveSNARK::::new( - &pp, + &mut pp, &circuits[0], &circuit_secondary, &[::Scalar::zero()], @@ -169,7 +169,7 @@ fn main() { for (i, circuit_primary) in circuits.iter().enumerate() { let start = Instant::now(); - let res = recursive_snark.prove_step(&pp, circuit_primary, &circuit_secondary); + let res = recursive_snark.prove_step(&mut pp, circuit_primary, &circuit_secondary); assert!(res.is_ok()); println!("RecursiveSNARK::prove {} : took {:?} ", i, start.elapsed()); @@ -178,7 +178,7 @@ fn main() { // verify the recursive SNARK println!("Verifying a RecursiveSNARK..."); let res = recursive_snark.verify( - &pp, + &mut pp, num_steps, &[::Scalar::ZERO], &[::Scalar::ZERO], @@ -188,11 +188,11 @@ fn main() { // produce a compressed SNARK println!("Generating a CompressedSNARK using Spartan with HyperKZG..."); - let (pk, vk) = CompressedSNARK::<_, _, _, _, S1, S2>::setup(&pp).unwrap(); + let (pk, mut vk) = CompressedSNARK::<_, _, _, _, S1, S2>::setup(&mut pp).unwrap(); let start = Instant::now(); - let res = CompressedSNARK::<_, _, _, _, S1, S2>::prove(&pp, &pk, &recursive_snark); + let res = CompressedSNARK::<_, _, _, _, S1, S2>::prove(&mut pp, &pk, &recursive_snark); println!( "CompressedSNARK::prove: {:?}, took {:?}", res.is_ok(), @@ -216,7 +216,7 @@ fn main() { println!("Verifying a CompressedSNARK..."); let start = Instant::now(); let res = compressed_snark.verify( - &vk, + &mut vk, num_steps, &[::Scalar::ZERO], &[::Scalar::ZERO], diff --git a/examples/minroot.rs b/examples/minroot.rs index 25ee7da8..989d6b1f 100644 --- a/examples/minroot.rs +++ b/examples/minroot.rs @@ -165,7 +165,7 @@ fn main() { // produce public parameters let start = Instant::now(); println!("Producing public parameters..."); - let pp = PublicParams::< + let mut pp = PublicParams::< E1, E2, MinRootCircuit<::GE>, @@ -224,7 +224,7 @@ fn main() { println!("Generating a RecursiveSNARK..."); let mut recursive_snark: RecursiveSNARK = RecursiveSNARK::::new( - &pp, + &mut pp, &minroot_circuits[0], &circuit_secondary, &z0_primary, @@ -234,7 +234,7 @@ fn main() { for (i, circuit_primary) in minroot_circuits.iter().enumerate() { let start = Instant::now(); - let res = recursive_snark.prove_step(&pp, circuit_primary, &circuit_secondary); + let res = recursive_snark.prove_step(&mut pp, circuit_primary, &circuit_secondary); assert!(res.is_ok()); println!( "RecursiveSNARK::prove_step {}: {:?}, took {:?} ", @@ -247,7 +247,7 @@ fn main() { // verify the recursive SNARK println!("Verifying a RecursiveSNARK..."); let start = Instant::now(); - let res = recursive_snark.verify(&pp, num_steps, &z0_primary, &z0_secondary); + let res = recursive_snark.verify(&mut pp, num_steps, &z0_primary, &z0_secondary); println!( "RecursiveSNARK::verify: {:?}, took {:?}", res.is_ok(), @@ -257,11 +257,11 @@ fn main() { // produce a compressed SNARK println!("Generating a CompressedSNARK using Spartan with HyperKZG..."); - let (pk, vk) = CompressedSNARK::<_, _, _, _, S1, S2>::setup(&pp).unwrap(); + let (pk, mut vk) = CompressedSNARK::<_, _, _, _, S1, S2>::setup(&mut pp).unwrap(); let start = Instant::now(); - let res = CompressedSNARK::<_, _, _, _, S1, S2>::prove(&pp, &pk, &recursive_snark); + let res = CompressedSNARK::<_, _, _, _, S1, S2>::prove(&mut pp, &pk, &recursive_snark); println!( "CompressedSNARK::prove: {:?}, took {:?}", res.is_ok(), @@ -284,7 +284,7 @@ fn main() { // verify the compressed SNARK println!("Verifying a CompressedSNARK..."); let start = Instant::now(); - let res = compressed_snark.verify(&vk, num_steps, &z0_primary, &z0_secondary); + let res = compressed_snark.verify(&mut vk, num_steps, &z0_primary, &z0_secondary); println!( "CompressedSNARK::verify: {:?}, took {:?}", res.is_ok(), diff --git a/src/lib.rs b/src/lib.rs index f5a397a6..51d3dc94 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -70,7 +70,7 @@ use frontend::{ }; use gadgets::utils::scalar_as_base; use nifs::{NIFSRelaxed, NIFS}; -use once_cell::sync::OnceCell; +// use once_cell::sync::OnceCell; use prelude::*; use r1cs::{ CommitmentKeyHint, R1CSInstance, R1CSShape, R1CSWitness, RelaxedR1CSInstance, RelaxedR1CSWitness, @@ -83,7 +83,7 @@ use traits::{ }; /// A type that holds public parameters of Nova -#[derive(Serialize, Deserialize)] +#[derive(Serialize, Deserialize, Clone)] #[serde(bound = "")] pub struct PublicParams where @@ -104,8 +104,8 @@ where r1cs_shape_secondary: R1CSShape, augmented_circuit_params_primary: NovaAugmentedCircuitParams, augmented_circuit_params_secondary: NovaAugmentedCircuitParams, - #[serde(skip, default = "OnceCell::new")] - digest: OnceCell, + #[serde(skip)] + digest: Option, _p: PhantomData<(C1, C2)>, } @@ -215,7 +215,7 @@ where return Err(NovaError::InvalidStepCircuitIO); } - let pp = PublicParams { + let mut pp = PublicParams { F_arity_primary, F_arity_secondary, ro_consts_primary, @@ -228,23 +228,31 @@ where r1cs_shape_secondary, augmented_circuit_params_primary, augmented_circuit_params_secondary, - digest: OnceCell::new(), + digest: None, _p: Default::default(), }; // call pp.digest() so the digest is computed here rather than in RecursiveSNARK methods - let _ = pp.digest(); + pp.digest(); Ok(pp) } /// Retrieve the digest of the public parameters. - pub fn digest(&self) -> E1::Scalar { - self - .digest - .get_or_try_init(|| DigestComputer::new(self).digest()) - .cloned() - .expect("Failure in retrieving digest") + pub fn digest(&mut self) -> E1::Scalar { + if self.digest.is_none() { + let computed_digest = DigestComputer::new(self) + .digest() + .expect("Failure in retrieving digest"); + self.digest = Some(computed_digest); + } + self.digest.unwrap() + + // self + // .digest + // .get_or_try_init(|| DigestComputer::new(self).digest()) + // .cloned() + // .expect("Failure in retrieving digest") } /// Returns the number of constraints in the primary and secondary circuits @@ -299,7 +307,7 @@ where { /// Create new instance of recursive SNARK pub fn new( - pp: &PublicParams, + pp: &mut PublicParams, c_primary: &C1, c_secondary: &C2, z0_primary: &[E1::Scalar], @@ -410,7 +418,7 @@ where /// by executing a step of the incremental computation pub fn prove_step( &mut self, - pp: &PublicParams, + pp: &mut PublicParams, c_primary: &C1, c_secondary: &C2, ) -> Result<(), NovaError> { @@ -420,10 +428,11 @@ where return Ok(()); } + let pp_clone = pp.clone(); // fold the secondary circuit's instance let (nifs_secondary, (r_U_secondary, r_W_secondary)) = NIFS::prove( - &pp.ck_secondary, - &pp.ro_consts_secondary, + &pp_clone.ck_secondary, + &pp_clone.ro_consts_secondary, &scalar_as_base::(pp.digest()), &pp.r1cs_shape_secondary, &self.r_U_secondary, @@ -458,10 +467,11 @@ where let (l_u_primary, l_w_primary) = cs_primary.r1cs_instance_and_witness(&pp.r1cs_shape_primary, &pp.ck_primary)?; + let pp_clone = pp.clone(); // fold the primary circuit's instance let (nifs_primary, (r_U_primary, r_W_primary)) = NIFS::prove( - &pp.ck_primary, - &pp.ro_consts_primary, + &pp_clone.ck_primary, + &pp_clone.ro_consts_primary, &pp.digest(), &pp.r1cs_shape_primary, &self.r_U_primary, @@ -527,7 +537,7 @@ where /// Verify the correctness of the `RecursiveSNARK` pub fn verify( &self, - pp: &PublicParams, + pp: &mut PublicParams, num_steps: usize, z0_primary: &[E1::Scalar], z0_secondary: &[E2::Scalar], @@ -734,7 +744,7 @@ where { /// Creates prover and verifier keys for `CompressedSNARK` pub fn setup( - pp: &PublicParams, + pp: &mut PublicParams, ) -> Result< ( ProverKey, @@ -769,16 +779,17 @@ where /// Create a new `CompressedSNARK` (provides zero-knowledge) pub fn prove( - pp: &PublicParams, + pp: &mut PublicParams, pk: &ProverKey, recursive_snark: &RecursiveSNARK, ) -> Result { // prove three foldings + let pp_clone = pp.clone(); // fold secondary U/W with secondary u/w to get Uf/Wf let (nifs_Uf_secondary, (r_Uf_secondary, r_Wf_secondary)) = NIFS::prove( - &pp.ck_secondary, - &pp.ro_consts_secondary, + &pp_clone.ck_secondary, + &pp_clone.ro_consts_secondary, &scalar_as_base::(pp.digest()), &pp.r1cs_shape_secondary, &recursive_snark.r_U_secondary, @@ -792,9 +803,10 @@ where .r1cs_shape_secondary .sample_random_instance_witness(&pp.ck_secondary)?; + let pp_clone = pp.clone(); let (nifs_Un_secondary, (r_Un_secondary, r_Wn_secondary)) = NIFSRelaxed::prove( - &pp.ck_secondary, - &pp.ro_consts_secondary, + &pp_clone.ck_secondary, + &pp_clone.ro_consts_secondary, &scalar_as_base::(pp.digest()), &pp.r1cs_shape_secondary, &r_Uf_secondary, @@ -808,9 +820,10 @@ where .r1cs_shape_primary .sample_random_instance_witness(&pp.ck_primary)?; + let pp_clone = pp.clone(); let (nifs_Un_primary, (r_Un_primary, r_Wn_primary)) = NIFSRelaxed::prove( - &pp.ck_primary, - &pp.ro_consts_primary, + &pp_clone.ck_primary, + &pp_clone.ro_consts_primary, &pp.digest(), &pp.r1cs_shape_primary, &recursive_snark.r_U_primary, @@ -890,7 +903,7 @@ where /// Verify the correctness of the `CompressedSNARK` (provides zero-knowledge) pub fn verify( &self, - vk: &VerifierKey, + vk: &mut VerifierKey, num_steps: usize, z0_primary: &[E1::Scalar], z0_secondary: &[E2::Scalar], @@ -996,12 +1009,12 @@ where || { self .snark_primary - .verify(&vk.vk_primary, &derandom_r_Un_primary) + .verify(&mut vk.vk_primary, &derandom_r_Un_primary) }, || { self .snark_secondary - .verify(&vk.vk_secondary, &derandom_r_Un_secondary) + .verify(&mut vk.vk_secondary, &derandom_r_Un_secondary) }, ); @@ -1100,7 +1113,8 @@ mod tests { // this tests public parameters with a size specifically intended for a spark-compressed SNARK let ck_hint1 = &*SPrime::>::ck_floor(); let ck_hint2 = &*SPrime::>::ck_floor(); - let pp = PublicParams::::setup(circuit1, circuit2, ck_hint1, ck_hint2).unwrap(); + let mut pp = + PublicParams::::setup(circuit1, circuit2, ck_hint1, ck_hint2).unwrap(); let digest_str = pp .digest() @@ -1144,7 +1158,7 @@ mod tests { let test_circuit2 = TrivialCircuit::<::Scalar>::default(); // produce public parameters - let pp = PublicParams::< + let mut pp = PublicParams::< E1, E2, TrivialCircuit<::Scalar>, @@ -1161,7 +1175,7 @@ mod tests { // produce a recursive SNARK let mut recursive_snark = RecursiveSNARK::new( - &pp, + &mut pp, &test_circuit1, &test_circuit2, &[::Scalar::ZERO], @@ -1169,13 +1183,13 @@ mod tests { ) .unwrap(); - let res = recursive_snark.prove_step(&pp, &test_circuit1, &test_circuit2); + let res = recursive_snark.prove_step(&mut pp, &test_circuit1, &test_circuit2); assert!(res.is_ok()); // verify the recursive SNARK let res = recursive_snark.verify( - &pp, + &mut pp, num_steps, &[::Scalar::ZERO], &[::Scalar::ZERO], @@ -1199,7 +1213,7 @@ mod tests { let circuit_secondary = CubicCircuit::default(); // produce public parameters - let pp = PublicParams::< + let mut pp = PublicParams::< E1, E2, TrivialCircuit<::Scalar>, @@ -1221,7 +1235,7 @@ mod tests { TrivialCircuit<::Scalar>, CubicCircuit<::Scalar>, >::new( - &pp, + &mut pp, &circuit_primary, &circuit_secondary, &[::Scalar::ONE], @@ -1230,12 +1244,12 @@ mod tests { .unwrap(); for i in 0..num_steps { - let res = recursive_snark.prove_step(&pp, &circuit_primary, &circuit_secondary); + let res = recursive_snark.prove_step(&mut pp, &circuit_primary, &circuit_secondary); assert!(res.is_ok()); // verify the recursive snark at each step of recursion let res = recursive_snark.verify( - &pp, + &mut pp, i + 1, &[::Scalar::ONE], &[::Scalar::ZERO], @@ -1245,7 +1259,7 @@ mod tests { // verify the recursive SNARK let res = recursive_snark.verify( - &pp, + &mut pp, num_steps, &[::Scalar::ONE], &[::Scalar::ZERO], @@ -1282,7 +1296,7 @@ mod tests { let circuit_secondary = CubicCircuit::default(); // produce public parameters - let pp = PublicParams::< + let mut pp = PublicParams::< E1, E2, TrivialCircuit<::Scalar>, @@ -1304,7 +1318,7 @@ mod tests { TrivialCircuit<::Scalar>, CubicCircuit<::Scalar>, >::new( - &pp, + &mut pp, &circuit_primary, &circuit_secondary, &[::Scalar::ONE], @@ -1313,13 +1327,13 @@ mod tests { .unwrap(); for _i in 0..num_steps { - let res = recursive_snark.prove_step(&pp, &circuit_primary, &circuit_secondary); + let res = recursive_snark.prove_step(&mut pp, &circuit_primary, &circuit_secondary); assert!(res.is_ok()); } // verify the recursive SNARK let res = recursive_snark.verify( - &pp, + &mut pp, num_steps, &[::Scalar::ONE], &[::Scalar::ZERO], @@ -1338,17 +1352,18 @@ mod tests { assert_eq!(zn_secondary, vec![::Scalar::from(2460515u64)]); // produce the prover and verifier keys for compressed snark - let (pk, vk) = CompressedSNARK::<_, _, _, _, S, S>::setup(&pp).unwrap(); + let (pk, mut vk) = + CompressedSNARK::<_, _, _, _, S, S>::setup(&mut pp).unwrap(); // produce a compressed SNARK let res = - CompressedSNARK::<_, _, _, _, S, S>::prove(&pp, &pk, &recursive_snark); + CompressedSNARK::<_, _, _, _, S, S>::prove(&mut pp, &pk, &recursive_snark); assert!(res.is_ok()); let compressed_snark = res.unwrap(); // verify the compressed SNARK let res = compressed_snark.verify( - &vk, + &mut vk, num_steps, &[::Scalar::ONE], &[::Scalar::ZERO], @@ -1382,7 +1397,7 @@ mod tests { let circuit_secondary = CubicCircuit::default(); // produce public parameters, which we'll use with a spark-compressed SNARK - let pp = PublicParams::< + let mut pp = PublicParams::< E1, E2, TrivialCircuit<::Scalar>, @@ -1404,7 +1419,7 @@ mod tests { TrivialCircuit<::Scalar>, CubicCircuit<::Scalar>, >::new( - &pp, + &mut pp, &circuit_primary, &circuit_secondary, &[::Scalar::ONE], @@ -1413,13 +1428,13 @@ mod tests { .unwrap(); for _i in 0..num_steps { - let res = recursive_snark.prove_step(&pp, &circuit_primary, &circuit_secondary); + let res = recursive_snark.prove_step(&mut pp, &circuit_primary, &circuit_secondary); assert!(res.is_ok()); } // verify the recursive SNARK let res = recursive_snark.verify( - &pp, + &mut pp, num_steps, &[::Scalar::ONE], &[::Scalar::ZERO], @@ -1439,12 +1454,12 @@ mod tests { // run the compressed snark with Spark compiler // produce the prover and verifier keys for compressed snark - let (pk, vk) = - CompressedSNARK::<_, _, _, _, SPrime, SPrime>::setup(&pp).unwrap(); + let (pk, mut vk) = + CompressedSNARK::<_, _, _, _, SPrime, SPrime>::setup(&mut pp).unwrap(); // produce a compressed SNARK let res = CompressedSNARK::<_, _, _, _, SPrime, SPrime>::prove( - &pp, + &mut pp, &pk, &recursive_snark, ); @@ -1453,7 +1468,7 @@ mod tests { // verify the compressed SNARK let res = compressed_snark.verify( - &vk, + &mut vk, num_steps, &[::Scalar::ONE], &[::Scalar::ZERO], @@ -1545,7 +1560,7 @@ mod tests { let circuit_secondary = TrivialCircuit::default(); // produce public parameters - let pp = PublicParams::< + let mut pp = PublicParams::< E1, E2, FifthRootCheckingCircuit<::Scalar>, @@ -1576,7 +1591,7 @@ mod tests { FifthRootCheckingCircuit<::Scalar>, TrivialCircuit<::Scalar>, >::new( - &pp, + &mut pp, &roots[0], &circuit_secondary, &z0_primary, @@ -1585,25 +1600,26 @@ mod tests { .unwrap(); for circuit_primary in roots.iter().take(num_steps) { - let res = recursive_snark.prove_step(&pp, circuit_primary, &circuit_secondary); + let res = recursive_snark.prove_step(&mut pp, circuit_primary, &circuit_secondary); assert!(res.is_ok()); } // verify the recursive SNARK - let res = recursive_snark.verify(&pp, num_steps, &z0_primary, &z0_secondary); + let res = recursive_snark.verify(&mut pp, num_steps, &z0_primary, &z0_secondary); assert!(res.is_ok()); // produce the prover and verifier keys for compressed snark - let (pk, vk) = CompressedSNARK::<_, _, _, _, S, S>::setup(&pp).unwrap(); + let (pk, mut vk) = + CompressedSNARK::<_, _, _, _, S, S>::setup(&mut pp).unwrap(); // produce a compressed SNARK let res = - CompressedSNARK::<_, _, _, _, S, S>::prove(&pp, &pk, &recursive_snark); + CompressedSNARK::<_, _, _, _, S, S>::prove(&mut pp, &pk, &recursive_snark); assert!(res.is_ok()); let compressed_snark = res.unwrap(); // verify the compressed SNARK - let res = compressed_snark.verify(&vk, num_steps, &z0_primary, &z0_secondary); + let res = compressed_snark.verify(&mut vk, num_steps, &z0_primary, &z0_secondary); assert!(res.is_ok()); } @@ -1623,7 +1639,7 @@ mod tests { let test_circuit2 = CubicCircuit::<::Scalar>::default(); // produce public parameters - let pp = PublicParams::< + let mut pp = PublicParams::< E1, E2, TrivialCircuit<::Scalar>, @@ -1645,7 +1661,7 @@ mod tests { TrivialCircuit<::Scalar>, CubicCircuit<::Scalar>, >::new( - &pp, + &mut pp, &test_circuit1, &test_circuit2, &[::Scalar::ONE], @@ -1654,13 +1670,13 @@ mod tests { .unwrap(); // produce a recursive SNARK - let res = recursive_snark.prove_step(&pp, &test_circuit1, &test_circuit2); + let res = recursive_snark.prove_step(&mut pp, &test_circuit1, &test_circuit2); assert!(res.is_ok()); // verify the recursive SNARK let res = recursive_snark.verify( - &pp, + &mut pp, num_steps, &[::Scalar::ONE], &[::Scalar::ZERO], diff --git a/src/r1cs/mod.rs b/src/r1cs/mod.rs index 263ef4fa..e8e2af78 100644 --- a/src/r1cs/mod.rs +++ b/src/r1cs/mod.rs @@ -15,7 +15,7 @@ use crate::{ }; use core::{cmp::max, marker::PhantomData}; use ff::Field; -use once_cell::sync::OnceCell; +// use once_cell::sync::OnceCell; use rand_core::OsRng; use rayon::prelude::*; use serde::{Deserialize, Serialize}; @@ -39,8 +39,8 @@ pub struct R1CSShape { pub(crate) A: SparseMatrix, pub(crate) B: SparseMatrix, pub(crate) C: SparseMatrix, - #[serde(skip, default = "OnceCell::new")] - pub(crate) digest: OnceCell, + #[serde(skip)] + pub(crate) digest: Option, } impl SimpleDigestible for R1CSShape {} @@ -138,17 +138,25 @@ impl R1CSShape { A, B, C, - digest: OnceCell::new(), + digest: None, }) } /// returned the digest of the `R1CSShape` - pub fn digest(&self) -> E::Scalar { - self - .digest - .get_or_try_init(|| DigestComputer::new(self).digest()) - .cloned() - .expect("Failure retrieving digest") + pub fn digest(&mut self) -> E::Scalar { + if self.digest.is_none() { + let computed_digest = DigestComputer::new(self) + .digest() + .expect("Failure in retrieving digest"); + self.digest = Some(computed_digest); + } + self.digest.unwrap() + + // self + // .digest + // .get_or_try_init(|| DigestComputer::new(self).digest()) + // .cloned() + // .expect("Failure retrieving digest") } // Checks regularity conditions on the R1CSShape, required in Spartan-class SNARKs @@ -344,7 +352,7 @@ impl R1CSShape { A: self.A.clone(), B: self.B.clone(), C: self.C.clone(), - digest: OnceCell::new(), + digest: None, }; } @@ -380,7 +388,7 @@ impl R1CSShape { A: A_padded, B: B_padded, C: C_padded, - digest: OnceCell::new(), + digest: None, } } diff --git a/src/spartan/direct.rs b/src/spartan/direct.rs index 93f58f35..9198ea7b 100644 --- a/src/spartan/direct.rs +++ b/src/spartan/direct.rs @@ -85,7 +85,7 @@ where impl> VerifierKey { /// Returns the digest of the verifier's key - pub fn digest(&self) -> E::Scalar { + pub fn digest(&mut self) -> E::Scalar { self.vk.digest() } } @@ -169,7 +169,7 @@ impl, C: StepCircuit> DirectSN } /// Verifies a proof of satisfiability - pub fn verify(&self, vk: &VerifierKey, io: &[E::Scalar]) -> Result<(), NovaError> { + pub fn verify(&self, vk: &mut VerifierKey, io: &[E::Scalar]) -> Result<(), NovaError> { // derandomize/unblind commitments let comm_W = E::CE::derandomize(&vk.dk, &self.comm_W, &self.blind_r_W); @@ -177,7 +177,7 @@ impl, C: StepCircuit> DirectSN let u_relaxed = RelaxedR1CSInstance::from_r1cs_instance_unchecked(&comm_W, io); // verify the snark using the constructed instance - self.snark.verify(&vk.vk, &u_relaxed)?; + self.snark.verify(&mut vk.vk, &u_relaxed)?; Ok(()) } @@ -272,7 +272,7 @@ mod tests { let circuit = CubicCircuit::default(); // produce keys - let (pk, vk) = + let (pk, mut vk) = DirectSNARK::::Scalar>>::setup(circuit.clone()).unwrap(); let num_steps = 3; @@ -296,7 +296,7 @@ mod tests { .into_iter() .chain(z_i_plus_one.clone()) .collect::>(); - let res = snark.verify(&vk, &io); + let res = snark.verify(&mut vk, &io); assert!(res.is_ok()); // set input to the next step diff --git a/src/spartan/ppsnark.rs b/src/spartan/ppsnark.rs index 7198faac..caf5390f 100644 --- a/src/spartan/ppsnark.rs +++ b/src/spartan/ppsnark.rs @@ -34,7 +34,7 @@ use crate::{ use core::cmp::max; use ff::Field; use itertools::Itertools as _; -use once_cell::sync::OnceCell; +// use once_cell::sync::OnceCell; use rayon::prelude::*; use serde::{Deserialize, Serialize}; @@ -856,8 +856,8 @@ pub struct VerifierKey> { num_vars: usize, vk_ee: EE::VerifierKey, S_comm: R1CSShapeSparkCommitment, - #[serde(skip, default = "OnceCell::new")] - digest: OnceCell, + #[serde(skip)] + digest: Option, } impl> SimpleDigestible for VerifierKey {} @@ -1044,15 +1044,23 @@ impl> VerifierKey { } impl> DigestHelperTrait for VerifierKey { /// Returns the digest of the verifier's key - fn digest(&self) -> E::Scalar { - self - .digest - .get_or_try_init(|| { - let dc = DigestComputer::new(self); - dc.digest() - }) - .cloned() - .expect("Failure to retrieve digest!") + fn digest(&mut self) -> E::Scalar { + if self.digest.is_none() { + let computed_digest = DigestComputer::new(self) + .digest() + .expect("Failure in retrieving digest"); + self.digest = Some(computed_digest); + } + self.digest.unwrap() + + // self + // .digest + // .get_or_try_init(|| { + // let dc = DigestComputer::new(self); + // dc.digest() + // }) + // .cloned() + // .expect("Failure to retrieve digest!") } } @@ -1083,7 +1091,7 @@ impl> RelaxedR1CSSNARKTrait for Relax let S_repr = R1CSShapeSparkRepr::new(&S); let S_comm = S_repr.commit(ck); - let vk = VerifierKey::new(S.num_cons, S.num_vars, S_comm.clone(), vk_ee); + let mut vk = VerifierKey::new(S.num_cons, S.num_vars, S_comm.clone(), vk_ee); let pk = ProverKey { pk_ee, @@ -1419,7 +1427,11 @@ impl> RelaxedR1CSSNARKTrait for Relax } /// verifies a proof of satisfiability of a `RelaxedR1CS` instance - fn verify(&self, vk: &Self::VerifierKey, U: &RelaxedR1CSInstance) -> Result<(), NovaError> { + fn verify( + &self, + vk: &mut Self::VerifierKey, + U: &RelaxedR1CSInstance, + ) -> Result<(), NovaError> { let mut transcript = E::TE::new(b"RelaxedR1CSSNARK"); // append the verifier key (including commitment to R1CS matrices) and the RelaxedR1CSInstance to the transcript diff --git a/src/spartan/snark.rs b/src/spartan/snark.rs index 3aa7f96c..0ab158e8 100644 --- a/src/spartan/snark.rs +++ b/src/spartan/snark.rs @@ -27,7 +27,7 @@ use crate::{ use ff::Field; use itertools::Itertools as _; -use once_cell::sync::OnceCell; +// use once_cell::sync::OnceCell; use rayon::prelude::*; use serde::{Deserialize, Serialize}; @@ -46,8 +46,8 @@ pub struct ProverKey> { pub struct VerifierKey> { vk_ee: EE::VerifierKey, S: R1CSShape, - #[serde(skip, default = "OnceCell::new")] - digest: OnceCell, + #[serde(skip)] + digest: Option, } impl> SimpleDigestible for VerifierKey {} @@ -57,22 +57,30 @@ impl> VerifierKey { VerifierKey { vk_ee, S: shape, - digest: OnceCell::new(), + digest: None, } } } impl> DigestHelperTrait for VerifierKey { /// Returns the digest of the verifier's key. - fn digest(&self) -> E::Scalar { - self - .digest - .get_or_try_init(|| { - let dc = DigestComputer::::new(self); - dc.digest() - }) - .cloned() - .expect("Failure to retrieve digest!") + fn digest(&mut self) -> E::Scalar { + if self.digest.is_none() { + let computed_digest = DigestComputer::new(self) + .digest() + .expect("Failure in retrieving digest"); + self.digest = Some(computed_digest); + } + self.digest.unwrap() + + // self + // .digest + // .get_or_try_init(|| { + // let dc = DigestComputer::::new(self); + // dc.digest() + // }) + // .cloned() + // .expect("Failure to retrieve digest!") } } @@ -104,7 +112,7 @@ impl> RelaxedR1CSSNARKTrait for Relax let S = S.pad(); - let vk: VerifierKey = VerifierKey::new(S, vk_ee); + let mut vk: VerifierKey = VerifierKey::new(S, vk_ee); let pk = ProverKey { pk_ee, @@ -272,11 +280,15 @@ impl> RelaxedR1CSSNARKTrait for Relax } /// verifies a proof of satisfiability of a `RelaxedR1CS` instance - fn verify(&self, vk: &Self::VerifierKey, U: &RelaxedR1CSInstance) -> Result<(), NovaError> { + fn verify( + &self, + vk: &mut Self::VerifierKey, + U: &RelaxedR1CSInstance, + ) -> Result<(), NovaError> { let mut transcript = E::TE::new(b"RelaxedR1CSSNARK"); // append the digest of R1CS matrices and the RelaxedR1CSInstance to the transcript - transcript.absorb(b"vk", &vk.digest()); + transcript.absorb(b"vk", &mut vk.digest()); transcript.absorb(b"U", U); let (num_rounds_x, num_rounds_y) = ( diff --git a/src/traits/snark.rs b/src/traits/snark.rs index 799a2843..335588d2 100644 --- a/src/traits/snark.rs +++ b/src/traits/snark.rs @@ -52,11 +52,12 @@ pub trait RelaxedR1CSSNARKTrait: ) -> Result; /// Verifies a SNARK for a relaxed R1CS - fn verify(&self, vk: &Self::VerifierKey, U: &RelaxedR1CSInstance) -> Result<(), NovaError>; + fn verify(&self, vk: &mut Self::VerifierKey, U: &RelaxedR1CSInstance) + -> Result<(), NovaError>; } /// A helper trait that defines the behavior of a verifier key of `zkSNARK` pub trait DigestHelperTrait { /// Returns the digest of the verifier's key - fn digest(&self) -> E::Scalar; + fn digest(&mut self) -> E::Scalar; }