diff --git a/src/core_crypto/num/mod.rs b/src/core_crypto/num/mod.rs index 9aeda20..926c48a 100644 --- a/src/core_crypto/num/mod.rs +++ b/src/core_crypto/num/mod.rs @@ -1,7 +1,9 @@ -use num_complex::Complex; +use num_bigint::{BigInt, BigUint}; +use num_complex::{Complex, ComplexFloat}; use num_traits::{ - AsPrimitive, CheckedShl, CheckedShr, MulAddAssign, Num, NumAssign, NumOps, One, Pow, PrimInt, - ToPrimitive, WrappingAdd, WrappingMul, WrappingShl, WrappingShr, WrappingSub, Zero, + AsPrimitive, CheckedShl, CheckedShr, Float, MulAddAssign, Num, NumAssign, NumCast, NumOps, + NumRef, One, Pow, PrimInt, RefNum, ToPrimitive, WrappingAdd, WrappingMul, WrappingShl, + WrappingShr, WrappingSub, Zero, }; use std::{ fmt::{Debug, Display}, @@ -49,35 +51,69 @@ impl NumericConstants for u128 { impl UnsignedInteger for u64 {} impl UnsignedInteger for u128 {} -pub trait Float: num_traits::Float + From { - const PI: Self; -} - -impl Float for f64 { - const PI: Self = std::f64::consts::PI * 2.0; -} - -pub trait ComplexNumber: Num + Div +// TODO(Jay): How to add trait bound for impl Add<&T, Output = T> for {}. 
Doing +// this removes unnecessary copy bounds in ckks/ops.rs +pub trait ComplexNumber: + Num + NumRef + for<'r> Div<&'r T, Output = Self> + for<'r> Mul<&'r T, Output = Self> where - T: Float, + T: BFloat, { + fn new(re: T, im: T) -> Self; fn nth_root(n: u32) -> Self; fn powu(&self, exp: u32) -> Self; fn re(&self) -> T; fn img(&self) -> T; } -impl ComplexNumber for Complex { +pub trait BUint: Num + NumRef + PartialOrd {} + +pub trait BInt: Num + NumRef + PartialOrd {} + +pub trait BFloat: Num + NumRef + Float + NumCast { + const TWICE_PI: Self; +} + +// TODO(Jay): CastToZp is unnecessary since it is only used on simd_encode +// where it can easily be replaced with TryConvertFrom<[f64]> to M +pub trait CastToZp { + fn cast(&self, q: &To) -> To; +} + +pub trait ModInverse { + fn mod_inverse(&self, q: Self) -> Self; +} + +impl ComplexNumber for Complex { + fn new(re: T, im: T) -> Self { + Complex::new(re, im) + } fn nth_root(n: u32) -> Self { - Complex::::from_polar(T::one(), T::PI / >::from(n)) + Complex::::from_polar(T::one(), T::TWICE_PI / T::from(n).unwrap()) } fn powu(&self, exp: u32) -> Self { self.powu(exp) } fn re(&self) -> T { - self.re + self.re.clone() } fn img(&self) -> T { - self.im + self.im.clone() + } +} + +impl BFloat for f64 { + const TWICE_PI: Self = std::f64::consts::PI * 2.0; +} + +impl BUint for BigUint {} + +impl CastToZp for f64 { + fn cast(&self, q: &BigUint) -> BigUint { + let v = self.round(); + if v < 0.0 { + q - (v.abs().to_u64().unwrap() % q) + } else { + v.to_u64().unwrap() % q + } } } diff --git a/src/parameters.rs b/src/parameters.rs index 4661a99..1ef684b 100644 --- a/src/parameters.rs +++ b/src/parameters.rs @@ -9,7 +9,7 @@ use crate::{ MontgomeryBackend, MontgomeryScalar, }, ntt::{Ntt, NttConfig}, - num::UnsignedInteger, + num::{BFloat, ComplexNumber, UnsignedInteger}, }, utils::{mod_inverse, mod_inverse_big_unit}, }; @@ -140,3 +140,17 @@ pub trait BfvEncodingDecodingParameters: Parameters { fn t_ntt_op(&self) -> &Self::NttOp; fn 
modt_op(&self) -> &Self::ModOp; } + +// CKKS encoding decoding parameters +pub trait CkksEncodingDecodingParameters: Parameters { + type F: BFloat; + type Complex: ComplexNumber; + type BU; + + fn delta(&self) -> Self::F; + fn psi_powers(&self) -> &[Self::Complex]; + fn rot_group(&self) -> &[usize]; + fn ring_size(&self) -> usize; + fn bigq_at_level(&self, level: usize) -> &Self::BU; + fn q_moduli_chain_at_level(&self, level: usize) -> &[Self::Scalar]; +} diff --git a/src/schemes/ckks/ops.rs b/src/schemes/ckks/ops.rs index b0563e0..e9197f7 100644 --- a/src/schemes/ckks/ops.rs +++ b/src/schemes/ckks/ops.rs @@ -1,24 +1,20 @@ +use itertools::{izip, Itertools}; use std::{ - clone, fmt::Debug, - ops::{Div, Rem, Sub}, - process::Output, + ops::{Rem, Sub}, }; -use itertools::{izip, Itertools}; -use num_traits::{Num, One}; - use crate::{ core_crypto::{ matrix::{Matrix, MatrixMut, RowMut}, - num::{ComplexNumber, Float, UnsignedInteger}, - ring, + num::{BFloat, BInt, BUint, CastToZp, ComplexNumber, UnsignedInteger}, }, - utils::bit_reverse_map, + parameters::CkksEncodingDecodingParameters, + utils::{bit_reverse_map, convert::TryConvertFrom}, }; // specialIFFT -fn special_inv_fft + Clone + Copy>( +pub fn special_inv_fft + Clone + Copy>( v: &mut [C], psi_powers: &[C], rot_group: &[usize], @@ -29,8 +25,8 @@ fn special_inv_fft + Clone + Copy>( v.len() ); debug_assert!( - v.len() * 4 + 1 == psi_powers.len(), - "psi_powers must have powers of psi for 0 <= j <= 4l, but its length is {}", + v.len() * 4 == psi_powers.len(), + "psi_powers must have powers of psi for 0 <= j < 4l, but its length is {}", psi_powers.len() ); @@ -48,7 +44,7 @@ fn special_inv_fft + Clone + Copy>( let idx = (lenq - (rot_group[j] % lenq)) * gap; // X + Y - let u = v[i + j] + v[i + j + lenh]; + let u = v[i + j] + &v[i + j + lenh]; // (X - Y) \cdot \psi_{idx} let k = (v[i + j] - v[i + j + lenh]) * psi_powers[idx]; @@ -62,10 +58,10 @@ fn special_inv_fft + Clone + Copy>( bit_reverse_map(v); v.iter_mut() - 
.for_each(|a| *a = *a / >::from(v_len as u32)); + .for_each(|a| *a = *a / &F::from(v_len).unwrap()); } -fn special_fft + Copy + Clone>( +pub fn special_fft + Copy>( v: &mut [C], psi_powers: &[C], rot_group: &[usize], @@ -76,8 +72,8 @@ fn special_fft + Copy + Clone>( v.len() ); debug_assert!( - v.len() * 4 + 1 == psi_powers.len(), - "psi_powers must have powers of psi for 0 <= j <= 4l, but its length is {}", + v.len() * 4 == psi_powers.len(), + "psi_powers must have powers of psi for 0 <= j < 4l, but its length is {}", psi_powers.len() ); @@ -96,6 +92,7 @@ fn special_fft + Copy + Clone>( for j in 0..lenh { let idx = (rot_group[j] % lenq) * gap; + // TODO(Jay): Remove bound of Copy let u = v[i + j]; let k = v[i + j + lenh] * psi_powers[idx]; @@ -109,116 +106,120 @@ fn special_fft + Copy + Clone>( } } -/// Scale input float a by delta and randomly rounds delta_a. Maps delta_a from -/// signed interval to unsigned interval [0, big_q) where big_q is product of -/// smaller primes qi. Maps delta_a in unsigned interval to big_q's moduli chain -/// and returns -fn map_float_to_zq_moduli_chain( - a: F, - delta: F, - big_q: &Uint, - q_moduli_chain: &[Uint], -) -> Vec -where - >::Error: Debug, - >::Error: Debug, - >::Error: Debug, - Uint: TryFrom + TryFrom + Num + PartialOrd, - for<'a> &'a Uint: Rem<&'a Uint, Output = Uint> + Sub<&'a Uint, Output = Uint>, - Scalar: TryFrom, -{ - let delta_a = a * delta; - let delta_a = delta_a.round(); // TODO random round - - // convert signed to unsigned representation - let delta_a_modq = if delta_a < F::zero() { - big_q - &Uint::try_from(-delta).unwrap() - } else { - Uint::try_from(delta).unwrap() - }; - - q_moduli_chain - .iter() - .map(|qi| >::try_from(&delta_a_modq % qi).unwrap()) - .collect_vec() -} - pub fn simd_encode< - Scalar: UnsignedInteger, - Uint, - F: Float, + Scalar: UnsignedInteger + TryFrom, + Uint: BUint, + F: BFloat + CastToZp, C: ComplexNumber + Clone + Copy, MMut: MatrixMut, + P: CkksEncodingDecodingParameters, >( p: 
&mut MMut, m: &[C], + params: &P, + level: usize, + delta: F, ) where - >::Error: Debug, - >::Error: Debug, >::Error: Debug, - Uint: TryFrom + TryFrom + Num + PartialOrd, - Scalar: TryFrom, - for<'a> &'a Uint: Rem<&'a Uint, Output = Uint> + Sub<&'a Uint, Output = Uint>, + for<'a> &'a Uint: Rem + Sub<&'a Uint, Output = Uint>, ::R: RowMut, { - let slots = 0usize; - let ring_size = 0; + let ring_size = params.ring_size(); + let slots = ring_size >> 1; + debug_assert!( m.len() == slots, "Expected m vector length {slots} but is {}", m.len() ); + + let psi_powers = params.psi_powers(); + let rot_group = params.rot_group(); debug_assert!( - ring_size == 2 * slots, - "Expected ring_size = {} (2*slots), but is {}", - 2 * slots, - ring_size + psi_powers.len() == ring_size * 2, + "Expected psi^i for 0 <= i < {}(M) but psi_powers has length {}", + ring_size * 2, + psi_powers.len() + ); + debug_assert!( + rot_group.len() == slots, + "Expected (5^j mod M) for 0 <= j < {}(l) but rot_group has length {}", + slots, + rot_group.len() ); - - let psi_powers = vec![]; - let rot_group = vec![]; let mut m = m.to_vec(); special_inv_fft(&mut m, &psi_powers, &rot_group); - let delta = F::zero(); + // scale by delta + izip!(m.iter_mut()).for_each(|v| { + *v = *v * &delta; + }); - let q_moduli_chain: Vec = vec![]; - let q_moduli_chain_big: Vec = q_moduli_chain - .iter() - .map(|v| (*v).try_into().unwrap()) - .collect_vec(); - let big_q = Uint::one(); + let q_moduli_chain = params.q_moduli_chain_at_level(level); + let big_q = params.bigq_at_level(level); for ri in 0..ring_size >> 1 { - izip!( - p.get_col_iter_mut(ri), - map_float_to_zq_moduli_chain::( - m[ri].re(), - delta, - &big_q, - &q_moduli_chain_big - ) - .iter() - ) - .for_each(|(to, from)| *to = *from); + let delta_m_ri = CastToZp::cast(&m[ri].re(), big_q); + izip!(p.get_col_iter_mut(ri), q_moduli_chain.iter()).for_each(|(x_qi, qi)| { + *x_qi = (&delta_m_ri % *qi).try_into().unwrap(); + }); } for ri in ring_size >> 1..ring_size { - izip!( - 
p.get_col_iter_mut(ri), - map_float_to_zq_moduli_chain::( - m[ri].img(), - delta, - &big_q, - &q_moduli_chain_big - ) - .iter() - ) - .for_each(|(to, from)| *to = *from); + let delta_m_ri = CastToZp::cast(&m[ri - (ring_size >> 1)].img(), big_q); + izip!(p.get_col_iter_mut(ri), q_moduli_chain.iter()).for_each(|(x_qi, qi)| { + *x_qi = (&delta_m_ri % *qi).try_into().unwrap(); + }); } } +pub fn simd_decode< + Scalar: UnsignedInteger, + F: BFloat, + // TODO(Jay): Remove Copy bound + C: ComplexNumber + Copy, + M: Matrix, + P: CkksEncodingDecodingParameters, +>( + p: &M, + params: &P, + level: usize, + delta: F, + m_out: &mut [C], +) where + Vec: TryConvertFrom, +{ + let ring_size = params.ring_size(); + let slots = ring_size >> 1; + debug_assert!( + m_out.len() == ring_size >> 1, + "Expected m_out to have {} slots but has {}", + slots, + m_out.len() + ); + + let q_moduli_chain = params.q_moduli_chain_at_level(level); + // TODO(Jay): `try_convert_from` first computes recomposition factors. Hence it is + // quite expensive. 
+ let mut p_reals = Vec::::try_convert_from(p, &q_moduli_chain); + p_reals + .iter_mut() + .map(|v| { + // scale by 1/delta + *v = *v / delta; + }) + .collect_vec(); + + for k in 0..slots { + m_out[k] = C::new(p_reals[k], p_reals[slots + k]); + } + + let psi_powers = params.psi_powers(); + let rot_group = params.rot_group(); + special_fft(m_out, &psi_powers, &rot_group); +} + // specialFFT // encoding // decoeding @@ -227,15 +228,21 @@ pub fn simd_encode< #[cfg(test)] mod tests { - use crate::utils::psi_powers; + use crate::{ + core_crypto::{prime::generate_primes_vec, ring}, + parameters::Parameters, + utils::{moduli_chain_to_biguint, psi_powers}, + }; use super::*; use itertools::Itertools; - use num_complex::{Complex64, ComplexDistribution}; + use num_bigint::BigUint; + use num_complex::{Complex, Complex64, ComplexDistribution}; + use num_traits::{zero, Zero}; use rand::{distributions::Uniform, thread_rng, Rng}; #[test] - fn special_inv_fft_round_trip() { + fn special_fft_round_trip() { // generate random complex values // create rot_group let m = 32; @@ -271,5 +278,101 @@ mod tests { // can't hardcode a value. 
Therefore, we pretty print here // difference and check ourselves that the error is negligible // TODO(Jay): Pretty print the difference + // dbg!(values_clone); + // dbg!(values); + } + + struct CkksParameters { + delta: f64, + psi_powers: Vec>, + rot_group: Vec, + ring_size: usize, + q_moduli_chain: Vec, + bigq: BigUint, + } + impl CkksParameters { + fn new(ring_size: usize, q_moduli_chain: Vec, delta: f64) -> Self { + let m = ring_size << 1; + let n = ring_size; + let l = n >> 1; + + let mut a = 1usize; + let mut rot_group = vec![]; + for _ in 0..l { + rot_group.push(a); + a = (a * 5) % m; + } + + let psi_powers = psi_powers(m as u32); + let bigq = moduli_chain_to_biguint(&q_moduli_chain); + + CkksParameters { + delta, + psi_powers, + rot_group, + ring_size, + q_moduli_chain, + bigq, + } + } + } + + impl Parameters for CkksParameters { + type Scalar = u64; + } + + impl CkksEncodingDecodingParameters for CkksParameters { + type BU = BigUint; + type Complex = Complex; + type F = f64; + + fn bigq_at_level(&self, level: usize) -> &Self::BU { + &self.bigq + } + fn delta(&self) -> Self::F { + self.delta + } + fn psi_powers(&self) -> &[Self::Complex] { + &self.psi_powers + } + fn q_moduli_chain_at_level(&self, level: usize) -> &[Self::Scalar] { + &self.q_moduli_chain + } + fn ring_size(&self) -> usize { + self.ring_size + } + fn rot_group(&self) -> &[usize] { + &self.rot_group + } + } + + #[test] + fn encoding_decoding_works() { + let ring_size = 1 << 4; + let q_moduli_chain = generate_primes_vec(&[50, 50, 50], ring_size, &[]); + let params = CkksParameters::new(ring_size, q_moduli_chain, 2.0_f64.powi(40i32)); + + // vec of length l with random complex values + let reals = Uniform::new(0.0, 100.0); + let imags = Uniform::new(0.0, 100.0); + let complex_distr = ComplexDistribution::new(reals, imags); + let values = thread_rng() + .sample_iter(complex_distr) + .take(params.ring_size() >> 1) + .collect_vec(); + + let level = 0; + let mut p = > as Matrix>::zeros( + 
params.q_moduli_chain_at_level(0).len(), + params.ring_size(), + ); + + simd_encode(&mut p, &values, &params, level, params.delta()); + + let mut m_out = vec![Complex::::zero(); ring_size >> 1]; + simd_decode(&p, &params, level, params.delta(), &mut m_out); + + // dbg!(values); + // dbg!(m_out); } } diff --git a/src/utils/convert.rs b/src/utils/convert.rs index df5e21d..8df1395 100644 --- a/src/utils/convert.rs +++ b/src/utils/convert.rs @@ -1,9 +1,11 @@ +use std::{iter::Map, ops::Neg, path::Iter}; + use aligned_vec::{AVec, CACHELINE_ALIGN}; -use itertools::Itertools; +use itertools::{izip, Itertools}; use num_bigint::{BigInt, BigUint, ToBigInt}; use num_traits::{ToPrimitive, Zero}; -use crate::core_crypto::matrix::Matrix; +use crate::core_crypto::matrix::{Matrix, MatrixMut, RowMut}; use super::{mod_inverse, moduli_chain_to_biguint}; @@ -171,6 +173,49 @@ where } } +impl TryConvertFrom for Vec +where + M: Matrix, +{ + type Parameters = [u64]; + + fn try_convert_from(value: &M, parameters: &Self::Parameters) -> Self { + let big_q = moduli_chain_to_biguint(parameters); + + // q/q_i + let mut q_over_qi_vec = vec![]; + // [[q/q_i]^{-1}]_q_i + let mut q_over_qi_inv_modqi_vec = vec![]; + parameters.iter().for_each(|qi| { + let q_over_qi = &big_q / qi; + let q_over_qi_inv_modqi = + BigUint::from(mod_inverse((&q_over_qi % qi).to_u64().unwrap(), *qi)); + q_over_qi_vec.push(q_over_qi); + q_over_qi_inv_modqi_vec.push(q_over_qi_inv_modqi); + }); + + let (_, ring_size) = value.dimension(); + + let mut out_coeffs = vec![]; + for ri in 0..ring_size { + let mut x = BigUint::zero(); + value.get_col_iter(ri).enumerate().for_each(|(i, xi)| { + x += xi * &q_over_qi_vec[i] * &q_over_qi_inv_modqi_vec[i]; + }); + x = x % &big_q; + + // convert x from unsigned representation to signed representation + if x >= &big_q >> 1 { + out_coeffs.push(((&big_q - x).to_bigint().unwrap().neg()).to_f64().unwrap()); + } else { + out_coeffs.push(x.to_bigint().unwrap().to_f64().unwrap()); + } + } + + out_coeffs + 
} +} + impl TryConvertFromParts for Vec where M: Matrix, @@ -238,6 +283,45 @@ where } } +// TODO(Jay): I wanted to implement `try_convert_from` from [f64] to Matrix. +// Couldn't figure out the right way to do so. After implementing this we can +// remove the `CastToZp for F` trait bound (CastToZp trait itself) in `simd_encode` +// and instead use TryConvertFrom<[f64]> for Matrix + +// pub trait TryConvertFromMut { +// type Parameters: ?Sized; + +// fn try_convert_from_mut(value: &T, parameters: &Self::Parameters, out: +// &mut Self); } +// impl TryConvertFromMut> for M +// where +// M: MatrixMut, +// ::R: RowMut, +// { +// type Parameters = [u64]; +// fn try_convert_from_mut( +// value: &dyn Iterator, +// parameters: &Self::Parameters, +// out: &mut Self, +// ) { +// // TODO(Jay): I don't think calculating mbig_q here is critical but +// not sure. // Recheck +// let big_q = moduli_chain_to_biguint(&parameters); +// let dim = out.dimension(); + +// debug_assert!( +// dim.0 == parameters.len(), +// "Expected Matrix to have {} rows but has {}", +// parameters.len(), +// dim.0 +// ); + +// let ring_size = dim.1; + +// izip!((0..ring_size).into_iter(), *value).for_each(|(col_index, v)| +// {}); } +// } + #[cfg(test)] mod tests { use aligned_vec::AVec; diff --git a/src/utils/mod.rs b/src/utils/mod.rs index 302e316..676357d 100644 --- a/src/utils/mod.rs +++ b/src/utils/mod.rs @@ -5,7 +5,7 @@ use num_traits::{FromPrimitive, One, ToBytes, ToPrimitive, Zero}; use crate::core_crypto::{ modulus::{BarrettBackend, ModulusBackendConfig, NativeModulusBackend}, - num::{ComplexNumber, Float, UnsignedInteger}, + num::{BFloat, ComplexNumber, UnsignedInteger}, }; use std::{ mem, @@ -234,9 +234,9 @@ pub fn bit_reverse_map(a: &mut [T]) { /// Calculates M^th root of unity and returns the subgroup \psi^{0}, \psi^{1}, /// ..., \psi^{M} -pub fn psi_powers>(m: u32) -> Vec { +pub fn psi_powers>(m: u32) -> Vec { let m_root_unity = C::nth_root(m); - (0..m + 1) + (0..m) .into_iter() .map(|i| 
m_root_unity.powu(i as u32)) .collect_vec()