feat: add parallelize_in helper function #46

Merged · 1 commit merged on May 22, 2023
1 change: 1 addition & 0 deletions halo2-base/Cargo.toml
@@ -11,6 +11,7 @@ num-traits = "0.2"
rand_chacha = "0.3"
rustc-hash = "1.1"
ff = "0.12"
rayon = "1.6.1"
serde = { version = "1.0", features = ["derive"] }
serde_json = "1.0"
log = "0.4"
3 changes: 3 additions & 0 deletions halo2-base/src/gates/builder.rs
@@ -17,6 +17,9 @@ use std::{
env::{set_var, var},
};

mod parallelize;
pub use parallelize::*;

/// Vector of thread advice column break points
pub type ThreadBreakPoints = Vec<usize>;
/// Vector of vectors tracking the thread break points across different halo2 phases
38 changes: 38 additions & 0 deletions halo2-base/src/gates/builder/parallelize.rs
@@ -0,0 +1,38 @@
use itertools::Itertools;
use rayon::prelude::*;

use crate::{utils::ScalarField, Context};

use super::GateThreadBuilder;

/// Utility function to parallelize an operation involving [`Context`]s in phase `phase`.
pub fn parallelize_in<F, T, R, FR>(
phase: usize,
builder: &mut GateThreadBuilder<F>,
input: Vec<T>,
f: FR,
) -> Vec<R>
where
F: ScalarField,
T: Send,
R: Send,
FR: Fn(&mut Context<F>, T) -> R + Send + Sync,
{
let witness_gen_only = builder.witness_gen_only();
// to prevent concurrency issues with context id, we generate all the ids first
let ctx_ids = input.iter().map(|_| builder.get_new_thread_id()).collect_vec();
let (outputs, mut ctxs): (Vec<_>, Vec<_>) = input
.into_par_iter()
.zip(ctx_ids.into_par_iter())
.map(|(input, ctx_id)| {
// create new context
let mut ctx = Context::new(witness_gen_only, ctx_id);
let output = f(&mut ctx, input);
(output, ctx)
})
.unzip();
// we collect the new threads to ensure they are in a FIXED order; otherwise `assign_threads_in` will get confused later
builder.threads[phase].append(&mut ctxs);

outputs
}
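
For context, a minimal usage sketch of the new helper (hypothetical example, not part of this PR; `double_all` is an invented name, and it assumes the `GateChip`/`GateInstructions` exports and `Context::load_witness` as they exist elsewhere in halo2-base):

```rust
use halo2_base::gates::builder::{parallelize_in, GateThreadBuilder};
use halo2_base::gates::{GateChip, GateInstructions};
use halo2_base::utils::ScalarField;
use halo2_base::AssignedValue;

// Hypothetical helper: double each input witness in parallel, one Context per input.
fn double_all<F: ScalarField>(
    builder: &mut GateThreadBuilder<F>,
    inputs: Vec<F>,
) -> Vec<AssignedValue<F>> {
    let gate = GateChip::<F>::default();
    // Each closure invocation receives a fresh `Context`; after the parallel map,
    // the contexts are appended to `builder.threads[0]` (phase 0) in input order.
    parallelize_in(0, builder, inputs, |ctx, x| {
        let x = ctx.load_witness(x);
        gate.add(ctx, x, x)
    })
}
```

The key property is that outputs and contexts come back in input order (rayon's `unzip` on an indexed parallel iterator preserves order), so column assignment stays deterministic regardless of thread scheduling.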
31 changes: 15 additions & 16 deletions halo2-ecc/src/ecc/fixed_base.rs
@@ -2,7 +2,7 @@
use super::{ec_add_unequal, ec_select, ec_select_from_bits, EcPoint, EccChip};
use crate::fields::{FieldChip, PrimeField, Selectable};
use group::Curve;
use halo2_base::gates::builder::GateThreadBuilder;
use halo2_base::gates::builder::{parallelize_in, GateThreadBuilder};
use halo2_base::{gates::GateInstructions, utils::CurveAffineExt, AssignedValue, Context};
use itertools::Itertools;
use rayon::prelude::*;
@@ -107,6 +107,7 @@ where
curr_point.unwrap()
}

/* To reduce total amount of code, just always use msm_par below.
// basically just adding up individual fixed_base::scalar_multiply except that we do all batched normalization of cached points at once to further save inversion time during witness generation
// we also use the random accumulator for some extra efficiency (which also works in scalar multiply case but that is TODO)
pub fn msm<F, FC, C>(
@@ -212,6 +213,7 @@ where
.collect_vec();
chip.sum::<C>(ctx, scalar_mults)
}
*/

/// # Assumptions
/// * `points.len() = scalars.len()`
@@ -269,25 +271,23 @@ where
C::Curve::batch_normalize(&cached_points_jacobian, &mut cached_points_affine);

let field_chip = chip.field_chip();
let witness_gen_only = builder.witness_gen_only();

let zero = builder.main(phase).load_zero();
let thread_ids = (0..scalars.len()).map(|_| builder.get_new_thread_id()).collect::<Vec<_>>();
let (new_threads, scalar_mults): (Vec<_>, Vec<_>) = cached_points_affine
.par_chunks(cached_points_affine.len() / points.len())
.zip_eq(scalars.into_par_iter())
.zip(thread_ids.into_par_iter())
.map(|((cached_points, scalar), thread_id)| {
let mut thread = Context::new(witness_gen_only, thread_id);
let ctx = &mut thread;

let scalar_mults = parallelize_in(
phase,
builder,
cached_points_affine
.chunks(cached_points_affine.len() / points.len())
.zip_eq(scalars)
.collect(),
|ctx, (cached_points, scalar)| {
let cached_points = cached_points
.iter()
.map(|point| chip.assign_constant_point(ctx, *point))
.collect_vec();
let cached_point_window_rev = cached_points.chunks(1usize << window_bits).rev();

debug_assert_eq!(scalar.len(), scalar_len);
assert_eq!(scalar.len(), scalar_len);
let bits = scalar
.into_iter()
.flat_map(|scalar_chunk| {
@@ -319,9 +319,8 @@ where
field_chip.gate().mul_add(ctx, is_started, is_zero_window, not_zero_window)
};
}
(thread, curr_point.unwrap())
})
.unzip();
builder.threads[phase].extend(new_threads);
curr_point.unwrap()
},
);
chip.sum::<C>(builder.main(phase), scalar_mults)
}
25 changes: 10 additions & 15 deletions halo2-ecc/src/ecc/mod.rs
@@ -971,6 +971,7 @@ impl<'chip, F: PrimeField, FC: FieldChip<F>> EccChip<'chip, F, FC> {
self.field_chip.assert_equal(ctx, P.y, Q.y);
}

/// None of the elements in `points` can be the point at infinity.
pub fn sum<C>(
&self,
ctx: &mut Context<F>,
@@ -1153,21 +1154,15 @@
#[cfg(feature = "display")]
println!("computing length {} fixed base msm", points.len());

// heuristic to decide when to use parallelism
if points.len() < 25 {
let ctx = builder.main(phase);
fixed_base::msm(self, ctx, points, scalars, max_scalar_bits_per_cell, clump_factor)
} else {
fixed_base::msm_par(
self,
builder,
points,
scalars,
max_scalar_bits_per_cell,
clump_factor,
phase,
)
}
fixed_base::msm_par(
self,
builder,
points,
scalars,
max_scalar_bits_per_cell,
clump_factor,
phase,
)

// Empirically, pippenger does not seem to be any better for fixed-base msm right now, because of the cost of `select_by_indicator`.
// Cell usage becomes roughly comparable when `points.len() > 100`, and `clump_factor` should always be 4
69 changes: 27 additions & 42 deletions halo2-ecc/src/ecc/pippenger.rs
@@ -7,11 +7,13 @@ use crate::{
fields::{FieldChip, PrimeField, Selectable},
};
use halo2_base::{
gates::{builder::GateThreadBuilder, GateInstructions},
gates::{
builder::{parallelize_in, GateThreadBuilder},
GateInstructions,
},
utils::CurveAffineExt,
AssignedValue, Context,
AssignedValue,
};
use rayon::prelude::*;

// Reference: https://jbootle.github.io/Misc/pippenger.pdf

@@ -238,7 +240,6 @@ where

// get a main thread
let ctx = builder.main(phase);
let witness_gen_only = ctx.witness_gen_only();
// single-threaded computation:
for scalar in scalars {
for (scalar_chunk, bool_chunk) in
@@ -250,32 +251,28 @@
}
}
}
// see the multi-product comments for an explanation of the code below

let c = clump_factor;
let num_rounds = (points.len() + c - 1) / c;
// to avoid adding two points that are equal or negatives of each other,
// we use a trick from halo2wrong where we load a "sufficiently generic" `C` point as witness
// note that while we load a random point, an adversary could load a specifically chosen point, so we must carefully handle edge cases with constraints
// we call it "any point" instead of "random point" to emphasize that "any" sufficiently generic point will do
let any_base = load_random_point::<F, FC, C>(chip, ctx);
let mut any_points = Vec::with_capacity(num_rounds);
any_points.push(any_base);
for _ in 1..num_rounds {
any_points.push(ec_double(chip, ctx, any_points.last().unwrap()));
}
// we will use a different thread per round
// to prevent concurrency issues with context id, we generate all the ids first
let thread_ids = (0..num_rounds).map(|_| builder.get_new_thread_id()).collect::<Vec<_>>();
// now begins multi-threading

// now begins multi-threading
// multi_prods is 2d vector of size `num_rounds` by `scalar_bits`
let (new_threads, multi_prods): (Vec<_>, Vec<_>) = points
.par_chunks(c)
.zip(any_points.par_iter())
.zip(thread_ids.into_par_iter())
.enumerate()
.map(|(round, ((points_clump, any_point), thread_id))| {
let multi_prods = parallelize_in(
phase,
builder,
points.chunks(c).into_iter().zip(any_points.iter()).enumerate().collect(),
|ctx, (round, (points_clump, any_point))| {
// compute all possible multi-products of elements in points[round * c .. round * (c+1)]
// create new thread
let mut thread = Context::new(witness_gen_only, thread_id);
let ctx = &mut thread;
// stores { any_point, any_point + points[0], any_point + points[1], any_point + points[0] + points[1] , ... }
let mut bucket = Vec::with_capacity(1 << c);
let any_point = into_strict_point(chip, ctx, any_point.clone());
@@ -294,7 +291,7 @@ where
bucket.push(new_point);
}
}
let multi_prods = bool_scalars
bool_scalars
.iter()
.map(|bits| {
strict_ec_select_from_bits(
@@ -304,31 +301,19 @@
&bits[round * c..round * c + points_clump.len()],
)
})
.collect::<Vec<_>>();

(thread, multi_prods)
})
.unzip();
// we collect the new threads to ensure they are a FIXED order, otherwise later `assign_threads_in` will get confused
builder.threads[phase].extend(new_threads);
.collect::<Vec<_>>()
},
);

// agg[j] = sum_{i=0..num_rounds} multi_prods[i][j] for j = 0..scalar_bits
let thread_ids = (0..scalar_bits).map(|_| builder.get_new_thread_id()).collect::<Vec<_>>();
let (new_threads, mut agg): (Vec<_>, Vec<_>) = thread_ids
.into_par_iter()
.enumerate()
.map(|(i, thread_id)| {
let mut thread = Context::new(witness_gen_only, thread_id);
let ctx = &mut thread;
let mut acc = multi_prods[0][i].clone();
for multi_prod in multi_prods.iter().skip(1) {
let _acc = ec_add_unequal(chip, ctx, &acc, &multi_prod[i], true);
acc = into_strict_point(chip, ctx, _acc);
}
(thread, acc)
})
.unzip();
builder.threads[phase].extend(new_threads);
let mut agg = parallelize_in(phase, builder, (0..scalar_bits).collect(), |ctx, i| {
let mut acc = multi_prods[0][i].clone();
for multi_prod in multi_prods.iter().skip(1) {
let _acc = ec_add_unequal(chip, ctx, &acc, &multi_prod[i], true);
acc = into_strict_point(chip, ctx, _acc);
}
acc
});

// gets the LAST thread for single threaded work
let ctx = builder.main(phase);
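
To make the multi-product bucket construction in `pippenger.rs` above concrete, here is a small worked instance (illustrative values only, with clump factor $c = 2$, points $P_0, P_1$, and "any point" $A$):

$$\mathrm{bucket} = [\,A,\; A+P_0,\; A+P_1,\; A+P_0+P_1\,], \qquad \mathrm{bucket}[2b_1 + b_0] = A + b_0 P_0 + b_1 P_1 .$$

So a single `strict_ec_select_from_bits` over the selection bits $(b_0, b_1)$ yields the whole clump's contribution at once instead of one conditional addition per point, and starting every bucket entry from the sufficiently generic $A$ keeps the `ec_add_unequal` calls on (generically) unequal inputs.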