Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: improve triple generation speed by spinning up as a tokio::task #801

Closed
wants to merge 2 commits into from
Closed
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Next Next commit
Triple protocol now gets spunned up as a tokio::task
Phuong Nguyen committed Aug 3, 2024
commit 1800d070bfa7333bb202efc22c03d756e3414347
2 changes: 1 addition & 1 deletion chain-signatures/node/src/protocol/cryptography.rs
Original file line number Diff line number Diff line change
@@ -381,7 +381,7 @@ impl CryptographicProtocol for RunningState {
if let Err(err) = triple_manager.stockpile(active, protocol_cfg) {
tracing::warn!(?err, "running: failed to stockpile triples");
}
for (p, msg) in triple_manager.poke(protocol_cfg).await {
for (p, msg) in triple_manager.poke().await {
let info = self.fetch_participant(&p)?;
messages.push(info.clone(), MpcMessage::Triple(msg));
}
2 changes: 1 addition & 1 deletion chain-signatures/node/src/protocol/message.rs
Original file line number Diff line number Diff line change
@@ -267,7 +267,7 @@ impl MessageHandler for RunningState {

if let Some(protocol) = protocol {
while let Some(message) = queue.pop_front() {
protocol.message(message.from, message.data);
protocol.message(message.from, message.data).await?;
}
}
}
4 changes: 1 addition & 3 deletions chain-signatures/node/src/protocol/mod.rs
Original file line number Diff line number Diff line change
@@ -327,9 +327,7 @@ impl MpcSignProtocol {

let message_time = Instant::now();
if let Err(err) = state.handle(&self, &mut queue).await {
tracing::info!("protocol unable to handle messages: {err:?}");
tokio::time::sleep(Duration::from_millis(100)).await;
continue;
tracing::warn!("protocol unable to handle messages: {err:?}");
}
crate::metrics::PROTOCOL_LATENCY_ITER_MESSAGE
.with_label_values(&[my_account_id.as_str()])
152 changes: 104 additions & 48 deletions chain-signatures/node/src/protocol/triple.rs
Original file line number Diff line number Diff line change
@@ -4,10 +4,11 @@ use super::message::TripleMessage;
use super::presignature::GenerationError;
use crate::gcp::error;
use crate::storage::triple_storage::{LockTripleNodeStorageBox, TripleData};
use crate::types::TripleProtocol;
use crate::util::AffinePointExt;

use cait_sith::protocol::{Action, InitializationError, Participant, ProtocolError};
use cait_sith::protocol::{
Action, InitializationError, MessageData, Participant, Protocol, ProtocolError,
};
use cait_sith::triples::{TripleGenerationOutput, TriplePub, TripleShare};
use chrono::Utc;
use highway::{HighwayHash, HighwayHasher};
@@ -38,41 +39,111 @@ pub struct Triple {
pub struct TripleGenerator {
pub id: TripleId,
pub participants: Vec<Participant>,
pub protocol: TripleProtocol,
pub threshold: usize,
pub timestamp: Option<Instant>,
pub timeout: Duration,

/// Join handle for spawned task that runs the protocol.
join_handle: tokio::task::JoinHandle<Result<(), ProtocolError>>,
/// Message sender for when the node receives a message and needs to forward it to the protocl task.
message_tx: tokio::sync::mpsc::Sender<(Participant, MessageData)>,
/// Message receiver for when the protocol needs to send a message and we should get it back
/// on the main runtime thread for it to be sent to other nodes.
protocol_rx: tokio::sync::mpsc::Receiver<Action<TripleGenerationOutput<Secp256k1>>>,
}

impl TripleGenerator {
pub fn new(
id: TripleId,
me: Participant,
participants: Vec<Participant>,
protocol: TripleProtocol,
threshold: usize,
timeout: u64,
) -> Self {
Self {
) -> Result<Self, InitializationError> {
let (message_tx, mut message_rx) = tokio::sync::mpsc::channel(2048);
let (protocol_tx, protocol_rx) = tokio::sync::mpsc::channel(2048);

let mut protocol =
cait_sith::triples::generate_triple::<Secp256k1>(&participants, me, threshold)?;

let join_handle = tokio::spawn(async move {
let mut interval = tokio::time::interval(Duration::from_millis(500));

loop {
tokio::select! {
Some((from, data)) = message_rx.recv() => {
protocol.message(from, data);
}
_ = interval.tick() => {
loop {
let action = protocol.poke()?;
match &action {
Action::Wait => {
protocol_tx.send(action).await.map_err(|err| {
ProtocolError::Other(Box::new(err))
})?;
break;
}
Action::Return(_) => {
protocol_tx.send(action).await.map_err(|err| {
ProtocolError::Other(Box::new(err))
})?;
return Ok(());
}
Action::SendMany(_) | Action::SendPrivate(_, _) => {
protocol_tx.send(action).await.map_err(|err| {
ProtocolError::Other(Box::new(err))
})?;
}
}
}
}
}
}
});

Ok(Self {
id,
participants,
protocol,
threshold,
timestamp: None,
timeout: Duration::from_millis(timeout),
}
join_handle,
message_tx,
protocol_rx,
})
}

pub fn poke(&mut self) -> Result<Action<TripleGenerationOutput<Secp256k1>>, ProtocolError> {
pub async fn poke(
&mut self,
) -> Result<Action<TripleGenerationOutput<Secp256k1>>, ProtocolError> {
let timestamp = self.timestamp.get_or_insert_with(Instant::now);
if timestamp.elapsed() > self.timeout {
tracing::info!(
id = self.id,
elapsed = ?timestamp.elapsed(),
"triple protocol timed out"
);
self.join_handle.abort();
return Err(ProtocolError::Other(
anyhow::anyhow!("triple protocol timed out").into(),
));
}

self.protocol.poke()
self.protocol_rx.recv().await.ok_or_else(|| {
ProtocolError::Other(anyhow::anyhow!("action sender has been dropped").into())
})
}

pub async fn message(
&mut self,
from: Participant,
data: MessageData,
) -> Result<(), ProtocolError> {
self.message_tx
.send((from, data))
.await
.map_err(|err| ProtocolError::Other(Box::new(err)))
}
}

@@ -222,14 +293,9 @@ impl TripleManager {

tracing::debug!(id, "starting protocol to generate a new triple");
let participants: Vec<_> = participants.keys().cloned().collect();
let protocol: TripleProtocol = Box::new(cait_sith::triples::generate_triple::<Secp256k1>(
&participants,
self.me,
self.threshold,
)?);
self.generators.insert(
id,
TripleGenerator::new(id, participants, protocol, timeout),
TripleGenerator::new(id, self.me, participants, self.threshold, timeout)?,
);
self.queued.push_back(id);
self.introduced.insert(id);
@@ -394,7 +460,7 @@ impl TripleManager {
id: TripleId,
participants: &Participants,
cfg: &ProtocolConfig,
) -> Result<Option<&mut TripleProtocol>, CryptographicError> {
) -> Result<Option<&mut TripleGenerator>, CryptographicError> {
if self.triples.contains_key(&id) || self.gc.contains_key(&id) {
Ok(None)
} else {
@@ -409,24 +475,20 @@ impl TripleManager {

tracing::debug!(id, "joining protocol to generate a new triple");
let participants = participants.keys_vec();
let protocol = Box::new(cait_sith::triples::generate_triple::<Secp256k1>(
&participants,
self.me,
self.threshold,
)?);
let generator = e.insert(TripleGenerator::new(
id,
self.me,
participants,
protocol,
self.threshold,
cfg.triple.generation_timeout,
));
)?);
self.queued.push_back(id);
crate::metrics::NUM_TOTAL_HISTORICAL_TRIPLE_GENERATORS
.with_label_values(&[self.my_account_id.as_str()])
.inc();
Ok(Some(&mut generator.protocol))
Ok(Some(generator))
}
Entry::Occupied(e) => Ok(Some(&mut e.into_mut().protocol)),
Entry::Occupied(e) => Ok(Some(e.into_mut())),
}
}
}
@@ -435,27 +497,19 @@ impl TripleManager {
/// messages to be sent to the respective participant.
///
/// An empty vector means we cannot progress until we receive a new message.
pub async fn poke(&mut self, cfg: &ProtocolConfig) -> Vec<(Participant, TripleMessage)> {
// Add more protocols to the ongoing pool if there is space.
let to_generate_len = cfg.max_concurrent_generation as usize - self.ongoing.len();
if !self.queued.is_empty() && to_generate_len > 0 {
for _ in 0..to_generate_len {
self.queued.pop_front().map(|id| self.ongoing.insert(id));
}
}

pub async fn poke(&mut self) -> Vec<(Participant, TripleMessage)> {
let mut messages = Vec::new();
let mut triples_to_insert = Vec::new();
let mut errors = Vec::new();
self.generators.retain(|id, generator| {
if !self.ongoing.contains(id) {
// If the protocol is not ongoing, we should retain it for the next time
// it is in the ongoing pool.
return true;
}

let ids = self.generators.keys().into_iter().cloned().collect::<Vec<_>>();
for id in &ids {
let Some((_, mut generator)) = self.generators.remove_entry(id) else {
continue;
};

loop {
let action = match generator.poke() {
let action = match generator.poke().await {
Ok(action) => action,
Err(e) => {
errors.push(e);
@@ -466,15 +520,17 @@ impl TripleManager {
elapsed = ?generator.timestamp.unwrap().elapsed(),
"added {id} to failed triples"
);
break false;
break;
}
};

match action {
Action::Wait => {
tracing::trace!("waiting");
// Retain protocol until we are finished
break true;

// protocol not done: insert back to our pool of generators.
self.generators.insert(*id, generator);
break;
}
Action::SendMany(data) => {
for p in &generator.participants {
@@ -559,12 +615,12 @@ impl TripleManager {
// Protocol done, remove it from the ongoing pool.
self.ongoing.remove(id);
self.introduced.remove(id);
// Do not retain the protocol
break false;
break;
}
}
}
});
}
// });
self.insert_triples_to_storage(triples_to_insert).await;

if !errors.is_empty() {
4 changes: 2 additions & 2 deletions chain-signatures/node/src/test_utils.rs
Original file line number Diff line number Diff line change
@@ -85,7 +85,7 @@ impl TestTripleManagers {

async fn poke(&mut self, index: usize) -> Result<bool, ProtocolError> {
let mut quiet = true;
let messages = self.managers[index].poke(&self.config.protocol).await;
let messages = self.managers[index].poke().await;
for (
participant,
ref tm @ TripleMessage {
@@ -101,7 +101,7 @@ impl TestTripleManagers {
.get_or_generate(id, &self.participants, &self.config.protocol)
.unwrap()
{
protocol.message(from, data.to_vec());
protocol.message(from, data.to_vec()).await.unwrap();
} else {
println!("Tried to write to completed mailbox {:?}", tm);
}
Loading