From d9b8412d794fe9596a3292f717f00e11f2bc08f2 Mon Sep 17 00:00:00 2001 From: Franco Nieddu Date: Tue, 3 Sep 2024 14:30:40 +0200 Subject: [PATCH] feat: added new rep3 impl --- mpc-core/src/protocols.rs | 1 + mpc-core/src/protocols/rep3new.rs | 8 + mpc-core/src/protocols/rep3new/arithmetic.rs | 170 ++++++++++++++++++ .../src/protocols/rep3new/arithmetic/ops.rs | 89 +++++++++ .../src/protocols/rep3new/arithmetic/types.rs | 103 +++++++++++ mpc-core/src/protocols/rep3new/binary.rs | 92 ++++++++++ mpc-core/src/protocols/rep3new/binary/ops.rs | 73 ++++++++ .../src/protocols/rep3new/binary/types.rs | 27 +++ 8 files changed, 563 insertions(+) create mode 100644 mpc-core/src/protocols/rep3new.rs create mode 100644 mpc-core/src/protocols/rep3new/arithmetic.rs create mode 100644 mpc-core/src/protocols/rep3new/arithmetic/ops.rs create mode 100644 mpc-core/src/protocols/rep3new/arithmetic/types.rs create mode 100644 mpc-core/src/protocols/rep3new/binary.rs create mode 100644 mpc-core/src/protocols/rep3new/binary/ops.rs create mode 100644 mpc-core/src/protocols/rep3new/binary/types.rs diff --git a/mpc-core/src/protocols.rs b/mpc-core/src/protocols.rs index ba899aebc..88b3a75ef 100644 --- a/mpc-core/src/protocols.rs +++ b/mpc-core/src/protocols.rs @@ -5,4 +5,5 @@ pub mod bridges; pub mod plain; pub mod rep3; +pub mod rep3new; pub mod shamir; diff --git a/mpc-core/src/protocols/rep3new.rs b/mpc-core/src/protocols/rep3new.rs new file mode 100644 index 000000000..b730c5d47 --- /dev/null +++ b/mpc-core/src/protocols/rep3new.rs @@ -0,0 +1,8 @@ +mod arithmetic; +mod binary; + +pub use arithmetic::types::Rep3PrimeFieldShare; +pub use arithmetic::types::Rep3PrimeFieldShareVec; +pub use arithmetic::Arithmetic; + +pub use binary::types::Rep3BigUintShare; diff --git a/mpc-core/src/protocols/rep3new/arithmetic.rs b/mpc-core/src/protocols/rep3new/arithmetic.rs new file mode 100644 index 000000000..1f6c3216d --- /dev/null +++ b/mpc-core/src/protocols/rep3new/arithmetic.rs @@ -0,0 +1,170 @@ +use std::marker::PhantomData; + +use ark_ff::PrimeField; +use itertools::{izip, Itertools}; +use types::{Rep3PrimeFieldShare, Rep3PrimeFieldShareVec}; + +use crate::protocols::rep3::{id::PartyID, network::Rep3Network, rngs::Rep3CorrelatedRng}; + +type FieldShare = Rep3PrimeFieldShare; +type FieldShareVec = Rep3PrimeFieldShareVec; +type IoResult = std::io::Result; + +mod ops; +pub(super) mod types; + +// this will be moved later +pub struct IoContext { + pub(crate) rngs: Rep3CorrelatedRng, + pub(crate) network: N, +} + +pub struct Arithmetic { + field: PhantomData, + network: PhantomData, +} + +impl Arithmetic { + pub fn add(a: &FieldShare, b: &FieldShare) -> FieldShare { + a + b + } + + pub fn add_public(shared: &FieldShare, public: F) -> FieldShare { + shared + public + } + + pub fn sub(a: &FieldShare, b: &FieldShare) -> FieldShare { + a - b + } + + pub fn sub_public(shared: &FieldShare, public: F) -> FieldShare { + shared - public + } + + pub async fn mul( + a: &FieldShare, + b: &FieldShare, + io_context: &mut IoContext, + ) -> IoResult> { + let local_a = a * b + io_context.rngs.rand.masking_field_element::(); + io_context.network.send_next(local_a)?; + let local_b = io_context.network.recv_prev()?; + Ok(FieldShare { + a: local_a, + b: local_b, + }) + } + + /// Multiply a share b by a public value a: c = a * \[b\]. + pub fn mul_with_public(shared: &FieldShare, public: F) -> FieldShare { + shared * public + } + + pub async fn mul_vec( + a: &FieldShareVec, + b: &FieldShareVec, + io_context: &mut IoContext, + ) -> IoResult> { + //debug_assert_eq!(a.len(), b.len()); + let local_a = izip!(a.a.iter(), a.b.iter(), b.a.iter(), b.b.iter()) + .map(|(aa, ab, ba, bb)| { + *aa * ba + *aa * bb + *ab * ba + io_context.rngs.rand.masking_field_element::() + }) + .collect_vec(); + io_context.network.send_next_many(&local_a)?; + let local_b = io_context.network.recv_prev_many()?; + if local_b.len() != local_a.len() { + return Err(std::io::Error::new( + std::io::ErrorKind::InvalidData, + "During execution of mul_vec in MPC: Invalid number of elements received", + )); + } + Ok(FieldShareVec::new(local_a, local_b)) + } + + /// Negates a shared value: \[b\] = -\[a\]. + pub fn neg(a: &FieldShare) -> FieldShare { + -a + } + + pub async fn inv(a: &FieldShare, io_context: &mut IoContext) -> IoResult> { + let r = FieldShare::rand(&mut io_context.rngs); + let y = Self::mul_open(a, &r, io_context).await?; + if y.is_zero() { + return Err(std::io::Error::new( + std::io::ErrorKind::InvalidData, + "During execution of inverse in MPC: cannot compute inverse of zero", + )); + } + let y_inv = y + .inverse() + .expect("we checked if y is zero. Must be possible to inverse."); + Ok(r * y_inv) + } + + pub async fn open(a: &FieldShare, io_context: &mut IoContext) -> IoResult { + io_context.network.send_next(a.b)?; + let c = io_context.network.recv_prev::()?; + Ok(a.a + a.b + c) + } + + /// Computes a CMUX: If cond is 1, returns truthy, otherwise returns falsy. + /// Implementations should not overwrite this method. + pub async fn cmux( + cond: &FieldShare, + truthy: &FieldShare, + falsy: &FieldShare, + io_context: &mut IoContext, + ) -> IoResult> { + let b_min_a = Self::sub(truthy, falsy); + let d = Self::mul(cond, &b_min_a, io_context).await?; + Ok(Self::add(falsy, &d)) + } + + /// Convenience method for \[a\] + \[b\] * c + pub fn add_mul_public(a: &FieldShare, b: &FieldShare, c: F) -> FieldShare { + Self::add(a, &Self::mul_with_public(b, c)) + } + + /// Convenience method for \[a\] + \[b\] * \[c\] + pub async fn add_mul( + &mut self, + a: &FieldShare, + b: &FieldShare, + c: &FieldShare, + io_context: &mut IoContext, + ) -> IoResult> { + Ok(Self::add(a, &Self::mul(c, b, io_context).await?)) + } + + /// Transforms a public value into a shared value: \[a\] = a. + pub fn promote_to_trivial_share( + public_value: F, + io_context: &mut IoContext, + ) -> FieldShare { + match io_context.network.get_id() { + PartyID::ID0 => Rep3PrimeFieldShare::new(public_value, F::zero()), + PartyID::ID1 => Rep3PrimeFieldShare::new(F::zero(), public_value), + PartyID::ID2 => Rep3PrimeFieldShare::zero_share(), + } + } + + /// This function performs a multiplication directly followed by an opening. This safes one round of communication in some MPC protocols compared to calling `mul` and `open` separately. + pub async fn mul_open( + a: &FieldShare, + b: &FieldShare, + io_context: &mut IoContext, + ) -> IoResult { + let a = a * b + io_context.rngs.rand.masking_field_element::(); + io_context.network.send_next(a.to_owned())?; + io_context + .network + .send(io_context.network.get_id().prev_id(), a.to_owned())?; + + let b = io_context.network.recv_prev::()?; + let c = io_context + .network + .recv::(io_context.network.get_id().next_id())?; + Ok(a + b + c) + } +} diff --git a/mpc-core/src/protocols/rep3new/arithmetic/ops.rs b/mpc-core/src/protocols/rep3new/arithmetic/ops.rs new file mode 100644 index 000000000..10d48b818 --- /dev/null +++ b/mpc-core/src/protocols/rep3new/arithmetic/ops.rs @@ -0,0 +1,89 @@ +use ark_ff::PrimeField; + +use super::types::Rep3PrimeFieldShare; + +impl std::ops::Add<&Rep3PrimeFieldShare> for &'_ Rep3PrimeFieldShare { + type Output = Rep3PrimeFieldShare; + + fn add(self, rhs: &Rep3PrimeFieldShare) -> Self::Output { + Rep3PrimeFieldShare:: { + a: self.a + rhs.a, + b: self.b + rhs.b, + } + } +} + +impl std::ops::Add for &Rep3PrimeFieldShare { + type Output = Rep3PrimeFieldShare; + + fn add(self, rhs: F) -> Self::Output { + Self::Output { + a: self.a + rhs, + b: self.b + rhs, + } + } +} + +impl std::ops::Sub<&Rep3PrimeFieldShare> for &'_ Rep3PrimeFieldShare { + type Output = Rep3PrimeFieldShare; + + fn sub(self, rhs: &Rep3PrimeFieldShare) -> Self::Output { + Rep3PrimeFieldShare:: { + a: self.a - rhs.a, + b: self.b - rhs.b, + } + } +} + +impl std::ops::Sub for &Rep3PrimeFieldShare { + type Output = Rep3PrimeFieldShare; + + fn sub(self, rhs: F) -> Self::Output { + Self::Output { + a: self.a - rhs, + b: self.b - rhs, + } + } +} + +impl std::ops::Mul<&Rep3PrimeFieldShare> for &'_ Rep3PrimeFieldShare { + type Output = F; + + // Local part of mul only + fn mul(self, rhs: &Rep3PrimeFieldShare) -> Self::Output { + self.a * rhs.a + self.a * rhs.b + self.b * rhs.a + } +} + +impl std::ops::Mul for Rep3PrimeFieldShare { + type Output = Rep3PrimeFieldShare; + + fn mul(self, rhs: F) -> Self::Output { + Self::Output { + a: self.a * rhs, + b: self.b * rhs, + } + } +} + +impl std::ops::Mul for &Rep3PrimeFieldShare { + type Output = Rep3PrimeFieldShare; + + fn mul(self, rhs: F) -> Self::Output { + Self::Output { + a: self.a * rhs, + b: self.b * rhs, + } + } +} + +impl std::ops::Neg for &Rep3PrimeFieldShare { + type Output = Rep3PrimeFieldShare; + + fn neg(self) -> Self::Output { + Rep3PrimeFieldShare:: { + a: -self.a, + b: -self.b, + } + } +} diff --git a/mpc-core/src/protocols/rep3new/arithmetic/types.rs b/mpc-core/src/protocols/rep3new/arithmetic/types.rs new file mode 100644 index 000000000..ac0544a53 --- /dev/null +++ b/mpc-core/src/protocols/rep3new/arithmetic/types.rs @@ -0,0 +1,103 @@ +use ark_ff::PrimeField; +use ark_serialize::{CanonicalDeserialize, CanonicalSerialize}; + +use crate::protocols::rep3::{id::PartyID, rngs::Rep3CorrelatedRng}; + +/// This type represents a replicated shared value. Since a replicated share of a field element contains additive shares of two parties, this type contains two field elements. +#[derive(Debug, Clone, PartialEq, Eq, Hash, CanonicalSerialize, CanonicalDeserialize)] +pub struct Rep3PrimeFieldShare { + pub(crate) a: F, + pub(crate) b: F, +} + +/// This type represents a vector of replicated shared value. Since a replicated share of a field element contains additive shares of two parties, this type contains two vectors of field elements. +#[derive(Debug, Clone, Default, PartialEq, Eq, CanonicalSerialize, CanonicalDeserialize)] +pub struct Rep3PrimeFieldShareVec { + pub(crate) a: Vec, + pub(crate) b: Vec, +} + +impl Rep3PrimeFieldShare { + /// Constructs the type from two additive shares. + pub fn new(a: F, b: F) -> Self { + Self { a, b } + } + + pub(super) fn zero_share() -> Self { + Self { + a: F::zero(), + b: F::zero(), + } + } + + /// Unwraps the type into two additive shares. + pub fn ab(self) -> (F, F) { + (self.a, self.b) + } + + pub(crate) fn double(&mut self) { + self.a.double_in_place(); + self.b.double_in_place(); + } + + pub(super) fn rand(rngs: &mut Rep3CorrelatedRng) -> Self { + let (a, b) = rngs.rand.random_fes(); + Self::new(a, b) + } + + /// Promotes a public field element to a replicated share by setting the additive share of the party with id=0 and leaving all other shares to be 0. Thus, the replicated shares of party 0 and party 1 are set. + pub fn promote_from_trivial(val: &F, id: PartyID) -> Self { + match id { + PartyID::ID0 => Self::new(*val, F::zero()), + PartyID::ID1 => Self::new(F::zero(), *val), + PartyID::ID2 => Self::zero_share(), + } + } +} + +impl Rep3PrimeFieldShareVec { + /// Constructs the type from two vectors of additive shares. + pub fn new(a: Vec, b: Vec) -> Self { + Self { a, b } + } + + /// Unwraps the type into two vectors of additive shares. + pub fn get_ab(self) -> (Vec, Vec) { + (self.a, self.b) + } + + /// Checks whether the wrapped vectors are empty. + pub fn is_empty(&self) -> bool { + debug_assert_eq!(self.a.is_empty(), self.b.is_empty()); + self.a.is_empty() + } + + /// Returns the length of the wrapped vectors. + pub fn len(&self) -> usize { + debug_assert_eq!(self.a.len(), self.b.len()); + self.a.len() + } + + /// Promotes a vector of public field elements to a vector of replicated shares by setting the additive shares of the party with id=0 and leaving all other shares to be 0. Thus, the replicated shares of party 0 and party 1 are set. + pub fn promote_from_trivial(val: &[F], id: PartyID) -> Self { + let len = val.len(); + + match id { + PartyID::ID0 => { + let a = val.to_vec(); + let b = vec![F::zero(); len]; + Self { a, b } + } + PartyID::ID1 => { + let a = vec![F::zero(); len]; + let b = val.to_vec(); + Self { a, b } + } + PartyID::ID2 => { + let a = vec![F::zero(); len]; + let b = vec![F::zero(); len]; + Self { a, b } + } + } + } +} diff --git a/mpc-core/src/protocols/rep3new/binary.rs b/mpc-core/src/protocols/rep3new/binary.rs new file mode 100644 index 000000000..e7d190bb1 --- /dev/null +++ b/mpc-core/src/protocols/rep3new/binary.rs @@ -0,0 +1,92 @@ +use std::marker::PhantomData; + +use ark_ff::PrimeField; +use num_bigint::BigUint; +use num_traits::ConstZero; + +use crate::protocols::rep3::{id::PartyID, network::Rep3Network}; + +use super::arithmetic::IoContext; + +mod ops; +pub(super) mod types; + +type BinaryShare = types::Rep3BigUintShare; +type IoResult = std::io::Result; + +pub struct Binary { + field: PhantomData, + network: PhantomData, +} + +impl Binary { + // this happens to compile time so we just do "as usize" + const BITLEN: usize = F::MODULUS_BIT_SIZE as usize; + + pub fn xor(a: &BinaryShare, b: &BinaryShare) -> BinaryShare { + a ^ b + } + + pub fn xor_public(shared: &BinaryShare, public: &BigUint) -> BinaryShare { + shared ^ public + } + + pub async fn and( + a: &BinaryShare, + b: &BinaryShare, + io_context: &mut IoContext, + ) -> IoResult { + debug_assert!(a.a.bits() <= u64::try_from(Self::BITLEN).expect("usize fits into u64")); + debug_assert!(b.a.bits() <= u64::try_from(Self::BITLEN).expect("usize fits into u64")); + let (mut mask, mask_b) = io_context.rngs.rand.random_biguint(Self::BITLEN); + mask ^= mask_b; + let local_a = (a & b) ^ mask; + io_context.network.send_next(local_a.to_owned())?; + let local_b = io_context.network.recv_prev()?; + Ok(BinaryShare::new(local_a, local_b)) + } + + pub fn and_with_public(shared: &BinaryShare, public: &BigUint) -> BinaryShare { + shared & public + } + + //pub async fn and_vec( + // a: &FieldShareVec, + // b: &FieldShareVec, + // io_context: &mut IoContext, + //) -> IoResult> { + // //debug_assert_eq!(a.len(), b.len()); + // let local_a = izip!(a.a.iter(), a.b.iter(), b.a.iter(), b.b.iter()) + // .map(|(aa, ab, ba, bb)| { + // *aa * ba + *aa * bb + *ab * ba + io_context.rngs.rand.masking_field_element::() + // }) + // .collect_vec(); + // io_context.network.send_next_many(&local_a)?; + // let local_b = io_context.network.recv_prev_many()?; + // if local_b.len() != local_a.len() { + // return Err(std::io::Error::new( + // std::io::ErrorKind::InvalidData, + // "During execution of mul_vec in MPC: Invalid number of elements received", + // )); + // } + // Ok(FieldShareVec::new(local_a, local_b)) + //} + + pub async fn open(a: &BinaryShare, io_context: &mut IoContext) -> IoResult { + io_context.network.send_next(a.b.clone())?; + let c = io_context.network.recv_prev::()?; + Ok(&a.a ^ &a.b ^ c) + } + + /// Transforms a public value into a shared value: \[a\] = a. + pub fn promote_to_trivial_share( + public_value: BigUint, + io_context: &mut IoContext, + ) -> BinaryShare { + match io_context.network.get_id() { + PartyID::ID0 => BinaryShare::new(public_value, BigUint::ZERO), + PartyID::ID1 => BinaryShare::new(BigUint::ZERO, public_value), + PartyID::ID2 => BinaryShare::zero_share(), + } + } +} diff --git a/mpc-core/src/protocols/rep3new/binary/ops.rs b/mpc-core/src/protocols/rep3new/binary/ops.rs new file mode 100644 index 000000000..f63b1071c --- /dev/null +++ b/mpc-core/src/protocols/rep3new/binary/ops.rs @@ -0,0 +1,73 @@ +use num_bigint::BigUint; + +use super::types::Rep3BigUintShare; + +impl std::ops::BitXor<&Rep3BigUintShare> for &'_ Rep3BigUintShare { + type Output = Rep3BigUintShare; + + fn bitxor(self, rhs: &Rep3BigUintShare) -> Self::Output { + Self::Output { + a: &self.a ^ &rhs.a, + b: &self.b ^ &rhs.b, + } + } +} + +impl std::ops::BitXor<&BigUint> for &Rep3BigUintShare { + type Output = Rep3BigUintShare; + + fn bitxor(self, rhs: &BigUint) -> Self::Output { + Self::Output { + a: &self.a ^ rhs, + b: &self.b ^ rhs, + } + } +} + +impl std::ops::BitAnd<&BigUint> for &Rep3BigUintShare { + type Output = Rep3BigUintShare; + + fn bitand(self, rhs: &BigUint) -> Self::Output { + Rep3BigUintShare { + a: &self.a & rhs, + b: &self.b & rhs, + } + } +} + +impl std::ops::BitAnd<&Rep3BigUintShare> for &'_ Rep3BigUintShare { + type Output = BigUint; + + fn bitand(self, rhs: &Rep3BigUintShare) -> Self::Output { + (&self.a & &rhs.a) ^ (&self.a & &rhs.b) ^ (&self.b & &rhs.a) + } +} + +impl std::ops::ShlAssign for Rep3BigUintShare { + fn shl_assign(&mut self, rhs: usize) { + self.a <<= rhs; + self.b <<= rhs; + } +} + +impl std::ops::Shl for Rep3BigUintShare { + type Output = Self; + + fn shl(self, rhs: usize) -> Self::Output { + Rep3BigUintShare { + a: &self.a << rhs, + b: &self.b << rhs, + } + } +} + +impl std::ops::Shr for &Rep3BigUintShare { + type Output = Rep3BigUintShare; + + fn shr(self, rhs: usize) -> Self::Output { + Rep3BigUintShare { + a: &self.a >> rhs, + b: &self.b >> rhs, + } + } +} diff --git a/mpc-core/src/protocols/rep3new/binary/types.rs b/mpc-core/src/protocols/rep3new/binary/types.rs new file mode 100644 index 000000000..52ed24882 --- /dev/null +++ b/mpc-core/src/protocols/rep3new/binary/types.rs @@ -0,0 +1,27 @@ +use num_bigint::BigUint; + +/// This type represents a packed vector of replicated shared bits. Each additively shared vector is represented as [BigUint]. Thus, this type contains two [BigUint]s. +#[derive(Debug, Clone, PartialEq, Eq, Hash, Default)] +pub struct Rep3BigUintShare { + pub(crate) a: BigUint, + pub(crate) b: BigUint, +} + +impl Rep3BigUintShare { + /// Constructs the type from two additive shares. + pub fn new(a: BigUint, b: BigUint) -> Self { + Self { a, b } + } + + pub(super) fn zero_share() -> Self { + Self { + a: BigUint::ZERO, + b: BigUint::ZERO, + } + } + + /// Unwraps the type into two additive shares. + pub fn ab(self) -> (BigUint, BigUint) { + (self.a, self.b) + } +}