diff --git a/snark-verifier-sdk/Cargo.toml b/snark-verifier-sdk/Cargo.toml index 4d5587fd..518afb3f 100644 --- a/snark-verifier-sdk/Cargo.toml +++ b/snark-verifier-sdk/Cargo.toml @@ -15,7 +15,7 @@ hex = "0.4" serde = { version = "1.0", features = ["derive"] } serde_json = "1.0" bincode = "1.3.3" -ark-std = { version = "0.3.0", features = ["print-trace"], optional = true } +ark-std = { version = "0.3.0" } halo2-base = { git = "https://github.com/scroll-tech/halo2-lib.git", branch = "halo2-ecc-snark-verifier-0323", default-features=false, features=["halo2-pse","display"] } snark-verifier = { path = "../snark-verifier", default-features = false } @@ -29,8 +29,8 @@ ethereum-types = { version = "0.14", default-features = false, features = ["std" env_logger = "0.10.0" log = "0.4.17" + [dev-dependencies] -ark-std = { version = "0.3.0", features = ["print-trace"] } ethers-signers = { version = "0.17.0" } paste = "1.0.7" pprof = { version = "0.11", features = ["criterion", "flamegraph"] } @@ -47,8 +47,8 @@ eth-types = { git = "https://github.com/scroll-tech/zkevm-circuits.git", branch mock = { git = "https://github.com/scroll-tech/zkevm-circuits.git", branch = "halo2-ecc-snark-verifier-0323" } [features] -default = ["loader_halo2", "loader_evm", "halo2-pse", "halo2-base/jemallocator"] -display = ["snark-verifier/display", "dep:ark-std"] +default = ["loader_halo2", "loader_evm", "halo2-pse", "halo2-base/jemallocator", "display"] +display = ["snark-verifier/display", "ark-std/print-trace"] loader_evm = ["snark-verifier/loader_evm", "dep:ethereum-types"] loader_halo2 = ["snark-verifier/loader_halo2"] parallel = ["snark-verifier/parallel"] diff --git a/snark-verifier-sdk/configs/two_layer_recursion_first_layer.config b/snark-verifier-sdk/configs/two_layer_recursion_first_layer.config index 65bf3ab8..c306834a 100644 --- a/snark-verifier-sdk/configs/two_layer_recursion_first_layer.config +++ b/snark-verifier-sdk/configs/two_layer_recursion_first_layer.config @@ -1 +1 @@ -{"strategy":"Simple","degree":25,"num_advice":[21],"num_lookup_advice":[2],"num_fixed":1,"lookup_bits":20,"limb_bits":88,"num_limbs":3} +{"strategy":"Simple","degree":26,"num_advice":[21],"num_lookup_advice":[2],"num_fixed":1,"lookup_bits":20,"limb_bits":88,"num_limbs":3} diff --git a/snark-verifier-sdk/src/halo2/aggregation.rs b/snark-verifier-sdk/src/halo2/aggregation.rs index 832ebf7c..69fd04a8 100644 --- a/snark-verifier-sdk/src/halo2/aggregation.rs +++ b/snark-verifier-sdk/src/halo2/aggregation.rs @@ -212,7 +212,7 @@ impl AggregationCircuit { pub fn new( params: &ParamsKZG, snarks: impl IntoIterator, - rng: impl Rng + Send, + rng: &mut (impl Rng + Send), ) -> Self { let svk = params.get_g()[0].into(); let snarks = snarks.into_iter().collect_vec(); diff --git a/snark-verifier-sdk/src/lib.rs b/snark-verifier-sdk/src/lib.rs index 4ff4691d..9fbce6f2 100644 --- a/snark-verifier-sdk/src/lib.rs +++ b/snark-verifier-sdk/src/lib.rs @@ -28,6 +28,9 @@ pub mod evm; pub mod halo2; mod evm_circuits; +pub mod multi_batch; + + #[cfg(test)] mod tests; diff --git a/snark-verifier-sdk/src/multi_batch.rs b/snark-verifier-sdk/src/multi_batch.rs new file mode 100644 index 00000000..c37c812e --- /dev/null +++ b/snark-verifier-sdk/src/multi_batch.rs @@ -0,0 +1,12 @@ +//! APIs to handle multi-batching snarks. +//! +//! In each batch iteration, we are doing two layers of recursions. +//! - use a wide recursive circuit to aggregate the proofs +//! - use a slim recursive circuit to shrink the size of the aggregated proof +//! + +mod evm; +mod halo2; + +pub use evm::*; +pub use halo2::*; diff --git a/snark-verifier-sdk/src/multi_batch/evm.rs b/snark-verifier-sdk/src/multi_batch/evm.rs new file mode 100644 index 00000000..4a65b372 --- /dev/null +++ b/snark-verifier-sdk/src/multi_batch/evm.rs @@ -0,0 +1,113 @@ +use crate::evm::{gen_evm_proof_shplonk, gen_evm_verifier_shplonk}; +use crate::gen_pk; +use crate::halo2::aggregation::AggregationCircuit; +use crate::halo2::gen_snark_shplonk; +use crate::multi_batch::gen_two_layer_recursive_snark; +use crate::{CircuitExt, Snark}; +#[cfg(feature = "display")] +use ark_std::{end_timer, start_timer}; +use halo2_base::halo2_proofs; +use halo2_base::halo2_proofs::halo2curves::bn256::Fr; +use halo2_proofs::{halo2curves::bn256::Bn256, poly::kzg::commitment::ParamsKZG}; +use itertools::Itertools; +use rand::Rng; + +/// Inputs: +/// - kzg parameters +/// - a list of input snarks +/// - rng +/// +/// Outputs: +/// - the evm byte code to verify the proof +/// - the instances +/// - the actual serialized proof +pub fn gen_two_layer_evm_verifier<'params>( + params: &'params ParamsKZG, + input_snarks: Vec, + rng: &mut (impl Rng + Send), +) -> (Vec, Vec>, Vec) { + let timer = start_timer!(|| "begin two layer recursions"); + // =============================================== + // first layer + // =============================================== + // use a wide config to aggregate snarks + std::env::set_var("VERIFY_CONFIG", "./configs/two_layer_recursion_first_layer.config"); + + let layer_1_snark = { + let layer_1_circuit = AggregationCircuit::new(¶ms, input_snarks, rng); + let layer_1_pk = gen_pk(¶ms, &layer_1_circuit, None); + gen_snark_shplonk(¶ms, &layer_1_pk, layer_1_circuit.clone(), rng, None::) + }; + + log::trace!("Finished layer 1 snark generation"); + + // =============================================== + // second layer + // =============================================== + // use a skim config to aggregate snarks + + std::env::set_var("VERIFY_CONFIG", "./configs/two_layer_recursion_second_layer.config"); + + let layer_2_circuit = AggregationCircuit::new(¶ms, [layer_1_snark], rng); + let layer_2_pk = gen_pk(¶ms, &layer_2_circuit, None); + + let snark = gen_evm_proof_shplonk( + ¶ms, + &layer_2_pk, + layer_2_circuit.clone(), + layer_2_circuit.instances(), + rng, + ); + // =============================================== + // bytecode + // =============================================== + let num_instance = layer_2_circuit.instances().iter().map(|x| x.len()).collect_vec(); + + let bytecode = gen_evm_verifier_shplonk::( + params, + layer_2_pk.get_vk(), + num_instance, + None, + ); + + log::trace!("Finished layer 2 snark generation"); + end_timer!(timer); + (bytecode, layer_2_circuit.instances(), snark) +} + +/// Generate the EVM bytecode and the proofs for the 4 layer recursion circuit +/// +/// Input +/// - kzg parameters +/// - a list of input snarks +/// - rng +/// +/// Output +/// - the evm byte code to verify the proof +/// - the instances +/// - the actual serialized proof +pub fn gen_evm_four_layer_recursive_snark<'params>( + params: &'params ParamsKZG, + input_snark_vecs: Vec>, + rng: &mut (impl Rng + Send), +) -> (Vec, Vec>, Vec) { + let timer = start_timer!(|| "begin two layer recursions"); + + let mut snarks = vec![]; + let len = input_snark_vecs[0].len(); + let inner_timer = start_timer!(|| "inner layers"); + for input_snarks in input_snark_vecs.iter() { + assert_eq!(len, input_snarks.len()); + + let snark = gen_two_layer_recursive_snark(params, input_snarks.clone(), rng); + snarks.push(snark); + } + end_timer!(inner_timer); + + let outer_timer = start_timer!(|| "outer layers"); + let (bytecode, instances, snark) = gen_two_layer_evm_verifier(params, snarks, rng); + end_timer!(outer_timer); + + end_timer!(timer); + (bytecode, instances, snark) +} diff --git a/snark-verifier-sdk/src/multi_batch/halo2.rs b/snark-verifier-sdk/src/multi_batch/halo2.rs new file mode 100644 index 00000000..1f539d3e --- /dev/null +++ b/snark-verifier-sdk/src/multi_batch/halo2.rs @@ -0,0 +1,82 @@ +use crate::gen_pk; +use crate::halo2::aggregation::AggregationCircuit; +use crate::halo2::gen_snark_shplonk; +use crate::Snark; +#[cfg(feature = "display")] +use ark_std::{end_timer, start_timer}; +use halo2_base::halo2_proofs; +use halo2_proofs::{halo2curves::bn256::Bn256, poly::kzg::commitment::ParamsKZG}; +use rand::Rng; + +/// Inputs: +/// - kzg parameters +/// - a list of input snarks +/// - rng +/// +/// Outputs: +// - +// - a SNARK which is a snark **proof** of statement s, where +// - s: I have seen the proofs for all circuits c1,...ck +pub fn gen_two_layer_recursive_snark<'params>( + params: &'params ParamsKZG, + input_snarks: Vec, + rng: &mut (impl Rng + Send), +) -> Snark { + let timer = start_timer!(|| "begin two layer recursions"); + // =============================================== + // first layer + // =============================================== + // use a wide config to aggregate snarks + + std::env::set_var("VERIFY_CONFIG", "./configs/two_layer_recursion_first_layer.config"); + + let layer_1_snark = { + let layer_1_circuit = AggregationCircuit::new(¶ms, input_snarks, rng); + let layer_1_pk = gen_pk(¶ms, &layer_1_circuit, None); + gen_snark_shplonk(¶ms, &layer_1_pk, layer_1_circuit.clone(), rng, None::) + }; + + log::trace!("Finished layer 1 snark generation"); + // =============================================== + // second layer + // =============================================== + // use a skim config to aggregate snarks + + std::env::set_var("VERIFY_CONFIG", "./configs/two_layer_recursion_second_layer.config"); + + let layer_2_circuit = AggregationCircuit::new(¶ms, [layer_1_snark], rng); + let layer_2_pk = gen_pk(¶ms, &layer_2_circuit, None); + + let snark = + gen_snark_shplonk(¶ms, &layer_2_pk, layer_2_circuit.clone(), rng, None::); + + log::trace!("Finished layer 2 snark generation"); + end_timer!(timer); + snark +} + +pub fn gen_four_layer_recursive_snark<'params>( + params: &'params ParamsKZG, + input_snark_vecs: Vec>, + rng: &mut (impl Rng + Send), +) -> Snark { + let timer = start_timer!(|| "begin two layer recursions"); + + let mut snarks = vec![]; + let len = input_snark_vecs[0].len(); + let inner_timer = start_timer!(|| "inner layers"); + for input_snarks in input_snark_vecs.iter() { + assert_eq!(len, input_snarks.len()); + + let snark = gen_two_layer_recursive_snark(params, input_snarks.clone(), rng); + snarks.push(snark); + } + end_timer!(inner_timer); + + let outer_timer = start_timer!(|| "outer layers"); + let snark = gen_two_layer_recursive_snark(params, snarks, rng); + end_timer!(outer_timer); + + end_timer!(timer); + snark +} diff --git a/snark-verifier-sdk/src/tests/mod.rs b/snark-verifier-sdk/src/tests/mod.rs index 3ab6481f..e42c4598 100644 --- a/snark-verifier-sdk/src/tests/mod.rs +++ b/snark-verifier-sdk/src/tests/mod.rs @@ -8,6 +8,7 @@ use test_circuit_1::TestCircuit1; use test_circuit_2::TestCircuit2; mod evm_verifier; +mod multi_batch_aggregation; mod single_layer_aggregation; mod test_circuit_1; mod test_circuit_2; diff --git a/snark-verifier-sdk/src/tests/multi_batch_aggregation.rs b/snark-verifier-sdk/src/tests/multi_batch_aggregation.rs new file mode 100644 index 00000000..439d9844 --- /dev/null +++ b/snark-verifier-sdk/src/tests/multi_batch_aggregation.rs @@ -0,0 +1,98 @@ +use super::TestCircuit1; +use crate::evm::evm_verify; +use crate::multi_batch::{gen_two_layer_recursive_snark, gen_evm_four_layer_recursive_snark}; +use crate::Snark; +use crate::{gen_pk, halo2::gen_snark_shplonk}; +use ark_std::test_rng; +use halo2_base::halo2_proofs; +use halo2_proofs::poly::commitment::Params; +use snark_verifier::loader::halo2::halo2_ecc::halo2_base::utils::fs::gen_srs; + +#[test] +fn test_partial_multi_batch_aggregation() { + let k = 8; + + // let config = MultiBatchConfig::new(k, k_layer_1, k_layer_2); + println!("finished configurations"); + let mut rng = test_rng(); + let params = gen_srs(26 as u32); + let mut param_inner = params.clone(); + param_inner.downsize(k as u32); + println!("finished SRS generation"); + + let circuits: Vec<_> = (0..2).map(|_| TestCircuit1::rand(&mut rng)).collect(); + let pk = gen_pk(¶m_inner, &circuits[0], None); + println!("finished pk and circuits generation"); + + // =============================================== + // convert input circuits to snarks + // =============================================== + let input_snarks: Vec = { + let k = pk.get_vk().get_domain().k(); + println!("inner circuit k = {}", k); + circuits + .iter() + .map(|circuit| { + gen_snark_shplonk::( + ¶m_inner, + &pk, + circuit.clone(), + &mut rng, + None::, + ) + }) + .collect() + }; + println!("Finished input snark generation"); + + let _snark = gen_two_layer_recursive_snark(¶ms, input_snarks, &mut rng); +} + +#[test] +fn test_full_multi_batch_aggregation() { + let k = 8; + + // let config = MultiBatchConfig::new(k, k_layer_1, k_layer_2); + println!("finished configurations"); + let mut rng = test_rng(); + let params = gen_srs(26 as u32); + let mut param_inner = params.clone(); + param_inner.downsize(k as u32); + println!("finished SRS generation"); + + let circuit_vecs: Vec> = + (0..2).map(|_| (0..2).map(|_| TestCircuit1::rand(&mut rng)).collect()).collect(); + let pk = gen_pk(¶m_inner, &circuit_vecs[0][0], None); + println!("finished pk and circuits generation"); + + // =============================================== + // convert input circuits to snarks + // =============================================== + let input_snarks: Vec> = { + let k = pk.get_vk().get_domain().k(); + println!("inner circuit k = {}", k); + circuit_vecs + .iter() + .map(|circuits| { + circuits + .iter() + .map(|circuit| { + gen_snark_shplonk::( + ¶m_inner, + &pk, + circuit.clone(), + &mut rng, + None::, + ) + }) + .collect() + }) + .collect() + }; + println!("Finished input snark generation"); + + let (byte_code, instances, proof) = gen_evm_four_layer_recursive_snark(¶ms, input_snarks, &mut rng); + + evm_verify(byte_code, instances, proof) + +}