Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor and add split for matrix #37

Merged
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions src/ciphertext.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
use std::{cell::RefCell, fmt::Display, sync::Arc};

use crate::core_crypto::{
matrix::{Matrix, MatrixMut, Row, RowMut},
matrix::{Matrix, MatrixEntity, MatrixMut, Row, RowMut},
modulus::NativeModulusBackend,
ntt::NativeNTTBackend,
num::UnsignedInteger,
Expand All @@ -17,7 +17,8 @@ pub trait Ciphertext {
type Scalar: UnsignedInteger;
type Row: Row<Element = Self::Scalar> + RowMut<Element = Self::Scalar>;
type Poly: Matrix<MatElement = Self::Scalar>
+ MatrixMut<MatElement = Self::Scalar, R = Self::Row>;
+ MatrixMut<MatElement = Self::Scalar, R = Self::Row>
+ MatrixEntity;

fn representation(&self) -> Representation;
fn representation_mut(&mut self) -> &mut Representation;
Expand Down
96 changes: 81 additions & 15 deletions src/core_crypto/matrix.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,17 +2,16 @@ use std::fmt::Debug;

use aligned_vec::{avec, AVec};
use itertools::Itertools;
use num_traits::Zero;

use crate::core_crypto::num::UnsignedInteger;

pub trait Matrix: AsRef<[Self::R]> + Debug {
pub trait Matrix: AsRef<[Self::R]> {
type MatElement;
type R: Row<Element = Self::MatElement>;

fn dimension(&self) -> (usize, usize);

fn zeros(row: usize, col: usize) -> Self;

fn get_row(&self, row_idx: usize) -> &Self::R {
&self.as_ref()[row_idx]
}
Expand Down Expand Up @@ -43,6 +42,16 @@ where
self.as_mut()[row_idx].as_mut()[column_idx] = val;
}

/* Ideally, you want the return type to be MatrixMut, However this is not doable because of object safety
However, when R is a vector or align vector, the result will always be a matrixmut because we explicitly implemented MatrixMut
for them.

Note the postion idx is included in the second half
*/
fn split(&mut self, idx: usize) -> (&mut [<Self as Matrix>::R], &mut [<Self as Matrix>::R]) {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

A small thing.

Can we change this to split_at_row? Much clearer.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yes you are right

self.as_mut().split_at_mut(idx)
}

fn get_col_iter_mut(
&mut self,
column_idx: usize,
Expand All @@ -67,6 +76,13 @@ where
}
}

pub trait MatrixEntity: Matrix
where
<Self as Matrix>::MatElement: Zero,
{
fn zeros(row: usize, col: usize) -> Self;
}

pub trait Drop2Dimension: Matrix {
/// Drop 2 dimensionality of matrix and reduces it self to a single row
/// vector
Expand Down Expand Up @@ -128,40 +144,80 @@ impl<T> IntoRowOwned for aligned_vec::ABox<[T]> {
}
}

impl<T> Matrix for AVec<AVec<T>>
where
T: UnsignedInteger,
{
impl<T> Matrix for AVec<AVec<T>> {
type MatElement = T;
type R = AVec<T>;

fn zeros(row: usize, col: usize) -> Self {
avec![avec![T::zero(); col]; row]
}

fn dimension(&self) -> (usize, usize) {
(self.len(), self[0].len())
}
}

impl<T> MatrixMut for AVec<AVec<T>> where T: UnsignedInteger {}
impl<T> MatrixMut for AVec<AVec<T>> {}

impl<T> Matrix for Vec<Vec<T>>
impl<T> MatrixEntity for AVec<AVec<T>>
where
T: UnsignedInteger,
T: Zero + Clone,
{
fn zeros(row: usize, col: usize) -> Self {
avec![avec![T::zero();col];row]
}
}

impl<T> Matrix for Vec<Vec<T>> {
type MatElement = T;
type R = Vec<T>;

fn dimension(&self) -> (usize, usize) {
(self.len(), self[0].len())
}
}
impl<T> MatrixMut for Vec<Vec<T>> {}

impl<T> MatrixEntity for Vec<Vec<T>>
where
T: Zero + Clone,
{
fn zeros(row: usize, col: usize) -> Self {
vec![vec![T::zero(); col]; row]
}
}
impl<T> Matrix for &[Vec<T>] {
type MatElement = T;
type R = Vec<T>;

fn dimension(&self) -> (usize, usize) {
(self.len(), self[0].len())
}
}

impl<T> Matrix for &mut [Vec<T>] {
type MatElement = T;
type R = Vec<T>;

fn dimension(&self) -> (usize, usize) {
(self.len(), self[0].len())
}
}
impl<T> MatrixMut for Vec<Vec<T>> where T: UnsignedInteger {}
impl<T> MatrixMut for &mut [Vec<T>] {}

impl<T> Matrix for &[AVec<T>] {
type MatElement = T;
type R = AVec<T>;

fn dimension(&self) -> (usize, usize) {
(self.len(), self[0].len())
}
}
impl<T> Matrix for &mut [AVec<T>] {
type MatElement = T;
type R = AVec<T>;

fn dimension(&self) -> (usize, usize) {
(self.len(), self[0].len())
}
}
impl<T> MatrixMut for &mut [AVec<T>] {}

#[cfg(test)]
mod test {
Expand Down Expand Up @@ -236,4 +292,14 @@ mod test {
v3.set(1, 1, 0);
assert_eq!(0, v3[1][1]);
}

#[test]
fn test_matrix_split() {
let v1 = vec![1_u64, 2_u64];
let v2 = vec![3_u64, 4_u64];
let mut v3 = vec![v1, v2];
let (mut first, mut second) = v3.split(1);
first.set(0, 0, 0);
assert_eq!(v3[0][0], 0)
}
}
17 changes: 9 additions & 8 deletions src/core_crypto/ring.rs
Original file line number Diff line number Diff line change
Expand Up @@ -670,6 +670,7 @@ mod tests {

use crate::{
core_crypto::{
matrix::MatrixEntity,
modulus::{ModulusBackendConfig, NativeModulusBackend},
ntt::NativeNTTBackend,
prime::generate_primes_vec,
Expand Down Expand Up @@ -731,7 +732,7 @@ mod tests {

let mut test = DefaultU64SeededRandomGenerator::new();

let mut poly_q_in = <Vec<Vec<u64>> as Matrix>::zeros(q_chain.len(), n);
let mut poly_q_in = <Vec<Vec<u64>> as MatrixEntity>::zeros(q_chain.len(), n);
RandomUniformDist::<Vec<Vec<u64>>>::random_fill(&mut test, &q_chain, &mut poly_q_in);
let mut poly_p_out = Vec::<Vec<u64>>::zeros(p_chain.len(), n);

Expand Down Expand Up @@ -828,10 +829,10 @@ mod tests {

let mut test_rng = DefaultU64SeededRandomGenerator::new();

let mut poly_q_in = <Vec<Vec<u64>> as Matrix>::zeros(q_chain.len(), n);
let mut poly_q_in = <Vec<Vec<u64>> as MatrixEntity>::zeros(q_chain.len(), n);
RandomUniformDist::<Vec<Vec<u64>>>::random_fill(&mut test_rng, &q_chain, &mut poly_q_in);

let mut poly_p_out = Vec::<Vec<u64>>::zeros(p_chain.len(), n);
let mut poly_p_out = <Vec<Vec<u64>> as MatrixEntity>::zeros(p_chain.len(), n);

switch_crt_basis(
&mut poly_p_out,
Expand Down Expand Up @@ -932,8 +933,8 @@ mod tests {
let mut test_rng = DefaultU64SeededRandomGenerator::new();

// P U Q
let mut poly_q_in = <Vec<Vec<u64>> as Matrix>::zeros(q_chain.len(), n);
let mut poly_p_in = <Vec<Vec<u64>> as Matrix>::zeros(p_chain.len(), n);
let mut poly_q_in = <Vec<Vec<u64>> as MatrixEntity>::zeros(q_chain.len(), n);
let mut poly_p_in = <Vec<Vec<u64>> as MatrixEntity>::zeros(p_chain.len(), n);
RandomUniformDist::<Vec<Vec<u64>>>::random_fill(&mut test_rng, &q_chain, &mut poly_q_in);
RandomUniformDist::<Vec<Vec<u64>>>::random_fill(&mut test_rng, &p_chain, &mut poly_p_in);

Expand Down Expand Up @@ -1040,7 +1041,7 @@ mod tests {

let mut test_rng = DefaultU64SeededRandomGenerator::new();

let mut poly_q_in = <Vec<Vec<u64>> as Matrix>::zeros(q_chain.len(), n);
let mut poly_q_in = <Vec<Vec<u64>> as MatrixEntity>::zeros(q_chain.len(), n);
RandomUniformDist::<Vec<Vec<u64>>>::random_fill(&mut test_rng, &q_chain, &mut poly_q_in);

let mut poly_t_out = Vec::<Vec<u64>>::zeros(1, n);
Expand Down Expand Up @@ -1155,8 +1156,8 @@ mod tests {
let mut test_rng = DefaultU64SeededRandomGenerator::new();

// Random polynomial in QP
let mut poly0_q_part = <Vec<Vec<u64>> as Matrix>::zeros(q_chain.len(), n);
let mut poly0_p_part = <Vec<Vec<u64>> as Matrix>::zeros(p_chain.len(), n);
let mut poly0_q_part = <Vec<Vec<u64>> as MatrixEntity>::zeros(q_chain.len(), n);
let mut poly0_p_part = <Vec<Vec<u64>> as MatrixEntity>::zeros(p_chain.len(), n);
RandomUniformDist::<Vec<Vec<u64>>>::random_fill(&mut test_rng, &q_chain, &mut poly0_q_part);
RandomUniformDist::<Vec<Vec<u64>>>::random_fill(&mut test_rng, &p_chain, &mut poly0_p_part);

Expand Down
9 changes: 5 additions & 4 deletions src/schemes/bfv/default_impl/entities.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
use crate::{
ciphertext::{Ciphertext, InitLevelledCiphertext, Representation, RlweCiphertext},
core_crypto::{
matrix::{Drop2Dimension, Matrix, MatrixMut, Row, RowMut},
matrix::{Drop2Dimension, Matrix, MatrixEntity, MatrixMut, Row, RowMut},
modulus::{
BarrettBackend, ModulusArithmeticBackend, ModulusBackendConfig, ModulusVecBackend,
MontgomeryBackend, MontgomeryScalar, NativeModulusBackend,
Expand Down Expand Up @@ -60,7 +60,7 @@ impl BfvSecretKey {

impl<P> Encryptor<[u64], BfvCiphertextScalarU64GenericStorage<P>> for BfvSecretKey
where
P: TryConvertFrom<[i32], Parameters = [u64]> + MatrixMut<MatElement = u64>,
P: TryConvertFrom<[i32], Parameters = [u64]> + MatrixMut<MatElement = u64> + MatrixEntity,
<P as Matrix>::R: RowMut,
{
fn encrypt(&self, message: &[u64]) -> BfvCiphertextScalarU64GenericStorage<P> {
Expand All @@ -77,6 +77,7 @@ impl<P> Decryptor<Vec<u64>, BfvCiphertextScalarU64GenericStorage<P>> for BfvSecr
where
P: Matrix<MatElement = u64>
+ MatrixMut
+ MatrixEntity
+ Drop2Dimension
+ TryConvertFrom<[i32], Parameters = [u64]>
+ Clone,
Expand All @@ -100,7 +101,7 @@ pub(super) struct BfvCiphertextScalarU64GenericStorage<P> {
impl<P, R> Ciphertext for BfvCiphertextScalarU64GenericStorage<P>
where
R: Row<Element = u64> + RowMut,
P: Matrix<MatElement = u64> + MatrixMut<MatElement = u64, R = R>,
P: Matrix<MatElement = u64> + MatrixMut<MatElement = u64, R = R> + MatrixEntity,
{
type Scalar = u64;
type Poly = P;
Expand Down Expand Up @@ -133,7 +134,7 @@ impl<P> InitLevelledCiphertext for BfvCiphertextScalarU64GenericStorage<P> {
impl<P, R> RlweCiphertext for BfvCiphertextScalarU64GenericStorage<P>
where
R: Row<Element = u64> + RowMut,
P: Matrix<MatElement = u64> + MatrixMut<MatElement = u64, R = R>,
P: Matrix<MatElement = u64> + MatrixMut<MatElement = u64, R = R> + MatrixEntity,
{
fn c_partq(&self) -> &[Self::Poly] {
&self.c_partq
Expand Down
6 changes: 3 additions & 3 deletions src/schemes/bfv/ops.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ use rand::{CryptoRng, RngCore};
use crate::{
ciphertext::{Ciphertext, InitLevelledCiphertext, Representation, RlweCiphertext},
core_crypto::{
matrix::{Matrix, MatrixMut, RowMut},
matrix::{Matrix, MatrixEntity, MatrixMut, RowMut},
modulus::ModulusVecBackend,
ntt::Ntt,
num::UnsignedInteger,
Expand Down Expand Up @@ -72,7 +72,7 @@ pub fn simd_decode_message<

pub fn secret_key_encryption<
Scalar: UnsignedInteger,
Poly: MatrixMut<MatElement = Scalar>,
Poly: MatrixMut<MatElement = Scalar> + MatrixEntity,
S: SecretKey<Scalar = i32>,
P: BfvEncryptionParameters<Scalar = Scalar>,
C: RlweCiphertext<Poly = Poly> + InitLevelledCiphertext<C = Vec<Poly>>,
Expand Down Expand Up @@ -303,7 +303,7 @@ pub fn ciphertext_mul<
let ring_size = parameters.ring_size();

// Scale and round c1 in basis Q by \frac{P}{Q} and output c0 in basis P
let mut poverq_c10_partp = <C::Poly as Matrix>::zeros(p_size, ring_size);
let mut poverq_c10_partp = <C::Poly>::zeros(p_size, ring_size);
fast_convert_p_over_q(
&mut poverq_c10_partp,
&c1.c_partq()[0],
Expand Down
26 changes: 18 additions & 8 deletions src/schemes/ckks/default_impl/entities.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ use rand_chacha::{ChaCha8Core, ChaCha8Rng};
use crate::{
ciphertext::{Ciphertext, Representation, RlweCiphertext, SeededCiphertext},
core_crypto::{
matrix::{Matrix, MatrixMut, RowMut},
matrix::{Matrix, MatrixEntity, MatrixMut, RowMut},
modulus::{ModulusBackendConfig, ModulusVecBackend, NativeModulusBackend},
ntt::{NativeNTTBackend, Ntt, NttConfig},
num::{big_float::BigFloat, BFloat, ComplexNumber},
Expand Down Expand Up @@ -40,7 +40,7 @@ pub static NATIVE_CKKS_CLIENT_PARAMETERS_U64: OnceLock<
>,
> = OnceLock::new();

impl<M: MatrixMut<MatElement = u64>> LevelEncoder<M> for Vec<DefaultComplex>
impl<M: MatrixMut<MatElement = u64> + MatrixEntity> LevelEncoder<M> for Vec<DefaultComplex>
where
<M as Matrix>::R: RowMut,
{
Expand Down Expand Up @@ -94,8 +94,12 @@ impl CkksSecretKey {
}
}

impl<M: MatrixMut<MatElement = u64> + Clone + TryConvertFrom<[i32], Parameters = [u64]>>
Encryptor<[DefaultComplex], CkksCiphertextGenericStorage<M>> for CkksSecretKey
impl<
M: MatrixMut<MatElement = u64>
+ MatrixEntity
+ Clone
+ TryConvertFrom<[i32], Parameters = [u64]>,
> Encryptor<[DefaultComplex], CkksCiphertextGenericStorage<M>> for CkksSecretKey
where
<M as Matrix>::R: RowMut,
{
Expand Down Expand Up @@ -125,8 +129,12 @@ where
}
}

impl<M: MatrixMut<MatElement = u64> + Clone + TryConvertFrom<[i32], Parameters = [u64]>>
Decryptor<Vec<DefaultComplex>, CkksCiphertextGenericStorage<M>> for CkksSecretKey
impl<
M: MatrixMut<MatElement = u64>
+ Clone
+ MatrixEntity
+ TryConvertFrom<[i32], Parameters = [u64]>,
> Decryptor<Vec<DefaultComplex>, CkksCiphertextGenericStorage<M>> for CkksSecretKey
where
<M as Matrix>::R: RowMut,
{
Expand Down Expand Up @@ -280,8 +288,9 @@ pub struct CkksCiphertextGenericStorage<M> {
representation: Representation,
}

impl<M: MatrixMut<MatElement = u64>> Ciphertext for CkksCiphertextGenericStorage<M>
impl<M> Ciphertext for CkksCiphertextGenericStorage<M>
where
M: MatrixMut<MatElement = u64> + MatrixEntity,
<M as Matrix>::R: RowMut,
{
type Poly = M;
Expand All @@ -308,8 +317,9 @@ impl<M> SeededCiphertext for CkksCiphertextGenericStorage<M> {
}
}

impl<M: MatrixMut<MatElement = u64>> RlweCiphertext for CkksCiphertextGenericStorage<M>
impl<M> RlweCiphertext for CkksCiphertextGenericStorage<M>
where
M: MatrixMut<MatElement = u64> + MatrixEntity,
<M as Matrix>::R: RowMut,
{
fn c_partq(&self) -> &[Self::Poly] {
Expand Down
Loading
Loading