Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Rweber/multi pbs #365

Merged
merged 12 commits into from
Mar 1, 2024
Merged
Show file tree
Hide file tree
Changes from 10 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 34 additions & 5 deletions sunscreen_tfhe/src/entities/univariate_lookup_table.rs
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,11 @@ impl<S: TorusOps> OverlaySize for UnivariateLookupTableRef<S> {
}

impl<S: TorusOps> UnivariateLookupTable<S> {
/// Creates a lookup table that is trivially encrypted.
/// Creates a trivially encrypted lookup table that computes a single function `map`.
///
/// # Remarks
/// The result of this can be used with
/// [`programmable_bootstrap_univariate`](crate::ops::bootstrapping::programmable_bootstrap_univariate).
pub fn trivial_from_fn<F>(map: F, glwe: &GlweDef, plaintext_bits: PlaintextBits) -> Self
where
F: Fn(u64) -> u64,
Expand All @@ -40,7 +44,32 @@ impl<S: TorusOps> UnivariateLookupTable<S> {
data: vec![Torus::zero(); UnivariateLookupTableRef::<S>::size(glwe.dim)],
};

lut.fill_trivial_from_fn(map, glwe, plaintext_bits);
lut.fill_trivial_from_fns(&[map], glwe, plaintext_bits);

lut
}

/// Creates a trivially encrypted lookup table that computes multiple functions
/// given by `maps`.
///
/// # Remarks
/// The result of this should be used with
/// [`generalized_programmable_bootstrap`](crate::ops::bootstrapping::generalized_programmable_bootstrap).
pub fn trivivial_multifunctional<F>(
maps: &[F],
glwe: &GlweDef,
plaintext_bits: PlaintextBits,
) -> Self
where
F: Fn(u64) -> u64,
{
assert!(maps.len() > 1);

let mut lut = UnivariateLookupTable {
data: vec![Torus::zero(); UnivariateLookupTableRef::<S>::size(glwe.dim)],
};

lut.fill_trivial_from_fns(maps, glwe, plaintext_bits);

lut
}
Expand All @@ -60,15 +89,15 @@ impl<S: TorusOps> UnivariateLookupTableRef<S> {

/// Generates a look up table filled with the values from the provided map,
/// and trivially encrypts the lookup table.
pub fn fill_trivial_from_fn<F: Fn(u64) -> u64>(
pub fn fill_trivial_from_fns<F: Fn(u64) -> u64>(
&mut self,
map: F,
maps: &[F],
glwe: &GlweDef,
plaintext_bits: PlaintextBits,
) {
allocate_scratch_ref!(poly, PolynomialRef<Torus<S>>, (glwe.dim.polynomial_degree));

generate_lut(poly, map, glwe, plaintext_bits);
generate_lut(poly, maps, glwe, plaintext_bits);

trivially_encrypt_glwe_ciphertext(self.glwe_mut(), poly, glwe);
}
Expand Down
6 changes: 3 additions & 3 deletions sunscreen_tfhe/src/high_level.rs
Original file line number Diff line number Diff line change
Expand Up @@ -200,8 +200,8 @@ pub mod keygen {
/// However, anyone who possesses `glwe_key` can easily use the returned
/// [`BootstrapKey`] to recover `sk`.
pub fn generate_bootstrapping_key(
sk: &LweSecretKey<u64>,
glwe_key: &GlweSecretKey<u64>,
sk: &LweSecretKeyRef<u64>,
glwe_key: &GlweSecretKeyRef<u64>,
lwe: &LweDef,
glwe: &GlweDef,
radix: &RadixDecomposition,
Expand Down Expand Up @@ -844,7 +844,7 @@ pub mod evaluation {
) -> LweCiphertext<u64> {
let mut out = LweCiphertext::new(&glwe.as_lwe_def());

crate::ops::bootstrapping::programmable_bootstrap(
crate::ops::bootstrapping::programmable_bootstrap_univariate(
&mut out, input, lut, bsk, lwe, glwe, radix,
);

Expand Down
111 changes: 83 additions & 28 deletions sunscreen_tfhe/src/ops/bootstrapping/circuit_bootstrapping.rs
Original file line number Diff line number Diff line change
@@ -1,13 +1,15 @@
use num::Complex;
use sunscreen_math::Zero;

use crate::{
dst::FromMutSlice,
entities::{
BootstrapKeyFftRef, CircuitBootstrappingKeyswitchKeysRef, GgswCiphertextRef,
LweCiphertextListRef, LweCiphertextRef, UnivariateLookupTableRef,
GlweCiphertextRef, LweCiphertextListRef, LweCiphertextRef, UnivariateLookupTableRef,
},
ops::{
bootstrapping::programmable_bootstrap, homomorphisms::rotate,
bootstrapping::generalized_programmable_bootstrap, ciphertext::sample_extract,
homomorphisms::rotate,
keyswitch::private_functional_keyswitch::private_functional_keyswitch,
},
scratch::allocate_scratch_ref,
Expand Down Expand Up @@ -206,13 +208,11 @@ fn level_0_to_level_2<S: TorusOps>(
pbs_radix: &RadixDecomposition,
cbs_radix: &RadixDecomposition,
) {
allocate_scratch_ref!(glwe_out, GlweCiphertextRef<S>, (glwe_2.dim));
allocate_scratch_ref!(lut, UnivariateLookupTableRef<S>, (glwe_2.dim));
allocate_scratch_ref!(lwe_rotated, LweCiphertextRef<S>, (lwe_0.dim));
allocate_scratch_ref!(
lwe_bootstrapped,
LweCiphertextRef<S>,
(glwe_2.as_lwe_def().dim)
);
allocate_scratch_ref!(extracted, LweCiphertextRef<S>, (glwe_2.as_lwe_def().dim));
assert!(cbs_radix.count.0 < 8);

// Rotate our input by q/4, putting 0 centered on q/4 and 1 centered on
// -q/4.
Expand All @@ -223,42 +223,96 @@ fn level_0_to_level_2<S: TorusOps>(
lwe_0,
);

let log_v = if cbs_radix.count.0.is_power_of_two() {
cbs_radix.count.0.ilog2()
} else {
cbs_radix.count.0.ilog2() + 1
};

fill_multifunctional_cbs_decomposition_lut(lut, glwe_2, cbs_radix);

generalized_programmable_bootstrap(
glwe_out,
lwe_rotated,
lut,
bsk,
0,
log_v,
lwe_0,
glwe_2,
pbs_radix,
);

for (i, lwe_2) in lwes_2.ciphertexts_mut(&glwe_2.as_lwe_def()).enumerate() {
let cur_level = i + 1;

// Treat value as a T_{b^l+1} with one extra place for rounding as the last
// step.
let plaintext_bits = PlaintextBits((cbs_radix.radix_log.0 * cur_level + 1) as u32);

// Exploiting the fact that our LUT is negacyclic, we can encode -1 in T_{b^l+1}
// everywhere. Any lookup < q/2 will give -1 and any lookup > q/2 will
// give 1. Since we've shifted our input lwe by q/4, a 1 plaintext
// value will map to 1 and a 0 will map to -1.
let minus_one = (S::one() << plaintext_bits.0 as usize) - S::one();

lut.fill_with_constant(minus_one, glwe_2, plaintext_bits);

programmable_bootstrap(
lwe_bootstrapped,
lwe_rotated,
lut,
bsk,
lwe_0,
glwe_2,
pbs_radix,
);
sample_extract(extracted, glwe_out, i, glwe_2);

// Now we rotate our message containing -1 or 1 by 1 (wrt plaintext_bits).
// This will overflow -1 to 0 and cause 1 to wrap to 2.
rotate(
lwe_2,
lwe_bootstrapped,
extracted,
Torus::encode(S::one(), plaintext_bits),
&glwe_2.as_lwe_def(),
);
}
}

fn fill_multifunctional_cbs_decomposition_lut<S: TorusOps>(
lut: &mut UnivariateLookupTableRef<S>,
glwe: &GlweDef,
cbs_radix: &RadixDecomposition,
) {
lut.clear();

// Pick a largish number of levels nobody would ever exceed.
let mut levels = [Torus::zero(); 16];

assert!(cbs_radix.count.0 < levels.len());

// Compute our base decomposition factors.
// Exploiting the fact that our LUT is negacyclic, we can encode -1 in T_{b^l+1}
// everywhere. Any lookup < q/2 will give -1 and any lookup > q/2 will
// give 1. Since we've shifted our input lwe by q/4, a 1 plaintext
// value will map to 1 and a 0 will map to -1.
for (i, x) in levels.iter_mut().enumerate() {
let i = i + 1;
if i * cbs_radix.radix_log.0 + 1 < S::BITS as usize {
let plaintext_bits = PlaintextBits((cbs_radix.radix_log.0 * i + 1) as u32);

let minus_one = (S::one() << plaintext_bits.0 as usize) - S::one();
*x = Torus::encode(minus_one, plaintext_bits);
}
}

// Fill the table with alternating factors padded with zeros to a power of 2
let log_v = if cbs_radix.count.0.is_power_of_two() {
cbs_radix.count.0.ilog2()
} else {
cbs_radix.count.0.ilog2() + 1
};

let v = 0x1usize << log_v;

for (i, x) in lut
.glwe_mut()
.b_mut(glwe)
.coeffs_mut()
.iter_mut()
.enumerate()
{
let fn_id = i % v;

*x = if fn_id < cbs_radix.count.0 {
levels[fn_id]
} else {
Torus::zero()
};
}
}

/// Bootstraps a level 2 GLWE ciphertext to a level 1 GLWE ciphertext.
pub fn level_2_to_level1<S: TorusOps>(
result: &mut GgswCiphertextRef<S>,
Expand Down Expand Up @@ -322,6 +376,7 @@ mod tests {

let sk = keygen::generate_binary_lwe_sk(&TEST_LWE_DEF_1);
let glwe_sk = keygen::generate_binary_glwe_sk(&glwe_params);

let bsk = keygen::generate_bootstrapping_key(
&sk,
&glwe_sk,
Expand Down
Loading
Loading