Skip to content

Commit

Permalink
feat: added new rep3 impl
Browse files Browse the repository at this point in the history
  • Loading branch information
0xThemis committed Oct 15, 2024
1 parent 8dc6fdd commit d9b8412
Show file tree
Hide file tree
Showing 8 changed files with 563 additions and 0 deletions.
1 change: 1 addition & 0 deletions mpc-core/src/protocols.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,4 +5,5 @@
pub mod bridges;
pub mod plain;
pub mod rep3;
pub mod rep3new;
pub mod shamir;
8 changes: 8 additions & 0 deletions mpc-core/src/protocols/rep3new.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
mod arithmetic;
mod binary;

pub use arithmetic::types::Rep3PrimeFieldShare;
pub use arithmetic::types::Rep3PrimeFieldShareVec;
pub use arithmetic::Arithmetic;

pub use binary::types::Rep3BigUintShare;
170 changes: 170 additions & 0 deletions mpc-core/src/protocols/rep3new/arithmetic.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,170 @@
use std::marker::PhantomData;

use ark_ff::PrimeField;
use itertools::{izip, Itertools};
use types::{Rep3PrimeFieldShare, Rep3PrimeFieldShareVec};

use crate::protocols::rep3::{id::PartyID, network::Rep3Network, rngs::Rep3CorrelatedRng};

type FieldShare<F> = Rep3PrimeFieldShare<F>;
type FieldShareVec<F> = Rep3PrimeFieldShareVec<F>;
type IoResult<F> = std::io::Result<F>;

mod ops;
pub(super) mod types;

// this will be moved later
pub struct IoContext<N: Rep3Network> {
pub(crate) rngs: Rep3CorrelatedRng,
pub(crate) network: N,
}

pub struct Arithmetic<F: PrimeField, N: Rep3Network> {
field: PhantomData<F>,
network: PhantomData<N>,
}

impl<F: PrimeField, N: Rep3Network> Arithmetic<F, N> {
pub fn add(a: &FieldShare<F>, b: &FieldShare<F>) -> FieldShare<F> {
a + b
}

pub fn add_public(shared: &FieldShare<F>, public: F) -> FieldShare<F> {
shared + public
}

pub fn sub(a: &FieldShare<F>, b: &FieldShare<F>) -> FieldShare<F> {
a - b
}

pub fn sub_public(shared: &FieldShare<F>, public: F) -> FieldShare<F> {
shared - public
}

pub async fn mul(
a: &FieldShare<F>,
b: &FieldShare<F>,
io_context: &mut IoContext<N>,
) -> IoResult<FieldShare<F>> {
let local_a = a * b + io_context.rngs.rand.masking_field_element::<F>();
io_context.network.send_next(local_a)?;
let local_b = io_context.network.recv_prev()?;
Ok(FieldShare {
a: local_a,
b: local_b,
})
}

/// Multiply a share b by a public value a: c = a * \[b\].
pub fn mul_with_public(shared: &FieldShare<F>, public: F) -> FieldShare<F> {
shared * public
}

pub async fn mul_vec(
a: &FieldShareVec<F>,
b: &FieldShareVec<F>,
io_context: &mut IoContext<N>,
) -> IoResult<FieldShareVec<F>> {
//debug_assert_eq!(a.len(), b.len());
let local_a = izip!(a.a.iter(), a.b.iter(), b.a.iter(), b.b.iter())
.map(|(aa, ab, ba, bb)| {
*aa * ba + *aa * bb + *ab * ba + io_context.rngs.rand.masking_field_element::<F>()
})
.collect_vec();
io_context.network.send_next_many(&local_a)?;
let local_b = io_context.network.recv_prev_many()?;
if local_b.len() != local_a.len() {
return Err(std::io::Error::new(
std::io::ErrorKind::InvalidData,
"During execution of mul_vec in MPC: Invalid number of elements received",
));
}
Ok(FieldShareVec::new(local_a, local_b))
}

/// Negates a shared value: \[b\] = -\[a\].
pub fn neg(a: &FieldShare<F>) -> FieldShare<F> {
-a
}

pub async fn inv(a: &FieldShare<F>, io_context: &mut IoContext<N>) -> IoResult<FieldShare<F>> {
let r = FieldShare::rand(&mut io_context.rngs);
let y = Self::mul_open(a, &r, io_context).await?;
if y.is_zero() {
return Err(std::io::Error::new(
std::io::ErrorKind::InvalidData,
"During execution of inverse in MPC: cannot compute inverse of zero",
));
}
let y_inv = y
.inverse()
.expect("we checked if y is zero. Must be possible to inverse.");
Ok(r * y_inv)
}

pub async fn open(a: &FieldShare<F>, io_context: &mut IoContext<N>) -> IoResult<F> {
io_context.network.send_next(a.b)?;
let c = io_context.network.recv_prev::<F>()?;
Ok(a.a + a.b + c)
}

/// Computes a CMUX: If cond is 1, returns truthy, otherwise returns falsy.
/// Implementations should not overwrite this method.
pub async fn cmux(
cond: &FieldShare<F>,
truthy: &FieldShare<F>,
falsy: &FieldShare<F>,
io_context: &mut IoContext<N>,
) -> IoResult<FieldShare<F>> {
let b_min_a = Self::sub(truthy, falsy);
let d = Self::mul(cond, &b_min_a, io_context).await?;
Ok(Self::add(falsy, &d))
}

/// Convenience method for \[a\] + \[b\] * c
pub fn add_mul_public(a: &FieldShare<F>, b: &FieldShare<F>, c: F) -> FieldShare<F> {
Self::add(a, &Self::mul_with_public(b, c))
}

/// Convenience method for \[a\] + \[b\] * \[c\]
pub async fn add_mul(
&mut self,
a: &FieldShare<F>,
b: &FieldShare<F>,
c: &FieldShare<F>,
io_context: &mut IoContext<N>,
) -> IoResult<FieldShare<F>> {
Ok(Self::add(a, &Self::mul(c, b, io_context).await?))
}

/// Transforms a public value into a shared value: \[a\] = a.
pub fn promote_to_trivial_share(
public_value: F,
io_context: &mut IoContext<N>,
) -> FieldShare<F> {
match io_context.network.get_id() {
PartyID::ID0 => Rep3PrimeFieldShare::new(public_value, F::zero()),
PartyID::ID1 => Rep3PrimeFieldShare::new(F::zero(), public_value),
PartyID::ID2 => Rep3PrimeFieldShare::zero_share(),
}
}

/// This function performs a multiplication directly followed by an opening. This safes one round of communication in some MPC protocols compared to calling `mul` and `open` separately.
pub async fn mul_open(
a: &FieldShare<F>,
b: &FieldShare<F>,
io_context: &mut IoContext<N>,
) -> IoResult<F> {
let a = a * b + io_context.rngs.rand.masking_field_element::<F>();
io_context.network.send_next(a.to_owned())?;
io_context
.network
.send(io_context.network.get_id().prev_id(), a.to_owned())?;

let b = io_context.network.recv_prev::<F>()?;
let c = io_context
.network
.recv::<F>(io_context.network.get_id().next_id())?;
Ok(a + b + c)
}
}
89 changes: 89 additions & 0 deletions mpc-core/src/protocols/rep3new/arithmetic/ops.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,89 @@
use ark_ff::PrimeField;

use super::types::Rep3PrimeFieldShare;

impl<F: PrimeField> std::ops::Add<&Rep3PrimeFieldShare<F>> for &'_ Rep3PrimeFieldShare<F> {
type Output = Rep3PrimeFieldShare<F>;

fn add(self, rhs: &Rep3PrimeFieldShare<F>) -> Self::Output {
Rep3PrimeFieldShare::<F> {
a: self.a + rhs.a,
b: self.b + rhs.b,
}
}
}

impl<F: PrimeField> std::ops::Add<F> for &Rep3PrimeFieldShare<F> {
type Output = Rep3PrimeFieldShare<F>;

fn add(self, rhs: F) -> Self::Output {
Self::Output {
a: self.a + rhs,
b: self.b + rhs,
}
}
}

impl<F: PrimeField> std::ops::Sub<&Rep3PrimeFieldShare<F>> for &'_ Rep3PrimeFieldShare<F> {
type Output = Rep3PrimeFieldShare<F>;

fn sub(self, rhs: &Rep3PrimeFieldShare<F>) -> Self::Output {
Rep3PrimeFieldShare::<F> {
a: self.a - rhs.a,
b: self.b - rhs.b,
}
}
}

impl<F: PrimeField> std::ops::Sub<F> for &Rep3PrimeFieldShare<F> {
type Output = Rep3PrimeFieldShare<F>;

fn sub(self, rhs: F) -> Self::Output {
Self::Output {
a: self.a - rhs,
b: self.b - rhs,
}
}
}

impl<F: PrimeField> std::ops::Mul<&Rep3PrimeFieldShare<F>> for &'_ Rep3PrimeFieldShare<F> {
type Output = F;

// Local part of mul only
fn mul(self, rhs: &Rep3PrimeFieldShare<F>) -> Self::Output {
self.a * rhs.a + self.a * rhs.b + self.b * rhs.a
}
}

impl<F: PrimeField> std::ops::Mul<F> for Rep3PrimeFieldShare<F> {
type Output = Rep3PrimeFieldShare<F>;

fn mul(self, rhs: F) -> Self::Output {
Self::Output {
a: self.a * rhs,
b: self.b * rhs,
}
}
}

impl<F: PrimeField> std::ops::Mul<F> for &Rep3PrimeFieldShare<F> {
type Output = Rep3PrimeFieldShare<F>;

fn mul(self, rhs: F) -> Self::Output {
Self::Output {
a: self.a * rhs,
b: self.b * rhs,
}
}
}

impl<F: PrimeField> std::ops::Neg for &Rep3PrimeFieldShare<F> {
type Output = Rep3PrimeFieldShare<F>;

fn neg(self) -> Self::Output {
Rep3PrimeFieldShare::<F> {
a: -self.a,
b: -self.b,
}
}
}
103 changes: 103 additions & 0 deletions mpc-core/src/protocols/rep3new/arithmetic/types.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,103 @@
use ark_ff::PrimeField;
use ark_serialize::{CanonicalDeserialize, CanonicalSerialize};

use crate::protocols::rep3::{id::PartyID, rngs::Rep3CorrelatedRng};

/// This type represents a replicated shared value. Since a replicated share of a field element contains additive shares of two parties, this type contains two field elements.
#[derive(Debug, Clone, PartialEq, Eq, Hash, CanonicalSerialize, CanonicalDeserialize)]
pub struct Rep3PrimeFieldShare<F: PrimeField> {
pub(crate) a: F,
pub(crate) b: F,
}

/// This type represents a vector of replicated shared value. Since a replicated share of a field element contains additive shares of two parties, this type contains two vectors of field elements.
#[derive(Debug, Clone, Default, PartialEq, Eq, CanonicalSerialize, CanonicalDeserialize)]
pub struct Rep3PrimeFieldShareVec<F: PrimeField> {
pub(crate) a: Vec<F>,
pub(crate) b: Vec<F>,
}

impl<F: PrimeField> Rep3PrimeFieldShare<F> {
/// Constructs the type from two additive shares.
pub fn new(a: F, b: F) -> Self {
Self { a, b }
}

pub(super) fn zero_share() -> Self {
Self {
a: F::zero(),
b: F::zero(),
}
}

/// Unwraps the type into two additive shares.
pub fn ab(self) -> (F, F) {
(self.a, self.b)
}

pub(crate) fn double(&mut self) {
self.a.double_in_place();
self.b.double_in_place();
}

pub(super) fn rand(rngs: &mut Rep3CorrelatedRng) -> Self {
let (a, b) = rngs.rand.random_fes();
Self::new(a, b)
}

/// Promotes a public field element to a replicated share by setting the additive share of the party with id=0 and leaving all other shares to be 0. Thus, the replicated shares of party 0 and party 1 are set.
pub fn promote_from_trivial(val: &F, id: PartyID) -> Self {
match id {
PartyID::ID0 => Self::new(*val, F::zero()),
PartyID::ID1 => Self::new(F::zero(), *val),
PartyID::ID2 => Self::zero_share(),
}
}
}

impl<F: PrimeField> Rep3PrimeFieldShareVec<F> {
/// Constructs the type from two vectors of additive shares.
pub fn new(a: Vec<F>, b: Vec<F>) -> Self {
Self { a, b }
}

/// Unwraps the type into two vectors of additive shares.
pub fn get_ab(self) -> (Vec<F>, Vec<F>) {
(self.a, self.b)
}

/// Checks whether the wrapped vectors are empty.
pub fn is_empty(&self) -> bool {
debug_assert_eq!(self.a.is_empty(), self.b.is_empty());
self.a.is_empty()
}

/// Returns the length of the wrapped vectors.
pub fn len(&self) -> usize {
debug_assert_eq!(self.a.len(), self.b.len());
self.a.len()
}

/// Promotes a vector of public field elements to a vector of replicated shares by setting the additive shares of the party with id=0 and leaving all other shares to be 0. Thus, the replicated shares of party 0 and party 1 are set.
pub fn promote_from_trivial(val: &[F], id: PartyID) -> Self {
let len = val.len();

match id {
PartyID::ID0 => {
let a = val.to_vec();
let b = vec![F::zero(); len];
Self { a, b }
}
PartyID::ID1 => {
let a = vec![F::zero(); len];
let b = val.to_vec();
Self { a, b }
}
PartyID::ID2 => {
let a = vec![F::zero(); len];
let b = vec![F::zero(); len];
Self { a, b }
}
}
}
}
Loading

0 comments on commit d9b8412

Please sign in to comment.