Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: improve triple generation speed by spinning up as a tokio::task #801

Closed
wants to merge 2 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion chain-signatures/node/src/protocol/cryptography.rs
Original file line number Diff line number Diff line change
Expand Up @@ -381,7 +381,7 @@ impl CryptographicProtocol for RunningState {
if let Err(err) = triple_manager.stockpile(active, protocol_cfg) {
tracing::warn!(?err, "running: failed to stockpile triples");
}
for (p, msg) in triple_manager.poke(protocol_cfg).await {
for (p, msg) in triple_manager.poke().await {
let info = self.fetch_participant(&p)?;
messages.push(info.clone(), MpcMessage::Triple(msg));
}
Expand Down
2 changes: 1 addition & 1 deletion chain-signatures/node/src/protocol/message.rs
Original file line number Diff line number Diff line change
Expand Up @@ -267,7 +267,7 @@ impl MessageHandler for RunningState {

if let Some(protocol) = protocol {
while let Some(message) = queue.pop_front() {
protocol.message(message.from, message.data);
protocol.message(message.from, message.data).await?;
}
}
}
Expand Down
4 changes: 1 addition & 3 deletions chain-signatures/node/src/protocol/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -327,9 +327,7 @@ impl MpcSignProtocol {

let message_time = Instant::now();
if let Err(err) = state.handle(&self, &mut queue).await {
tracing::info!("protocol unable to handle messages: {err:?}");
tokio::time::sleep(Duration::from_millis(100)).await;
continue;
tracing::warn!("protocol unable to handle messages: {err:?}");
}
crate::metrics::PROTOCOL_LATENCY_ITER_MESSAGE
.with_label_values(&[my_account_id.as_str()])
Expand Down
152 changes: 104 additions & 48 deletions chain-signatures/node/src/protocol/triple.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,11 @@ use super::message::TripleMessage;
use super::presignature::GenerationError;
use crate::gcp::error;
use crate::storage::triple_storage::{LockTripleNodeStorageBox, TripleData};
use crate::types::TripleProtocol;
use crate::util::AffinePointExt;

use cait_sith::protocol::{Action, InitializationError, Participant, ProtocolError};
use cait_sith::protocol::{
Action, InitializationError, MessageData, Participant, Protocol, ProtocolError,
};
use cait_sith::triples::{TripleGenerationOutput, TriplePub, TripleShare};
use chrono::Utc;
use highway::{HighwayHash, HighwayHasher};
Expand Down Expand Up @@ -38,41 +39,111 @@ pub struct Triple {
pub struct TripleGenerator {
pub id: TripleId,
pub participants: Vec<Participant>,
pub protocol: TripleProtocol,
pub threshold: usize,
pub timestamp: Option<Instant>,
pub timeout: Duration,

/// Join handle for spawned task that runs the protocol.
join_handle: tokio::task::JoinHandle<Result<(), ProtocolError>>,
/// Message sender for when the node receives a message and needs to forward it to the protocl task.
message_tx: tokio::sync::mpsc::Sender<(Participant, MessageData)>,
/// Message receiver for when the protocol needs to send a message and we should get it back
/// on the main runtime thread for it to be sent to other nodes.
protocol_rx: tokio::sync::mpsc::Receiver<Action<TripleGenerationOutput<Secp256k1>>>,
}

impl TripleGenerator {
pub fn new(
id: TripleId,
me: Participant,
participants: Vec<Participant>,
protocol: TripleProtocol,
threshold: usize,
timeout: u64,
) -> Self {
Self {
) -> Result<Self, InitializationError> {
let (message_tx, mut message_rx) = tokio::sync::mpsc::channel(2048);
let (protocol_tx, protocol_rx) = tokio::sync::mpsc::channel(2048);

let mut protocol =
cait_sith::triples::generate_triple::<Secp256k1>(&participants, me, threshold)?;

let join_handle = tokio::spawn(async move {
let mut interval = tokio::time::interval(Duration::from_millis(500));

loop {
tokio::select! {
Some((from, data)) = message_rx.recv() => {
protocol.message(from, data);
}
_ = interval.tick() => {
loop {
let action = protocol.poke()?;
match &action {
Action::Wait => {
protocol_tx.send(action).await.map_err(|err| {
ProtocolError::Other(Box::new(err))
})?;
break;
}
Action::Return(_) => {
protocol_tx.send(action).await.map_err(|err| {
ProtocolError::Other(Box::new(err))
})?;
return Ok(());
}
Action::SendMany(_) | Action::SendPrivate(_, _) => {
protocol_tx.send(action).await.map_err(|err| {
ProtocolError::Other(Box::new(err))
})?;
}
}
}
}
}
}
});

Ok(Self {
id,
participants,
protocol,
threshold,
timestamp: None,
timeout: Duration::from_millis(timeout),
}
join_handle,
message_tx,
protocol_rx,
})
}

pub fn poke(&mut self) -> Result<Action<TripleGenerationOutput<Secp256k1>>, ProtocolError> {
pub async fn poke(
&mut self,
) -> Result<Action<TripleGenerationOutput<Secp256k1>>, ProtocolError> {
let timestamp = self.timestamp.get_or_insert_with(Instant::now);
if timestamp.elapsed() > self.timeout {
tracing::info!(
id = self.id,
elapsed = ?timestamp.elapsed(),
"triple protocol timed out"
);
self.join_handle.abort();
return Err(ProtocolError::Other(
anyhow::anyhow!("triple protocol timed out").into(),
));
}

self.protocol.poke()
self.protocol_rx.recv().await.ok_or_else(|| {
ProtocolError::Other(anyhow::anyhow!("action sender has been dropped").into())
})
}

pub async fn message(
&mut self,
from: Participant,
data: MessageData,
) -> Result<(), ProtocolError> {
self.message_tx
.send((from, data))
.await
.map_err(|err| ProtocolError::Other(Box::new(err)))
}
}

Expand Down Expand Up @@ -222,14 +293,9 @@ impl TripleManager {

tracing::debug!(id, "starting protocol to generate a new triple");
let participants: Vec<_> = participants.keys().cloned().collect();
let protocol: TripleProtocol = Box::new(cait_sith::triples::generate_triple::<Secp256k1>(
&participants,
self.me,
self.threshold,
)?);
self.generators.insert(
id,
TripleGenerator::new(id, participants, protocol, timeout),
TripleGenerator::new(id, self.me, participants, self.threshold, timeout)?,
);
self.queued.push_back(id);
self.introduced.insert(id);
Expand Down Expand Up @@ -394,7 +460,7 @@ impl TripleManager {
id: TripleId,
participants: &Participants,
cfg: &ProtocolConfig,
) -> Result<Option<&mut TripleProtocol>, CryptographicError> {
) -> Result<Option<&mut TripleGenerator>, CryptographicError> {
if self.triples.contains_key(&id) || self.gc.contains_key(&id) {
Ok(None)
} else {
Expand All @@ -409,24 +475,20 @@ impl TripleManager {

tracing::debug!(id, "joining protocol to generate a new triple");
let participants = participants.keys_vec();
let protocol = Box::new(cait_sith::triples::generate_triple::<Secp256k1>(
&participants,
self.me,
self.threshold,
)?);
let generator = e.insert(TripleGenerator::new(
id,
self.me,
participants,
protocol,
self.threshold,
cfg.triple.generation_timeout,
));
)?);
self.queued.push_back(id);
crate::metrics::NUM_TOTAL_HISTORICAL_TRIPLE_GENERATORS
.with_label_values(&[self.my_account_id.as_str()])
.inc();
Ok(Some(&mut generator.protocol))
Ok(Some(generator))
}
Entry::Occupied(e) => Ok(Some(&mut e.into_mut().protocol)),
Entry::Occupied(e) => Ok(Some(e.into_mut())),
}
}
}
Expand All @@ -435,27 +497,19 @@ impl TripleManager {
/// messages to be sent to the respective participant.
///
/// An empty vector means we cannot progress until we receive a new message.
pub async fn poke(&mut self, cfg: &ProtocolConfig) -> Vec<(Participant, TripleMessage)> {
// Add more protocols to the ongoing pool if there is space.
let to_generate_len = cfg.max_concurrent_generation as usize - self.ongoing.len();
if !self.queued.is_empty() && to_generate_len > 0 {
for _ in 0..to_generate_len {
self.queued.pop_front().map(|id| self.ongoing.insert(id));
}
}

pub async fn poke(&mut self) -> Vec<(Participant, TripleMessage)> {
let mut messages = Vec::new();
let mut triples_to_insert = Vec::new();
let mut errors = Vec::new();
self.generators.retain(|id, generator| {
if !self.ongoing.contains(id) {
// If the protocol is not ongoing, we should retain it for the next time
// it is in the ongoing pool.
return true;
}

let ids = self.generators.keys().into_iter().cloned().collect::<Vec<_>>();
for id in &ids {
let Some((_, mut generator)) = self.generators.remove_entry(id) else {
continue;
};

loop {
let action = match generator.poke() {
let action = match generator.poke().await {
Ok(action) => action,
Err(e) => {
errors.push(e);
Expand All @@ -466,15 +520,17 @@ impl TripleManager {
elapsed = ?generator.timestamp.unwrap().elapsed(),
"added {id} to failed triples"
);
break false;
break;
}
};

match action {
Action::Wait => {
tracing::trace!("waiting");
// Retain protocol until we are finished
break true;

// protocol not done: insert back to our pool of generators.
self.generators.insert(*id, generator);
break;
}
Action::SendMany(data) => {
for p in &generator.participants {
Expand Down Expand Up @@ -559,12 +615,12 @@ impl TripleManager {
// Protocol done, remove it from the ongoing pool.
self.ongoing.remove(id);
self.introduced.remove(id);
// Do not retain the protocol
break false;
break;
}
}
}
});
}
// });
self.insert_triples_to_storage(triples_to_insert).await;

if !errors.is_empty() {
Expand Down
4 changes: 2 additions & 2 deletions chain-signatures/node/src/test_utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ impl TestTripleManagers {

async fn poke(&mut self, index: usize) -> Result<bool, ProtocolError> {
let mut quiet = true;
let messages = self.managers[index].poke(&self.config.protocol).await;
let messages = self.managers[index].poke().await;
for (
participant,
ref tm @ TripleMessage {
Expand All @@ -101,7 +101,7 @@ impl TestTripleManagers {
.get_or_generate(id, &self.participants, &self.config.protocol)
.unwrap()
{
protocol.message(from, data.to_vec());
protocol.message(from, data.to_vec()).await.unwrap();
} else {
println!("Tried to write to completed mailbox {:?}", tm);
}
Expand Down
Loading
Loading