diff --git a/Cargo.lock b/Cargo.lock index 014cc268b0b..96bfb177d48 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2422,6 +2422,7 @@ dependencies = [ "libp2p-swarm", "libp2p-yamux", "log", + "parking_lot 0.12.1", "prost", "prost-build", "quickcheck-ext", diff --git a/protocols/kad/Cargo.toml b/protocols/kad/Cargo.toml index 26c19275014..c119ad3e74b 100644 --- a/protocols/kad/Cargo.toml +++ b/protocols/kad/Cargo.toml @@ -20,6 +20,7 @@ futures = "0.3.26" log = "0.4" libp2p-core = { version = "0.39.0", path = "../../core" } libp2p-swarm = { version = "0.42.0", path = "../../swarm" } +parking_lot = "0.12.0" prost = "0.11" rand = "0.8" sha2 = "0.10.0" diff --git a/protocols/kad/src/behaviour.rs b/protocols/kad/src/behaviour.rs index 986b24a0740..f093b59f77f 100644 --- a/protocols/kad/src/behaviour.rs +++ b/protocols/kad/src/behaviour.rs @@ -1745,7 +1745,7 @@ where &mut self, key: record::Key, provider: KadPeer, - guard: InboundStreamEventGuard, + guard: Arc, ) { if &provider.node_id != self.kbuckets.local_key().preimage() { let record = ProviderRecord { @@ -1777,7 +1777,7 @@ where KademliaEvent::InboundRequest { request: InboundRequest::AddProvider { record: Some(record), - guard: Some(Arc::new(guard)), + guard: Some(guard), }, }, )); diff --git a/protocols/kad/src/handler.rs b/protocols/kad/src/handler.rs index 7b54b67a59d..47ffef1c0f8 100644 --- a/protocols/kad/src/handler.rs +++ b/protocols/kad/src/handler.rs @@ -36,9 +36,10 @@ use libp2p_swarm::{ KeepAlive, NegotiatedSubstream, SubstreamProtocol, }; use log::trace; +use parking_lot::Mutex; use std::collections::VecDeque; use std::sync::atomic::{AtomicBool, Ordering}; -use std::sync::Arc; +use std::sync::{Arc, Weak}; use std::task::Waker; use std::{ error, fmt, io, marker::PhantomData, pin::Pin, task::Context, task::Poll, time::Duration, @@ -178,13 +179,14 @@ enum OutboundSubstreamState { #[derive(Debug)] pub struct InboundStreamEventGuard { ready: Arc, - waker: Option, + waker: Mutex>, } impl Drop for InboundStreamEventGuard { fn drop(&mut self) { self.ready.store(true, Ordering::Release); self.waker + .lock() .take() .expect("Only called once in Drop impl") .wake(); @@ -206,9 +208,9 @@ enum InboundSubstreamState { KadInStreamSink, Option, ), - PendingStateTransition { - ready: Arc, - next_state: Box>, + PendingProcessing { + weak_guard: Weak, + substream: KadInStreamSink, }, /// Waiting to send an answer back to the remote. PendingSend( @@ -268,15 +270,12 @@ impl InboundSubstreamState { ) { InboundSubstreamState::WaitingMessage { substream, .. } | InboundSubstreamState::WaitingBehaviour(_, substream, _) + | InboundSubstreamState::PendingProcessing { substream, .. } | InboundSubstreamState::PendingSend(_, substream, _) | InboundSubstreamState::PendingFlush(_, substream) | InboundSubstreamState::Closing(substream) => { *self = InboundSubstreamState::Closing(substream); } - InboundSubstreamState::PendingStateTransition { next_state, .. } => { - *self = *next_state; - self.close(); - } InboundSubstreamState::Cancelled => { *self = InboundSubstreamState::Cancelled; } @@ -348,7 +347,7 @@ pub enum KademliaHandlerEvent { /// The peer that is the provider of the value for `key`. provider: KadPeer, /// Guard corresponding to inbound stream that generated this event. - guard: InboundStreamEventGuard, + guard: Arc, }, /// Request to get a value from the dht records @@ -1077,16 +1076,14 @@ where } Poll::Ready(Some(Ok(KadRequestMsg::AddProvider { key, provider }))) => { let ready = Arc::new(AtomicBool::new(false)); - let guard = InboundStreamEventGuard { - ready: ready.clone(), - waker: Some(cx.waker().clone()), - }; - let next_state = Box::new(InboundSubstreamState::WaitingMessage { - first: false, - connection_id, - substream, + let guard = Arc::new(InboundStreamEventGuard { + ready, + waker: Mutex::new(Some(cx.waker().clone())), }); - *this = InboundSubstreamState::PendingStateTransition { ready, next_state }; + *this = InboundSubstreamState::PendingProcessing { + weak_guard: Arc::downgrade(&guard), + substream, + }; return Poll::Ready(Some(ConnectionHandlerEvent::Custom( KademliaHandlerEvent::AddProvider { @@ -1145,11 +1142,22 @@ where return Poll::Pending; } - InboundSubstreamState::PendingStateTransition { ready, next_state } => { - *this = if ready.load(Ordering::Acquire) { - *next_state + InboundSubstreamState::PendingProcessing { + weak_guard, + substream, + } => { + *this = if let Some(guard) = weak_guard.upgrade() { + let old_waker = guard.waker.lock().replace(cx.waker().clone()); + if old_waker.is_none() || guard.ready.load(Ordering::Acquire) { + return Poll::Ready(None); + } else { + InboundSubstreamState::PendingProcessing { + weak_guard, + substream, + } + } } else { - InboundSubstreamState::PendingStateTransition { ready, next_state } + return Poll::Ready(None); }; return Poll::Pending;