Skip to content

Commit

Permalink
[FEAT] implement multi-batch aggregation
Browse files Browse the repository at this point in the history
  • Loading branch information
zhenfeizhang committed Apr 24, 2023
1 parent a3d0a5a commit 654706d
Show file tree
Hide file tree
Showing 24 changed files with 406 additions and 315 deletions.
8 changes: 4 additions & 4 deletions snark-verifier-sdk/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ hex = "0.4"
serde = { version = "1.0", features = ["derive"] }
serde_json = "1.0"
bincode = "1.3.3"
ark-std = { version = "0.3.0", features = ["print-trace"], optional = true }
ark-std = { version = "0.3.0" }

halo2-base = { git = "https://github.com/scroll-tech/halo2-lib.git", branch = "halo2-ecc-snark-verifier-0323", default-features=false, features=["halo2-pse","display"] }
snark-verifier = { path = "../snark-verifier", default-features = false }
Expand All @@ -29,8 +29,8 @@ ethereum-types = { version = "0.14", default-features = false, features = ["std"

env_logger = "0.10.0"
log = "0.4.17"

[dev-dependencies]
ark-std = { version = "0.3.0", features = ["print-trace"] }
ethers-signers = { version = "0.17.0" }
paste = "1.0.7"
pprof = { version = "0.11", features = ["criterion", "flamegraph"] }
Expand All @@ -47,8 +47,8 @@ eth-types = { git = "https://github.com/scroll-tech/zkevm-circuits.git", branch
mock = { git = "https://github.com/scroll-tech/zkevm-circuits.git", branch = "halo2-ecc-snark-verifier-0323" }

[features]
default = ["loader_halo2", "loader_evm", "halo2-pse", "halo2-base/jemallocator"]
display = ["snark-verifier/display", "dep:ark-std"]
default = ["loader_halo2", "loader_evm", "halo2-pse", "halo2-base/jemallocator", "display"]
display = ["snark-verifier/display", "ark-std/print-trace"]
loader_evm = ["snark-verifier/loader_evm", "dep:ethereum-types"]
loader_halo2 = ["snark-verifier/loader_halo2"]
parallel = ["snark-verifier/parallel"]
Expand Down
Original file line number Diff line number Diff line change
@@ -1 +1 @@
{"strategy":"Simple","degree":25,"num_advice":[21],"num_lookup_advice":[2],"num_fixed":1,"lookup_bits":20,"limb_bits":88,"num_limbs":3}
{"strategy":"Simple","degree":26,"num_advice":[21],"num_lookup_advice":[2],"num_fixed":1,"lookup_bits":20,"limb_bits":88,"num_limbs":3}
5 changes: 1 addition & 4 deletions snark-verifier-sdk/src/evm.rs
Original file line number Diff line number Diff line change
Expand Up @@ -51,10 +51,7 @@ where
{
#[cfg(debug_assertions)]
{
use halo2_base::halo2_proofs::{
dev::MockProver,
poly::commitment::Params,
};
use halo2_base::halo2_proofs::{dev::MockProver, poly::commitment::Params};
MockProver::run(params.k(), &circuit, instances.clone()).unwrap().assert_satisfied();
}

Expand Down
2 changes: 1 addition & 1 deletion snark-verifier-sdk/src/evm_circuits/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -311,7 +311,7 @@ mod test {
//
{
std::env::set_var("VERIFY_CONFIG", "./configs/two_layer_recursion_second_layer.config");

let agg_circuit = AggregationCircuit::new(&params_layer_2, [layer_1_snark], &mut rng);
let pk_outer = gen_pk(&params_layer_2, &agg_circuit, None);
log::info!("finished layer 2 aggregation circuit generation");
Expand Down
4 changes: 3 additions & 1 deletion snark-verifier-sdk/src/halo2.rs
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,9 @@ where
#[cfg(debug_assertions)]
{
use halo2_proofs::poly::commitment::Params;
halo2_proofs::dev::MockProver::run(params.k(), &circuit, instances.clone()).unwrap().assert_satisfied();
halo2_proofs::dev::MockProver::run(params.k(), &circuit, instances.clone())
.unwrap()
.assert_satisfied();
}

if let Some((instance_path, proof_path)) = path {
Expand Down
2 changes: 1 addition & 1 deletion snark-verifier-sdk/src/halo2/aggregation.rs
Original file line number Diff line number Diff line change
Expand Up @@ -212,7 +212,7 @@ impl AggregationCircuit {
pub fn new(
params: &ParamsKZG<Bn256>,
snarks: impl IntoIterator<Item = Snark>,
rng: impl Rng + Send,
rng: &mut (impl Rng + Send),
) -> Self {
let svk = params.get_g()[0].into();
let snarks = snarks.into_iter().collect_vec();
Expand Down
1 change: 1 addition & 0 deletions snark-verifier-sdk/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ pub mod evm;
pub mod halo2;

mod evm_circuits;
pub mod multi_batch;

#[cfg(test)]
mod tests;
Expand Down
12 changes: 12 additions & 0 deletions snark-verifier-sdk/src/multi_batch.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
//! APIs to handle multi-batching snarks.
//!
//! In each batch iteration, we are doing two layers of recursions.
//! - use a wide recursive circuit to aggregate the proofs
//! - use a slim recursive circuit to shrink the size of the aggregated proof
//!
mod evm;
mod halo2;

pub use evm::*;
pub use halo2::*;
107 changes: 107 additions & 0 deletions snark-verifier-sdk/src/multi_batch/evm.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,107 @@
use crate::evm::{gen_evm_proof_shplonk, gen_evm_verifier_shplonk};
use crate::gen_pk;
use crate::halo2::aggregation::AggregationCircuit;
use crate::halo2::gen_snark_shplonk;
use crate::multi_batch::gen_two_layer_recursive_snark;
use crate::{CircuitExt, Snark};
#[cfg(feature = "display")]
use ark_std::{end_timer, start_timer};
use halo2_base::halo2_proofs;
use halo2_proofs::{halo2curves::bn256::Bn256, poly::kzg::commitment::ParamsKZG};
use itertools::Itertools;
use rand::Rng;

/// Inputs:
/// - kzg parameters
/// - a public key: all circuit should share a same public key
/// - circuit instances: c1,...ck
/// - rng
/// Output
/// -
/// - the evm byte code to verify the proof
/// - the actual serialized proof
pub fn gen_two_layer_evm_verifier<'params>(
params: &'params ParamsKZG<Bn256>,
input_snarks: Vec<Snark>,
rng: &mut (impl Rng + Send),
) -> (Vec<u8>, Vec<u8>) {
let timer = start_timer!(|| "begin two layer recursions");
// ===============================================
// first layer
// ===============================================
// use a wide config to aggregate snarks

std::env::set_var("VERIFY_CONFIG", "./configs/two_layer_recursion_first_layer.config");

let layer_1_snark = {
let layer_1_circuit = AggregationCircuit::new(&params, input_snarks, rng);
let layer_1_pk = gen_pk(&params, &layer_1_circuit, None);
gen_snark_shplonk(&params, &layer_1_pk, layer_1_circuit.clone(), rng, None::<String>)
};

println!("Finished layer 1 snark generation");

// ===============================================
// second layer
// ===============================================
// use a skim config to aggregate snarks

std::env::set_var("VERIFY_CONFIG", "./configs/two_layer_recursion_second_layer.config");

let layer_2_circuit = AggregationCircuit::new(&params, [layer_1_snark], rng);
let layer_2_pk = gen_pk(&params, &layer_2_circuit, None);

let snark = gen_evm_proof_shplonk(
&params,
&layer_2_pk,
layer_2_circuit.clone(),
layer_2_circuit.instances(),
rng,
);
// ===============================================
// bytecode
// ===============================================
let num_instance = layer_2_circuit.instances().iter().map(|x| x.len()).collect_vec();

let bytecode = gen_evm_verifier_shplonk::<AggregationCircuit>(
params,
layer_2_pk.get_vk(),
num_instance,
None,
);

println!("Finished layer 2 snark generation");
end_timer!(timer);
(bytecode, snark)
}

/// Generate the EVM bytecode and the proofs for the 4 layer recursion circuit
/// Output
/// -
/// - the evm byte code to verify the proof
/// - the actual serialized proof
pub fn gen_evm_four_layer_recursive_snark<'params>(
params: &'params ParamsKZG<Bn256>,
input_snark_vecs: Vec<Vec<Snark>>,
rng: &mut (impl Rng + Send),
) -> (Vec<u8>, Vec<u8>) {
let timer = start_timer!(|| "begin two layer recursions");

let mut snarks = vec![];
let len = input_snark_vecs[0].len();
let inner_timer = start_timer!(|| "inner layers");
for input_snarks in input_snark_vecs.iter() {
assert_eq!(len, input_snarks.len());

let snark = gen_two_layer_recursive_snark(params, input_snarks.clone(), rng);
snarks.push(snark);
}
end_timer!(inner_timer);

let outer_timer = start_timer!(|| "outer layers");
let (bytecode, snark) = gen_two_layer_evm_verifier(params, snarks, rng);
end_timer!(outer_timer);

end_timer!(timer);
(bytecode, snark)
}
84 changes: 84 additions & 0 deletions snark-verifier-sdk/src/multi_batch/halo2.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
use crate::gen_pk;
use crate::halo2::aggregation::AggregationCircuit;
use crate::halo2::gen_snark_shplonk;
use crate::Snark;
#[cfg(feature = "display")]
use ark_std::{end_timer, start_timer};
use halo2_base::halo2_proofs;
use halo2_proofs::{halo2curves::bn256::Bn256, poly::kzg::commitment::ParamsKZG};
use rand::Rng;

// use super::MultiBatchConfig;

// Inputs:
// - kzg parameters
// - a public key: all circuit should share a same public key
// - circuit instances: c1,...ck
// - rng
// Output
// -
// - a SNARK which is a snark **proof** of statement s, where
// - s: I have seen the proofs for all circuits c1,...ck
pub fn gen_two_layer_recursive_snark<'params>(
params: &'params ParamsKZG<Bn256>,
input_snarks: Vec<Snark>,
rng: &mut (impl Rng + Send),
) -> Snark {
let timer = start_timer!(|| "begin two layer recursions");
// ===============================================
// first layer
// ===============================================
// use a wide config to aggregate snarks

std::env::set_var("VERIFY_CONFIG", "./configs/two_layer_recursion_first_layer.config");

let layer_1_snark = {
let layer_1_circuit = AggregationCircuit::new(&params, input_snarks, rng);
let layer_1_pk = gen_pk(&params, &layer_1_circuit, None);
gen_snark_shplonk(&params, &layer_1_pk, layer_1_circuit.clone(), rng, None::<String>)
};

println!("Finished layer 1 snark generation");
// ===============================================
// second layer
// ===============================================
// use a skim config to aggregate snarks

std::env::set_var("VERIFY_CONFIG", "./configs/two_layer_recursion_second_layer.config");

let layer_2_circuit = AggregationCircuit::new(&params, [layer_1_snark], rng);
let layer_2_pk = gen_pk(&params, &layer_2_circuit, None);

let snark =
gen_snark_shplonk(&params, &layer_2_pk, layer_2_circuit.clone(), rng, None::<String>);

println!("Finished layer 2 snark generation");
end_timer!(timer);
snark
}

pub fn gen_four_layer_recursive_snark<'params>(
params: &'params ParamsKZG<Bn256>,
input_snark_vecs: Vec<Vec<Snark>>,
rng: &mut (impl Rng + Send),
) -> Snark {
let timer = start_timer!(|| "begin two layer recursions");

let mut snarks = vec![];
let len = input_snark_vecs[0].len();
let inner_timer = start_timer!(|| "inner layers");
for input_snarks in input_snark_vecs.iter() {
assert_eq!(len, input_snarks.len());

let snark = gen_two_layer_recursive_snark(params, input_snarks.clone(), rng);
snarks.push(snark);
}
end_timer!(inner_timer);

let outer_timer = start_timer!(|| "outer layers");
let snark = gen_two_layer_recursive_snark(params, snarks, rng);
end_timer!(outer_timer);

end_timer!(timer);
snark
}
1 change: 1 addition & 0 deletions snark-verifier-sdk/src/tests/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ use test_circuit_1::TestCircuit1;
use test_circuit_2::TestCircuit2;

mod evm_verifier;
mod multi_batch_aggregation;
mod single_layer_aggregation;
mod test_circuit_1;
mod test_circuit_2;
Expand Down
94 changes: 94 additions & 0 deletions snark-verifier-sdk/src/tests/multi_batch_aggregation.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,94 @@
use super::TestCircuit1;
use crate::multi_batch::{gen_four_layer_recursive_snark, gen_two_layer_recursive_snark};
use crate::Snark;
use crate::{gen_pk, halo2::gen_snark_shplonk};
use ark_std::test_rng;
use halo2_base::halo2_proofs;
use halo2_proofs::poly::commitment::Params;
use snark_verifier::loader::halo2::halo2_ecc::halo2_base::utils::fs::gen_srs;

#[test]
fn test_partial_multi_batch_aggregation() {
let k = 8;

// let config = MultiBatchConfig::new(k, k_layer_1, k_layer_2);
println!("finished configurations");
let mut rng = test_rng();
let params = gen_srs(26 as u32);
let mut param_inner = params.clone();
param_inner.downsize(k as u32);
println!("finished SRS generation");

let circuits: Vec<_> = (0..2).map(|_| TestCircuit1::rand(&mut rng)).collect();
let pk = gen_pk(&param_inner, &circuits[0], None);
println!("finished pk and circuits generation");

// ===============================================
// convert input circuits to snarks
// ===============================================
let input_snarks: Vec<Snark> = {
let k = pk.get_vk().get_domain().k();
println!("inner circuit k = {}", k);
circuits
.iter()
.map(|circuit| {
gen_snark_shplonk::<TestCircuit1>(
&param_inner,
&pk,
circuit.clone(),
&mut rng,
None::<String>,
)
})
.collect()
};
println!("Finished input snark generation");

let _snark = gen_two_layer_recursive_snark(&params, input_snarks, &mut rng);
}

#[test]
fn test_full_multi_batch_aggregation() {
let k = 8;

// let config = MultiBatchConfig::new(k, k_layer_1, k_layer_2);
println!("finished configurations");
let mut rng = test_rng();
let params = gen_srs(26 as u32);
let mut param_inner = params.clone();
param_inner.downsize(k as u32);
println!("finished SRS generation");

let circuit_vecs: Vec<Vec<_>> =
(0..2).map(|_| (0..2).map(|_| TestCircuit1::rand(&mut rng)).collect()).collect();
let pk = gen_pk(&param_inner, &circuit_vecs[0][0], None);
println!("finished pk and circuits generation");

// ===============================================
// convert input circuits to snarks
// ===============================================
let input_snarks: Vec<Vec<Snark>> = {
let k = pk.get_vk().get_domain().k();
println!("inner circuit k = {}", k);
circuit_vecs
.iter()
.map(|circuits| {
circuits
.iter()
.map(|circuit| {
gen_snark_shplonk::<TestCircuit1>(
&param_inner,
&pk,
circuit.clone(),
&mut rng,
None::<String>,
)
})
.collect()
})
.collect()
};
println!("Finished input snark generation");

let _snark = gen_four_layer_recursive_snark(&params, input_snarks, &mut rng);
}
Loading

0 comments on commit 654706d

Please sign in to comment.