diff --git a/util/network/src/host.rs b/util/network/src/host.rs index 887b7ade0e3..f75158e4a4d 100644 --- a/util/network/src/host.rs +++ b/util/network/src/host.rs @@ -868,6 +868,12 @@ impl Host { let reserved = self.reserved_nodes.read(); if let Some(h) = handlers.get(&p).clone() { h.connected(&NetworkContext::new(io, p, session.clone(), self.sessions.clone(), &reserved), &token); + + // accumulate pending packets. + if let Some(session) = session.as_ref() { + let mut session = session.lock(); + packet_data.extend(session.mark_connected(p)); + } } } for (p, packet_id, data) in packet_data { diff --git a/util/network/src/session.rs b/util/network/src/session.rs index a1336a0ad38..9c8bed9da73 100644 --- a/util/network/src/session.rs +++ b/util/network/src/session.rs @@ -18,6 +18,8 @@ use std::{str, io}; use std::net::SocketAddr; use std::cmp::Ordering; use std::sync::*; +use std::collections::HashMap; + use mio::*; use mio::deprecated::{Handler, EventLoop}; use mio::tcp::*; @@ -36,6 +38,14 @@ use time; const PING_TIMEOUT_SEC: u64 = 15; const PING_INTERVAL_SEC: u64 = 30; +#[derive(Debug, Clone)] +enum ProtocolState { + // Packets pending protocol on_connect event return. + Pending(Vec<(Vec, u8)>), + // Protocol connected. + Connected, +} + /// Peer session over encrypted connection. /// When created waits for Hello packet exchange and signals ready state. /// Sends and receives protocol packets and handles basic packes such as ping/pong and disconnect. @@ -49,6 +59,8 @@ pub struct Session { ping_time_ns: u64, pong_time_ns: Option, state: State, + // Protocol states -- accumulates pending packets until signaled as ready. + protocol_states: HashMap, } enum State { @@ -186,6 +198,7 @@ impl Session { ping_time_ns: 0, pong_time_ns: None, expired: false, + protocol_states: HashMap::new(), }) } @@ -361,6 +374,20 @@ impl Session { self.connection().token() } + /// Signal that a subprotocol has handled the connection successfully and + /// get all pending packets in order received. + pub fn mark_connected(&mut self, protocol: ProtocolId) -> Vec<(ProtocolId, u8, Vec)> { + match self.protocol_states.insert(protocol, ProtocolState::Connected) { + None => Vec::new(), + Some(ProtocolState::Connected) => { + debug!(target: "network", "Protocol {:?} marked as connected more than once", protocol); + Vec::new() + } + Some(ProtocolState::Pending(pending)) => + pending.into_iter().map(|(data, id)| (protocol, id, data)).collect(), + } + } + fn read_packet(&mut self, io: &IoContext, packet: Packet, host: &HostInfo) -> Result where Message: Send + Sync + Clone { if packet.data.len() < 2 { @@ -409,8 +436,19 @@ impl Session { // map to protocol let protocol = self.info.capabilities[i].protocol; let pid = packet_id - self.info.capabilities[i].id_offset; - trace!(target: "network", "Packet {} mapped to {:?}:{}, i={}, capabilities={:?}", packet_id, protocol, pid, i, self.info.capabilities); - Ok(SessionData::Packet { data: packet.data, protocol: protocol, packet_id: pid } ) + + match *self.protocol_states.entry(protocol).or_insert_with(|| ProtocolState::Pending(Vec::new())) { + ProtocolState::Connected => { + trace!(target: "network", "Packet {} mapped to {:?}:{}, i={}, capabilities={:?}", packet_id, protocol, pid, i, self.info.capabilities); + Ok(SessionData::Packet { data: packet.data, protocol: protocol, packet_id: pid } ) + } + ProtocolState::Pending(ref mut pending) => { + trace!(target: "network", "Packet {} deferred until protocol connection event completion", packet_id); + pending.push((packet.data, packet_id)); + + Ok(SessionData::Continue) + } + } }, _ => { debug!(target: "network", "Unknown packet: {:?}", packet_id);