Skip to content

Commit

Permalink
Use pre-computation to speed-up Secp256r1 verification (#595)
Browse files Browse the repository at this point in the history
  • Loading branch information
jonas-lj authored May 31, 2023
1 parent 50403b9 commit ed1c2e1
Show file tree
Hide file tree
Showing 13 changed files with 1,025 additions and 9 deletions.
1 change: 1 addition & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions fastcrypto/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@ ark-secp256r1 = "0.4.0"
ark-ec = "0.4.1"
ark-ff = "0.4.1"
ark-serialize = "0.4.1"
lazy_static = "1.4.0"

fastcrypto-derive = { path = "../fastcrypto-derive", version = "0.1.3" }

Expand Down
108 changes: 107 additions & 1 deletion fastcrypto/benches/groups.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,14 @@
extern crate criterion;

mod group_benches {
use criterion::measurement::Measurement;
use criterion::{measurement, BenchmarkGroup, Criterion};
use fastcrypto::groups::bls12381::{G1Element, G2Element, GTElement};
use fastcrypto::groups::multiplier::windowed::WindowedScalarMultiplier;
use fastcrypto::groups::multiplier::ScalarMultiplier;
use fastcrypto::groups::ristretto255::RistrettoPoint;
use fastcrypto::groups::{GroupElement, HashToGroupElement, Pairing, Scalar};
use fastcrypto::groups::secp256r1::ProjectivePoint;
use fastcrypto::groups::{secp256r1, GroupElement, HashToGroupElement, Pairing, Scalar};
use rand::thread_rng;

fn add_single<G: GroupElement, M: measurement::Measurement>(
Expand Down Expand Up @@ -36,12 +40,99 @@ mod group_benches {
c.bench_function(&(name.to_string()), move |b| b.iter(|| x * y));
}

fn scale_single_precomputed<G: GroupElement, Mul: ScalarMultiplier<G>, M: Measurement>(
name: &str,
c: &mut BenchmarkGroup<M>,
) {
let x = G::generator() * G::ScalarType::rand(&mut thread_rng());
let y = G::ScalarType::rand(&mut thread_rng());

let multiplier = Mul::new(x);
c.bench_function(&(name.to_string()), move |b| b.iter(|| multiplier.mul(&y)));
}

fn scale(c: &mut Criterion) {
let mut group: BenchmarkGroup<_> = c.benchmark_group("Scalar To Point Multiplication");
scale_single::<G1Element, _>("BLS12381-G1", &mut group);
scale_single::<G2Element, _>("BLS12381-G2", &mut group);
scale_single::<GTElement, _>("BLS12381-GT", &mut group);
scale_single::<RistrettoPoint, _>("Ristretto255", &mut group);
scale_single::<ProjectivePoint, _>("Secp256r1", &mut group);

scale_single_precomputed::<
ProjectivePoint,
WindowedScalarMultiplier<ProjectivePoint, secp256r1::Scalar, 16, 32, 5>,
_,
>("Secp256r1 Fixed window (16)", &mut group);
scale_single_precomputed::<
ProjectivePoint,
WindowedScalarMultiplier<ProjectivePoint, secp256r1::Scalar, 32, 32, 5>,
_,
>("Secp256r1 Fixed window (32)", &mut group);
scale_single_precomputed::<
ProjectivePoint,
WindowedScalarMultiplier<ProjectivePoint, secp256r1::Scalar, 64, 32, 5>,
_,
>("Secp256r1 Fixed window (64)", &mut group);
scale_single_precomputed::<
ProjectivePoint,
WindowedScalarMultiplier<ProjectivePoint, secp256r1::Scalar, 128, 32, 5>,
_,
>("Secp256r1 Fixed window (128)", &mut group);
scale_single_precomputed::<
ProjectivePoint,
WindowedScalarMultiplier<ProjectivePoint, secp256r1::Scalar, 256, 32, 5>,
_,
>("Secp256r1 Fixed window (256)", &mut group);
}

fn double_scale_single<G: GroupElement, Mul: ScalarMultiplier<G>, M: Measurement>(
name: &str,
c: &mut BenchmarkGroup<M>,
) {
let g1 = G::generator() * G::ScalarType::rand(&mut thread_rng());
let s1 = G::ScalarType::rand(&mut thread_rng());
let g2 = G::generator() * G::ScalarType::rand(&mut thread_rng());
let s2 = G::ScalarType::rand(&mut thread_rng());

let multiplier = Mul::new(g1);
c.bench_function(&(name.to_string()), move |b| {
b.iter(|| multiplier.two_scalar_mul(&s1, &g2, &s2))
});
}

fn double_scale(c: &mut Criterion) {
let mut group: BenchmarkGroup<_> = c.benchmark_group("Double Scalar Multiplication");

double_scale_single::<
ProjectivePoint,
WindowedScalarMultiplier<ProjectivePoint, secp256r1::Scalar, 16, 32, 5>,
_,
>("Secp256r1 Straus (16)", &mut group);
double_scale_single::<
ProjectivePoint,
WindowedScalarMultiplier<ProjectivePoint, secp256r1::Scalar, 32, 32, 5>,
_,
>("Secp256r1 Straus (32)", &mut group);
double_scale_single::<
ProjectivePoint,
WindowedScalarMultiplier<ProjectivePoint, secp256r1::Scalar, 64, 32, 5>,
_,
>("Secp256r1 Straus (64)", &mut group);
double_scale_single::<
ProjectivePoint,
WindowedScalarMultiplier<ProjectivePoint, secp256r1::Scalar, 128, 32, 5>,
_,
>("Secp256r1 Straus (128)", &mut group);
double_scale_single::<
ProjectivePoint,
WindowedScalarMultiplier<ProjectivePoint, secp256r1::Scalar, 256, 32, 5>,
_,
>("Secp256r1 Straus (256)", &mut group);
double_scale_single::<ProjectivePoint, DefaultMultiplier<ProjectivePoint>, _>(
"Secp256r1",
&mut group,
);
}

fn hash_to_group_single<G: GroupElement + HashToGroupElement, M: measurement::Measurement>(
Expand Down Expand Up @@ -76,6 +167,20 @@ mod group_benches {
pairing_single::<G1Element, _>("BLS12381-G1", &mut group);
}

/// Implementation of a `Multiplier` where scalar multiplication is done without any pre-computation by
/// simply calling the GroupElement implementation. Only used for benchmarking.
struct DefaultMultiplier<G: GroupElement>(G);

impl<G: GroupElement> ScalarMultiplier<G> for DefaultMultiplier<G> {
fn new(base_element: G) -> Self {
Self(base_element)
}

fn mul(&self, scalar: &G::ScalarType) -> G {
self.0 * scalar
}
}

criterion_group! {
name = group_benches;
config = Criterion::default().sample_size(100);
Expand All @@ -84,6 +189,7 @@ mod group_benches {
scale,
hash_to_group,
pairing,
double_scale,
}
}

Expand Down
9 changes: 9 additions & 0 deletions fastcrypto/src/groups/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,10 @@ use std::ops::{AddAssign, SubAssign};
pub mod bls12381;
pub mod ristretto255;

pub mod secp256r1;

pub mod multiplier;

/// Trait impl'd by elements of an additive cyclic group.
pub trait GroupElement:
Copy
Expand All @@ -36,6 +40,11 @@ pub trait GroupElement:

/// Return an instance of the generator for this group.
fn generator() -> Self;

/// Compute 2 * Self. May be overwritten by implementations that have a fast doubling operation.
fn double(&self) -> Self {
*self + self
}
}

/// Trait impl'd by scalars to be used with [GroupElement].
Expand Down
176 changes: 176 additions & 0 deletions fastcrypto/src/groups/multiplier/bgmw.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,176 @@
// Copyright (c) 2022, Mysten Labs, Inc.
// SPDX-License-Identifier: Apache-2.0

use crate::groups::multiplier::integer_utils::{compute_base_2w_expansion, div_ceil};
use crate::groups::multiplier::ScalarMultiplier;
use crate::groups::GroupElement;
use crate::serde_helpers::ToFromByteArray;

/// Performs scalar multiplication using a windowed method with a larger pre-computation table than
/// the one used in the `windowed` multiplier. We must have HEIGHT >= ceil(SCALAR_SIZE * 8 / ceil(log2(WIDTH))
/// where WIDTH is the window width, and the pre-computation tables will be of size WIDTH x HEIGHT.
/// Once pre-computation has been done, a scalar multiplication requires HEIGHT additions. Both `mul`
/// and `double_mul` are constant time assuming the group operations for `G` are constant time.
///
/// The algorithm used is the BGMW algorithm with base `2^WIDTH` and the basic digit set set to `0, ..., 2^WIDTH-1`.
///
/// This method is faster than the WindowedScalarMultiplier for a single multiplication, but it requires
/// a larger number of precomputed points.
pub struct BGMWScalarMultiplier<
G: GroupElement<ScalarType = S>,
S: GroupElement + ToFromByteArray<SCALAR_SIZE>,
const WIDTH: usize,
const HEIGHT: usize,
const SCALAR_SIZE: usize,
> {
/// Precomputed multiples of the base element, B, up to WIDTH x HEIGHT - 1.
cache: [[G; WIDTH]; HEIGHT],
}

impl<
G: GroupElement<ScalarType = S>,
S: GroupElement + ToFromByteArray<SCALAR_SIZE>,
const WIDTH: usize,
const HEIGHT: usize,
const SCALAR_SIZE: usize,
> BGMWScalarMultiplier<G, S, WIDTH, HEIGHT, SCALAR_SIZE>
{
/// The number of bits in the window. This is equal to the floor of the log2 of the `WIDTH`.
const WINDOW_WIDTH: usize = (usize::BITS - WIDTH.leading_zeros() - 1) as usize;

/// Get 2^{column * WINDOW_WIDTH} * row * base_point.
fn get_precomputed_multiple(&self, row: usize, column: usize) -> G {
self.cache[row][column]
}
}

impl<
G: GroupElement<ScalarType = S>,
S: GroupElement + ToFromByteArray<SCALAR_SIZE>,
const WIDTH: usize,
const HEIGHT: usize,
const SCALAR_SIZE: usize,
> ScalarMultiplier<G> for BGMWScalarMultiplier<G, S, WIDTH, HEIGHT, SCALAR_SIZE>
{
fn new(base_element: G) -> Self {
// Verify parameters
let lower_limit = div_ceil(SCALAR_SIZE * 8, Self::WINDOW_WIDTH);
if HEIGHT < lower_limit {
panic!("Invalid parameters. HEIGHT needs to be at least {} with the given WIDTH and SCALAR_SIZE.", lower_limit);
}

// Store cache[i][j] = 2^{i w} * j * base_element
let mut cache = [[G::zero(); WIDTH]; HEIGHT];

// Compute cache[0][j] = j * base_element.
for j in 1..WIDTH {
cache[0][j] = cache[0][j - 1] + base_element;
}

// Compute cache[i][j] = 2^w * cache[i-1][j] for i > 0.
for i in 1..HEIGHT {
for j in 0..WIDTH {
cache[i][j] = cache[i - 1][j];
for _ in 0..Self::WINDOW_WIDTH {
cache[i][j] = cache[i][j].double();
}
}
}
Self { cache }
}

fn mul(&self, scalar: &S) -> G {
// Scalar as bytes in little-endian representation.
let scalar_bytes = scalar.to_byte_array();

let base_2w_expansion =
compute_base_2w_expansion::<SCALAR_SIZE>(&scalar_bytes, Self::WINDOW_WIDTH);

let mut result = self.get_precomputed_multiple(0, base_2w_expansion[0]);
for (i, digit) in base_2w_expansion.iter().enumerate().skip(1) {
result += self.get_precomputed_multiple(i, *digit);
}
result
}
}

#[cfg(test)]
mod tests {
use super::*;
use crate::groups::ristretto255::{RistrettoPoint, RistrettoScalar};
use crate::groups::secp256r1::{ProjectivePoint, Scalar};
use ark_ff::{BigInteger, PrimeField};
use ark_secp256r1::Fr;

#[test]
fn test_scalar_multiplication_ristretto() {
let multiplier = BGMWScalarMultiplier::<RistrettoPoint, RistrettoScalar, 16, 64, 32>::new(
RistrettoPoint::generator(),
);

let scalars = [
RistrettoScalar::from(0),
RistrettoScalar::from(1),
RistrettoScalar::from(2),
RistrettoScalar::from(1234),
RistrettoScalar::from(123456),
RistrettoScalar::from(123456789),
RistrettoScalar::from(0xffffffffffffffff),
RistrettoScalar::group_order(),
RistrettoScalar::group_order() - RistrettoScalar::from(1),
RistrettoScalar::group_order() + RistrettoScalar::from(1),
];

for scalar in scalars {
let expected = RistrettoPoint::generator() * scalar;
let actual = multiplier.mul(&scalar);
assert_eq!(expected, actual);
}
}

#[test]
fn test_scalar_multiplication_secp256r1() {
let mut modulus_minus_one = Fr::MODULUS_MINUS_ONE_DIV_TWO;
modulus_minus_one.mul2();
let scalars = [
Scalar::from(0),
Scalar::from(1),
Scalar::from(2),
Scalar::from(1234),
Scalar::from(123456),
Scalar::from(123456789),
Scalar::from(0xffffffffffffffff),
Scalar(Fr::from(modulus_minus_one)),
];

for scalar in scalars {
let expected = ProjectivePoint::generator() * scalar;

let multiplier = BGMWScalarMultiplier::<ProjectivePoint, Scalar, 16, 64, 32>::new(
ProjectivePoint::generator(),
);
let actual = multiplier.mul(&scalar);
assert_eq!(expected, actual);

let multiplier = BGMWScalarMultiplier::<ProjectivePoint, Scalar, 32, 52, 32>::new(
ProjectivePoint::generator(),
);
let actual = multiplier.mul(&scalar);
assert_eq!(expected, actual);

let multiplier = BGMWScalarMultiplier::<ProjectivePoint, Scalar, 64, 43, 32>::new(
ProjectivePoint::generator(),
);
let actual = multiplier.mul(&scalar);
assert_eq!(expected, actual);
}

// Assert a panic due to setting the HEIGHT too small
assert!(std::panic::catch_unwind(|| {
BGMWScalarMultiplier::<ProjectivePoint, Scalar, 16, 63, 32>::new(
ProjectivePoint::generator(),
)
})
.is_err());
}
}
Loading

0 comments on commit ed1c2e1

Please sign in to comment.