Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Optimize shamir double randomnes generation using seeds #214

Merged
merged 1 commit into from
Oct 11, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
127 changes: 82 additions & 45 deletions mpc-core/src/protocols/shamir.rs
Original file line number Diff line number Diff line change
Expand Up @@ -171,7 +171,7 @@ pub fn combine_curve_point<C: CurveGroup>(
Ok(rec)
}

/// This type is used to construct a [`SahmirProtocol`].
/// This type is used to construct a [`ShamirProtocol`].
/// Preprocess `amount` number of corre;ated randomness pairs that are consumed while using the protocol.
pub struct ShamirPreprocessing<F: PrimeField, N: ShamirNetwork> {
threshold: usize,
Expand All @@ -188,18 +188,16 @@ impl<F: PrimeField, N: ShamirNetwork> ShamirPreprocessing<F, N> {
eyre::bail!("Threshold too large for number of parties")
}

let num_parties = network.get_num_parties();

let seed: [u8; crate::SEED_SIZE] = RngType::from_entropy().gen();
let mut rng_buffer = ShamirRng::new(seed, threshold, num_parties);
let mut rng_buffer = ShamirRng::new(seed, threshold, &mut network)?;

tracing::info!(
"Party {}: generating correlated randomness..",
network.get_id()
);
let start = Instant::now();
// buffer_triple generates amount * (t + 1), so we ceil dive the amount we want
let amount = amount.div_ceil(threshold + 1);
// buffer_triple generates amount * batch_size, so we ceil dive the amount we want
let amount = amount.div_ceil(rng_buffer.get_size_per_batch());
rng_buffer.buffer_triples(&mut network, amount)?;
tracing::info!(
"Party {}: generating took {} ms",
Expand Down Expand Up @@ -234,11 +232,22 @@ impl<F: PrimeField, N: ShamirNetwork> From<ShamirPreprocessing<F, N>> for Shamir
let mul_lagrange_2t =
core::lagrange_from_coeff(&(1..=2 * value.threshold + 1).collect::<Vec<_>>());

#[allow(clippy::assertions_on_constants)]
{
debug_assert_eq!(Self::KING_ID, 0); // Slightly different implementation required in degree reduce if not
}

// precompute the poly for interpolating a secret with known zero shares
let num_non_zero = num_parties - value.threshold;
let zero_points = (num_non_zero + 1..=num_parties).collect::<Vec<_>>();
let mul_reconstruct_with_zeros = core::interpolation_poly_from_zero_points(&zero_points);

ShamirProtocol {
threshold: value.threshold,
open_lagrange_t,
open_lagrange_2t,
mul_lagrange_2t,
mul_reconstruct_with_zeros,
rng: value.rng_buffer.rng,
r_t: value.rng_buffer.r_t,
r_2t: value.rng_buffer.r_2t,
Expand All @@ -255,6 +264,7 @@ pub struct ShamirProtocol<F: PrimeField, N: ShamirNetwork> {
pub open_lagrange_t: Vec<F>,
pub(crate) open_lagrange_2t: Vec<F>,
mul_lagrange_2t: Vec<F>,
mul_reconstruct_with_zeros: Vec<F>,
rng: RngType,
pub(crate) r_t: Vec<F>,
pub(crate) r_2t: Vec<F>,
Expand All @@ -272,6 +282,7 @@ impl<F: PrimeField, N: ShamirNetwork> ShamirProtocol<F, N> {
open_lagrange_t: self.open_lagrange_t.clone(),
open_lagrange_2t: self.open_lagrange_2t.clone(),
mul_lagrange_2t: self.mul_lagrange_2t.clone(),
mul_reconstruct_with_zeros: self.mul_reconstruct_with_zeros.clone(),
rng: RngType::from_seed(self.rng.gen()),
r_t: self.r_t.drain(0..amount).collect(),
r_2t: self.r_2t.drain(0..amount).collect(),
Expand All @@ -297,6 +308,8 @@ impl<F: PrimeField, N: ShamirNetwork> ShamirProtocol<F, N> {
}

pub(crate) fn degree_reduce(&mut self, mut input: F) -> std::io::Result<ShamirShare<F>> {
let num_non_zero = self.network.get_num_parties() - self.threshold;

let (r_t, r_2t) = self.get_pair()?;
input += r_2t;

Expand All @@ -315,18 +328,21 @@ impl<F: PrimeField, N: ShamirNetwork> ShamirProtocol<F, N> {
// So far parties who do not require sending, do not send, so no receive here

// Send fresh shares
let shares = core::share(
acc,
self.network.get_num_parties(),
self.threshold,
&mut self.rng,
// Since <acc> does not have to be private, we share it as a known polynomial, such that t parties know their share is 0. Consequently we can reduce the amount of communication.
// Note: When expanding t+1 double shares to n double shares (Atlas) we cannot do this anymore, since <acc> needs to stay private. Atlas also requires rotating the King server.

let poly = core::poly_with_zeros_from_precomputed(
&acc,
self.mul_reconstruct_with_zeros.to_owned(),
);

let mut my_share = F::default();
for (other_id, share) in shares.into_iter().enumerate() {
if my_id == other_id {
my_share = share;
for id in 0..num_non_zero {
let val = core::evaluate_poly(&poly, F::from(id as u64 + 1));
if id == my_id {
my_share = val;
} else {
self.network.send(other_id, share)?;
self.network.send(id, val)?;
}
}
my_share
Expand All @@ -335,7 +351,11 @@ impl<F: PrimeField, N: ShamirNetwork> ShamirProtocol<F, N> {
// Only send if my items are required
self.network.send(Self::KING_ID, input)?;
}
self.network.recv(Self::KING_ID)?
if my_id < num_non_zero {
self.network.recv(Self::KING_ID)?
} else {
F::zero()
}
};

Ok(ShamirShare::new(my_share - r_t))
Expand All @@ -346,6 +366,8 @@ impl<F: PrimeField, N: ShamirNetwork> ShamirProtocol<F, N> {
&mut self,
mut inputs: Vec<F>,
) -> std::io::Result<Vec<ShamirShare<F>>> {
let num_non_zero = self.network.get_num_parties() - self.threshold;

let len = inputs.len();
let mut r_ts = Vec::with_capacity(len);

Expand Down Expand Up @@ -379,28 +401,29 @@ impl<F: PrimeField, N: ShamirNetwork> ShamirProtocol<F, N> {
// So far parties who do not require sending, do not send, so no receive here

// Send fresh shares
let mut shares = (0..self.network.get_num_parties())
.map(|_| Vec::with_capacity(len))
.collect::<Vec<_>>();
// Since <acc> does not have to be private, we share it as a known polynomial, such that t parties know their share is 0. Consequently we can reduce the amount of communication.
// Note: When expanding t+1 double shares to n double shares (Atlas) we cannot do this anymore, since <acc> needs to stay private. Atlas also requires rotating the King server.

let mut polys = Vec::with_capacity(acc.len());
for acc in acc {
let s = core::share(
acc,
self.network.get_num_parties(),
self.threshold,
&mut self.rng,
let poly = core::poly_with_zeros_from_precomputed(
&acc,
self.mul_reconstruct_with_zeros.to_owned(),
);
for (des, src) in izip!(&mut shares, s) {
des.push(src);
}
polys.push(poly);
}

let mut my_share = Vec::new();
for (other_id, share) in shares.into_iter().enumerate() {
if my_id == other_id {
my_share = share;
for id in 0..num_non_zero {
let id_f = F::from(id as u64 + 1);
let vals = polys
.iter()
.map(|poly| core::evaluate_poly(poly, id_f))
.collect::<Vec<_>>();
if id == my_id {
my_share = vals;
} else {
self.network.send_many(other_id, &share)?;
self.network.send_many(id, &vals)?;
}
}
my_share
Expand All @@ -409,13 +432,17 @@ impl<F: PrimeField, N: ShamirNetwork> ShamirProtocol<F, N> {
// Only send if my items are required
self.network.send_many(Self::KING_ID, &inputs)?;
}
let r = self.network.recv_many::<F>(Self::KING_ID)?;
if r.len() != len {
return Err(std::io::Error::new(
if my_id < num_non_zero {
let r = self.network.recv_many::<F>(Self::KING_ID)?;
if r.len() != len {
return Err(std::io::Error::new(
std::io::ErrorKind::InvalidData,"During execution of degree_reduce_vec in MPC: Invalid number of elements received",
));
}
r
} else {
vec![F::zero(); len]
}
r
};

for (share, r) in izip!(&mut my_shares, r_ts) {
Expand All @@ -431,6 +458,8 @@ impl<F: PrimeField, N: ShamirNetwork> ShamirProtocol<F, N> {
where
C: CurveGroup + std::ops::Mul<F, Output = C> + for<'a> std::ops::Mul<&'a F, Output = C>,
{
let num_non_zero = self.network.get_num_parties() - self.threshold;

let (r_t, r_2t) = self.get_pair()?;
let r_t = C::generator().mul(r_t);
let r_2t = C::generator().mul(r_2t);
Expand All @@ -452,27 +481,35 @@ impl<F: PrimeField, N: ShamirNetwork> ShamirProtocol<F, N> {
// So far parties who do not require sending, do not send, so no receive here

// Send fresh shares
let shares = core::share_point(
acc,
self.network.get_num_parties(),
self.threshold,
&mut self.rng,
// Since <acc> does not have to be private, we share it as a known polynomial, such that t parties know their share is 0. Consequently we can reduce the amount of communication.
// Note: When expanding t+1 double shares to n double shares (Atlas) we cannot do this anymore, since <acc> needs to stay private. Atlas also requires rotating the King server.

let poly = core::poly_with_zeros_from_precomputed_point(
&acc,
&self.mul_reconstruct_with_zeros,
);

let mut my_share = C::default();
for (other_id, share) in shares.into_iter().enumerate() {
if my_id == other_id {
my_share = share;
for id in 0..num_non_zero {
let val = core::evaluate_poly_point(&poly, C::ScalarField::from(id as u64 + 1));
if id == my_id {
my_share = val;
} else {
self.network.send(other_id, share)?;
self.network.send(id, val)?;
}
}

my_share
} else {
if my_id <= self.threshold * 2 {
// Only send if my items are required
self.network.send(Self::KING_ID, input)?;
}
self.network.recv(Self::KING_ID)?
if my_id < num_non_zero {
self.network.recv(Self::KING_ID)?
} else {
C::default()
}
};

Ok(ShamirPointShare::new(my_share - r_t))
Expand Down
Loading
Loading