diff --git a/sunscreen_tfhe/src/entities/bootstrap_key.rs b/sunscreen_tfhe/src/entities/bootstrap_key.rs index 08c4f09c7..622b3860c 100644 --- a/sunscreen_tfhe/src/entities/bootstrap_key.rs +++ b/sunscreen_tfhe/src/entities/bootstrap_key.rs @@ -14,7 +14,7 @@ use crate::{ dst! { /// Keys used for bootstrapping. The [BootstrapKeyFft] variant of this type /// is used by the bootstrapping functions such as - /// [`programmable_bootstrap`](crate::ops::bootstrapping::programmable_bootstrap). + /// [`programmable_bootstrap_univariate`](crate::ops::bootstrapping::programmable_bootstrap_univariate). BootstrapKey, BootstrapKeyRef, Torus, @@ -116,7 +116,7 @@ impl BootstrapKeyRef { dst! { /// Keys used for bootstrapping. Used by the bootstrapping functions such as - /// [`programmable_bootstrap`](crate::ops::bootstrapping::programmable_bootstrap). + /// [`programmable_bootstrap_univariate`](crate::ops::bootstrapping::programmable_bootstrap_univariate). /// The non-FFT variant of this type is [BootstrapKey]. BootstrapKeyFft, BootstrapKeyFftRef, @@ -140,7 +140,7 @@ impl BootstrapKeyFft> { /// encrypts a single bit of an LWE secret key. In this representation, the /// GGSW ciphertexts are in the frequency domain and can be used directly by /// the bootstrapping functions such as - /// [`programmable_bootstrap`](crate::ops::bootstrapping::programmable_bootstrap). + /// [`programmable_bootstrap_univariate`](crate::ops::bootstrapping::programmable_bootstrap_univariate). pub fn new(lwe_params: &LweDef, glwe_params: &GlweDef, radix: &RadixDecomposition) -> Self { let len = BootstrapKeyFftRef::size((lwe_params.dim, glwe_params.dim, radix.count)); diff --git a/sunscreen_tfhe/src/entities/univariate_lookup_table.rs b/sunscreen_tfhe/src/entities/univariate_lookup_table.rs index 820fa7945..832005b7c 100644 --- a/sunscreen_tfhe/src/entities/univariate_lookup_table.rs +++ b/sunscreen_tfhe/src/entities/univariate_lookup_table.rs @@ -13,7 +13,7 @@ use super::GlweCiphertextRef; dst! { /// Lookup table for a univariate function used during - /// [`programmable_bootstrap`](crate::ops::bootstrapping::programmable_bootstrap) + /// [`programmable_bootstrap_univariate`](crate::ops::bootstrapping::programmable_bootstrap_univariate) /// and [`circuit_bootstrap`](crate::ops::bootstrapping::circuit_bootstrap). UnivariateLookupTable, UnivariateLookupTableRef, @@ -31,7 +31,11 @@ impl OverlaySize for UnivariateLookupTableRef { } impl UnivariateLookupTable { - /// Creates a lookup table that is trivially encrypted. + /// Creates a trivially encrypted lookup table that computes a single function `map`. + /// + /// # Remarks + /// The result of this can be used with + /// [`programmable_bootstrap_univariate`](crate::ops::bootstrapping::programmable_bootstrap_univariate). pub fn trivial_from_fn(map: F, glwe: &GlweDef, plaintext_bits: PlaintextBits) -> Self where F: Fn(u64) -> u64, @@ -40,7 +44,32 @@ impl UnivariateLookupTable { data: vec![Torus::zero(); UnivariateLookupTableRef::::size(glwe.dim)], }; - lut.fill_trivial_from_fn(map, glwe, plaintext_bits); + lut.fill_trivial_from_fns(&[map], glwe, plaintext_bits); + + lut + } + + /// Creates a trivially encrypted lookup table that computes multiple functions + /// given by `maps`. + /// + /// # Remarks + /// The result of this should be used with + /// [`generalized_programmable_bootstrap`](crate::ops::bootstrapping::generalized_programmable_bootstrap). + pub fn trivivial_multifunctional( + maps: &[F], + glwe: &GlweDef, + plaintext_bits: PlaintextBits, + ) -> Self + where + F: Fn(u64) -> u64, + { + assert!(maps.len() > 1); + + let mut lut = UnivariateLookupTable { + data: vec![Torus::zero(); UnivariateLookupTableRef::::size(glwe.dim)], + }; + + lut.fill_trivial_from_fns(maps, glwe, plaintext_bits); lut } @@ -60,15 +89,15 @@ impl UnivariateLookupTableRef { /// Generates a look up table filled with the values from the provided map, /// and trivially encrypts the lookup table. - pub fn fill_trivial_from_fn u64>( + pub fn fill_trivial_from_fns u64>( &mut self, - map: F, + maps: &[F], glwe: &GlweDef, plaintext_bits: PlaintextBits, ) { allocate_scratch_ref!(poly, PolynomialRef>, (glwe.dim.polynomial_degree)); - generate_lut(poly, map, glwe, plaintext_bits); + generate_lut(poly, maps, glwe, plaintext_bits); trivially_encrypt_glwe_ciphertext(self.glwe_mut(), poly, glwe); } diff --git a/sunscreen_tfhe/src/high_level.rs b/sunscreen_tfhe/src/high_level.rs index 420cbe647..6d1143fe5 100644 --- a/sunscreen_tfhe/src/high_level.rs +++ b/sunscreen_tfhe/src/high_level.rs @@ -200,8 +200,8 @@ pub mod keygen { /// However, anyone who possesses `glwe_key` can easily use the returned /// [`BootstrapKey`] to recover `sk`. pub fn generate_bootstrapping_key( - sk: &LweSecretKey, - glwe_key: &GlweSecretKey, + sk: &LweSecretKeyRef, + glwe_key: &GlweSecretKeyRef, lwe: &LweDef, glwe: &GlweDef, radix: &RadixDecomposition, @@ -844,7 +844,7 @@ pub mod evaluation { ) -> LweCiphertext { let mut out = LweCiphertext::new(&glwe.as_lwe_def()); - crate::ops::bootstrapping::programmable_bootstrap( + crate::ops::bootstrapping::programmable_bootstrap_univariate( &mut out, input, lut, bsk, lwe, glwe, radix, ); diff --git a/sunscreen_tfhe/src/ops/bootstrapping/circuit_bootstrapping.rs b/sunscreen_tfhe/src/ops/bootstrapping/circuit_bootstrapping.rs index a4f6e1b69..c63efeee3 100644 --- a/sunscreen_tfhe/src/ops/bootstrapping/circuit_bootstrapping.rs +++ b/sunscreen_tfhe/src/ops/bootstrapping/circuit_bootstrapping.rs @@ -1,13 +1,15 @@ use num::Complex; +use sunscreen_math::Zero; use crate::{ dst::FromMutSlice, entities::{ BootstrapKeyFftRef, CircuitBootstrappingKeyswitchKeysRef, GgswCiphertextRef, - LweCiphertextListRef, LweCiphertextRef, UnivariateLookupTableRef, + GlweCiphertextRef, LweCiphertextListRef, LweCiphertextRef, UnivariateLookupTableRef, }, ops::{ - bootstrapping::programmable_bootstrap, homomorphisms::rotate, + bootstrapping::generalized_programmable_bootstrap, ciphertext::sample_extract, + homomorphisms::rotate, keyswitch::private_functional_keyswitch::private_functional_keyswitch, }, scratch::allocate_scratch_ref, @@ -206,13 +208,11 @@ fn level_0_to_level_2( pbs_radix: &RadixDecomposition, cbs_radix: &RadixDecomposition, ) { + allocate_scratch_ref!(glwe_out, GlweCiphertextRef, (glwe_2.dim)); allocate_scratch_ref!(lut, UnivariateLookupTableRef, (glwe_2.dim)); allocate_scratch_ref!(lwe_rotated, LweCiphertextRef, (lwe_0.dim)); - allocate_scratch_ref!( - lwe_bootstrapped, - LweCiphertextRef, - (glwe_2.as_lwe_def().dim) - ); + allocate_scratch_ref!(extracted, LweCiphertextRef, (glwe_2.as_lwe_def().dim)); + assert!(cbs_radix.count.0 < 8); // Rotate our input by q/4, putting 0 centered on q/4 and 1 centered on // -q/4. @@ -223,42 +223,96 @@ fn level_0_to_level_2( lwe_0, ); + let log_v = if cbs_radix.count.0.is_power_of_two() { + cbs_radix.count.0.ilog2() + } else { + cbs_radix.count.0.ilog2() + 1 + }; + + fill_multifunctional_cbs_decomposition_lut(lut, glwe_2, cbs_radix); + + generalized_programmable_bootstrap( + glwe_out, + lwe_rotated, + lut, + bsk, + 0, + log_v, + lwe_0, + glwe_2, + pbs_radix, + ); + for (i, lwe_2) in lwes_2.ciphertexts_mut(&glwe_2.as_lwe_def()).enumerate() { let cur_level = i + 1; - - // Treat value as a T_{b^l+1} with one extra place for rounding as the last - // step. let plaintext_bits = PlaintextBits((cbs_radix.radix_log.0 * cur_level + 1) as u32); - // Exploiting the fact that our LUT is negacyclic, we can encode -1 in T_{b^l+1} - // everywhere. Any lookup < q/2 will give -1 and any lookup > q/2 will - // give 1. Since we've shifted our input lwe by q/4, a 1 plaintext - // value will map to 1 and a 0 will map to -1. - let minus_one = (S::one() << plaintext_bits.0 as usize) - S::one(); - - lut.fill_with_constant(minus_one, glwe_2, plaintext_bits); - - programmable_bootstrap( - lwe_bootstrapped, - lwe_rotated, - lut, - bsk, - lwe_0, - glwe_2, - pbs_radix, - ); + sample_extract(extracted, glwe_out, i, glwe_2); // Now we rotate our message containing -1 or 1 by 1 (wrt plaintext_bits). // This will overflow -1 to 0 and cause 1 to wrap to 2. rotate( lwe_2, - lwe_bootstrapped, + extracted, Torus::encode(S::one(), plaintext_bits), &glwe_2.as_lwe_def(), ); } } +fn fill_multifunctional_cbs_decomposition_lut( + lut: &mut UnivariateLookupTableRef, + glwe: &GlweDef, + cbs_radix: &RadixDecomposition, +) { + lut.clear(); + + // Pick a largish number of levels nobody would ever exceed. + let mut levels = [Torus::zero(); 16]; + + assert!(cbs_radix.count.0 < levels.len()); + + // Compute our base decomposition factors. + // Exploiting the fact that our LUT is negacyclic, we can encode -1 in T_{b^l+1} + // everywhere. Any lookup < q/2 will give -1 and any lookup > q/2 will + // give 1. Since we've shifted our input lwe by q/4, a 1 plaintext + // value will map to 1 and a 0 will map to -1. + for (i, x) in levels.iter_mut().enumerate() { + let i = i + 1; + if i * cbs_radix.radix_log.0 + 1 < S::BITS as usize { + let plaintext_bits = PlaintextBits((cbs_radix.radix_log.0 * i + 1) as u32); + + let minus_one = (S::one() << plaintext_bits.0 as usize) - S::one(); + *x = Torus::encode(minus_one, plaintext_bits); + } + } + + // Fill the table with alternating factors padded with zeros to a power of 2 + let log_v = if cbs_radix.count.0.is_power_of_two() { + cbs_radix.count.0.ilog2() + } else { + cbs_radix.count.0.ilog2() + 1 + }; + + let v = 0x1usize << log_v; + + for (i, x) in lut + .glwe_mut() + .b_mut(glwe) + .coeffs_mut() + .iter_mut() + .enumerate() + { + let fn_id = i % v; + + *x = if fn_id < cbs_radix.count.0 { + levels[fn_id] + } else { + Torus::zero() + }; + } +} + /// Bootstraps a level 2 GLWE ciphertext to a level 1 GLWE ciphertext. pub fn level_2_to_level1( result: &mut GgswCiphertextRef, @@ -322,6 +376,7 @@ mod tests { let sk = keygen::generate_binary_lwe_sk(&TEST_LWE_DEF_1); let glwe_sk = keygen::generate_binary_glwe_sk(&glwe_params); + let bsk = keygen::generate_bootstrapping_key( &sk, &glwe_sk, diff --git a/sunscreen_tfhe/src/ops/bootstrapping/programmable_bootstrapping.rs b/sunscreen_tfhe/src/ops/bootstrapping/programmable_bootstrapping.rs index 413d883a2..f5e2eda27 100644 --- a/sunscreen_tfhe/src/ops/bootstrapping/programmable_bootstrapping.rs +++ b/sunscreen_tfhe/src/ops/bootstrapping/programmable_bootstrapping.rs @@ -10,7 +10,10 @@ use crate::{ }, ops::{ bootstrapping::rotate_glwe_positive_monomial_negacyclic, - ciphertext::{add_lwe_inplace, modulus_switch, sample_extract, scalar_mul_ciphertext_mad}, + ciphertext::{ + add_lwe_inplace, lwe_ciphertext_modulus_switch, sample_extract, + scalar_mul_ciphertext_mad, + }, encryption::encrypt_ggsw_ciphertext_scalar, fft_ops::cmux, }, @@ -26,7 +29,7 @@ use super::rotate_glwe_negative_monomial_negacyclic; /// the secret key being encrypted. /// /// See -/// [`programmable_bootstrap`](crate::ops::bootstrapping::programmable_bootstrap) +/// [`programmable_bootstrap_univariate`](crate::ops::bootstrapping::programmable_bootstrap_univariate) /// for an example of how to use this key. pub fn generate_bootstrap_key( bootstrap_key: &mut BootstrapKeyRef, @@ -115,9 +118,14 @@ fn generate_negacyclic_lut( /// The input `map` is used for generating programmable bootstrapping LUTs. This /// function takes an element in the plaintext space and must produce another /// element in the plaintext space. +/// +/// # Remarks +/// This function supports multiple functions, which appear as adjacent +/// entries in the ciphertext (padded with 0 up to a power of 2). This +/// pattern repeats until `n/p` terms have been filled. pub(crate) fn generate_lut( output: &mut PolynomialRef>, - map: F, + maps: &[F], params: &GlweDef, plaintext_bits: PlaintextBits, ) where @@ -127,6 +135,16 @@ pub(crate) fn generate_lut( let p = (1 << plaintext_bits.0) as usize; let n = params.dim.polynomial_degree.0; + let v = maps.len(); + + let log_v = if v.is_power_of_two() { + v.ilog2() + } else { + v.ilog2() + 1 + }; + + let ceil_v = 0x1usize << log_v; + assert!(n >= p); let stride = n / p; @@ -136,14 +154,20 @@ pub(crate) fn generate_lut( let c = output.coeffs_mut(); for (j, p_i_unmapped) in (0..=p - 1).enumerate() { - let p_i = map(p_i_unmapped as u64); + // Insert a stride amount into the LUT + c[j * stride..(j + 1) * stride].iter_mut().enumerate().for_each(|(k, c)| { + let fn_id = k % ceil_v; - assert!(p_i < (p as u64), "The map function must produce a value less than p. Map produced the relation ({} -> {})", p_i_unmapped, p_i); + let p_i = if fn_id < v { + maps[fn_id](p_i_unmapped as u64) + } else { + 0u64 + }; - let p_i = p_i << delta; + assert!(p_i < (p as u64), "The map function must produce a value less than p. Map produced the relation ({} -> {})", p_i_unmapped, p_i); + + let p_i = p_i << delta; - // Insert a stride amount into the LUT - c[j * stride..(j + 1) * stride].iter_mut().for_each(|c| { *c = Torus::from(S::from_u64(p_i)); }); } @@ -177,7 +201,7 @@ pub(crate) fn generate_lut( /// use sunscreen_tfhe::{ /// high_level::{keygen, encryption, fft}, /// entities::{UnivariateLookupTable, LweCiphertext}, -/// ops::bootstrapping::programmable_bootstrap, +/// ops::bootstrapping::programmable_bootstrap_univariate, /// params::{ /// GLWE_1_1024_80, /// LWE_512_80, @@ -233,7 +257,7 @@ pub(crate) fn generate_lut( /// /// // Perform the programmable bootstrapping /// let mut result = LweCiphertext::new(&glwe_params.as_lwe_def()); -/// programmable_bootstrap( +/// programmable_bootstrap_univariate( /// &mut result, /// &input, /// &lut, @@ -261,7 +285,7 @@ pub(crate) fn generate_lut( /// [`programmable_bootstrap_bivariate`](programmable_bootstrap_bivariate) and /// its associated LUT /// [`BivariateLookupTable`](crate::entities::BivariateLookupTable). -pub fn programmable_bootstrap( +pub fn programmable_bootstrap_univariate( output: &mut LweCiphertextRef, input: &LweCiphertextRef, lut: &UnivariateLookupTableRef, @@ -271,6 +295,59 @@ pub fn programmable_bootstrap( radix: &RadixDecomposition, ) where S: TorusOps, +{ + allocate_scratch_ref!(glwe, GlweCiphertextRef, (glwe_params.dim)); + + generalized_programmable_bootstrap( + glwe, + input, + lut, + bootstrap_key, + 0, + 0, + lwe_params, + glwe_params, + radix, + ); + + // 3. Sample extract. + sample_extract(output, glwe, 0, glwe_params); +} + +#[allow(clippy::too_many_arguments)] +/// A generalized version of programmable bootstrapping. +/// Computes a function `lut` of the encrypted `input`. +/// However, this generalization features the ability to select which +/// bits to take during modulus switching. This capability enables +/// encoding multiple functions into `lut` and bootstrapping each of them +/// simultaneously. +/// +/// # Remarks +/// While [`programmable_bootstrap_univariate`] and +/// [`programmable_bootstrap_bivariate`] compute a single function of the +/// input ciphertext, this can compute multiple functions. To do this, +/// create a [`UnivariateLookupTable`](crate::entities::UnivariateLookupTable) using +/// [`UnivariateLookupTable::trivivial_multifunctional`](crate::entities::UnivariateLookupTable::trivivial_multifunctional). +/// +/// `log_v` should equal `ceil(log2(maps.len()))` for the `maps` you +/// used when creating the LUT. +/// +/// `log_chi` is the number of most-significant bits to drop during +/// bootstrapping. Generally, you should set this to zero unless building +/// other cryptographic primitives, such as Without Padding Bootstrapping +/// (WoP-PBS) +pub fn generalized_programmable_bootstrap( + output: &mut GlweCiphertextRef, + input: &LweCiphertextRef, + lut: &UnivariateLookupTableRef, + bootstrap_key: &BootstrapKeyFftRef>, + log_chi: u32, + log_v: u32, + lwe_params: &LweDef, + glwe_params: &GlweDef, + radix: &RadixDecomposition, +) where + S: TorusOps, { lwe_params.assert_valid(); glwe_params.assert_valid(); @@ -278,7 +355,7 @@ pub fn programmable_bootstrap( bootstrap_key.assert_valid(lwe_params, glwe_params, radix); lut.assert_valid(glwe_params); input.assert_valid(lwe_params); - output.assert_valid(&glwe_params.as_lwe_def()); + output.assert_valid(glwe_params); // Steps: // 1. Modulus switch the ciphertext to 2N. @@ -292,7 +369,7 @@ pub fn programmable_bootstrap( // 1. Modulus switch the ciphertext to 2N. let mut ct = input.to_owned(); - modulus_switch(&mut ct, S::BITS, two_n, lwe_params); + lwe_ciphertext_modulus_switch(&mut ct, log_chi, log_v, two_n, lwe_params); let (ct_a, ct_b) = ct.a_b(lwe_params); @@ -300,11 +377,10 @@ pub fn programmable_bootstrap( // key (the input LWE secret key bits). // Perform V_0 ^ X^{-b} - allocate_scratch_ref!(cmux_output, GlweCiphertextRef, (glwe_params.dim)); - cmux_output.clear(); + output.clear(); rotate_glwe_negative_monomial_negacyclic( - cmux_output, + output, lut.glwe(), ct_b.inner().to_u64() as usize, glwe_params, @@ -315,29 +391,19 @@ pub fn programmable_bootstrap( // Perform the cmux tree from the bootstrap key with the relation // V_n = V_{n-1} ^ X^{a_{n-1} s_{n-1}} for (a_i, index_select) in ct_a.iter().zip(bootstrap_key.rows(glwe_params, radix)) { - let tmp = cmux_output.to_owned(); + let tmp = output.to_owned(); // This operation performs a copy so the rotated_ct doesn't need to be // cleared. rotate_glwe_positive_monomial_negacyclic( rotated_ct, - cmux_output, + output, a_i.inner().to_u64() as usize, glwe_params, ); - cmux( - cmux_output, - &tmp, - rotated_ct, - index_select, - glwe_params, - radix, - ); + cmux(output, &tmp, rotated_ct, index_select, glwe_params, radix); } - - // 3. Sample extract. - sample_extract(output, cmux_output, 0, glwe_params); } /// Evaluate a bivariate function on a packed input. @@ -379,7 +445,7 @@ pub(crate) fn generate_bivariate_lut( generate_lut( output, - wrapped_func, + &[wrapped_func], params, PlaintextBits(plaintext_bits.0 + carry_bits.0), ); @@ -500,7 +566,7 @@ pub(crate) fn generate_bivariate_lut( /// # See also /// /// For the univariate version of programmable bootstrapping, see -/// [`programmable_bootstrap`](programmable_bootstrap) and its associated LUT +/// [`programmable_bootstrap_univariate`](programmable_bootstrap_univariate) and its associated LUT /// [`UnivariateLookupTable`](crate::entities::UnivariateLookupTable). #[allow(clippy::too_many_arguments)] pub fn programmable_bootstrap_bivariate( @@ -540,7 +606,7 @@ pub fn programmable_bootstrap_bivariate( scalar_mul_ciphertext_mad(pbs_input, &S::from_u64(shift), left_input, lwe_params); add_lwe_inplace(pbs_input, right_input, lwe_params); - programmable_bootstrap( + programmable_bootstrap_univariate( output, pbs_input, lut.as_univariate(), @@ -556,15 +622,15 @@ mod tests { use crate::{ entities::{ - BivariateLookupTable, BootstrapKey, BootstrapKeyFft, LweCiphertext, LweKeyswitchKey, - UnivariateLookupTable, + BivariateLookupTable, BootstrapKey, BootstrapKeyFft, GlweCiphertext, LweCiphertext, + LweKeyswitchKey, UnivariateLookupTable, }, - high_level::{keygen, TEST_GLWE_DEF_1, TEST_LWE_DEF_1, TEST_RADIX}, + high_level::{encryption, fft, keygen, TEST_GLWE_DEF_1, TEST_LWE_DEF_1, TEST_RADIX}, ops::{ encryption::{decrypt_ggsw_ciphertext, encrypt_lwe_ciphertext}, keyswitch::lwe_keyswitch_key::generate_keyswitch_key_lwe, }, - RoundedDiv, GLWE_1_1024_80, + RoundedDiv, GLWE_1_1024_80, LWE_512_80, }; use super::*; @@ -681,7 +747,15 @@ mod tests { let mut new_ct = LweCiphertext::new(&glwe.as_lwe_def()); - programmable_bootstrap(&mut new_ct, &original_ct, &lut, &bsk, &lwe, &glwe, &radix); + programmable_bootstrap_univariate( + &mut new_ct, + &original_ct, + &lut, + &bsk, + &lwe, + &glwe, + &radix, + ); let decoded = glwe_sk .to_lwe_secret_key() @@ -844,4 +918,67 @@ mod tests { } } } + + #[test] + fn can_generalized_bootstrap() { + let radix = &TEST_RADIX; + let lwe = &LWE_512_80; + let glwe = &GLWE_1_1024_80; + + // 1 message bit + 1 padding + let bits = PlaintextBits(1); + + let lwe_sk = keygen::generate_binary_lwe_sk(lwe); + let glwe_sk = keygen::generate_binary_glwe_sk(glwe); + let bs_key = keygen::generate_bootstrapping_key(&lwe_sk, &glwe_sk, lwe, glwe, radix); + let bs_key = fft::fft_bootstrap_key(&bs_key, lwe, glwe, radix); + + // Fill the LUT with nonsense and we'll overwrite it with + // the correct encoding. + let lut = UnivariateLookupTable::trivivial_multifunctional( + [|x| x % 2, |x| (x + 1) % 2, |x| x % 2].as_slice(), + glwe, + bits, + ); + + for i in [0, 1] { + //let input = encryption::encrypt_lwe_secret(i, &lwe_sk, lwe, bits); + let input = encryption::trivial_lwe(i, lwe, PlaintextBits(2)); + let mut output = GlweCiphertext::new(glwe); + + generalized_programmable_bootstrap( + &mut output, + &input, + &lut, + &bs_key, + 0, + 3, + lwe, + glwe, + radix, + ); + + let res = encryption::decrypt_glwe(&output, &glwe_sk, glwe, bits); + + if i == 0 { + assert_eq!(res.coeffs()[0], 0); + assert_eq!(res.coeffs()[1], 1); + assert_eq!(res.coeffs()[2], 0); + assert_eq!(res.coeffs()[3], 0); + assert_eq!(res.coeffs()[4], 0); + assert_eq!(res.coeffs()[5], 1); + assert_eq!(res.coeffs()[6], 0); + assert_eq!(res.coeffs()[7], 0); + } else { + assert_eq!(res.coeffs()[0], 1); + assert_eq!(res.coeffs()[1], 0); + assert_eq!(res.coeffs()[2], 1); + assert_eq!(res.coeffs()[3], 0); + assert_eq!(res.coeffs()[4], 1); + assert_eq!(res.coeffs()[5], 0); + assert_eq!(res.coeffs()[6], 1); + assert_eq!(res.coeffs()[7], 0); + } + } + } } diff --git a/sunscreen_tfhe/src/ops/ciphertext/lwe_ciphertext_ops.rs b/sunscreen_tfhe/src/ops/ciphertext/lwe_ciphertext_ops.rs index dbfa8b019..fe77f2002 100644 --- a/sunscreen_tfhe/src/ops/ciphertext/lwe_ciphertext_ops.rs +++ b/sunscreen_tfhe/src/ops/ciphertext/lwe_ciphertext_ops.rs @@ -1,4 +1,4 @@ -use crate::{entities::LweCiphertextRef, LweDef, RoundedDiv, Torus, TorusOps}; +use crate::{entities::LweCiphertextRef, LweDef, Torus, TorusOps}; /// Add the coefficients of a to the coefficients of c in place. pub fn add_lwe_inplace(c: &mut LweCiphertextRef, a: &LweCiphertextRef, params: &LweDef) @@ -68,10 +68,37 @@ pub(crate) fn scalar_mul_ciphertext_mad( /// Perform modulus switching on a ciphertext. We are assuming that moduli are /// both powers of two, and that the original number of bits is greater than the /// new number of bits. -pub fn modulus_switch( +/// +/// # Remarks +/// When performing the mod switch, the first `log_chi` MSBs are skipped in the input and +/// the message is padded with `log_v` bits in the LSB. Example: +/// +/// ```ignore +/// chi x r dropped +/// --------------------------------------------- +/// | 0 0 | 1 1 0 1 0 | 1 | 1 0 1 0 1 1 0 1 0 ... +/// +/// | +/// V +/// +/// | 1 1 0 1 1 | 0 0 0 | +/// ``` +/// +/// We drop the first `log_chi` bits then round the `x` section using the `r` bit. We copy +/// down the bits in the rounded `x` value and append `log_v` 0s as LSBs. +/// +/// When performing vanilla programmable bootstrapping, `log_chi` and `log_v` will be zero. +/// `log_chi` and `log_v` are used when performing multi-output PBS. +/// +/// For more information on generalized bootstrapping, see +/// "Improved Programmable Bootstrapping with Larger Precision and Efficient Arithmetic +/// Circuits for TFHE" +/// by Chillotti et al. +pub fn lwe_ciphertext_modulus_switch( ct: &mut LweCiphertextRef, - original_bits: u32, - new_bits: u32, + log_chi: u32, + log_v: u32, + log_modulus: u32, params: &LweDef, ) where S: TorusOps, @@ -81,13 +108,58 @@ pub fn modulus_switch( // We specifically want to zero out the MSBs instead of shifting them back // around. for a in c_a { - let c = a.inner().to_u64() as u128; - let res = (c * (1 << new_bits)).div_rounded(1 << original_bits as u128); - *a = Torus::from(S::from_u64(res as u64)); + let res = modulus_switch( + a.inner(), + log_chi as usize, + log_v as usize, + log_modulus as usize, + ); + *a = Torus::from(res); } - let c = c_b.inner().to_u64() as u128; - let res = (c * (1 << new_bits)).div_rounded(1 << original_bits as u128); + let res = modulus_switch( + c_b.inner(), + log_chi as usize, + log_v as usize, + log_modulus as usize, + ); - *c_b = Torus::from(S::from_u64(res as u64)); + *c_b = Torus::from(res); +} + +#[inline(never)] +fn modulus_switch(x: S, log_chi: usize, log_v: usize, log_modulus: usize) -> S { + let one = S::one(); + let mask = (one << log_modulus) - one; + let x = x << log_chi; + let shift_amount = S::BITS as usize - (log_modulus - log_v); + + let round = (x >> (shift_amount - 1)) & one; + let x = x >> shift_amount; + + // TODO: Non-power-of_two input moduli + + (x.wrapping_add(&round) & mask) << log_v +} + +#[cfg(test)] +mod tests { + use super::modulus_switch; + + #[test] + fn can_modulus_switch() { + let x = 0xDEADBEEF_BEEFDEADu64; + + let y = modulus_switch(x, 0, 0, 10); + assert_eq!(y, 0b11_0111_1011); + + let y = modulus_switch(x, 2, 0, 10); + assert_eq!(y, 0b01_1110_1011); + + let y = modulus_switch(x, 0, 3, 10); + assert_eq!(y, 0b11_0111_1000); + + let y = modulus_switch(x, 2, 3, 10); + assert_eq!(y, 0b01_1110_1000); + } }