Skip to content

Commit

Permalink
improve bigfloat to bigint conversion
Browse files Browse the repository at this point in the history
  • Loading branch information
Janmajayamall committed Mar 6, 2024
1 parent 2a1cb7e commit 56898a8
Show file tree
Hide file tree
Showing 4 changed files with 112 additions and 27 deletions.
94 changes: 76 additions & 18 deletions src/core_crypto/num/big_float.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,17 +4,18 @@ use std::{
sync::OnceLock,
};

use astro_float::{BigFloat as AstroBFloat, Consts};
use astro_float::{BigFloat as AstroBFloat, Consts, Sign};
use num_bigint::{BigInt, BigUint};
use num_complex::Complex;
use num_traits::{Float, Num, NumOps, NumRef, One, Zero};
use num_traits::{Float, FromPrimitive, Num, NumOps, NumRef, One, Zero};

use super::{BFloat, CastToZp, ComplexNumber};

#[derive(Clone, Debug)]
pub struct BigFloat(AstroBFloat);

const PRECISION: usize = 256usize;
const MATISSA_LEN: usize = PRECISION / astro_float::WORD_BIT_SIZE;
const ROUNDING_MODE: astro_float::RoundingMode = astro_float::RoundingMode::None;

fn astro_one() -> &'static AstroBFloat {
Expand Down Expand Up @@ -235,8 +236,8 @@ impl From<u32> for BigFloat {
}
}

impl From<BigInt> for BigFloat {
fn from(value: BigInt) -> Self {
impl From<&BigInt> for BigFloat {
fn from(value: &BigInt) -> Self {
let mut consts = Consts::new().unwrap();
let (sign, digits) = value.to_radix_be(10);
let sign = if sign == num_bigint::Sign::Minus {
Expand All @@ -256,26 +257,51 @@ impl From<BigInt> for BigFloat {
}
}

impl From<&BigFloat> for BigInt {
fn from(value: &BigFloat) -> Self {
let (raw, _, s, _, _) = value.0.as_raw_parts().unwrap();
let exponent = value.0.exponent().unwrap();

let mut bits = exponent;
let mut index = MATISSA_LEN - 1;
let mut res = BigInt::zero();
while bits > 0 {
res += (BigInt::from_u64(raw[index as usize]).unwrap())
<< (astro_float::WORD_BIT_SIZE * index);
bits -= astro_float::WORD_BIT_SIZE as i32;
index -= 1;
}
res >>= (PRECISION as i32) - exponent;
if s.is_negative() {
res.neg()
} else {
res
}
}
}

impl CastToZp<BigUint> for BigFloat {
fn cast(&self, q: &BigUint) -> BigUint {
let mut consts = Consts::new().unwrap();
let (sign, radix_repr, mut ex) = self
.0
.round(0, ROUNDING_MODE)
.convert_to_radix(astro_float::Radix::Dec, ROUNDING_MODE, &mut consts)
.unwrap();

if ex < 0 {
ex = 0;
let (raw, _, s, _, _) = self.0.as_raw_parts().unwrap();
let exponent = self.0.exponent().unwrap();

let mut bits = exponent;
let mut index = MATISSA_LEN - 1;
let mut res = BigUint::zero();
while bits > 0 {
res += (BigUint::from_u64(raw[index as usize]).unwrap())
<< (astro_float::WORD_BIT_SIZE * index);
bits -= astro_float::WORD_BIT_SIZE as i32;
index -= 1;
}
res >>= (PRECISION as i32) - exponent;

let v = BigUint::from_radix_be(&radix_repr[..(ex as usize)], 10).unwrap() % q;

if sign.is_negative() {
return q - v;
res %= q;
if s.is_negative() {
return q - res;
}

v
res
}
}

Expand All @@ -288,4 +314,36 @@ impl std::fmt::Display for BigFloat {
#[cfg(test)]
mod tests {
use super::*;

const K: usize = 128;

#[test]
fn convert_bigfloat_to_biguint() {
// This suffices to test CastToZp<BigUint> for BigFloat since both
// `From<BigFloat> for BigInt` and `CastToZp<BigUint> for BigFloat` follow the
// process
for _ in 0..K {
let float = BigFloat(AstroBFloat::random_normal(PRECISION, 0, 100));
let bigint: BigInt = (&float).into();

// expected
let expected_bigint = {
let mut consts = Consts::new().unwrap();
let (sign, radix_repr, mut ex) = float
.0
// .round(0, ROUNDING_MODE)
.convert_to_radix(astro_float::Radix::Dec, ROUNDING_MODE, &mut consts)
.unwrap();
if sign.is_negative() {
BigInt::from_radix_be(num_bigint::Sign::Minus, &radix_repr[..(ex as usize)], 10)
.unwrap()
} else {
BigInt::from_radix_be(num_bigint::Sign::Plus, &radix_repr[..(ex as usize)], 10)
.unwrap()
}
};

assert_eq!(bigint, expected_bigint, "{:?}", float);
}
}
}
29 changes: 26 additions & 3 deletions src/schemes/ckks/default_impl/entities.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,13 +21,18 @@ use crate::{
keys::{Decryptor, Encryptor, LevelDecoder, LevelEncoder, SecretKey},
parameters::{CkksEncDecParameters, Parameters},
schemes::{
ckks::ops::{secret_key_decryption, secret_key_encryption, simd_decode, simd_encode},
ckks::ops::{
secret_key_decryption, secret_key_encryption, simd_decode, simd_encode,
ScaledCkksCiphertext,
},
ops::generate_ternery_secret_with_hamming_weight,
WithGlobal,
},
utils::{convert::TryConvertFrom, moduli_chain_to_biguint, psi_powers},
};

use super::CkksCiphertext;

type DefaultBigFloat = crate::core_crypto::num::big_float::BigFloat;
type DefaultComplex = Complex<DefaultBigFloat>;

Expand Down Expand Up @@ -102,9 +107,11 @@ where
fn encrypt(&self, message: &[DefaultComplex]) -> CkksCiphertextGenericStorage<M> {
CkksClientParametersU64::with_global(|params| {
DefaultU64SeededRandomGenerator::with_local_mut(|rng| {
let delta = params.delta();

// encode
let mut m_poly = M::zeros(params.q_moduli_chain_len, params.ring_size());
simd_encode(&mut m_poly, message, params, 0, params.delta());
simd_encode(&mut m_poly, message, params, 0, delta);

// `encrypt` function wants m_poly in Evaluation representation.
foward_lazy(&mut m_poly, params.q_nttops_at_level(0));
Expand All @@ -116,6 +123,7 @@ where
is_lazy: false,
seed: <ChaCha8Rng as SeedableRng>::Seed::default(),
representation: Representation::Evaluation,
scale: delta.clone(),
};
secret_key_encryption(&mut c_out, &m_poly, self, params, rng, 0);

Expand Down Expand Up @@ -145,7 +153,7 @@ where

// decode
let mut m_out = vec![DefaultComplex::zero(); params.ring_size() >> 1];
simd_decode(&m_poly, params, c.level(), params.delta(), &mut m_out);
simd_decode(&m_poly, params, c.level(), c.scale(), &mut m_out);
m_out
})
}
Expand Down Expand Up @@ -278,6 +286,7 @@ pub struct CkksCiphertextGenericStorage<M> {
is_lazy: bool,
seed: <ChaCha8Rng as SeedableRng>::Seed,
representation: Representation,
scale: DefaultBigFloat,
}

impl<M: MatrixMut<MatElement = u64>> Ciphertext for CkksCiphertextGenericStorage<M>
Expand Down Expand Up @@ -337,6 +346,20 @@ where
}
}

impl<M: MatrixMut<MatElement = u64>> ScaledCkksCiphertext for CkksCiphertextGenericStorage<M>
where
<M as Matrix>::R: RowMut,
{
type F = DefaultBigFloat;
fn scale(&self) -> &Self::F {
&self.scale
}

fn scale_mut(&mut self) -> &mut Self::F {
&mut self.scale
}
}

pub fn build_parameters(q_moduli_chain_sizes: &[usize], delta: DefaultBigFloat, ring_size: usize) {
let q_moduli_chain = generate_primes_vec(q_moduli_chain_sizes, ring_size, &[]);

Expand Down
10 changes: 7 additions & 3 deletions src/schemes/ckks/ops.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,12 @@ use crate::{
utils::{bit_reverse_map, convert::TryConvertFrom},
};

pub trait ScaledCkksCiphertext: RlweCiphertext {
type F: BFloat;
fn scale(&self) -> &Self::F;
fn scale_mut(&mut self) -> &mut Self::F;
}

pub fn special_inv_fft<F: BFloat + From<u32>, C: ComplexNumber<F> + Clone>(
v: &mut [C],
psi_powers: &[C],
Expand Down Expand Up @@ -169,12 +175,10 @@ pub fn simd_encode<
let mut m = m.to_vec();
special_inv_fft(&mut m, &psi_powers, &rot_group);

// println!("{}", &m[0]);
// scale by delta
izip!(m.iter_mut()).for_each(|v| {
*v = &*v * delta;
});
// println!("{}", &m[0]);

let q_moduli_chain = params.q_moduli_chain_at_level(level);
let big_q = params.bigq_at_level(level);
Expand Down Expand Up @@ -470,7 +474,7 @@ mod tests {
use super::*;
use itertools::Itertools;
use num_bigint::BigUint;
use num_complex::{Complex, Complex64, ComplexDistribution};
use num_complex::{Complex, ComplexDistribution};
use num_traits::{zero, Zero};
use rand::{distributions::Uniform, thread_rng, Rng};

Expand Down
6 changes: 3 additions & 3 deletions src/utils/convert.rs
Original file line number Diff line number Diff line change
Expand Up @@ -199,7 +199,7 @@ where

let (_, ring_size) = value.dimension();

let mut out_coeffs = vec![];
let mut out_coeffs: Vec<BigFloat> = vec![];
for ri in 0..ring_size {
let mut x = BigUint::zero();
value.get_col_iter(ri).enumerate().for_each(|(i, xi)| {
Expand All @@ -209,9 +209,9 @@ where

// convert x from unsigned representation to signed representation
if x >= &big_q >> 1 {
out_coeffs.push(((&big_q - x).to_bigint().unwrap().neg()).into());
out_coeffs.push((&((&big_q - x).to_bigint().unwrap().neg())).into());
} else {
out_coeffs.push(x.to_bigint().unwrap().to_f64().unwrap().into());
out_coeffs.push((&(x.to_bigint().unwrap())).into());
}
}

Expand Down

0 comments on commit 56898a8

Please sign in to comment.