diff --git a/mpc-core/src/protocols/rep3/rngs.rs b/mpc-core/src/protocols/rep3/rngs.rs index 9d4a3a928..c4cd4938f 100644 --- a/mpc-core/src/protocols/rep3/rngs.rs +++ b/mpc-core/src/protocols/rep3/rngs.rs @@ -57,7 +57,7 @@ impl Rep3Rand { } pub fn random_biguint(&mut self, bitlen: usize) -> (BigUint, BigUint) { - let limbsize = (bitlen + 31) / 32; + let limbsize = bitlen.div_ceil(8); let a = BigUint::new((0..limbsize).map(|_| self.rng1.gen()).collect()); let b = BigUint::new((0..limbsize).map(|_| self.rng2.gen()).collect()); let mask = (BigUint::from(1u32) << bitlen) - BigUint::one(); diff --git a/mpc-core/src/protocols/rep3new.rs b/mpc-core/src/protocols/rep3new.rs index b730c5d47..5c00e07a1 100644 --- a/mpc-core/src/protocols/rep3new.rs +++ b/mpc-core/src/protocols/rep3new.rs @@ -1,8 +1,86 @@ -mod arithmetic; -mod binary; +mod a2b; +pub mod arithmetic; +pub mod binary; pub use arithmetic::types::Rep3PrimeFieldShare; pub use arithmetic::types::Rep3PrimeFieldShareVec; -pub use arithmetic::Arithmetic; pub use binary::types::Rep3BigUintShare; + +pub mod conversion { + use ark_ff::PrimeField; + use num_bigint::BigUint; + + use crate::protocols::rep3::{id::PartyID, network::Rep3Network}; + + type IoResult = std::io::Result; + use super::{a2b, arithmetic::IoContext, Rep3BigUintShare, Rep3PrimeFieldShare}; + + //re-export a2b + pub use super::a2b::a2b; + + /// Transforms the replicated shared value x from a binary sharing to an arithmetic sharing. I.e., x = x_1 xor x_2 xor x_3 gets transformed into x = x'_1 + x'_2 + x'_3. This implementation currently works only for a binary sharing of a valid field element, i.e., x = x_1 xor x_2 xor x_3 < p. + + // Keep in mind: Only works if the input is actually a binary sharing of a valid field element + // If the input has the correct number of bits, but is >= P, then either x can be reduced with self.low_depth_sub_p_cmux(x) first, or self.low_depth_binary_add_2_mod_p(x, y) is extended to subtract 2P in parallel as well. The second solution requires another multiplexer in the end. + pub async fn b2a( + x: Rep3BigUintShare, + io_context: &mut IoContext, + ) -> IoResult> { + let mut y = Rep3BigUintShare::zero_share(); + let mut res = Rep3PrimeFieldShare::zero_share(); + + let bitlen = usize::try_from(F::MODULUS_BIT_SIZE).expect("u32 fits into usize"); + let (mut r, r2) = io_context.rngs.rand.random_biguint(bitlen); + r ^= r2; + + match io_context.id { + PartyID::ID0 => { + let k3 = io_context.rngs.bitcomp2.random_fes_3keys::(); + + res.b = (k3.0 + k3.1 + k3.2).neg(); + y.a = r; + } + PartyID::ID1 => { + let k2 = io_context.rngs.bitcomp1.random_fes_3keys::(); + + res.a = (k2.0 + k2.1 + k2.2).neg(); + y.a = r; + } + PartyID::ID2 => { + let k2 = io_context.rngs.bitcomp1.random_fes_3keys::(); + let k3 = io_context.rngs.bitcomp2.random_fes_3keys::(); + + let k2_comp = k2.0 + k2.1 + k2.2; + let k3_comp = k3.0 + k3.1 + k3.2; + let val: BigUint = (k2_comp + k3_comp).into(); + y.a = val ^ r; + res.a = k3_comp.neg(); + res.b = k2_comp.neg(); + } + } + + // Reshare y + io_context.network.send_next(y.a.to_owned())?; + let local_b = io_context.network.recv_prev()?; + y.b = local_b; + + let z = a2b::low_depth_binary_add_mod_p::(x, y, io_context, bitlen).await?; + + match io_context.id { + PartyID::ID0 => { + io_context.network.send_next(z.b.to_owned())?; + let rcv: BigUint = io_context.network.recv_prev()?; + res.a = (z.a ^ z.b ^ rcv).into(); + } + PartyID::ID1 => { + let rcv: BigUint = io_context.network.recv_prev()?; + res.b = (z.a ^ z.b ^ rcv).into(); + } + PartyID::ID2 => { + io_context.network.send_next(z.b)?; + } + } + Ok(res) + } +} diff --git a/mpc-core/src/protocols/rep3new/a2b.rs b/mpc-core/src/protocols/rep3new/a2b.rs new file mode 100644 index 000000000..1487a0db5 --- /dev/null +++ b/mpc-core/src/protocols/rep3new/a2b.rs @@ -0,0 +1,216 @@ +use ark_ff::One; +use ark_ff::PrimeField; +use ark_ff::Zero; +use num_bigint::BigUint; + +use crate::protocols::rep3::id::PartyID; +use crate::protocols::rep3::network::Rep3Network; + +use super::arithmetic::IoContext; +use super::binary; +use super::Rep3BigUintShare; +use super::Rep3PrimeFieldShare; + +type IoResult = std::io::Result; + +/// Transforms the replicated shared value x from an arithmetic sharing to a binary sharing. I.e., x = x_1 + x_2 + x_3 gets transformed into x = x'_1 xor x'_2 xor x'_3. +pub async fn a2b( + x: &Rep3PrimeFieldShare, + io_context: &mut IoContext, +) -> IoResult { + let mut x01 = Rep3BigUintShare::zero_share(); + let mut x2 = Rep3BigUintShare::zero_share(); + + let bitlen = usize::try_from(F::MODULUS_BIT_SIZE).expect("u32 fits into usize"); + + let (mut r, r2) = io_context.rngs.rand.random_biguint(bitlen); + r ^= r2; + + match io_context.id { + PartyID::ID0 => { + x01.a = r; + x2.b = x.b.into(); + } + PartyID::ID1 => { + let val: BigUint = (x.a + x.b).into(); + x01.a = val ^ r; + } + PartyID::ID2 => { + x01.a = r; + x2.a = x.a.into(); + } + } + + // Reshare x01 + io_context.network.send_next(x01.a.to_owned())?; + let local_b = io_context.network.recv_prev()?; + x01.b = local_b; + + low_depth_binary_add_mod_p::(x01, x2, io_context, bitlen).await +} + +pub(super) async fn low_depth_binary_add_mod_p( + x1: Rep3BigUintShare, + x2: Rep3BigUintShare, + io_context: &mut IoContext, + bitlen: usize, +) -> IoResult { + let x = low_depth_binary_add(x1, x2, io_context, bitlen).await?; + low_depth_sub_p_cmux::(x, io_context, bitlen).await +} + +async fn low_depth_binary_add( + x1: Rep3BigUintShare, + x2: Rep3BigUintShare, + io_context: &mut IoContext, + bitlen: usize, +) -> IoResult { + // Add x1 + x2 via a packed Kogge-Stone adder + let p = &x1 ^ &x2; + let g = binary::and(&x1, &x2, io_context, bitlen).await?; + kogge_stone_inner(p, g, io_context, bitlen).await +} + +async fn kogge_stone_inner( + mut p: Rep3BigUintShare, + mut g: Rep3BigUintShare, + io_context: &mut IoContext, + bitlen: usize, +) -> IoResult { + let d = ceil_log2(bitlen); + let s_ = p.to_owned(); + + for i in 0..d { + let shift = 1 << i; + let mut p_ = p.to_owned(); + let mut g_ = g.to_owned(); + let mask = (BigUint::from(1u64) << (bitlen - shift)) - BigUint::one(); + p_ &= &mask; + g_ &= &mask; + let p_shift = &p >> shift; + + // TODO: Make and more communication efficient, ATM we send the full element for each level, even though they reduce in size + // maybe just input the mask into AND? + let (r1, r2) = and_twice(p_shift, g_, p_, io_context, bitlen - shift).await?; + p = r2 << shift; + g ^= &(r1 << shift); + } + g <<= 1; + g ^= &s_; + Ok(g) +} + +async fn low_depth_sub_p_cmux( + mut x: Rep3BigUintShare, + io_context: &mut IoContext, + bitlen: usize, +) -> IoResult { + let mask = (BigUint::from(1u64) << bitlen) - BigUint::one(); + let x_msb = &x >> bitlen; + x &= &mask; + let mut y = low_depth_binary_sub_p::(&x, io_context, bitlen).await?; + let y_msb = &y >> (bitlen + 1); + y &= &mask; + + // Spread the ov share to the whole biguint + let ov_a = (x_msb.a.iter_u64_digits().next().unwrap_or_default() + ^ y_msb.a.iter_u64_digits().next().unwrap_or_default()) + & 1; + let ov_b = (x_msb.b.iter_u64_digits().next().unwrap_or_default() + ^ y_msb.b.iter_u64_digits().next().unwrap_or_default()) + & 1; + + let ov_a = if ov_a == 1 { + mask.to_owned() + } else { + BigUint::zero() + }; + let ov_b = if ov_b == 1 { mask } else { BigUint::zero() }; + let ov = Rep3BigUintShare::new(ov_a, ov_b); + + // one big multiplexer + let res = binary::cmux(&ov, &y, &x, io_context, bitlen).await?; + Ok(res) +} + +// Calculates 2^k + x1 - x2 +async fn low_depth_binary_sub( + x1: Rep3BigUintShare, + x2: Rep3BigUintShare, + io_context: &mut IoContext, + bitlen: usize, +) -> IoResult { + // Let x2' = be the bit_not of x2 + // Add x1 + x2' via a packed Kogge-Stone adder, where carry_in = 1 + // This is equivalent to x1 - x2 = x1 + two's complement of x2 + let mask = (BigUint::from(1u64) << bitlen) - BigUint::one(); + // bitnot of x2 + let x2 = binary::xor_public(&x2, &mask, io_context.id); + // Now start the Kogge-Stone adder + let p = &x1 ^ &x2; + let mut g = binary::and(&x1, &x2, io_context, bitlen).await?; + // Since carry_in = 1, we need to XOR the LSB of x1 and x2 to g (i.e., xor the LSB of p) + g ^= &(&p & &BigUint::one()); + + let res = kogge_stone_inner(p, g, io_context, bitlen).await?; + let res = binary::xor_public(&res, &BigUint::one(), io_context.id); // cin=1 + Ok(res) +} + +fn ceil_log2(x: usize) -> usize { + let mut y = 0; + let mut x = x - 1; + while x > 0 { + x >>= 1; + y += 1; + } + y +} + +async fn and_twice( + a: Rep3BigUintShare, + b1: Rep3BigUintShare, + b2: Rep3BigUintShare, + io_context: &mut IoContext, + bitlen: usize, +) -> IoResult<(Rep3BigUintShare, Rep3BigUintShare)> { + debug_assert!(a.a.bits() <= bitlen as u64); + debug_assert!(b1.a.bits() <= bitlen as u64); + debug_assert!(b2.a.bits() <= bitlen as u64); + let (mut mask1, mask_b) = io_context.rngs.rand.random_biguint(bitlen); + mask1 ^= mask_b; + + let (mut mask2, mask_b) = io_context.rngs.rand.random_biguint(bitlen); + mask2 ^= mask_b; + + let local_a1 = (&b1 & &a) ^ mask1; + let local_a2 = (&a & &b2) ^ mask2; + io_context.network.send_next(local_a1.to_owned())?; + io_context.network.send_next(local_a2.to_owned())?; + let local_b1 = io_context.network.recv_prev()?; + let local_b2 = io_context.network.recv_prev()?; + + let r1 = Rep3BigUintShare { + a: local_a1, + b: local_b1, + }; + let r2 = Rep3BigUintShare { + a: local_a2, + b: local_b2, + }; + + Ok((r1, r2)) +} + +async fn low_depth_binary_sub_p( + x: &Rep3BigUintShare, + io_context: &mut IoContext, + bitlen: usize, +) -> IoResult { + let p_ = (BigUint::from(1u64) << (bitlen + 1)) - F::MODULUS.into(); + + // Add x1 + p_ via a packed Kogge-Stone adder + let p = binary::xor_public(&x, &p_, io_context.id); + let g = x & &p_; + kogge_stone_inner(p, g, io_context, bitlen + 1).await +} diff --git a/mpc-core/src/protocols/rep3new/arithmetic.rs b/mpc-core/src/protocols/rep3new/arithmetic.rs index 1f6c3216d..2073e3b97 100644 --- a/mpc-core/src/protocols/rep3new/arithmetic.rs +++ b/mpc-core/src/protocols/rep3new/arithmetic.rs @@ -1,5 +1,3 @@ -use std::marker::PhantomData; - use ark_ff::PrimeField; use itertools::{izip, Itertools}; use types::{Rep3PrimeFieldShare, Rep3PrimeFieldShareVec}; @@ -15,156 +13,153 @@ pub(super) mod types; // this will be moved later pub struct IoContext { + pub(crate) id: PartyID, pub(crate) rngs: Rep3CorrelatedRng, pub(crate) network: N, } -pub struct Arithmetic { - field: PhantomData, - network: PhantomData, +pub fn add(a: &FieldShare, b: &FieldShare) -> FieldShare { + a + b } -impl Arithmetic { - pub fn add(a: &FieldShare, b: &FieldShare) -> FieldShare { - a + b - } - - pub fn add_public(shared: &FieldShare, public: F) -> FieldShare { - shared + public - } +pub fn add_public(shared: &FieldShare, public: F) -> FieldShare { + shared + public +} - pub fn sub(a: &FieldShare, b: &FieldShare) -> FieldShare { - a - b - } +pub fn sub(a: &FieldShare, b: &FieldShare) -> FieldShare { + a - b +} - pub fn sub_public(shared: &FieldShare, public: F) -> FieldShare { - shared - public - } +pub fn sub_public(shared: &FieldShare, public: F) -> FieldShare { + shared - public +} - pub async fn mul( - a: &FieldShare, - b: &FieldShare, - io_context: &mut IoContext, - ) -> IoResult> { - let local_a = a * b + io_context.rngs.rand.masking_field_element::(); - io_context.network.send_next(local_a)?; - let local_b = io_context.network.recv_prev()?; - Ok(FieldShare { - a: local_a, - b: local_b, - }) - } +pub async fn mul( + a: &FieldShare, + b: &FieldShare, + io_context: &mut IoContext, +) -> IoResult> { + let local_a = a * b + io_context.rngs.rand.masking_field_element::(); + io_context.network.send_next(local_a)?; + let local_b = io_context.network.recv_prev()?; + Ok(FieldShare { + a: local_a, + b: local_b, + }) +} - /// Multiply a share b by a public value a: c = a * \[b\]. - pub fn mul_with_public(shared: &FieldShare, public: F) -> FieldShare { - shared * public - } +/// Multiply a share b by a public value a: c = a * \[b\]. +pub fn mul_with_public(shared: &FieldShare, public: F) -> FieldShare { + shared * public +} - pub async fn mul_vec( - a: &FieldShareVec, - b: &FieldShareVec, - io_context: &mut IoContext, - ) -> IoResult> { - //debug_assert_eq!(a.len(), b.len()); - let local_a = izip!(a.a.iter(), a.b.iter(), b.a.iter(), b.b.iter()) - .map(|(aa, ab, ba, bb)| { - *aa * ba + *aa * bb + *ab * ba + io_context.rngs.rand.masking_field_element::() - }) - .collect_vec(); - io_context.network.send_next_many(&local_a)?; - let local_b = io_context.network.recv_prev_many()?; - if local_b.len() != local_a.len() { - return Err(std::io::Error::new( - std::io::ErrorKind::InvalidData, - "During execution of mul_vec in MPC: Invalid number of elements received", - )); - } - Ok(FieldShareVec::new(local_a, local_b)) +pub async fn mul_vec( + a: &FieldShareVec, + b: &FieldShareVec, + io_context: &mut IoContext, +) -> IoResult> { + //debug_assert_eq!(a.len(), b.len()); + let local_a = izip!(a.a.iter(), a.b.iter(), b.a.iter(), b.b.iter()) + .map(|(aa, ab, ba, bb)| { + *aa * ba + *aa * bb + *ab * ba + io_context.rngs.rand.masking_field_element::() + }) + .collect_vec(); + io_context.network.send_next_many(&local_a)?; + let local_b = io_context.network.recv_prev_many()?; + if local_b.len() != local_a.len() { + return Err(std::io::Error::new( + std::io::ErrorKind::InvalidData, + "During execution of mul_vec in MPC: Invalid number of elements received", + )); } + Ok(FieldShareVec::new(local_a, local_b)) +} - /// Negates a shared value: \[b\] = -\[a\]. - pub fn neg(a: &FieldShare) -> FieldShare { - -a - } +/// Negates a shared value: \[b\] = -\[a\]. +pub fn neg(a: &FieldShare) -> FieldShare { + -a +} - pub async fn inv(a: &FieldShare, io_context: &mut IoContext) -> IoResult> { - let r = FieldShare::rand(&mut io_context.rngs); - let y = Self::mul_open(a, &r, io_context).await?; - if y.is_zero() { - return Err(std::io::Error::new( - std::io::ErrorKind::InvalidData, - "During execution of inverse in MPC: cannot compute inverse of zero", - )); - } - let y_inv = y - .inverse() - .expect("we checked if y is zero. Must be possible to inverse."); - Ok(r * y_inv) +pub async fn inv( + a: &FieldShare, + io_context: &mut IoContext, +) -> IoResult> { + let r = FieldShare::rand(&mut io_context.rngs); + let y = mul_open(a, &r, io_context).await?; + if y.is_zero() { + return Err(std::io::Error::new( + std::io::ErrorKind::InvalidData, + "During execution of inverse in MPC: cannot compute inverse of zero", + )); } + let y_inv = y + .inverse() + .expect("we checked if y is zero. Must be possible to inverse."); + Ok(r * y_inv) +} - pub async fn open(a: &FieldShare, io_context: &mut IoContext) -> IoResult { - io_context.network.send_next(a.b)?; - let c = io_context.network.recv_prev::()?; - Ok(a.a + a.b + c) - } +pub async fn open( + a: &FieldShare, + io_context: &mut IoContext, +) -> IoResult { + io_context.network.send_next(a.b)?; + let c = io_context.network.recv_prev::()?; + Ok(a.a + a.b + c) +} - /// Computes a CMUX: If cond is 1, returns truthy, otherwise returns falsy. - /// Implementations should not overwrite this method. - pub async fn cmux( - cond: &FieldShare, - truthy: &FieldShare, - falsy: &FieldShare, - io_context: &mut IoContext, - ) -> IoResult> { - let b_min_a = Self::sub(truthy, falsy); - let d = Self::mul(cond, &b_min_a, io_context).await?; - Ok(Self::add(falsy, &d)) - } +/// Computes a CMUX: If cond is 1, returns truthy, otherwise returns falsy. +/// Implementations should not overwrite this method. +pub async fn cmux( + cond: &FieldShare, + truthy: &FieldShare, + falsy: &FieldShare, + io_context: &mut IoContext, +) -> IoResult> { + let b_min_a = sub(truthy, falsy); + let d = mul(cond, &b_min_a, io_context).await?; + Ok(add(falsy, &d)) +} - /// Convenience method for \[a\] + \[b\] * c - pub fn add_mul_public(a: &FieldShare, b: &FieldShare, c: F) -> FieldShare { - Self::add(a, &Self::mul_with_public(b, c)) - } +/// Convenience method for \[a\] + \[b\] * c +pub fn add_mul_public(a: &FieldShare, b: &FieldShare, c: F) -> FieldShare { + add(a, &mul_with_public(b, c)) +} - /// Convenience method for \[a\] + \[b\] * \[c\] - pub async fn add_mul( - &mut self, - a: &FieldShare, - b: &FieldShare, - c: &FieldShare, - io_context: &mut IoContext, - ) -> IoResult> { - Ok(Self::add(a, &Self::mul(c, b, io_context).await?)) - } +/// Convenience method for \[a\] + \[b\] * \[c\] +pub async fn add_mul( + a: &FieldShare, + b: &FieldShare, + c: &FieldShare, + io_context: &mut IoContext, +) -> IoResult> { + let mul = mul(c, b, io_context).await?; + Ok(add(a, &mul)) +} - /// Transforms a public value into a shared value: \[a\] = a. - pub fn promote_to_trivial_share( - public_value: F, - io_context: &mut IoContext, - ) -> FieldShare { - match io_context.network.get_id() { - PartyID::ID0 => Rep3PrimeFieldShare::new(public_value, F::zero()), - PartyID::ID1 => Rep3PrimeFieldShare::new(F::zero(), public_value), - PartyID::ID2 => Rep3PrimeFieldShare::zero_share(), - } +/// Transforms a public value into a shared value: \[a\] = a. +pub fn promote_to_trivial_share(id: PartyID, public_value: F) -> FieldShare { + match id { + PartyID::ID0 => Rep3PrimeFieldShare::new(public_value, F::zero()), + PartyID::ID1 => Rep3PrimeFieldShare::new(F::zero(), public_value), + PartyID::ID2 => Rep3PrimeFieldShare::zero_share(), } +} - /// This function performs a multiplication directly followed by an opening. This safes one round of communication in some MPC protocols compared to calling `mul` and `open` separately. - pub async fn mul_open( - a: &FieldShare, - b: &FieldShare, - io_context: &mut IoContext, - ) -> IoResult { - let a = a * b + io_context.rngs.rand.masking_field_element::(); - io_context.network.send_next(a.to_owned())?; - io_context - .network - .send(io_context.network.get_id().prev_id(), a.to_owned())?; - - let b = io_context.network.recv_prev::()?; - let c = io_context - .network - .recv::(io_context.network.get_id().next_id())?; - Ok(a + b + c) - } +/// This function performs a multiplication directly followed by an opening. This safes one round of communication in some MPC protocols compared to calling `mul` and `open` separately. +pub async fn mul_open( + a: &FieldShare, + b: &FieldShare, + io_context: &mut IoContext, +) -> IoResult { + let a = a * b + io_context.rngs.rand.masking_field_element::(); + io_context.network.send_next(a.to_owned())?; + io_context + .network + .send(io_context.network.get_id().prev_id(), a.to_owned())?; + + let b = io_context.network.recv_prev::()?; + let c = io_context + .network + .recv::(io_context.network.get_id().next_id())?; + Ok(a + b + c) } diff --git a/mpc-core/src/protocols/rep3new/arithmetic/types.rs b/mpc-core/src/protocols/rep3new/arithmetic/types.rs index ac0544a53..a51a20596 100644 --- a/mpc-core/src/protocols/rep3new/arithmetic/types.rs +++ b/mpc-core/src/protocols/rep3new/arithmetic/types.rs @@ -23,7 +23,7 @@ impl Rep3PrimeFieldShare { Self { a, b } } - pub(super) fn zero_share() -> Self { + pub fn zero_share() -> Self { Self { a: F::zero(), b: F::zero(), diff --git a/mpc-core/src/protocols/rep3new/binary.rs b/mpc-core/src/protocols/rep3new/binary.rs index e7d190bb1..7bda091e3 100644 --- a/mpc-core/src/protocols/rep3new/binary.rs +++ b/mpc-core/src/protocols/rep3new/binary.rs @@ -1,8 +1,4 @@ -use std::marker::PhantomData; - -use ark_ff::PrimeField; use num_bigint::BigUint; -use num_traits::ConstZero; use crate::protocols::rep3::{id::PartyID, network::Rep3Network}; @@ -14,79 +10,87 @@ pub(super) mod types; type BinaryShare = types::Rep3BigUintShare; type IoResult = std::io::Result; -pub struct Binary { - field: PhantomData, - network: PhantomData, +pub fn xor(a: &BinaryShare, b: &BinaryShare) -> BinaryShare { + a ^ b } -impl Binary { - // this happens to compile time so we just do "as usize" - const BITLEN: usize = F::MODULUS_BIT_SIZE as usize; - - pub fn xor(a: &BinaryShare, b: &BinaryShare) -> BinaryShare { - a ^ b - } - - pub fn xor_public(shared: &BinaryShare, public: &BigUint) -> BinaryShare { +pub fn xor_public(shared: &BinaryShare, public: &BigUint, id: PartyID) -> BinaryShare { + if let PartyID::ID0 = id { shared ^ public + } else { + shared.to_owned() } +} - pub async fn and( - a: &BinaryShare, - b: &BinaryShare, - io_context: &mut IoContext, - ) -> IoResult { - debug_assert!(a.a.bits() <= u64::try_from(Self::BITLEN).expect("usize fits into u64")); - debug_assert!(b.a.bits() <= u64::try_from(Self::BITLEN).expect("usize fits into u64")); - let (mut mask, mask_b) = io_context.rngs.rand.random_biguint(Self::BITLEN); - mask ^= mask_b; - let local_a = (a & b) ^ mask; - io_context.network.send_next(local_a.to_owned())?; - let local_b = io_context.network.recv_prev()?; - Ok(BinaryShare::new(local_a, local_b)) - } +pub async fn and( + a: &BinaryShare, + b: &BinaryShare, + io_context: &mut IoContext, + bitlen: usize, +) -> IoResult { + debug_assert!(a.a.bits() <= u64::try_from(bitlen).expect("usize fits into u64")); + debug_assert!(b.a.bits() <= u64::try_from(bitlen).expect("usize fits into u64")); + let (mut mask, mask_b) = io_context.rngs.rand.random_biguint(bitlen); + mask ^= mask_b; + let local_a = (a & b) ^ mask; + io_context.network.send_next(local_a.to_owned())?; + let local_b = io_context.network.recv_prev()?; + Ok(BinaryShare::new(local_a, local_b)) +} - pub fn and_with_public(shared: &BinaryShare, public: &BigUint) -> BinaryShare { - shared & public - } +pub fn and_with_public(shared: &BinaryShare, public: &BigUint) -> BinaryShare { + shared & public +} - //pub async fn and_vec( - // a: &FieldShareVec, - // b: &FieldShareVec, - // io_context: &mut IoContext, - //) -> IoResult> { - // //debug_assert_eq!(a.len(), b.len()); - // let local_a = izip!(a.a.iter(), a.b.iter(), b.a.iter(), b.b.iter()) - // .map(|(aa, ab, ba, bb)| { - // *aa * ba + *aa * bb + *ab * ba + io_context.rngs.rand.masking_field_element::() - // }) - // .collect_vec(); - // io_context.network.send_next_many(&local_a)?; - // let local_b = io_context.network.recv_prev_many()?; - // if local_b.len() != local_a.len() { - // return Err(std::io::Error::new( - // std::io::ErrorKind::InvalidData, - // "During execution of mul_vec in MPC: Invalid number of elements received", - // )); - // } - // Ok(FieldShareVec::new(local_a, local_b)) - //} +//pub async fn and_vec( +// a: &FieldShareVec, +// b: &FieldShareVec, +// io_context: &mut IoContext, +//) -> IoResult> { +// //debug_assert_eq!(a.len(), b.len()); +// let local_a = izip!(a.a.iter(), a.b.iter(), b.a.iter(), b.b.iter()) +// .map(|(aa, ab, ba, bb)| { +// *aa * ba + *aa * bb + *ab * ba + io_context.rngs.rand.masking_field_element::() +// }) +// .collect_vec(); +// io_context.network.send_next_many(&local_a)?; +// let local_b = io_context.network.recv_prev_many()?; +// if local_b.len() != local_a.len() { +// return Err(std::io::Error::new( +// std::io::ErrorKind::InvalidData, +// "During execution of mul_vec in MPC: Invalid number of elements received", +// )); +// } +// Ok(FieldShareVec::new(local_a, local_b)) +//} - pub async fn open(a: &BinaryShare, io_context: &mut IoContext) -> IoResult { - io_context.network.send_next(a.b.clone())?; - let c = io_context.network.recv_prev::()?; - Ok(&a.a ^ &a.b ^ c) - } +pub async fn open( + a: &BinaryShare, + io_context: &mut IoContext, +) -> IoResult { + io_context.network.send_next(a.b.clone())?; + let c = io_context.network.recv_prev::()?; + Ok(&a.a ^ &a.b ^ c) +} - /// Transforms a public value into a shared value: \[a\] = a. - pub fn promote_to_trivial_share( - public_value: BigUint, - io_context: &mut IoContext, - ) -> BinaryShare { - match io_context.network.get_id() { - PartyID::ID0 => BinaryShare::new(public_value, BigUint::ZERO), - PartyID::ID1 => BinaryShare::new(BigUint::ZERO, public_value), - PartyID::ID2 => BinaryShare::zero_share(), - } +/// Transforms a public value into a shared value: \[a\] = a. +pub fn promote_to_trivial_share(id: PartyID, public_value: BigUint) -> BinaryShare { + match id { + PartyID::ID0 => BinaryShare::new(public_value, BigUint::ZERO), + PartyID::ID1 => BinaryShare::new(BigUint::ZERO, public_value), + PartyID::ID2 => BinaryShare::zero_share(), } } + +pub async fn cmux( + c: &BinaryShare, + x_t: &BinaryShare, + x_f: &BinaryShare, + io_context: &mut IoContext, + bitlen: usize, +) -> IoResult { + let xor = x_f ^ x_t; + let mut and = and(c, &xor, io_context, bitlen).await?; + and ^= x_f; + Ok(and) +} diff --git a/mpc-core/src/protocols/rep3new/binary/ops.rs b/mpc-core/src/protocols/rep3new/binary/ops.rs index f63b1071c..35ef6ef67 100644 --- a/mpc-core/src/protocols/rep3new/binary/ops.rs +++ b/mpc-core/src/protocols/rep3new/binary/ops.rs @@ -24,6 +24,13 @@ impl std::ops::BitXor<&BigUint> for &Rep3BigUintShare { } } +impl std::ops::BitXorAssign<&Self> for Rep3BigUintShare { + fn bitxor_assign(&mut self, rhs: &Self) { + self.a ^= &rhs.a; + self.b ^= &rhs.b; + } +} + impl std::ops::BitAnd<&BigUint> for &Rep3BigUintShare { type Output = Rep3BigUintShare; @@ -43,6 +50,13 @@ impl std::ops::BitAnd<&Rep3BigUintShare> for &'_ Rep3BigUintShare { } } +impl std::ops::BitAndAssign<&BigUint> for Rep3BigUintShare { + fn bitand_assign(&mut self, rhs: &BigUint) { + self.a &= rhs; + self.b &= rhs; + } +} + impl std::ops::ShlAssign for Rep3BigUintShare { fn shl_assign(&mut self, rhs: usize) { self.a <<= rhs; diff --git a/mpc-core/src/protocols/rep3new/binary/types.rs b/mpc-core/src/protocols/rep3new/binary/types.rs index 52ed24882..6e10bee8d 100644 --- a/mpc-core/src/protocols/rep3new/binary/types.rs +++ b/mpc-core/src/protocols/rep3new/binary/types.rs @@ -1,7 +1,7 @@ use num_bigint::BigUint; /// This type represents a packed vector of replicated shared bits. Each additively shared vector is represented as [BigUint]. Thus, this type contains two [BigUint]s. -#[derive(Debug, Clone, PartialEq, Eq, Hash, Default)] +#[derive(Debug, Clone, PartialEq, Eq, Hash)] pub struct Rep3BigUintShare { pub(crate) a: BigUint, pub(crate) b: BigUint, @@ -13,7 +13,7 @@ impl Rep3BigUintShare { Self { a, b } } - pub(super) fn zero_share() -> Self { + pub fn zero_share() -> Self { Self { a: BigUint::ZERO, b: BigUint::ZERO,