diff --git a/Cargo.lock b/Cargo.lock index 4d4e80f573..52c9fdb428 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1555,6 +1555,7 @@ dependencies = [ "hex-literal", "hkdf", "k256", + "lazy_static", "merlin", "once_cell", "p256", diff --git a/fastcrypto/Cargo.toml b/fastcrypto/Cargo.toml index b5a9748756..702200868c 100644 --- a/fastcrypto/Cargo.toml +++ b/fastcrypto/Cargo.toml @@ -58,6 +58,7 @@ ark-secp256r1 = "0.4.0" ark-ec = "0.4.1" ark-ff = "0.4.1" ark-serialize = "0.4.1" +lazy_static = "1.4.0" fastcrypto-derive = { path = "../fastcrypto-derive", version = "0.1.3" } diff --git a/fastcrypto/benches/groups.rs b/fastcrypto/benches/groups.rs index 0504f0eedd..dc244d654d 100644 --- a/fastcrypto/benches/groups.rs +++ b/fastcrypto/benches/groups.rs @@ -4,10 +4,14 @@ extern crate criterion; mod group_benches { + use criterion::measurement::Measurement; use criterion::{measurement, BenchmarkGroup, Criterion}; use fastcrypto::groups::bls12381::{G1Element, G2Element, GTElement}; + use fastcrypto::groups::multiplier::windowed::WindowedScalarMultiplier; + use fastcrypto::groups::multiplier::ScalarMultiplier; use fastcrypto::groups::ristretto255::RistrettoPoint; - use fastcrypto::groups::{GroupElement, HashToGroupElement, Pairing, Scalar}; + use fastcrypto::groups::secp256r1::ProjectivePoint; + use fastcrypto::groups::{secp256r1, GroupElement, HashToGroupElement, Pairing, Scalar}; use rand::thread_rng; fn add_single( @@ -36,12 +40,99 @@ mod group_benches { c.bench_function(&(name.to_string()), move |b| b.iter(|| x * y)); } + fn scale_single_precomputed, M: Measurement>( + name: &str, + c: &mut BenchmarkGroup, + ) { + let x = G::generator() * G::ScalarType::rand(&mut thread_rng()); + let y = G::ScalarType::rand(&mut thread_rng()); + + let multiplier = Mul::new(x); + c.bench_function(&(name.to_string()), move |b| b.iter(|| multiplier.mul(&y))); + } + fn scale(c: &mut Criterion) { let mut group: BenchmarkGroup<_> = c.benchmark_group("Scalar To Point Multiplication"); scale_single::("BLS12381-G1", &mut group); scale_single::("BLS12381-G2", &mut group); scale_single::("BLS12381-GT", &mut group); scale_single::("Ristretto255", &mut group); + scale_single::("Secp256r1", &mut group); + + scale_single_precomputed::< + ProjectivePoint, + WindowedScalarMultiplier, + _, + >("Secp256r1 Fixed window (16)", &mut group); + scale_single_precomputed::< + ProjectivePoint, + WindowedScalarMultiplier, + _, + >("Secp256r1 Fixed window (32)", &mut group); + scale_single_precomputed::< + ProjectivePoint, + WindowedScalarMultiplier, + _, + >("Secp256r1 Fixed window (64)", &mut group); + scale_single_precomputed::< + ProjectivePoint, + WindowedScalarMultiplier, + _, + >("Secp256r1 Fixed window (128)", &mut group); + scale_single_precomputed::< + ProjectivePoint, + WindowedScalarMultiplier, + _, + >("Secp256r1 Fixed window (256)", &mut group); + } + + fn double_scale_single, M: Measurement>( + name: &str, + c: &mut BenchmarkGroup, + ) { + let g1 = G::generator() * G::ScalarType::rand(&mut thread_rng()); + let s1 = G::ScalarType::rand(&mut thread_rng()); + let g2 = G::generator() * G::ScalarType::rand(&mut thread_rng()); + let s2 = G::ScalarType::rand(&mut thread_rng()); + + let multiplier = Mul::new(g1); + c.bench_function(&(name.to_string()), move |b| { + b.iter(|| multiplier.two_scalar_mul(&s1, &g2, &s2)) + }); + } + + fn double_scale(c: &mut Criterion) { + let mut group: BenchmarkGroup<_> = c.benchmark_group("Double Scalar Multiplication"); + + double_scale_single::< + ProjectivePoint, + WindowedScalarMultiplier, + _, + 
>("Secp256r1 Straus (16)", &mut group); + double_scale_single::< + ProjectivePoint, + WindowedScalarMultiplier, + _, + >("Secp256r1 Straus (32)", &mut group); + double_scale_single::< + ProjectivePoint, + WindowedScalarMultiplier, + _, + >("Secp256r1 Straus (64)", &mut group); + double_scale_single::< + ProjectivePoint, + WindowedScalarMultiplier, + _, + >("Secp256r1 Straus (128)", &mut group); + double_scale_single::< + ProjectivePoint, + WindowedScalarMultiplier, + _, + >("Secp256r1 Straus (256)", &mut group); + double_scale_single::, _>( + "Secp256r1", + &mut group, + ); } fn hash_to_group_single( @@ -76,6 +167,20 @@ mod group_benches { pairing_single::("BLS12381-G1", &mut group); } + /// Implementation of a `Multiplier` where scalar multiplication is done without any pre-computation by + /// simply calling the GroupElement implementation. Only used for benchmarking. + struct DefaultMultiplier(G); + + impl ScalarMultiplier for DefaultMultiplier { + fn new(base_element: G) -> Self { + Self(base_element) + } + + fn mul(&self, scalar: &G::ScalarType) -> G { + self.0 * scalar + } + } + criterion_group! { name = group_benches; config = Criterion::default().sample_size(100); @@ -84,6 +189,7 @@ mod group_benches { scale, hash_to_group, pairing, + double_scale, } } diff --git a/fastcrypto/src/groups/mod.rs b/fastcrypto/src/groups/mod.rs index b10df81bfb..82f578765e 100644 --- a/fastcrypto/src/groups/mod.rs +++ b/fastcrypto/src/groups/mod.rs @@ -10,6 +10,10 @@ use std::ops::{AddAssign, SubAssign}; pub mod bls12381; pub mod ristretto255; +pub mod secp256r1; + +pub mod multiplier; + /// Trait impl'd by elements of an additive cyclic group. pub trait GroupElement: Copy @@ -36,6 +40,11 @@ pub trait GroupElement: /// Return an instance of the generator for this group. fn generator() -> Self; + + /// Compute 2 * Self. May be overwritten by implementations that have a fast doubling operation. + fn double(&self) -> Self { + *self + self + } } /// Trait impl'd by scalars to be used with [GroupElement]. diff --git a/fastcrypto/src/groups/multiplier/bgmw.rs b/fastcrypto/src/groups/multiplier/bgmw.rs new file mode 100644 index 0000000000..2229401744 --- /dev/null +++ b/fastcrypto/src/groups/multiplier/bgmw.rs @@ -0,0 +1,176 @@ +// Copyright (c) 2022, Mysten Labs, Inc. +// SPDX-License-Identifier: Apache-2.0 + +use crate::groups::multiplier::integer_utils::{compute_base_2w_expansion, div_ceil}; +use crate::groups::multiplier::ScalarMultiplier; +use crate::groups::GroupElement; +use crate::serde_helpers::ToFromByteArray; + +/// Performs scalar multiplication using a windowed method with a larger pre-computation table than +/// the one used in the `windowed` multiplier. We must have HEIGHT >= ceil(SCALAR_SIZE * 8 / ceil(log2(WIDTH)) +/// where WIDTH is the window width, and the pre-computation tables will be of size WIDTH x HEIGHT. +/// Once pre-computation has been done, a scalar multiplication requires HEIGHT additions. Both `mul` +/// and `double_mul` are constant time assuming the group operations for `G` are constant time. +/// +/// The algorithm used is the BGMW algorithm with base `2^WIDTH` and the basic digit set set to `0, ..., 2^WIDTH-1`. +/// +/// This method is faster than the WindowedScalarMultiplier for a single multiplication, but it requires +/// a larger number of precomputed points. 
+pub struct BGMWScalarMultiplier< + G: GroupElement, + S: GroupElement + ToFromByteArray, + const WIDTH: usize, + const HEIGHT: usize, + const SCALAR_SIZE: usize, +> { + /// Precomputed multiples of the base element, B, up to WIDTH x HEIGHT - 1. + cache: [[G; WIDTH]; HEIGHT], +} + +impl< + G: GroupElement, + S: GroupElement + ToFromByteArray, + const WIDTH: usize, + const HEIGHT: usize, + const SCALAR_SIZE: usize, + > BGMWScalarMultiplier +{ + /// The number of bits in the window. This is equal to the floor of the log2 of the `WIDTH`. + const WINDOW_WIDTH: usize = (usize::BITS - WIDTH.leading_zeros() - 1) as usize; + + /// Get 2^{column * WINDOW_WIDTH} * row * base_point. + fn get_precomputed_multiple(&self, row: usize, column: usize) -> G { + self.cache[row][column] + } +} + +impl< + G: GroupElement, + S: GroupElement + ToFromByteArray, + const WIDTH: usize, + const HEIGHT: usize, + const SCALAR_SIZE: usize, + > ScalarMultiplier for BGMWScalarMultiplier +{ + fn new(base_element: G) -> Self { + // Verify parameters + let lower_limit = div_ceil(SCALAR_SIZE * 8, Self::WINDOW_WIDTH); + if HEIGHT < lower_limit { + panic!("Invalid parameters. HEIGHT needs to be at least {} with the given WIDTH and SCALAR_SIZE.", lower_limit); + } + + // Store cache[i][j] = 2^{i w} * j * base_element + let mut cache = [[G::zero(); WIDTH]; HEIGHT]; + + // Compute cache[0][j] = j * base_element. + for j in 1..WIDTH { + cache[0][j] = cache[0][j - 1] + base_element; + } + + // Compute cache[i][j] = 2^w * cache[i-1][j] for i > 0. + for i in 1..HEIGHT { + for j in 0..WIDTH { + cache[i][j] = cache[i - 1][j]; + for _ in 0..Self::WINDOW_WIDTH { + cache[i][j] = cache[i][j].double(); + } + } + } + Self { cache } + } + + fn mul(&self, scalar: &S) -> G { + // Scalar as bytes in little-endian representation. 
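        // Annotation (not part of the patch): if the scalar s has base-2^w digits
        // d_0, ..., d_{k-1}, i.e. s = sum_i d_i * 2^{i*w}, the loop below returns
        // cache[0][d_0] + cache[1][d_1] + ... = sum_i 2^{i*w} * d_i * B = s * B,
        // one table lookup and one addition per digit.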
+ let scalar_bytes = scalar.to_byte_array(); + + let base_2w_expansion = + compute_base_2w_expansion::(&scalar_bytes, Self::WINDOW_WIDTH); + + let mut result = self.get_precomputed_multiple(0, base_2w_expansion[0]); + for (i, digit) in base_2w_expansion.iter().enumerate().skip(1) { + result += self.get_precomputed_multiple(i, *digit); + } + result + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::groups::ristretto255::{RistrettoPoint, RistrettoScalar}; + use crate::groups::secp256r1::{ProjectivePoint, Scalar}; + use ark_ff::{BigInteger, PrimeField}; + use ark_secp256r1::Fr; + + #[test] + fn test_scalar_multiplication_ristretto() { + let multiplier = BGMWScalarMultiplier::::new( + RistrettoPoint::generator(), + ); + + let scalars = [ + RistrettoScalar::from(0), + RistrettoScalar::from(1), + RistrettoScalar::from(2), + RistrettoScalar::from(1234), + RistrettoScalar::from(123456), + RistrettoScalar::from(123456789), + RistrettoScalar::from(0xffffffffffffffff), + RistrettoScalar::group_order(), + RistrettoScalar::group_order() - RistrettoScalar::from(1), + RistrettoScalar::group_order() + RistrettoScalar::from(1), + ]; + + for scalar in scalars { + let expected = RistrettoPoint::generator() * scalar; + let actual = multiplier.mul(&scalar); + assert_eq!(expected, actual); + } + } + + #[test] + fn test_scalar_multiplication_secp256r1() { + let mut modulus_minus_one = Fr::MODULUS_MINUS_ONE_DIV_TWO; + modulus_minus_one.mul2(); + let scalars = [ + Scalar::from(0), + Scalar::from(1), + Scalar::from(2), + Scalar::from(1234), + Scalar::from(123456), + Scalar::from(123456789), + Scalar::from(0xffffffffffffffff), + Scalar(Fr::from(modulus_minus_one)), + ]; + + for scalar in scalars { + let expected = ProjectivePoint::generator() * scalar; + + let multiplier = BGMWScalarMultiplier::::new( + ProjectivePoint::generator(), + ); + let actual = multiplier.mul(&scalar); + assert_eq!(expected, actual); + + let multiplier = BGMWScalarMultiplier::::new( + ProjectivePoint::generator(), + ); + let actual = multiplier.mul(&scalar); + assert_eq!(expected, actual); + + let multiplier = BGMWScalarMultiplier::::new( + ProjectivePoint::generator(), + ); + let actual = multiplier.mul(&scalar); + assert_eq!(expected, actual); + } + + // Assert a panic due to setting the HEIGHT too small + assert!(std::panic::catch_unwind(|| { + BGMWScalarMultiplier::::new( + ProjectivePoint::generator(), + ) + }) + .is_err()); + } +} diff --git a/fastcrypto/src/groups/multiplier/integer_utils.rs b/fastcrypto/src/groups/multiplier/integer_utils.rs new file mode 100644 index 0000000000..34efd35317 --- /dev/null +++ b/fastcrypto/src/groups/multiplier/integer_utils.rs @@ -0,0 +1,156 @@ +// Copyright (c) 2022, Mysten Labs, Inc. +// SPDX-License-Identifier: Apache-2.0 + +/// Given a binary representation of a number in little-endian format, return the digits of its base +/// `2^bits_per_digit` expansion. +pub fn compute_base_2w_expansion( + bytes: &[u8; N], + bits_per_digit: usize, +) -> Vec { + assert!(0 < bits_per_digit && bits_per_digit <= usize::BITS as usize); + + // The base 2^window_size expansions digits in little-endian representation. + let mut digits = Vec::new(); + + // Compute the number of digits needed to represent the numbed in base 2^w. This is equal to + // ceil(8*N / window_size), and we compute like this because div_ceil is unstable as of rustc 1.69.0. 
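    // Worked example (illustrative annotation): for bytes = [0xD6] (N = 1) and bits_per_digit = 3,
    // the expansion has ceil(8 / 3) = 3 digits, namely [6, 2, 3], and indeed
    // 6 + 2 * 2^3 + 3 * 2^6 = 214 = 0xD6.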
+ let digits_count = div_ceil(8 * N, bits_per_digit); + + for i in 0..digits_count { + digits.push(get_bits_from_bytes( + bytes, + bits_per_digit * i, + bits_per_digit * (i + 1), + )); + } + digits +} + +/// Get the integer represented by a given range of bits of a byte from start to end (exclusive). +#[inline] +fn get_lendian_from_substring(byte: &u8, start: usize, end: usize) -> u8 { + assert!(start <= end); + byte >> start & ((1 << (end - start)) - 1) as u8 +} + +/// Compute ceil(numerator / denominator). +pub(crate) fn div_ceil(numerator: usize, denominator: usize) -> usize { + (numerator + denominator - 1) / denominator +} + +/// Get the integer represented by a given range of bits of a an integer represented by a little-endian +/// byte array from start to end (exclusive). The `end` argument may be arbitrarily large, but if it +/// is larger than 8*N, the remaining bits of the byte array will be assumed to be zero. +#[inline] +pub fn get_bits_from_bytes(bytes: &[u8; N], start: usize, end: usize) -> usize { + assert!(start <= end && start < 8 * N); + + let mut result: usize = 0; + let mut bits_added = 0; + + let mut current_bit = start % 8; + let mut current_byte = start / 8; + + while bits_added < end - start && current_byte < N { + let remaining_bits = end - start - bits_added; + let (bits_to_read, next_byte, next_bit) = if remaining_bits < 8 - current_bit { + // There are enough bits left in the current byte + (remaining_bits, current_byte, current_bit + remaining_bits) + } else { + // There are not enough bits in the current byte. Take the remaining bits and increment the byte index + (8 - current_bit, current_byte + 1, 0) + }; + + // Add the bits to the result + result += (get_lendian_from_substring( + &bytes[current_byte], + current_bit, + current_bit + bits_to_read, + ) as usize) + << bits_added; + + // Increment the counters + bits_added += bits_to_read; + current_bit = next_bit; + current_byte = next_byte; + } + result +} + +/// Return true iff the bit at the given index is set. +#[inline] +pub fn test_bit(bytes: &[u8; N], index: usize) -> bool { + assert!(index < 8 * N); + let byte = index >> 3; + let shifted = bytes[byte] >> (index & 7); + shifted & 1 != 0 +} + +/// Compute the floor of the base-2 logarithm of x. 
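/// The input must be positive. Illustrative values (annotation, not part of the patch):
/// log2(1) == 0, log2(16) == 4, log2(17) == 4.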
+pub const fn log2(x: usize) -> usize { + (usize::BITS - x.leading_zeros() - 1) as usize +} + +#[cfg(test)] +mod tests { + use super::*; + use std::assert_eq; + + #[test] + fn test_lendian_from_substring() { + let byte = 0b00000001; + assert_eq!(0, get_lendian_from_substring(&byte, 0, 0)); + assert_eq!(1, get_lendian_from_substring(&byte, 0, 1)); + assert_eq!(1, get_lendian_from_substring(&byte, 0, 3)); + assert_eq!(1, get_lendian_from_substring(&byte, 0, 8)); + assert_eq!(0, get_lendian_from_substring(&byte, 1, 8)); + + let byte = 0b00000011; + assert_eq!(1, get_lendian_from_substring(&byte, 0, 1)); + assert_eq!(3, get_lendian_from_substring(&byte, 0, 2)); + assert_eq!(3, get_lendian_from_substring(&byte, 0, 3)); + assert_eq!(1, get_lendian_from_substring(&byte, 1, 8)); + assert_eq!(0, get_lendian_from_substring(&byte, 2, 8)); + + let byte = 0b10000001; + assert_eq!(1, get_lendian_from_substring(&byte, 0, 1)); + assert_eq!(1, get_lendian_from_substring(&byte, 0, 7)); + assert_eq!(129, get_lendian_from_substring(&byte, 0, 8)); + assert_eq!(64, get_lendian_from_substring(&byte, 1, 8)); + assert_eq!(16, get_lendian_from_substring(&byte, 3, 8)); + } + + #[test] + fn test_base_2w_expansion() { + let value: u128 = 123812341234567; + let bytes = value.to_le_bytes(); + + // Is w = 8, the base 2^w expansion should be equal to the le bytes. + let expansion = compute_base_2w_expansion::<16>(&bytes, 8); + assert_eq!( + bytes.to_vec(), + expansion.iter().map(|x| *x as u8).collect::>() + ); + + // Verify that the expansion is correct for w = 1, ..., 64 + for window_size in 1..=64 { + let expansion = compute_base_2w_expansion::<16>(&bytes, window_size); + let mut sum = 0u128; + for (i, value) in expansion.iter().enumerate() { + sum += (1 << (window_size * i)) * *value as u128; + } + assert_eq!(value, sum); + } + } + + #[test] + fn test_bits_form_bytes() { + let bytes = [0b00000001, 0b00000011, 0b10000001]; + assert_eq!(0, get_bits_from_bytes(&bytes, 0, 0)); + assert_eq!(1, get_bits_from_bytes(&bytes, 0, 1)); + assert_eq!(3, get_bits_from_bytes(&bytes, 8, 10)); + assert_eq!(1, get_bits_from_bytes(&bytes, 16, 17)); + assert_eq!(0, get_bits_from_bytes(&bytes, 17, 23)); + assert_eq!(1, get_bits_from_bytes(&bytes, 23, 100)); + } +} diff --git a/fastcrypto/src/groups/multiplier/mod.rs b/fastcrypto/src/groups/multiplier/mod.rs new file mode 100644 index 0000000000..7a2295c824 --- /dev/null +++ b/fastcrypto/src/groups/multiplier/mod.rs @@ -0,0 +1,32 @@ +// Copyright (c) 2022, Mysten Labs, Inc. +// SPDX-License-Identifier: Apache-2.0 + +//! This module contains implementations of optimised scalar multiplication algorithms where the +//! group element is fixed and certain multiples of this may be pre-computed. + +use crate::groups::GroupElement; + +#[cfg(feature = "experimental")] +pub mod bgmw; +mod integer_utils; +pub mod windowed; + +/// Trait for scalar multiplication for a fixed group element, e.g. by using precomputed values. +pub trait ScalarMultiplier { + /// Create a new scalar multiplier with the given base element. + fn new(base_element: G) -> Self; + + /// Compute `self.base_element * scalar`. + fn mul(&self, scalar: &G::ScalarType) -> G; + + /// Compute `self.base_element * base_scalar + other_element * other_scalar`. + fn two_scalar_mul( + &self, + base_scalar: &G::ScalarType, + other_element: &G, + other_scalar: &G::ScalarType, + ) -> G { + // The default implementation. May be overwritten by implementations that allow optimised double multiplication. 
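        // Annotation: only the fixed base element benefits from any precomputation here, while
        // `other_element * other_scalar` falls back to the group's generic multiplication. Concrete
        // multipliers such as the windowed one override this with a proper double-scalar
        // (Straus-style) routine.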
+ self.mul(base_scalar) + *other_element * other_scalar + } +} diff --git a/fastcrypto/src/groups/multiplier/windowed.rs b/fastcrypto/src/groups/multiplier/windowed.rs new file mode 100644 index 0000000000..20e1df4b68 --- /dev/null +++ b/fastcrypto/src/groups/multiplier/windowed.rs @@ -0,0 +1,334 @@ +// Copyright (c) 2022, Mysten Labs, Inc. +// SPDX-License-Identifier: Apache-2.0 + +use std::collections::HashMap; +use std::iter::successors; + +use crate::groups::multiplier::integer_utils::{get_bits_from_bytes, test_bit}; +use crate::groups::multiplier::{integer_utils, ScalarMultiplier}; +use crate::groups::GroupElement; +use crate::serde_helpers::ToFromByteArray; + +/// This scalar multiplier uses pre-computation with the windowed method. This multiplier is particularly +/// fast for double multiplications, where a sliding window method is used, but this implies that the +/// `double_mul`, is NOT constant time. However, the single multiplication method `mul` is constant +/// time if the group operations for `G` are constant time. +/// +/// The `CACHE_SIZE` should be a power of two. The `SCALAR_SIZE` is the number of bytes in the byte +/// representation of the scalar type `S`, and we assume that the `S::to_byte_array` method returns +/// the scalar in little-endian format. +/// +/// The `SLIDING_WINDOW_WIDTH` is the number of bits in the sliding window of the elements not already +/// with precomputed multiples. This should be approximately log2(sqrt(SCALAR_SIZE_IN_BITS)) + 1 for +/// optimal performance. +pub struct WindowedScalarMultiplier< + G: GroupElement, + S: GroupElement + ToFromByteArray, + const CACHE_SIZE: usize, + const SCALAR_SIZE: usize, + const SLIDING_WINDOW_WIDTH: usize, +> { + /// Precomputed multiples of the base element from 0 up to CACHE_SIZE - 1 = 2^WINDOW_WIDTH - 1. + cache: [G; CACHE_SIZE], +} + +impl< + G: GroupElement, + S: GroupElement + ToFromByteArray, + const CACHE_SIZE: usize, + const SCALAR_SIZE: usize, + const SLIDING_WINDOW_WIDTH: usize, + > WindowedScalarMultiplier +{ + /// The number of bits in the window. This is equal to the floor of the log2 of the cache size. + const WINDOW_WIDTH: usize = integer_utils::log2(CACHE_SIZE); +} + +impl< + G: GroupElement, + S: GroupElement + ToFromByteArray, + const CACHE_SIZE: usize, + const SCALAR_SIZE: usize, + const SLIDING_WINDOW_WIDTH: usize, + > ScalarMultiplier + for WindowedScalarMultiplier +{ + fn new(base_element: G) -> Self { + let mut cache = [G::zero(); CACHE_SIZE]; + cache[1] = base_element; + for i in 2..CACHE_SIZE { + cache[i] = cache[i - 1] + base_element; + } + Self { cache } + } + + fn mul(&self, scalar: &S) -> G { + // Scalar as bytes in little-endian representation. + let scalar_bytes = scalar.to_byte_array(); + + let base_2w_expansion = integer_utils::compute_base_2w_expansion::( + &scalar_bytes, + Self::WINDOW_WIDTH, + ); + + // Computer multiplication using the fixed-window method to ensure that it's constant time. + let mut result: G = self.cache[base_2w_expansion[base_2w_expansion.len() - 1]]; + for digit in base_2w_expansion.iter().rev().skip(1) { + for _ in 1..=Self::WINDOW_WIDTH { + result = result.double(); + } + result += self.cache[*digit]; + } + result + } + + fn two_scalar_mul( + &self, + base_scalar: &G::ScalarType, + other_element: &G, + other_scalar: &G::ScalarType, + ) -> G { + // Compute the sum of the two multiples using Straus' algorithm combined with a sliding window algorithm. 
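        // Annotation: the call below reuses the upper half of this multiplier's table,
        // cache[CACHE_SIZE/2..], as the precomputed window table for the fixed base (key 0 in the
        // map), while `other_element` gets a table of width SLIDING_WINDOW_WIDTH computed on the
        // fly inside `multi_scalar_mul`.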
+ multi_scalar_mul( + &[*base_scalar, *other_scalar], + &[self.cache[1], *other_element], + &HashMap::from([(0, self.cache[CACHE_SIZE / 2..CACHE_SIZE].to_vec())]), + SLIDING_WINDOW_WIDTH, + ) + } +} + +/// This method computes the linear combination of the given scalars and group elements using the +/// sliding window method. Some group elements may have tables of precomputed elements which can +/// be given in the `precomputed` hash map. For the elements which does not have a precomputed table +/// a table of size 2default_window_width - 1 is computed. +/// +/// The precomputed tables for an element g should contain the multiples 2w-1 g +/// , ..., (2w - 1) g for some integer w > 1 which is the window width for the +/// given element. +/// +/// The `default_window_width` is the window width for the elements that does not have a precomputation +/// table and may be set to any value >= 1. As rule-of-thumb, this should be set to approximately +/// the bit length of the square root of the scalar size for optimal performance. +pub fn multi_scalar_mul< + G: GroupElement, + S: GroupElement + ToFromByteArray, + const SCALAR_SIZE: usize, + const N: usize, +>( + scalars: &[G::ScalarType; N], + elements: &[G; N], + precomputed_multiples: &HashMap>, + default_window_width: usize, +) -> G { + let mut window_sizes = [0usize; N]; + + // Compute missing precomputation tables. + let mut missing_precomputations = HashMap::new(); + for (i, element) in elements.iter().enumerate() { + if !precomputed_multiples.contains_key(&i) { + missing_precomputations.insert(i, compute_multiples(element, default_window_width)); + } + } + + // Create vector with all precomputation tables. + let mut all_precomputed_multiples = vec![]; + for i in 0..N { + match precomputed_multiples.get(&i).take() { + Some(precomputed_multiples) => { + all_precomputed_multiples.push(precomputed_multiples); + window_sizes[i] = integer_utils::log2(all_precomputed_multiples[i].len()) + 1; + } + None => { + all_precomputed_multiples.push(&missing_precomputations[&i]); + window_sizes[i] = default_window_width; + } + } + } + + // Compute little-endian byte representations of scalars. + let scalar_bytes = scalars + .iter() + .map(|s| s.to_byte_array()) + .collect::>(); + + // We iterate from the top bit and down for all scalars until we reach a set bit. This marks the + // beginning of a window, and we continue the iteration. When the iterations exists the window, + // we add the corresponding precomputed value and keeps iterating until the next one bit is found + // which marks the beginning of the next window. + let mut is_in_window = [false; N]; + let mut index_in_window = [0usize; N]; // Counter for the current window + let mut precomputed_multiple_index = [0usize; N]; + + // We may skip doubling until result is non-zero. + let mut is_zero = true; + let mut result = G::zero(); + + // Iterate through all bits of the scalars from the top. + for bit in (0..SCALAR_SIZE * 8).rev() { + if !is_zero { + result = result.double(); + } + for i in 0..N { + if is_in_window[i] { + // A window has been set for this scalar. Keep iterating until the window is finished. + index_in_window[i] += 1; + if index_in_window[i] == window_sizes[i] { + // This window is finished. Add the right precomputed value and indicate that we are ready for a new window. 
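                    // Annotation: the table for element i only stores the upper-half multiples
                    // 2^(w-1) * g, ..., (2^w - 1) * g, so table index j stands for the multiple
                    // 2^(w-1) + j; the window's leading one bit is implicit in that offset.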
+ result = if is_zero { + is_zero = false; + all_precomputed_multiples[i][precomputed_multiple_index[i]] + } else { + result + all_precomputed_multiples[i][precomputed_multiple_index[i]] + }; + is_in_window[i] = false; + } + } else if test_bit(&scalar_bytes[i], bit) { + // The iteration has reached a set bit for the i'th scalar. + if bit >= window_sizes[i] - 1 { + // There is enough room for a window. Set indicator and reset window index. + is_in_window[i] = true; + index_in_window[i] = 1; + precomputed_multiple_index[i] = get_bits_from_bytes( + &scalar_bytes[i], + bit + 1 - window_sizes[i], + bit, // The last bit is always one, so we ignore it and only precompute the upper half of the first 2^window_sizes multiples. + ); + } else { + // There is not enough room left for a window. Continue with regular double-and-add. + result = if is_zero { + is_zero = false; + elements[i] + } else { + result + elements[i] + }; + } + } + } + } + result +} + +/// Compute multiples 2w-1 base_element, (2w-1 + 1) base_element, ..., (2w - 1) base_element. +fn compute_multiples(base_element: &G, window_size: usize) -> Vec { + assert!(window_size > 0, "Window size must be strictly positive."); + let mut smallest_multiple = base_element.double(); + for _ in 2..window_size { + smallest_multiple = smallest_multiple.double(); + } + successors(Some(smallest_multiple), |g| Some(*g + base_element)) + .take(1 << (window_size - 1)) + .collect::>() +} + +#[cfg(test)] +mod tests { + use ark_ff::{BigInteger, PrimeField}; + use ark_secp256r1::Fr; + use rand::thread_rng; + + use crate::groups::ristretto255::{RistrettoPoint, RistrettoScalar}; + use crate::groups::secp256r1::{ProjectivePoint, Scalar}; + use crate::groups::Scalar as ScalarTrait; + + use super::*; + + #[test] + fn test_scalar_multiplication_ristretto() { + let multiplier = + WindowedScalarMultiplier::::new( + RistrettoPoint::generator(), + ); + + let scalars = [ + RistrettoScalar::from(0), + RistrettoScalar::from(1), + RistrettoScalar::from(2), + RistrettoScalar::from(1234), + RistrettoScalar::from(123456), + RistrettoScalar::from(123456789), + RistrettoScalar::from(0xffffffffffffffff), + RistrettoScalar::group_order(), + RistrettoScalar::group_order() - RistrettoScalar::from(1), + RistrettoScalar::group_order() + RistrettoScalar::from(1), + ]; + + for scalar in scalars { + let expected = RistrettoPoint::generator() * scalar; + let actual = multiplier.mul(&scalar); + assert_eq!(expected, actual); + } + } + + #[test] + fn test_scalar_multiplication_secp256r1() { + let mut modulus_minus_one = Fr::MODULUS_MINUS_ONE_DIV_TWO; + modulus_minus_one.mul2(); + let scalars = [ + Scalar::from(0), + Scalar::from(1), + Scalar::from(2), + Scalar::from(1234), + Scalar::from(123456), + Scalar::from(123456789), + Scalar::from(0xffffffffffffffff), + Scalar(Fr::from(modulus_minus_one)), + ]; + + for scalar in scalars { + let expected = ProjectivePoint::generator() * scalar; + + let multiplier = WindowedScalarMultiplier::::new( + ProjectivePoint::generator(), + ); + let actual = multiplier.mul(&scalar); + assert_eq!(expected, actual); + + let multiplier = WindowedScalarMultiplier::::new( + ProjectivePoint::generator(), + ); + let actual = multiplier.mul(&scalar); + assert_eq!(expected, actual); + + let multiplier = WindowedScalarMultiplier::::new( + ProjectivePoint::generator(), + ); + let actual = multiplier.mul(&scalar); + assert_eq!(expected, actual); + + let multiplier = WindowedScalarMultiplier::::new( + ProjectivePoint::generator(), + ); + let actual = multiplier.mul(&scalar); 
+ assert_eq!(expected, actual); + + let multiplier = WindowedScalarMultiplier::::new( + ProjectivePoint::generator(), + ); + let actual = multiplier.mul(&scalar); + assert_eq!(expected, actual); + + let multiplier = WindowedScalarMultiplier::::new( + ProjectivePoint::generator(), + ); + let actual = multiplier.mul(&scalar); + assert_eq!(expected, actual); + } + } + + #[test] + fn test_double_mul_ristretto() { + let multiplier = + WindowedScalarMultiplier::::new( + RistrettoPoint::generator(), + ); + + let other_point = RistrettoPoint::generator() * RistrettoScalar::from(3); + + let a = RistrettoScalar::rand(&mut thread_rng()); + let b = RistrettoScalar::rand(&mut thread_rng()); + let expected = RistrettoPoint::generator() * a + other_point * b; + let actual = multiplier.two_scalar_mul(&a, &other_point, &b); + assert_eq!(expected, actual); + } +} diff --git a/fastcrypto/src/groups/secp256r1.rs b/fastcrypto/src/groups/secp256r1.rs new file mode 100644 index 0000000000..d9bb463ddc --- /dev/null +++ b/fastcrypto/src/groups/secp256r1.rs @@ -0,0 +1,123 @@ +// Copyright (c) 2022, Mysten Labs, Inc. +// SPDX-License-Identifier: Apache-2.0 + +//! Implementation of the Secp256r1 (aka P-256) curve. This is a 256-bit Weirstrass curve of prime order. +//! See "SEC 2: Recommended Elliptic Curve Domain Parameters" for details." + +use crate::error::{FastCryptoError, FastCryptoResult}; +use crate::groups::{GroupElement, Scalar as ScalarTrait}; +use crate::serde_helpers::ToFromByteArray; +use crate::traits::AllowedRng; +use ark_ec::Group; +use ark_ff::{Field, One, PrimeField, Zero}; +use ark_secp256r1::{Fr, Projective}; +use ark_serialize::CanonicalSerialize; +use derive_more::{Add, From, Neg, Sub}; +use fastcrypto_derive::GroupOpsExtend; +use std::ops::{Div, Mul}; + +pub const SCALAR_SIZE_IN_BYTES: usize = 32; + +/// A point on the Secp256r1 curve in projective coordinates. +#[derive(Default, Clone, Copy, Debug, PartialEq, Eq, From, Add, Sub, Neg, GroupOpsExtend)] +pub struct ProjectivePoint(pub(crate) Projective); + +impl GroupElement for ProjectivePoint { + type ScalarType = Scalar; + + fn zero() -> Self { + Self(Projective::zero()) + } + + fn generator() -> Self { + Self(Projective::generator()) + } + + fn double(&self) -> Self { + ProjectivePoint::from(self.0.double()) + } +} + +impl Mul for ProjectivePoint { + type Output = ProjectivePoint; + + fn mul(self, rhs: Scalar) -> ProjectivePoint { + ProjectivePoint::from(self.0 * rhs.0) + } +} + +#[allow(clippy::suspicious_arithmetic_impl)] +impl Div for ProjectivePoint { + type Output = Result; + + fn div(self, rhs: Scalar) -> Result { + Ok(self * rhs.inverse()?) + } +} + +/// A field element in the prime field of the same order as the curve. +#[derive(Default, Clone, Copy, Debug, PartialEq, Eq, From, Add, Sub, Neg, GroupOpsExtend)] +pub struct Scalar(pub(crate) Fr); + +impl GroupElement for Scalar { + type ScalarType = Scalar; + + fn zero() -> Self { + Scalar(Fr::zero()) + } + + fn generator() -> Self { + Scalar(Fr::one()) + } +} + +impl Mul for Scalar { + type Output = Scalar; + + fn mul(self, rhs: Scalar) -> Self::Output { + Scalar(self.0 * rhs.0) + } +} + +#[allow(clippy::suspicious_arithmetic_impl)] +impl Div for Scalar { + type Output = Result; + + fn div(self, rhs: Scalar) -> Result { + Ok(self * rhs.inverse()?) 
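        // Annotation: field division is multiplication by the modular inverse, so dividing by
        // Scalar::zero() surfaces as FastCryptoError::InvalidInput from `inverse` rather than a panic.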
+ } +} + +impl From for Scalar { + fn from(value: u64) -> Self { + Scalar(Fr::from(value)) + } +} + +impl ScalarTrait for Scalar { + fn rand(rng: &mut R) -> Self { + let mut bytes = [0u8; SCALAR_SIZE_IN_BYTES]; + rng.fill_bytes(&mut bytes); + Scalar(Fr::from_be_bytes_mod_order(&bytes)) + } + + fn inverse(&self) -> FastCryptoResult { + Ok(Scalar( + self.0.inverse().ok_or(FastCryptoError::InvalidInput)?, + )) + } +} + +impl ToFromByteArray for Scalar { + fn from_byte_array(bytes: &[u8; SCALAR_SIZE_IN_BYTES]) -> Result { + Ok(Scalar(Fr::from_le_bytes_mod_order(bytes))) + } + + fn to_byte_array(&self) -> [u8; SCALAR_SIZE_IN_BYTES] { + let mut bytes = [0u8; SCALAR_SIZE_IN_BYTES]; + self.0 + .serialize_uncompressed(&mut bytes[..]) + .expect("Byte array not large enough"); + bytes + } +} diff --git a/fastcrypto/src/lib.rs b/fastcrypto/src/lib.rs index 019a77dea4..7d1394effe 100644 --- a/fastcrypto/src/lib.rs +++ b/fastcrypto/src/lib.rs @@ -81,6 +81,10 @@ pub mod test_helpers; #[path = "tests/utils_tests.rs"] pub mod utils_tests; +#[cfg(test)] +#[path = "tests/secp256r1_group_tests.rs"] +pub mod secp256r1_group_tests; + pub mod traits; #[cfg(any(test, feature = "experimental"))] diff --git a/fastcrypto/src/secp256r1/mod.rs b/fastcrypto/src/secp256r1/mod.rs index 4ad47af45b..8ca81c0f68 100644 --- a/fastcrypto/src/secp256r1/mod.rs +++ b/fastcrypto/src/secp256r1/mod.rs @@ -20,14 +20,17 @@ pub mod recoverable; pub mod conversion; +use crate::groups::GroupElement; use crate::serde_helpers::BytesRepresentation; use crate::{ generate_bytes_representation, impl_base64_display_fmt, serialize_deserialize_with_to_from_bytes, }; -use ark_ec::{AffineRepr, CurveGroup, Group}; +use ark_ec::{AffineRepr, CurveGroup}; use ark_ff::Field; +use ark_secp256r1::Projective; use elliptic_curve::{Curve, FieldBytesEncoding, PrimeField}; +use lazy_static::lazy_static; use once_cell::sync::OnceCell; use p256::ecdsa::{ Signature as ExternalSignature, Signature, SigningKey as ExternalSecretKey, @@ -42,6 +45,10 @@ use zeroize::Zeroize; use fastcrypto_derive::{SilentDebug, SilentDisplay}; +use crate::groups::multiplier::windowed::WindowedScalarMultiplier; +use crate::groups::multiplier::ScalarMultiplier; +use crate::groups::secp256r1; +use crate::groups::secp256r1::{ProjectivePoint, SCALAR_SIZE_IN_BYTES}; use crate::hash::{HashFunction, Sha256}; use crate::secp256r1::conversion::{ affine_pt_p256_to_arkworks, arkworks_fq_to_fr, fr_arkworks_to_p256, fr_p256_to_arkworks, @@ -70,6 +77,12 @@ pub const SECP256R1_SIGNATURE_LENTH: usize = 64; /// The key pair bytes length is the same as the private key length. This enforces deserialization to always derive the public key from the private key. pub const SECP256R1_KEYPAIR_LENGTH: usize = SECP256R1_PRIVATE_KEY_LENGTH; +/// The number of precomputed points used for scalar multiplication. +pub const PRECOMPUTED_POINTS: usize = 256; + +/// The size of the sliding window used for scalar multiplication. +pub const SLIDING_WINDOW_WIDTH: usize = 5; + /// Default hash function used for signing and verifying messages unless another hash function is /// specified using the `with_hash` functions. pub type DefaultHash = Sha256; @@ -134,6 +147,22 @@ impl VerifyingKey for Secp256r1PublicKey { } } +lazy_static! 
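// Annotation: the block below builds a process-wide, lazily initialised multiplier holding
// PRECOMPUTED_POINTS (= 256) multiples of the secp256r1 generator, so the signing and
// verification code further down can replace `generator * k` with table lookups after first use.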
{ + static ref MULTIPLIER: WindowedScalarMultiplier< + ProjectivePoint, + crate::groups::secp256r1::Scalar, + PRECOMPUTED_POINTS, + SCALAR_SIZE_IN_BYTES, + SLIDING_WINDOW_WIDTH, + > = WindowedScalarMultiplier::< + ProjectivePoint, + crate::groups::secp256r1::Scalar, + PRECOMPUTED_POINTS, + SCALAR_SIZE_IN_BYTES, + SLIDING_WINDOW_WIDTH, + >::new(secp256r1::ProjectivePoint::generator()); +} + serialize_deserialize_with_to_from_bytes!(Secp256r1PublicKey, SECP256R1_PUBLIC_KEY_LENGTH); generate_bytes_representation!( Secp256r1PublicKey, @@ -173,7 +202,16 @@ impl Secp256r1PublicKey { // Verify signature let u1 = z * s_inv; let u2 = r * s_inv; - let p = ark_secp256r1::Projective::generator() * u1 + q * u2; + + // Do optimised double multiplication + let p = MULTIPLIER + .two_scalar_mul( + &secp256r1::Scalar(u1), + &ProjectivePoint(Projective::from(q)), + &secp256r1::Scalar(u2), + ) + .0; + let x = get_affine_x_coordinate(&p); // Note that x is none if and only if p is zero, in which case the signature is invalid. See @@ -379,7 +417,7 @@ impl Secp256r1KeyPair { let k_inv = k.inverse().expect("k should not be zero"); // Compute R = kG - let big_r = (ark_secp256r1::Affine::generator() * k).into_affine(); + let big_r = MULTIPLIER.mul(&secp256r1::Scalar(k)).0.into_affine(); // Lift x-coordinate of R and reduce it into an element of the scalar field let r = arkworks_fq_to_fr(big_r.x().expect("R should not be zero")); diff --git a/fastcrypto/src/secp256r1/recoverable.rs b/fastcrypto/src/secp256r1/recoverable.rs index d501ece9d2..11651ede43 100644 --- a/fastcrypto/src/secp256r1/recoverable.rs +++ b/fastcrypto/src/secp256r1/recoverable.rs @@ -18,13 +18,16 @@ //! assert_eq!(kp.public(), &signature.recover(message).unwrap()); //! ``` +use crate::groups::multiplier::ScalarMultiplier; +use crate::groups::secp256r1; +use crate::groups::secp256r1::ProjectivePoint; use crate::hash::HashFunction; use crate::secp256r1::conversion::{ affine_pt_arkworks_to_p256, affine_pt_p256_to_arkworks, fq_arkworks_to_p256, fr_p256_to_arkworks, reduce_bytes, }; use crate::secp256r1::{ - DefaultHash, Secp256r1KeyPair, Secp256r1PublicKey, Secp256r1Signature, + DefaultHash, Secp256r1KeyPair, Secp256r1PublicKey, Secp256r1Signature, MULTIPLIER, SECP256R1_SIGNATURE_LENTH, }; use crate::traits::{RecoverableSignature, RecoverableSigner, VerifyRecoverable}; @@ -34,7 +37,7 @@ use crate::{ traits::{EncodeDecodeBase64, ToFromBytes}, }; use crate::{impl_base64_display_fmt, serialize_deserialize_with_to_from_bytes}; -use ark_ec::{AffineRepr, CurveGroup, Group}; +use ark_ec::{AffineRepr, CurveGroup}; use ark_ff::Field; use ark_secp256r1::Projective; use ecdsa::elliptic_curve::scalar::IsHigh; @@ -225,11 +228,17 @@ impl RecoverableSignature for Secp256r1RecoverableSignature { // Compute public key let u1 = -(r_inv * z); let u2 = r_inv * s; - let pk = - affine_pt_arkworks_to_p256(&(Projective::generator() * u1 + big_r * u2).into_affine()); + + let pk = MULTIPLIER + .two_scalar_mul( + &secp256r1::Scalar(u1), + &ProjectivePoint(Projective::from(big_r)), + &secp256r1::Scalar(u2), + ) + .0; Ok(Secp256r1PublicKey { - pubkey: VerifyingKey::from_affine(pk) + pubkey: VerifyingKey::from_affine(affine_pt_arkworks_to_p256(&pk.into_affine())) .map_err(|_| FastCryptoError::GeneralOpaqueError)?, bytes: OnceCell::new(), }) diff --git a/fastcrypto/src/tests/secp256r1_group_tests.rs b/fastcrypto/src/tests/secp256r1_group_tests.rs new file mode 100644 index 0000000000..1ccceeaa06 --- /dev/null +++ b/fastcrypto/src/tests/secp256r1_group_tests.rs @@ -0,0 +1,27 @@ +// 
Copyright (c) 2022, Mysten Labs, Inc.
+// SPDX-License-Identifier: Apache-2.0
+
+use crate::groups::GroupElement;
+
+use crate::groups::secp256r1::{ProjectivePoint, Scalar};
+use crate::groups::{secp256r1, Scalar as ScalarTrait};
+use crate::serde_helpers::ToFromByteArray;
+use rand::thread_rng;
+
+#[test]
+fn test_to_from_byte_array() {
+    let scalar = secp256r1::Scalar::rand(&mut thread_rng());
+    let bytes = scalar.to_byte_array();
+    let reconstructed = Scalar::from_byte_array(&bytes).unwrap();
+    assert_eq!(scalar, reconstructed);
+}
+
+#[test]
+fn test_arithmetic() {
+    let p = ProjectivePoint::generator();
+    let two_p = p + p;
+    let s = Scalar::from(2);
+    assert_eq!(two_p, p.double());
+    assert_eq!(two_p, p * s);
+    assert_eq!(p, two_p * (Scalar::generator() / s).unwrap());
+}
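For reference, a minimal sketch of how downstream code could use the new `ScalarMultiplier` API with the secp256r1 group added in this patch. The const-generic parameters (a 16-entry cache, 32-byte scalars, a 5-bit sliding window) and the `main` wrapper are illustrative assumptions, not values mandated by the patch.

```rust
use fastcrypto::groups::multiplier::windowed::WindowedScalarMultiplier;
use fastcrypto::groups::multiplier::ScalarMultiplier;
use fastcrypto::groups::secp256r1::{ProjectivePoint, Scalar};
use fastcrypto::groups::{GroupElement, Scalar as ScalarTrait};
use rand::thread_rng;

fn main() {
    // Pay the precomputation cost once for the fixed base point...
    let multiplier = WindowedScalarMultiplier::<ProjectivePoint, Scalar, 16, 32, 5>::new(
        ProjectivePoint::generator(),
    );

    // ...then every single-scalar multiplication reuses the table.
    let s = Scalar::rand(&mut thread_rng());
    assert_eq!(multiplier.mul(&s), ProjectivePoint::generator() * s);

    // Double multiplications of the form s*G + t*Q (the shape used in the ECDSA
    // verification and recovery changes above) go through the Straus/sliding-window path.
    let q = ProjectivePoint::generator() * Scalar::from(7);
    let t = Scalar::rand(&mut thread_rng());
    assert_eq!(
        multiplier.two_scalar_mul(&s, &q, &t),
        ProjectivePoint::generator() * s + q * t
    );
}
```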