Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[FEAT] implement multi-batch aggregation #11

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions snark-verifier-sdk/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ hex = "0.4"
serde = { version = "1.0", features = ["derive"] }
serde_json = "1.0"
bincode = "1.3.3"
ark-std = { version = "0.3.0", features = ["print-trace"], optional = true }
ark-std = { version = "0.3.0" }

halo2-base = { git = "https://github.com/scroll-tech/halo2-lib.git", branch = "halo2-ecc-snark-verifier-0323", default-features=false, features=["halo2-pse","display"] }
snark-verifier = { path = "../snark-verifier", default-features = false }
Expand All @@ -29,8 +29,8 @@ ethereum-types = { version = "0.14", default-features = false, features = ["std"

env_logger = "0.10.0"
log = "0.4.17"

[dev-dependencies]
ark-std = { version = "0.3.0", features = ["print-trace"] }
ethers-signers = { version = "0.17.0" }
paste = "1.0.7"
pprof = { version = "0.11", features = ["criterion", "flamegraph"] }
Expand All @@ -47,8 +47,8 @@ eth-types = { git = "https://github.com/scroll-tech/zkevm-circuits.git", branch
mock = { git = "https://github.com/scroll-tech/zkevm-circuits.git", branch = "halo2-ecc-snark-verifier-0323" }

[features]
default = ["loader_halo2", "loader_evm", "halo2-pse", "halo2-base/jemallocator"]
display = ["snark-verifier/display", "dep:ark-std"]
default = ["loader_halo2", "loader_evm", "halo2-pse", "halo2-base/jemallocator", "display"]
display = ["snark-verifier/display", "ark-std/print-trace"]
loader_evm = ["snark-verifier/loader_evm", "dep:ethereum-types"]
loader_halo2 = ["snark-verifier/loader_halo2"]
parallel = ["snark-verifier/parallel"]
Expand Down
Original file line number Diff line number Diff line change
@@ -1 +1 @@
{"strategy":"Simple","degree":25,"num_advice":[21],"num_lookup_advice":[2],"num_fixed":1,"lookup_bits":20,"limb_bits":88,"num_limbs":3}
{"strategy":"Simple","degree":26,"num_advice":[21],"num_lookup_advice":[2],"num_fixed":1,"lookup_bits":20,"limb_bits":88,"num_limbs":3}
2 changes: 1 addition & 1 deletion snark-verifier-sdk/src/halo2/aggregation.rs
Original file line number Diff line number Diff line change
Expand Up @@ -212,7 +212,7 @@ impl AggregationCircuit {
pub fn new(
params: &ParamsKZG<Bn256>,
snarks: impl IntoIterator<Item = Snark>,
rng: impl Rng + Send,
rng: &mut (impl Rng + Send),
) -> Self {
let svk = params.get_g()[0].into();
let snarks = snarks.into_iter().collect_vec();
Expand Down
3 changes: 3 additions & 0 deletions snark-verifier-sdk/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,9 @@ pub mod evm;
pub mod halo2;

mod evm_circuits;
pub mod multi_batch;



#[cfg(test)]
mod tests;
Expand Down
12 changes: 12 additions & 0 deletions snark-verifier-sdk/src/multi_batch.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
//! APIs to handle multi-batching snarks.
//!
//! In each batch iteration, we are doing two layers of recursions.
//! - use a wide recursive circuit to aggregate the proofs
//! - use a slim recursive circuit to shrink the size of the aggregated proof
//!
mod evm;
mod halo2;

pub use evm::*;
pub use halo2::*;
113 changes: 113 additions & 0 deletions snark-verifier-sdk/src/multi_batch/evm.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,113 @@
use crate::evm::{gen_evm_proof_shplonk, gen_evm_verifier_shplonk};
use crate::gen_pk;
use crate::halo2::aggregation::AggregationCircuit;
use crate::halo2::gen_snark_shplonk;
use crate::multi_batch::gen_two_layer_recursive_snark;
use crate::{CircuitExt, Snark};
#[cfg(feature = "display")]
use ark_std::{end_timer, start_timer};
use halo2_base::halo2_proofs;
use halo2_base::halo2_proofs::halo2curves::bn256::Fr;
use halo2_proofs::{halo2curves::bn256::Bn256, poly::kzg::commitment::ParamsKZG};
use itertools::Itertools;
use rand::Rng;

/// Inputs:
/// - kzg parameters
/// - a list of input snarks
/// - rng
///
/// Outputs:
/// - the evm byte code to verify the proof
/// - the instances
/// - the actual serialized proof
pub fn gen_two_layer_evm_verifier<'params>(
params: &'params ParamsKZG<Bn256>,
input_snarks: Vec<Snark>,
rng: &mut (impl Rng + Send),
) -> (Vec<u8>, Vec<Vec<Fr>>, Vec<u8>) {
let timer = start_timer!(|| "begin two layer recursions");
// ===============================================
// first layer
// ===============================================
// use a wide config to aggregate snarks
std::env::set_var("VERIFY_CONFIG", "./configs/two_layer_recursion_first_layer.config");

let layer_1_snark = {
let layer_1_circuit = AggregationCircuit::new(&params, input_snarks, rng);
let layer_1_pk = gen_pk(&params, &layer_1_circuit, None);
gen_snark_shplonk(&params, &layer_1_pk, layer_1_circuit.clone(), rng, None::<String>)
};

log::trace!("Finished layer 1 snark generation");

// ===============================================
// second layer
// ===============================================
// use a skim config to aggregate snarks

std::env::set_var("VERIFY_CONFIG", "./configs/two_layer_recursion_second_layer.config");

let layer_2_circuit = AggregationCircuit::new(&params, [layer_1_snark], rng);
let layer_2_pk = gen_pk(&params, &layer_2_circuit, None);

let snark = gen_evm_proof_shplonk(
&params,
&layer_2_pk,
layer_2_circuit.clone(),
layer_2_circuit.instances(),
rng,
);
// ===============================================
// bytecode
// ===============================================
let num_instance = layer_2_circuit.instances().iter().map(|x| x.len()).collect_vec();

let bytecode = gen_evm_verifier_shplonk::<AggregationCircuit>(
params,
layer_2_pk.get_vk(),
num_instance,
None,
);

log::trace!("Finished layer 2 snark generation");
end_timer!(timer);
(bytecode, layer_2_circuit.instances(), snark)
}

/// Generate the EVM bytecode and the proofs for the 4 layer recursion circuit
///
/// Input
/// - kzg parameters
/// - a list of input snarks
/// - rng
///
/// Output
/// - the evm byte code to verify the proof
/// - the instances
/// - the actual serialized proof
pub fn gen_evm_four_layer_recursive_snark<'params>(
params: &'params ParamsKZG<Bn256>,
input_snark_vecs: Vec<Vec<Snark>>,
rng: &mut (impl Rng + Send),
) -> (Vec<u8>, Vec<Vec<Fr>>, Vec<u8>) {
let timer = start_timer!(|| "begin two layer recursions");

let mut snarks = vec![];
let len = input_snark_vecs[0].len();
let inner_timer = start_timer!(|| "inner layers");
for input_snarks in input_snark_vecs.iter() {
assert_eq!(len, input_snarks.len());

let snark = gen_two_layer_recursive_snark(params, input_snarks.clone(), rng);
snarks.push(snark);
}
end_timer!(inner_timer);

let outer_timer = start_timer!(|| "outer layers");
let (bytecode, instances, snark) = gen_two_layer_evm_verifier(params, snarks, rng);
end_timer!(outer_timer);

end_timer!(timer);
(bytecode, instances, snark)
}
82 changes: 82 additions & 0 deletions snark-verifier-sdk/src/multi_batch/halo2.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
use crate::gen_pk;
use crate::halo2::aggregation::AggregationCircuit;
use crate::halo2::gen_snark_shplonk;
use crate::Snark;
#[cfg(feature = "display")]
use ark_std::{end_timer, start_timer};
use halo2_base::halo2_proofs;
use halo2_proofs::{halo2curves::bn256::Bn256, poly::kzg::commitment::ParamsKZG};
use rand::Rng;

/// Inputs:
/// - kzg parameters
/// - a list of input snarks
/// - rng
///
/// Outputs:
// -
// - a SNARK which is a snark **proof** of statement s, where
// - s: I have seen the proofs for all circuits c1,...ck
pub fn gen_two_layer_recursive_snark<'params>(
params: &'params ParamsKZG<Bn256>,
input_snarks: Vec<Snark>,
rng: &mut (impl Rng + Send),
) -> Snark {
let timer = start_timer!(|| "begin two layer recursions");
// ===============================================
// first layer
// ===============================================
// use a wide config to aggregate snarks

std::env::set_var("VERIFY_CONFIG", "./configs/two_layer_recursion_first_layer.config");

let layer_1_snark = {
let layer_1_circuit = AggregationCircuit::new(&params, input_snarks, rng);
let layer_1_pk = gen_pk(&params, &layer_1_circuit, None);
gen_snark_shplonk(&params, &layer_1_pk, layer_1_circuit.clone(), rng, None::<String>)
};

log::trace!("Finished layer 1 snark generation");
// ===============================================
// second layer
// ===============================================
// use a skim config to aggregate snarks

std::env::set_var("VERIFY_CONFIG", "./configs/two_layer_recursion_second_layer.config");

let layer_2_circuit = AggregationCircuit::new(&params, [layer_1_snark], rng);
let layer_2_pk = gen_pk(&params, &layer_2_circuit, None);

let snark =
gen_snark_shplonk(&params, &layer_2_pk, layer_2_circuit.clone(), rng, None::<String>);

log::trace!("Finished layer 2 snark generation");
end_timer!(timer);
snark
}

pub fn gen_four_layer_recursive_snark<'params>(
params: &'params ParamsKZG<Bn256>,
input_snark_vecs: Vec<Vec<Snark>>,
rng: &mut (impl Rng + Send),
) -> Snark {
let timer = start_timer!(|| "begin two layer recursions");

let mut snarks = vec![];
let len = input_snark_vecs[0].len();
let inner_timer = start_timer!(|| "inner layers");
for input_snarks in input_snark_vecs.iter() {
assert_eq!(len, input_snarks.len());

let snark = gen_two_layer_recursive_snark(params, input_snarks.clone(), rng);
snarks.push(snark);
}
end_timer!(inner_timer);

let outer_timer = start_timer!(|| "outer layers");
let snark = gen_two_layer_recursive_snark(params, snarks, rng);
end_timer!(outer_timer);

end_timer!(timer);
snark
}
1 change: 1 addition & 0 deletions snark-verifier-sdk/src/tests/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ use test_circuit_1::TestCircuit1;
use test_circuit_2::TestCircuit2;

mod evm_verifier;
mod multi_batch_aggregation;
mod single_layer_aggregation;
mod test_circuit_1;
mod test_circuit_2;
Expand Down
98 changes: 98 additions & 0 deletions snark-verifier-sdk/src/tests/multi_batch_aggregation.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,98 @@
use super::TestCircuit1;
use crate::evm::evm_verify;
use crate::multi_batch::{gen_two_layer_recursive_snark, gen_evm_four_layer_recursive_snark};
use crate::Snark;
use crate::{gen_pk, halo2::gen_snark_shplonk};
use ark_std::test_rng;
use halo2_base::halo2_proofs;
use halo2_proofs::poly::commitment::Params;
use snark_verifier::loader::halo2::halo2_ecc::halo2_base::utils::fs::gen_srs;

#[test]
fn test_partial_multi_batch_aggregation() {
let k = 8;

// let config = MultiBatchConfig::new(k, k_layer_1, k_layer_2);
println!("finished configurations");
let mut rng = test_rng();
let params = gen_srs(26 as u32);
let mut param_inner = params.clone();
param_inner.downsize(k as u32);
println!("finished SRS generation");

let circuits: Vec<_> = (0..2).map(|_| TestCircuit1::rand(&mut rng)).collect();
let pk = gen_pk(&param_inner, &circuits[0], None);
println!("finished pk and circuits generation");

// ===============================================
// convert input circuits to snarks
// ===============================================
let input_snarks: Vec<Snark> = {
let k = pk.get_vk().get_domain().k();
println!("inner circuit k = {}", k);
circuits
.iter()
.map(|circuit| {
gen_snark_shplonk::<TestCircuit1>(
&param_inner,
&pk,
circuit.clone(),
&mut rng,
None::<String>,
)
})
.collect()
};
println!("Finished input snark generation");

let _snark = gen_two_layer_recursive_snark(&params, input_snarks, &mut rng);
}

#[test]
fn test_full_multi_batch_aggregation() {
let k = 8;

// let config = MultiBatchConfig::new(k, k_layer_1, k_layer_2);
println!("finished configurations");
let mut rng = test_rng();
let params = gen_srs(26 as u32);
let mut param_inner = params.clone();
param_inner.downsize(k as u32);
println!("finished SRS generation");

let circuit_vecs: Vec<Vec<_>> =
(0..2).map(|_| (0..2).map(|_| TestCircuit1::rand(&mut rng)).collect()).collect();
let pk = gen_pk(&param_inner, &circuit_vecs[0][0], None);
println!("finished pk and circuits generation");

// ===============================================
// convert input circuits to snarks
// ===============================================
let input_snarks: Vec<Vec<Snark>> = {
let k = pk.get_vk().get_domain().k();
println!("inner circuit k = {}", k);
circuit_vecs
.iter()
.map(|circuits| {
circuits
.iter()
.map(|circuit| {
gen_snark_shplonk::<TestCircuit1>(
&param_inner,
&pk,
circuit.clone(),
&mut rng,
None::<String>,
)
})
.collect()
})
.collect()
};
println!("Finished input snark generation");

let (byte_code, instances, proof) = gen_evm_four_layer_recursive_snark(&params, input_snarks, &mut rng);

evm_verify(byte_code, instances, proof)

}