From 98d1f6d47bc4065602bca5cc39f8d4a9c5b5499f Mon Sep 17 00:00:00 2001 From: Enrico Bottazzi <85900164+enricobottazzi@users.noreply.github.com> Date: Tue, 20 Feb 2024 12:17:36 +0100 Subject: [PATCH] feat: add `random_fill` function body for `RandomGaussianDist` trait --- src/core_crypto/guassian_sampler.rs | 71 ----------------------------- src/core_crypto/mod.rs | 1 - src/core_crypto/random.rs | 49 +++++++++++--------- 3 files changed, 28 insertions(+), 93 deletions(-) delete mode 100644 src/core_crypto/guassian_sampler.rs diff --git a/src/core_crypto/guassian_sampler.rs b/src/core_crypto/guassian_sampler.rs deleted file mode 100644 index 7928d7d..0000000 --- a/src/core_crypto/guassian_sampler.rs +++ /dev/null @@ -1,71 +0,0 @@ -use rand::SeedableRng; -use rand_chacha::ChaCha20Rng; -use rand_distr::{Distribution, Normal}; - -#[derive(Clone, Debug)] -struct TruncatedDiscreteGaussian { - sigma: f64, - bound: i64, -} - -/// `TruncatedDiscreteGaussianSampler` represents an instance to sample values -/// from a trucated discrete Gaussian distribution with: -/// * standard deviation `sigma` -/// * bounds `[-bound, bound]`, where bound is equal to `(6*sigma).round()` -/// * mean 0 -#[derive(Clone, Debug)] -struct TruncatedDiscreteGaussianSampler { - sigma: f64, - bound: i64, - normal: Normal, - rng: ChaCha20Rng, -} - -impl TruncatedDiscreteGaussianSampler { - /// `new` returns a new `TruncatedDiscreteGaussianSampler` with the given - /// parameters starting from the standard deviation `sigma`. - pub fn new(sigma: f64) -> Self { - assert!(sigma > 0.0, "sigma must be positive"); - let bound = (6.0 * sigma).round() as i64; - let rng = ChaCha20Rng::from_entropy(); - let normal = Normal::new(0.0, sigma).unwrap(); - Self { - sigma, - bound, - normal, - rng, - } - } - - /// `sample` returns a sample from the truncated discrete Gaussian - /// distribution. The technique used is patterned after [lattigo](https://github.com/tuneinsight/lattigo/blob/c031b14be1fb3697945709d7afbed264fa845442/ring/sampler_gaussian.go#L71). - /// In particular, `sampled_val` is sampled from a normal distribution with - /// mean 0 and standard deviation `sigma`. If `sampled_val` is within - /// the bounds, it is rounded and returned. - pub fn sample(&mut self) -> i64 { - let sampled_val = self.normal.sample(&mut self.rng); - if sampled_val.abs() < self.bound as f64 { - sampled_val.round() as i64 - } else { - self.sample() - } - } -} - -#[cfg(test)] -mod tests { - use super::TruncatedDiscreteGaussianSampler; - use rand::Rng; - #[test] - fn gaussian_sampler_truncation_limit() { - // Assert that the samples are integers within the bounds of the truncated - // Gaussian distribution. - let sigma = rand::thread_rng().gen_range(1.0..=100.0); - let mut sampler = TruncatedDiscreteGaussianSampler::new(sigma); - - for _ in 0..5000000 { - let sample = sampler.sample(); - assert!((-sampler.bound..=sampler.bound).contains(&sample)); - } - } -} diff --git a/src/core_crypto/mod.rs b/src/core_crypto/mod.rs index 2bab608..d9e1e46 100644 --- a/src/core_crypto/mod.rs +++ b/src/core_crypto/mod.rs @@ -1,4 +1,3 @@ -pub mod guassian_sampler; pub mod matrix; pub mod modulus; pub mod ntt; diff --git a/src/core_crypto/random.rs b/src/core_crypto/random.rs index 4c8c51f..af7789d 100644 --- a/src/core_crypto/random.rs +++ b/src/core_crypto/random.rs @@ -1,12 +1,9 @@ -use std::{borrow::Borrow, cell::RefCell}; - -use itertools::{izip, Itertools}; -use num_traits::Zero; -use rand::{ - distributions::{uniform::SampleUniform, Uniform}, - thread_rng, CryptoRng, Rng, RngCore, SeedableRng, -}; +use std::cell::RefCell; + +use itertools::izip; +use rand::{distributions::Uniform, thread_rng, CryptoRng, Rng, RngCore, SeedableRng}; use rand_chacha::ChaCha8Rng; +use rand_distr::{Distribution, Normal}; use super::matrix::{Matrix, MatrixMut, RowMut}; @@ -15,6 +12,8 @@ where M: ?Sized, { type Parameters: ?Sized; + const MEAN: f64 = 0.0; + const STD_DEV: f64 = 3.2; fn random_fill(&mut self, parameters: &Self::Parameters, container: &mut M); } @@ -157,25 +156,33 @@ where parameters.len() ); - // TODO (Jay) - // izip!(container.iter_rows_mut(), parameters.iter()).for_each(|(r, - // qi)| { izip!( - // r.as_mut().iter_mut(), - // (&mut self.rng).sample_iter(Uniform::new(0, *qi)) - // ) - // .for_each(|(r_el, random_el)| *r_el = random_el); - // }); + let normal = Normal::new( + >::MEAN, + >::STD_DEV, + ) + .unwrap(); + + izip!(container.iter_rows_mut(), parameters.iter()).for_each(|(r, qi)| { + izip!(r.as_mut().iter_mut(), normal.sample_iter(&mut self.rng)) + .for_each(|(r_el, random_el)| *r_el = (random_el as f64).round() as u64); + }); } } impl RandomGaussianDist<[u64]> for DefaultU64SeededRandomGenerator { type Parameters = u64; fn random_fill(&mut self, parameters: &Self::Parameters, container: &mut [u64]) { - // izip!( - // container.as_mut().iter_mut(), - // (&mut self.rng).sample_iter(Uniform::new(0, *parameters)) - // ) - // .for_each(|(r_el, random_el)| *r_el = random_el); + let normal = Normal::new( + >::MEAN, + >::STD_DEV, + ) + .unwrap(); + + izip!( + container.as_mut().iter_mut(), + normal.sample_iter(&mut self.rng) + ) + .for_each(|(r_el, random_el)| *r_el = (random_el as f64).round() as u64); } }