Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Extract out GLev encryption into its own operations #382

Merged
merged 3 commits into from
Nov 21, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 27 additions & 3 deletions sunscreen_tfhe/src/entities/glev_ciphertext.rs
Original file line number Diff line number Diff line change
@@ -1,14 +1,16 @@
use num::Complex;
use num::{Complex, Zero};
use serde::{Deserialize, Serialize};

use crate::{dst::OverlaySize, GlweDef, GlweDimension, RadixCount, Torus, TorusOps};
use crate::{
dst::OverlaySize, GlweDef, GlweDimension, RadixCount, RadixDecomposition, Torus, TorusOps,
};

use super::{
GlevCiphertextFftRef, GlweCiphertextIterator, GlweCiphertextIteratorMut, GlweCiphertextRef,
};

dst! {
/// A GLEV ciphertext. For the FFT variant, see
/// A GLev ciphertext. For the FFT variant, see
/// [`GlevCiphertextFft`](crate::entities::GlevCiphertextFft).
GlevCiphertext,
GlevCiphertextRef,
Expand All @@ -29,6 +31,20 @@ where
}
}

impl<S> GlevCiphertext<S>
where
S: TorusOps,
{
/// Create a new zero GLev ciphertext with the given parameters.
pub fn new(params: &GlweDef, radix: &RadixDecomposition) -> Self {
let elems = GlevCiphertextRef::<S>::size((params.dim, radix.count));

Self {
data: avec![Torus::zero(); elems],
}
}
}

impl<S> GlevCiphertextRef<S>
where
S: TorusOps,
Expand All @@ -55,4 +71,12 @@ where
i.fft(fft, params);
}
}

/// Assert that this entityt is valid.
pub fn assert_valid(&self, params: &GlweDef, radix: &RadixDecomposition) {
assert_eq!(
self.data.len(),
GlevCiphertextRef::<S>::size((params.dim, radix.count))
);
}
}
112 changes: 22 additions & 90 deletions sunscreen_tfhe/src/ops/encryption/ggsw_encryption.rs
Original file line number Diff line number Diff line change
@@ -1,19 +1,17 @@
use num::Zero;

use crate::{
dst::FromMutSlice,
entities::{GgswCiphertextRef, GlweCiphertextRef, GlweSecretKeyRef, Polynomial, PolynomialRef},
polynomial::{polynomial_external_mad, polynomial_scalar_mad, polynomial_scalar_mul},
ops::encryption::encrypt_glev_ciphertext_generic,
polynomial::polynomial_external_mad,
scratch::allocate_scratch_ref,
GlweDef, PlaintextBits, RadixDecomposition, Torus, TorusOps,
};

use super::{
decrypt_glwe_ciphertext, encrypt_glwe_ciphertext_secret,
trivially_encrypt_glwe_with_sk_argument,
decrypt_glwe_in_glev, encrypt_glwe_ciphertext_secret, trivially_encrypt_glwe_with_sk_argument,
};

/// Perform a ggsw encryption. This is generic in case a trivial GGSW encryption
/// Perform a GGSW encryption. This is generic in case a trivial GGSW encryption
/// is wanted (for example, for testing purposes).
pub(crate) fn encrypt_ggsw_ciphertext_generic<S>(
ggsw_ciphertext: &mut GgswCiphertextRef<S>,
Expand All @@ -34,7 +32,6 @@ pub(crate) fn encrypt_ggsw_ciphertext_generic<S>(
let max_val = S::from_u64(0x1 << plaintext_bits.0);
assert!(msg.coeffs().iter().all(|x| *x < max_val));

let decomposition_radix_log = radix.radix_log.0;
let polynomial_degree = params.dim.polynomial_degree.0;
let glwe_size = params.dim.size.0;

Expand Down Expand Up @@ -62,18 +59,7 @@ pub(crate) fn encrypt_ggsw_ciphertext_generic<S>(
msg.as_torus()
};

for (j, col) in row.glwe_ciphertexts_mut(params).enumerate() {
let mut scaled_msg = Polynomial::zero(polynomial_degree);

// The factor is q / B^{i+1}. Since B is a power of 2, this is equivalent to
// multiplying by 2^{log2(q) - log2(B) * (i + 1)}
let decomp_factor =
S::from_u64(0x1 << (S::BITS as usize - decomposition_radix_log * (j + 1)));

polynomial_scalar_mul(&mut scaled_msg, m_times_s, decomp_factor);

encrypt(col, &scaled_msg, glwe_secret_key, params);
}
encrypt_glev_ciphertext_generic(row, m_times_s, glwe_secret_key, params, radix, &encrypt);
}
}

Expand Down Expand Up @@ -130,63 +116,25 @@ pub fn encrypt_ggsw_ciphertext_scalar<S>(
ggsw_ciphertext: &mut GgswCiphertextRef<S>,
msg: S,
glwe_secret_key: &GlweSecretKeyRef<S>,
glwe_def: &GlweDef,
params: &GlweDef,
radix: &RadixDecomposition,
plaintext_bits: PlaintextBits,
) where
S: TorusOps,
{
assert!(plaintext_bits.0 < S::BITS);
radix.assert_valid::<S>();
glwe_def.assert_valid();
glwe_secret_key.assert_valid(glwe_def);
ggsw_ciphertext.assert_valid(glwe_def, radix);

let max_val = S::from_u64(0x1 << plaintext_bits.0);
assert!(msg < max_val);

let decomposition_radix_log = radix.radix_log.0;
let polynomial_degree = glwe_def.dim.polynomial_degree.0;
let glwe_size = glwe_def.dim.size.0;

// k + 1 rows with l columns of glwe ciphertexts. Element (i,j) is a glwe encryption
// of -M/B^{i+1} * s_j, except for j=k+1, where it's simply an encryption of
// M/B^{j+1}
for (i, row) in ggsw_ciphertext.rows_mut(glwe_def, radix).enumerate() {
let mut m_times_s = Polynomial::<Torus<S>>::zero(polynomial_degree);
let m_times_s = if i < glwe_size {
let s = glwe_secret_key.s(glwe_def).nth(i).unwrap();
polynomial_scalar_mad(&mut m_times_s, s.as_torus(), msg);
&m_times_s
} else {
// Last row isn't multiplied by secret key.
m_times_s.clear();
m_times_s.coeffs_mut()[0] = Torus::from(msg);
&m_times_s
};

for (j, col) in row.glwe_ciphertexts_mut(glwe_def).enumerate() {
let mut scaled_msg = Polynomial::zero(polynomial_degree);
// The factor is q / B^{i+1}. Since B is a power of 2, this is equivalent to
// multiplying by 2^{log2(q) - log2(B) * (i + 1)}
let decomp_factor =
S::from_u64(0x1 << (S::BITS as usize - decomposition_radix_log * (j + 1)));

if i < glwe_size {
let decomp_factor = decomp_factor.wrapping_neg();
let polynomial_degree = params.dim.polynomial_degree.0;

polynomial_scalar_mul(&mut scaled_msg, m_times_s, decomp_factor);
} else {
scaled_msg.coeffs_mut()[0] = Torus::from(msg.wrapping_mul(&decomp_factor));
let mut poly_msg = Polynomial::<S>::zero(polynomial_degree);
poly_msg.coeffs_mut()[0] = msg;

for c in scaled_msg.coeffs_mut().iter_mut().skip(1) {
*c = Torus::zero();
}
}

encrypt_glwe_ciphertext_secret(col, &scaled_msg, glwe_secret_key, glwe_def);
}
}
encrypt_ggsw_ciphertext(
ggsw_ciphertext,
&poly_msg,
glwe_secret_key,
params,
radix,
plaintext_bits,
)
}

fn decrypt_glwe_in_ggsw<S>(
Expand All @@ -201,27 +149,8 @@ fn decrypt_glwe_in_ggsw<S>(
where
S: TorusOps,
{
let decomposition_radix_log = radix.radix_log.0;

// To decrypt a GGSW ciphertext, it suffices to decrypt the first GLWE ciphertext in
// the last row and divide by its decomposition factor.
let glev = ggsw_ciphertext.rows(params, radix).nth(row)?;
let glwe = glev.glwe_ciphertexts(params).nth(column)?;

// Decrypt that specific GLWE ciphertext, which should have a message of
// q / beta ^ {column + 1} * SM, where SM is the message times the secret
// every row but the last (-SM) and M for the last row.
decrypt_glwe_ciphertext(msg, glwe, glwe_secret_key, params);

let mask = (0x1 << decomposition_radix_log) - 1;

for c in msg.coeffs_mut() {
let val = c.inner() >> (S::BITS as usize - decomposition_radix_log * (column + 1));
let r = (c.inner() >> (S::BITS as usize - decomposition_radix_log * (column + 1) - 1))
& S::from_u64(0x1);

*c = Torus::from((val + r) & S::from_u64(mask));
}
decrypt_glwe_in_glev(msg, glev, glwe_secret_key, params, radix, column)?;

Some(())
}
Expand All @@ -242,8 +171,11 @@ pub fn decrypt_ggsw_ciphertext<S>(
ggsw_ciphertext.assert_valid(params, radix);
glwe_secret_key.assert_valid(params);

// To decrypt a GGSW ciphertext, it suffices to decrypt the first GLWE
// ciphertext in the last row. We can decrypt any of the GLWE ciphertexts in
// the last row and divide them by their decomposition factor; we choose the
// first GLWE ciphertext.
let row = params.dim.size.0;

decrypt_glwe_in_ggsw(msg, ggsw_ciphertext, glwe_secret_key, params, radix, row, 0).unwrap();
}

Expand Down
Loading
Loading