diff --git a/sunscreen_tfhe/src/high_level.rs b/sunscreen_tfhe/src/high_level.rs index 6ad1e88ed..957c964e6 100644 --- a/sunscreen_tfhe/src/high_level.rs +++ b/sunscreen_tfhe/src/high_level.rs @@ -863,8 +863,9 @@ pub mod evaluation { use crate::{ entities::{ BootstrapKeyFft, BootstrapKeyFftRef, CircuitBootstrappingKeyswitchKeysRef, - GgswCiphertext, GgswCiphertextFftRef, GlweCiphertext, GlweCiphertextRef, LweCiphertext, - LweCiphertextRef, LweKeyswitchKeyRef, UnivariateLookupTableRef, + GgswCiphertext, GgswCiphertextFftRef, GlevCiphertext, GlevCiphertextRef, + GlweCiphertext, GlweCiphertextRef, LweCiphertext, LweCiphertextRef, LweKeyswitchKeyRef, + UnivariateLookupTableRef, }, GlweDef, LweDef, RadixDecomposition, }; @@ -903,6 +904,41 @@ pub mod evaluation { result } + /// Perform a multiplexing operation over [`GlevCiphertext`]s. + /// When `b_fft` encrypts a zero polynomial, the resulting [`GlevCiphertext`] will + /// the same message as `d_0`. When `b_fft` encrypts the 1 polynomial, the result will + /// contain the same message as `d_1`. + /// + /// # Remarks + /// `b_fft`, `d_0`, and `d_1` must all be encrypted under the same + /// [`GlweSecretKey`](crate::entities::GlweSecretKey). This implies `params` must + /// correspond with all three values. + /// + /// Additionally, `radix` must correspond to `b_fft`. + /// + /// For + /// [`GgswCiphertext`] resulting from [`circuit_bootstrap`] operations, + /// `radix` must be the same as `cbs_radix` and `params` must be the same as + /// `glwe_1`. + /// + /// # Panics + /// If `params` doesn't correspond with `b_fft`, `d_0`, `d_1`. + /// If `radix` doesn't correspond with `b_fft`. + /// If `radix` or `params` are invalid. + pub fn glev_cmux( + b_fft: &GgswCiphertextFftRef>, + d_0: &GlevCiphertextRef, + d_1: &GlevCiphertextRef, + params: &GlweDef, + radix: &RadixDecomposition, + ) -> GlevCiphertext { + let mut result = GlevCiphertext::new(params, radix); + + crate::ops::fft_ops::glev_cmux(&mut result, d_0, d_1, b_fft, params, radix); + + result + } + #[allow(clippy::too_many_arguments)] /// Perform a programmable bootstrapping operation. Bootstrapping takes /// `input` and produces a new ciphertext with a fixed noise level, applying @@ -965,7 +1001,7 @@ pub mod evaluation { /// For step 2, we use private functional keyswitching (PFKS) to transform the /// `cbs_radix.count` [LweCiphertext]s encrypted under `glwe_2` into /// `cbs_radix.count * glwe_2.size + 1` [GlweCiphertext]s. The PFKS operations multiply each - /// [GlevCiphertext](crate::entities::GlevCiphertext) by the corresponding polynomial in + /// [GlevCiphertext] by the corresponding polynomial in /// the `glwe_1` [GlweSecretKey](crate::entities::GlweSecretKey) to create a valid /// [GgswCiphertext]. /// diff --git a/sunscreen_tfhe/src/ops/encryption/glev_encryption.rs b/sunscreen_tfhe/src/ops/encryption/glev_encryption.rs index b02e4d368..162246889 100644 --- a/sunscreen_tfhe/src/ops/encryption/glev_encryption.rs +++ b/sunscreen_tfhe/src/ops/encryption/glev_encryption.rs @@ -49,12 +49,19 @@ pub(crate) fn encrypt_secret_glev_ciphertext_generic( ); for (j, glwe) in glev_ciphertext.glwe_ciphertexts_mut(params).enumerate() { - scale_msg(scaled_msg, msg, decomposition_radix_log, j); + scale_msg_by_gadget_factor(scaled_msg, msg, decomposition_radix_log, j); encrypt(glwe, scaled_msg, glwe_secret_key, params); } } -fn scale_msg( +/// Multiplies each [`Torus`] coefficient in `msg` by `1/beta^(j + 1)`, writing the result +/// to `scaled_msg`. +/// +/// # Remarks +/// GLEV ciphertexts feature redundant encryptions of `msg` where each message is scaled +/// by a corresponding gadget factor. This sets up some clever algebraic cancellation +/// that enables the GGSW-times-GLWE outer product. +pub fn scale_msg_by_gadget_factor( scaled_msg: &mut PolynomialRef>, msg: &PolynomialRef>, decomposition_radix_log: usize, @@ -146,7 +153,7 @@ pub fn encrypt_rlev_ciphertext( ); for (j, glwe) in rlev_ciphertext.glwe_ciphertexts_mut(params).enumerate() { - scale_msg(scaled_msg, msg, radix.radix_log.0, j); + scale_msg_by_gadget_factor(scaled_msg, msg, radix.radix_log.0, j); dbg!(&scaled_msg.coeffs()[0..16]); diff --git a/sunscreen_tfhe/src/ops/encryption/glwe_encryption.rs b/sunscreen_tfhe/src/ops/encryption/glwe_encryption.rs index 9092c9e87..5ea6786e8 100644 --- a/sunscreen_tfhe/src/ops/encryption/glwe_encryption.rs +++ b/sunscreen_tfhe/src/ops/encryption/glwe_encryption.rs @@ -98,6 +98,9 @@ pub fn trivially_encrypt_glwe_ciphertext( } /// Decrypt GLWE ciphertext `ct` into `msg` using secret key `sk`. +/// +/// # Remarks +/// This method does not decode the resulting `msg`. pub fn decrypt_glwe_ciphertext( msg: &mut PolynomialRef>, ct: &GlweCiphertextRef, diff --git a/sunscreen_tfhe/src/ops/fft_ops.rs b/sunscreen_tfhe/src/ops/fft_ops.rs index d0fcd4568..864d0325b 100644 --- a/sunscreen_tfhe/src/ops/fft_ops.rs +++ b/sunscreen_tfhe/src/ops/fft_ops.rs @@ -132,6 +132,16 @@ pub fn glwe_polynomial_mad( /// where the output `c` is a different encryption than either of the initial /// inputs. Note that this will result in higher noise than in the original /// ciphertexts. +/// +/// # Remarks +/// To make some internal computations, this function actually homomorphically computes +/// +/// ```text +/// c += cmux(d_0, d_1, b_fft); +/// ``` +/// +/// Unless you want this behavior, you should first call `c.clear()`, use a freshly +/// allocated `c`, or use [crate::high_level::evaluation::cmux]. pub fn cmux( c: &mut GlweCiphertextRef, d_0: &GlweCiphertextRef, @@ -166,6 +176,45 @@ pub fn cmux( add_glwe_ciphertexts(c, prod, d_0, params); } +/// Compute a cmux between [`GlevCiphertext`](crate::entities::GlevCiphertext)s `d_0`, +/// `d_1`, and select bit `b_fft`. +/// +/// # Remarks +/// A glev_cmux simply computes a cmux over each of the constituent GLWE ciphertexts within +/// the +/// +/// `ggsw_radix` describes the [`RadixDecomposition`] of the `b_fft` +/// [`GgswCiphertextFft`](crate::entities::GgswCiphertextFft), ciphertext, not the GLEV +/// decomposition. +/// +/// # Remarks +/// To make some internal computations, this function actually homomorphically computes +/// +/// ```text +/// c += cmux(d_0, d_1, b_fft); +/// ``` +/// +/// Unless you want this behavior, you should first call `c.clear()`, use a freshly +/// allocated `c`, or use [crate::high_level::evaluation::glev_cmux]. +pub fn glev_cmux( + c: &mut GlevCiphertextRef, + d_0: &GlevCiphertextRef, + d_1: &GlevCiphertextRef, + b_fft: &GgswCiphertextFftRef>, + params: &GlweDef, + ggsw_radix: &RadixDecomposition, +) where + S: TorusOps, +{ + for ((c, d_0), d_1) in c + .glwe_ciphertexts_mut(params) + .zip(d_0.glwe_ciphertexts(params)) + .zip(d_1.glwe_ciphertexts(params)) + { + cmux(c, d_0, d_1, b_fft, params, ggsw_radix); + } +} + /// This is the same as `generate_encrypted_secret_key_component` but it assumes /// that all the positions where the index is not being written are already /// zeroed out. @@ -400,11 +449,12 @@ mod tests { GgswCiphertext, GgswCiphertextFft, GlevCiphertext, GlweCiphertext, GlweCiphertextFft, GlweSecretKey, Polynomial, SchemeSwitchKey, SchemeSwitchKeyFft, }, - high_level::*, + high_level::{self, *}, ops::{ bootstrapping::{generate_scheme_switch_key, scheme_switch}, encryption::{ decrypt_ggsw_ciphertext, decrypt_glev_ciphertext, encrypt_secret_glev_ciphertext, + scale_msg_by_gadget_factor, }, }, polynomial::polynomial_external_mad, @@ -750,4 +800,61 @@ mod tests { _can_cmux_after_scheme_switch_fft(message); } } + + #[test] + fn can_glev_cmux() { + let params = TEST_RLWE_DEF; + let radix = TEST_RADIX; + + let sk = keygen::generate_binary_glwe_sk(¶ms); + + let zero = Polynomial::zero(params.dim.polynomial_degree.0); + let zero_ct = high_level::encryption::trivial_binary_glev(&zero, ¶ms, &radix); + + let mut one = Polynomial::zero(params.dim.polynomial_degree.0); + zero.map_into(&mut one, |_| 1); + let one_ct = high_level::encryption::trivial_binary_glev(&one, ¶ms, &radix); + + for _ in 0..100 { + let sel_0 = + high_level::encryption::encrypt_ggsw(0, &sk, ¶ms, &radix, PlaintextBits(1)); + let sel_0 = high_level::fft::fft_ggsw(&sel_0, ¶ms, &radix); + + let sel_1 = + high_level::encryption::encrypt_ggsw(1, &sk, ¶ms, &radix, PlaintextBits(1)); + let sel_1 = high_level::fft::fft_ggsw(&sel_1, ¶ms, &radix); + + let mut result = GlevCiphertext::new(¶ms, &radix); + + glev_cmux(&mut result, &zero_ct, &one_ct, &sel_0, ¶ms, &radix); + + for glwe in result.glwe_ciphertexts(¶ms) { + let actual = + high_level::encryption::decrypt_glwe(glwe, &sk, ¶ms, PlaintextBits(1)); + + assert_eq!(actual, zero); + } + + glev_cmux(&mut result, &zero_ct, &one_ct, &sel_1, ¶ms, &radix); + + for (i, glwe) in result.glwe_ciphertexts(¶ms).enumerate() { + // The i'th decomposition factor requires (i + 1) * radix_log.0 bits of + // message space. + let pt_bits = PlaintextBits(((i + 1) * radix.radix_log.0) as u32); + let actual = high_level::encryption::decrypt_glwe(glwe, &sk, ¶ms, pt_bits); + + let mut scaled = Polynomial::zero(params.dim.polynomial_degree.0); + + // Compute 1 / beta^(i + 1). This will be shifted into the MSBs, so we + // need to decode this message before we can compare + scale_msg_by_gadget_factor(&mut scaled, one.as_torus(), radix.radix_log.0, i); + + // Decode expected msg. No need to round because we didn't encrypt it + // hence no noise. + let expected = scaled.map(|x| x.inner() >> (u64::BITS - pt_bits.0)); + + assert_eq!(actual, expected); + } + } + } }