From 903d1b0477a7dbc82681b9d83349fe9eb52ceb38 Mon Sep 17 00:00:00 2001 From: Or Neeman Date: Tue, 27 Oct 2020 16:12:09 -0600 Subject: [PATCH] Cherry pick new dial scheduler from upstream (#1192) * celo-blockchain cherry-pick from upstream: p2p: new dial scheduler (#20592) Conflicts: p2p/dial.go p2p/server.go p2p/server_test.go * p2p: new dial scheduler This change replaces the peer-to-peer dial scheduler with a new and improved implementation. The new code is better than the previous implementation in two key aspects: - The time between discovery of a node and dialing that node is significantly lower in the new version. The old dialState kept a buffer of nodes and launched a task to refill it whenever the buffer became empty. This worked well with the discovery interface we used to have, but doesn't really work with the new iterator-based discovery API. - Selection of static dial candidates (created by Server.AddPeer or through static-nodes.json) performs much better for large amounts of static peers. Connections to static nodes are now limited like dynamic dials and can no longer overstep MaxPeers or the dial ratio. * p2p/simulations/adapters: adapt to new NodeDialer interface * p2p: re-add check for self in checkDial * p2p: remove peersetCh * p2p: allow static dials when discovery is disabled * p2p: add test for dialScheduler.removeStatic * p2p: remove blank line * p2p: fix documentation of maxDialPeers * p2p: change "ok" to "added" in static node log * p2p: improve dialTask docs Also increase log level for "Can't resolve node" * p2p: ensure dial resolver is truly nil without discovery * p2p: add "looking for peers" log message * p2p: clean up Server.run comments * p2p: fix maxDialedConns for maxpeers < dialRatio Always allocate at least one dial slot unless dialing is disabled using NoDial or MaxPeers == 0. Most importantly, this fixes MaxPeers == 1 to dedicate the sole slot to dialing instead of listening. * p2p: fix RemovePeer to disconnect the peer again Also make RemovePeer synchronous and add a test. * p2p: remove "Connection set up" log message * p2p: clean up connection logging We previously logged outgoing connection failures up to three times. - in SetupConn() as "Setting up connection failed addr=..." - in setupConn() with an error-specific message and "id=... addr=..." - in dial() as "Dial error task=..." This commit ensures a single log message is emitted per failure and adds "id=... addr=... conn=..." everywhere (id= omitted when the ID isn't known yet). Also avoid printing a log message when a static dial fails but can't be resolved because discv4 is disabled. The light client hit this case all the time, increasing the message count to four lines per failed connection. * p2p: document that RemovePeer blocks * les: add bootstrap nodes as initial discoveries (#20688) * celo-blockchain addition: make static dials exempt from the maxDialPeers limit Static peers are ones we explicitly want to connect to, and so should not be subject to this limit. This is especially important since connections between validators/proxies and other validators/proxies rely on this being so, and because that's how it was prior to the new dial scheduler from upstream.
* celo-blockchain addition: Add tests for two untested RemovePeer() scenarios Co-authored-by: Felix Lange --- les/serverpool.go | 13 + p2p/dial.go | 597 ++++++++++------ p2p/dial_test.go | 1059 +++++++++++++++------------- p2p/peer_test.go | 44 +- p2p/server.go | 290 ++++---- p2p/server_test.go | 261 +++---- p2p/simulations/adapters/inproc.go | 3 +- p2p/util.go | 20 +- p2p/util_test.go | 12 +- 9 files changed, 1272 insertions(+), 1027 deletions(-) diff --git a/les/serverpool.go b/les/serverpool.go index d5c6e55d6b5f..2e4c48064dbc 100644 --- a/les/serverpool.go +++ b/les/serverpool.go @@ -179,6 +179,19 @@ func (pool *serverPool) start(server *p2p.Server, topic discv5.Topic) { pool.checkDial() pool.wg.Add(1) go pool.eventLoop() + + // Inject the bootstrap nodes as initial dial candidates. + pool.wg.Add(1) + go func() { + defer pool.wg.Done() + for _, n := range server.BootstrapNodes { + select { + case pool.discNodes <- n: + case <-pool.closeCh: + return + } + } + }() } func (pool *serverPool) stop() { diff --git a/p2p/dial.go b/p2p/dial.go index 7fed0a5abf51..4f97cf1ef2e3 100644 --- a/p2p/dial.go +++ b/p2p/dial.go @@ -17,11 +17,17 @@ package p2p import ( + "context" + crand "crypto/rand" + "encoding/binary" "errors" "fmt" + mrand "math/rand" "net" + "sync" "time" + "github.com/ethereum/go-ethereum/common/mclock" "github.com/ethereum/go-ethereum/log" "github.com/ethereum/go-ethereum/p2p/enode" "github.com/ethereum/go-ethereum/p2p/netutil" ) const ( // private networks. dialHistoryExpiration = inboundThrottleTime + 5*time.Second - // If no peers are found for this amount of time, the initial bootnodes are dialed. - fallbackInterval = 20 * time.Second + // Config for the "Looking for peers" message. + dialStatsLogInterval = 10 * time.Second // printed at most this often + dialStatsPeerLimit = 3 // but not if more than this many dialed peers // Endpoint resolution is throttled with bounded backoff. initialResolveDelay = 60 * time.Second ) // NodeDialer is used to connect to nodes in the network, typically by using -// an underlying net.Dialer but also using net.Pipe in tests +// an underlying net.Dialer but also using net.Pipe in tests. type NodeDialer interface { - Dial(*enode.Node) (net.Conn, error) + Dial(context.Context, *enode.Node) (net.Conn, error) } type nodeResolver interface { Resolve(*enode.Node) *enode.Node } -// TCPDialer implements the NodeDialer interface by using a net.Dialer to -// create TCP connections to nodes in the network -type TCPDialer struct { - *net.Dialer +// tcpDialer implements NodeDialer using real TCP connections. +type tcpDialer struct { + d *net.Dialer } -// Dial creates a TCP connection to the node -func (t TCPDialer) Dial(dest *enode.Node) (net.Conn, error) { - addr := &net.TCPAddr{IP: dest.IP(), Port: dest.TCP()} - return t.Dialer.Dial("tcp", addr.String()) +func (t tcpDialer) Dial(ctx context.Context, dest *enode.Node) (net.Conn, error) { + return t.d.DialContext(ctx, "tcp", nodeAddr(dest).String()) } -// dialstate schedules dials and discovery lookups. -// It gets a chance to compute new tasks on every iteration -// of the main loop in Server.run.
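The NodeDialer change above is the hook the new scheduler builds on: Dial now takes a context, so an in-flight TCP attempt is torn down when the scheduler shuts down. Below is a minimal sketch of a custom implementation against the new interface; the loggingDialer name is invented for this example and is not part of the patch.

```go
// Hypothetical NodeDialer implementation: log each attempt, then delegate to
// net.Dialer.DialContext so cancellation of the scheduler context aborts the
// connection attempt. Illustrates the new interface shape only.
package p2p

import (
	"context"
	"net"

	"github.com/ethereum/go-ethereum/log"
	"github.com/ethereum/go-ethereum/p2p/enode"
)

type loggingDialer struct {
	d *net.Dialer
}

func (l loggingDialer) Dial(ctx context.Context, dest *enode.Node) (net.Conn, error) {
	addr := &net.TCPAddr{IP: dest.IP(), Port: dest.TCP()}
	log.Trace("Dialing node", "id", dest.ID(), "addr", addr)
	return l.d.DialContext(ctx, "tcp", addr.String())
}
```

A Server.Dialer left unset still gets a default: setupDialScheduler further down installs tcpDialer with defaultDialTimeout when config.dialer is nil.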
-type dialstate struct { - maxDynDials int - netrestrict *netutil.Netlist - self enode.ID - bootnodes []*enode.Node // default dials when there are no peers - log log.Logger +func nodeAddr(n *enode.Node) net.Addr { + return &net.TCPAddr{IP: n.IP(), Port: n.TCP()} +} + +// checkDial errors: +var ( + errSelf = errors.New("is self") + errAlreadyDialing = errors.New("already dialing") + errAlreadyConnected = errors.New("already connected") + errRecentlyDialed = errors.New("recently dialed") + errNotWhitelisted = errors.New("not contained in netrestrict whitelist") +) - start time.Time // time when the dialer was first used - lookupRunning bool - dialing map[enode.ID]connFlag - lookupBuf []*enode.Node // current discovery lookup results - static map[enode.ID]*dialTask - hist expHeap +// dialer creates outbound connections and submits them into Server. +// Two types of peer connections can be created: +// +// - static dials are pre-configured connections. The dialer attempts +// to keep these nodes connected at all times. +// +// - dynamic dials are created from node discovery results. The dialer +// continuously reads candidate nodes from its input iterator and attempts +// to create peer connections to nodes arriving through the iterator. +// +type dialScheduler struct { + dialConfig + setupFunc dialSetupFunc + wg sync.WaitGroup + cancel context.CancelFunc + ctx context.Context + nodesIn chan *enode.Node + doneCh chan *dialTask + addStaticCh chan *enode.Node + remStaticCh chan *enode.Node + addPeerCh chan *conn + remPeerCh chan *conn + + // Everything below here belongs to loop and + // should only be accessed by code on the loop goroutine. + dialing map[enode.ID]*dialTask // active tasks + peers map[enode.ID]connFlag // all connected peers + dialPeers int // current number of dialed peers + + // The static map tracks all static dial tasks. The subset of usable static dial tasks + // (i.e. those passing checkDial) is kept in staticPool. The scheduler prefers + // launching static tasks from the pool over launching dynamic dials from the + // iterator. + static map[enode.ID]*dialTask + staticPool []*dialTask + + // The dial history keeps recently dialed nodes. Members of history are not dialed.
+ history expHeap + historyTimer mclock.Timer + historyTimerTime mclock.AbsTime + + // for logStats + lastStatsLog mclock.AbsTime + doneSinceLastLog int } -type task interface { - Do(*Server) +type dialSetupFunc func(net.Conn, connFlag, *enode.Node) error + +type dialConfig struct { + self enode.ID // our own ID + maxDialPeers int // maximum number of dialed peers + maxActiveDials int // maximum number of active dials + netRestrict *netutil.Netlist // IP whitelist, disabled if nil + resolver nodeResolver + dialer NodeDialer + log log.Logger + clock mclock.Clock + rand *mrand.Rand } -func newDialState(self enode.ID, maxdyn int, cfg *Config) *dialstate { - s := &dialstate{ - maxDynDials: maxdyn, - self: self, - netrestrict: cfg.NetRestrict, - log: cfg.Logger, - static: make(map[enode.ID]*dialTask), - dialing: make(map[enode.ID]connFlag), - bootnodes: make([]*enode.Node, len(cfg.BootstrapNodes)), +func (cfg dialConfig) withDefaults() dialConfig { + if cfg.maxActiveDials == 0 { + cfg.maxActiveDials = defaultMaxPendingPeers + } + if cfg.log == nil { + cfg.log = log.Root() } - copy(s.bootnodes, cfg.BootstrapNodes) - if s.log == nil { - s.log = log.Root() + if cfg.clock == nil { + cfg.clock = mclock.System{} } - for _, n := range cfg.StaticNodes { - s.addStatic(n) + if cfg.rand == nil { + seedb := make([]byte, 8) + crand.Read(seedb) + seed := int64(binary.BigEndian.Uint64(seedb)) + cfg.rand = mrand.New(mrand.NewSource(seed)) } - return s + return cfg } -func (s *dialstate) addStatic(n *enode.Node) { - // This overwrites the task instead of updating an existing - // entry, giving users the opportunity to force a resolve operation. - s.static[n.ID()] = &dialTask{flags: staticDialedConn, dest: n} +func newDialScheduler(config dialConfig, it enode.Iterator, setupFunc dialSetupFunc) *dialScheduler { + d := &dialScheduler{ + dialConfig: config.withDefaults(), + setupFunc: setupFunc, + dialing: make(map[enode.ID]*dialTask), + static: make(map[enode.ID]*dialTask), + peers: make(map[enode.ID]connFlag), + doneCh: make(chan *dialTask), + nodesIn: make(chan *enode.Node), + addStaticCh: make(chan *enode.Node), + remStaticCh: make(chan *enode.Node), + addPeerCh: make(chan *conn), + remPeerCh: make(chan *conn), + } + d.lastStatsLog = d.clock.Now() + d.ctx, d.cancel = context.WithCancel(context.Background()) + d.wg.Add(2) + go d.readNodes(it) + go d.loop(it) + return d } -func (s *dialstate) removeStatic(n *enode.Node) { - // This removes a task so future attempts to connect will not be made. - delete(s.static, n.ID()) +// stop shuts down the dialer, canceling all current dial tasks. +func (d *dialScheduler) stop() { + d.cancel() + d.wg.Wait() } -func (s *dialstate) isStatic(n *enode.Node) bool { - _, isStatic := s.static[n.ID()] - return isStatic +// addStatic adds a static dial candidate. +func (d *dialScheduler) addStatic(n *enode.Node) { + select { + case d.addStaticCh <- n: + case <-d.ctx.Done(): + } } -func (s *dialstate) newTasks(nRunning int, peers map[enode.ID]*Peer, now time.Time) []task { - var newtasks []task - addDial := func(flag connFlag, n *enode.Node) bool { - if err := s.checkDial(n, peers); err != nil { - s.log.Trace("Skipping dial candidate", "id", n.ID(), "addr", &net.TCPAddr{IP: n.IP(), Port: n.TCP()}, "err", err) - return false - } - s.dialing[n.ID()] = flag - newtasks = append(newtasks, &dialTask{flags: flag, dest: n}) - return true - } - - if s.start.IsZero() { - s.start = now - } - s.hist.expire(now) - - // Create dials for static nodes if they are not connected. 
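A small detail of dialConfig.withDefaults above: the scheduler owns a private math/rand generator seeded from crypto/rand, so its ordering choices don't follow the package-global deterministic seed, while tests can inject a fixed rand for reproducibility. The idiom in isolation:

```go
// Self-contained sketch of the seeding idiom from dialConfig.withDefaults:
// eight bytes of crypto/rand output seed a private math/rand source. The
// patch ignores the crand.Read error, and this sketch does the same.
package main

import (
	crand "crypto/rand"
	"encoding/binary"
	"fmt"
	mrand "math/rand"
)

func newSeededRand() *mrand.Rand {
	seedb := make([]byte, 8)
	crand.Read(seedb)
	seed := int64(binary.BigEndian.Uint64(seedb))
	return mrand.New(mrand.NewSource(seed))
}

func main() {
	r := newSeededRand()
	fmt.Println(r.Intn(100)) // varies per process, unlike an unseeded source
}
```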
- for id, t := range s.static { - err := s.checkDial(t.dest, peers) - switch err { - case errNotWhitelisted, errSelf: - s.log.Warn("Removing static dial candidate", "id", t.dest.ID, "addr", &net.TCPAddr{IP: t.dest.IP(), Port: t.dest.TCP()}, "err", err) - delete(s.static, t.dest.ID()) - case nil: - s.dialing[id] = t.flags - newtasks = append(newtasks, t) - } +// removeStatic removes a static dial candidate. +func (d *dialScheduler) removeStatic(n *enode.Node) { + select { + case d.remStaticCh <- n: + case <-d.ctx.Done(): } +} - // Compute number of dynamic dials needed. - needDynDials := s.maxDynDials - for _, p := range peers { - if p.rw.is(dynDialedConn) { - needDynDials-- - } +// peerAdded updates the peer set. +func (d *dialScheduler) peerAdded(c *conn) { + select { + case d.addPeerCh <- c: + case <-d.ctx.Done(): } - for _, flag := range s.dialing { - if flag&dynDialedConn != 0 { - needDynDials-- - } +} + +// peerRemoved updates the peer set. +func (d *dialScheduler) peerRemoved(c *conn) { + select { + case d.remPeerCh <- c: + case <-d.ctx.Done(): } +} + +// loop is the main loop of the dialer. +func (d *dialScheduler) loop(it enode.Iterator) { + var ( + nodesCh chan *enode.Node + historyExp = make(chan struct{}, 1) + ) + +loop: + for { + // Launch new dials if slots are available. + slots := d.freeDialSlots() + slots -= d.startStaticDials() + if slots > 0 { + nodesCh = d.nodesIn + } else { + nodesCh = nil + } + d.rearmHistoryTimer(historyExp) + d.logStats() + + select { + case node := <-nodesCh: + if err := d.checkDial(node); err != nil { + d.log.Trace("Discarding dial candidate", "id", node.ID(), "ip", node.IP(), "reason", err) + } else { + d.startDial(newDialTask(node, dynDialedConn)) + } + + case task := <-d.doneCh: + id := task.dest.ID() + delete(d.dialing, id) + d.updateStaticPool(id) + d.doneSinceLastLog++ + + case c := <-d.addPeerCh: + if c.is(dynDialedConn) || c.is(staticDialedConn) { + d.dialPeers++ + } + id := c.node.ID() + d.peers[id] = c.flags + // Remove from static pool because the node is now connected. + task := d.static[id] + if task != nil && task.staticPoolIndex >= 0 { + d.removeFromStaticPool(task.staticPoolIndex) + } + // TODO: cancel dials to connected peers + + case c := <-d.remPeerCh: + if c.is(dynDialedConn) || c.is(staticDialedConn) { + d.dialPeers-- + } + delete(d.peers, c.node.ID()) + d.updateStaticPool(c.node.ID()) + + case node := <-d.addStaticCh: + id := node.ID() + _, exists := d.static[id] + d.log.Trace("Adding static node", "id", id, "ip", node.IP(), "added", !exists) + if exists { + continue loop + } + task := newDialTask(node, staticDialedConn) + d.static[id] = task + if d.checkDial(node) == nil { + d.addToStaticPool(task) + } - // If we don't have any peers whatsoever, try to dial a random bootnode. This - // scenario is useful for the testnet (and private networks) where the discovery - // table might be full of mostly bad peers, making it hard to find good ones. - if len(peers) == 0 && len(s.bootnodes) > 0 && needDynDials > 0 && now.Sub(s.start) > fallbackInterval { - bootnode := s.bootnodes[0] - s.bootnodes = append(s.bootnodes[:0], s.bootnodes[1:]...) 
- s.bootnodes = append(s.bootnodes, bootnode) - if addDial(dynDialedConn, bootnode) { - needDynDials-- + case node := <-d.remStaticCh: + id := node.ID() + task := d.static[id] + d.log.Trace("Removing static node", "id", id, "ok", task != nil) + if task != nil { + delete(d.static, id) + if task.staticPoolIndex >= 0 { + d.removeFromStaticPool(task.staticPoolIndex) + } + } + + case <-historyExp: + d.expireHistory() + + case <-d.ctx.Done(): + it.Close() + break loop } } - // Create dynamic dials from discovery results. - i := 0 - for ; i < len(s.lookupBuf) && needDynDials > 0; i++ { - if addDial(dynDialedConn, s.lookupBuf[i]) { - needDynDials-- + d.stopHistoryTimer(historyExp) + for range d.dialing { + <-d.doneCh + } + d.wg.Done() +} + +// readNodes runs in its own goroutine and delivers nodes from +// the input iterator to the nodesIn channel. +func (d *dialScheduler) readNodes(it enode.Iterator) { + defer d.wg.Done() + + for it.Next() { + discoveredPeersCounter.Inc(1) + select { + case d.nodesIn <- it.Node(): + case <-d.ctx.Done(): } } - s.lookupBuf = s.lookupBuf[:copy(s.lookupBuf, s.lookupBuf[i:])] +} - // Launch a discovery lookup if more candidates are needed. - if len(s.lookupBuf) < needDynDials && !s.lookupRunning { - s.lookupRunning = true - newtasks = append(newtasks, &discoverTask{want: needDynDials - len(s.lookupBuf)}) +// logStats prints dialer statistics to the log. The message is suppressed when enough +// peers are connected because users should only see it while their client is starting up +// or comes back online. +func (d *dialScheduler) logStats() { + now := d.clock.Now() + if d.lastStatsLog.Add(dialStatsLogInterval) > now { + return + } + if d.dialPeers < dialStatsPeerLimit && d.dialPeers < d.maxDialPeers { + d.log.Info("Looking for peers", "peercount", len(d.peers), "tried", d.doneSinceLastLog, "static", len(d.static)) } + d.doneSinceLastLog = 0 + d.lastStatsLog = now +} - // Launch a timer to wait for the next node to expire if all - // candidates have been tried and no task is currently active. - // This should prevent cases where the dialer logic is not ticked - // because there are no pending events. - if nRunning == 0 && len(newtasks) == 0 && s.hist.Len() > 0 { - t := &waitExpireTask{s.hist.nextExpiry().Sub(now)} - newtasks = append(newtasks, t) +// rearmHistoryTimer configures d.historyTimer to fire when the +// next item in d.history expires. +func (d *dialScheduler) rearmHistoryTimer(ch chan struct{}) { + if len(d.history) == 0 || d.historyTimerTime == d.history.nextExpiry() { + return } - return newtasks + d.stopHistoryTimer(ch) + d.historyTimerTime = d.history.nextExpiry() + timeout := time.Duration(d.historyTimerTime - d.clock.Now()) + d.historyTimer = d.clock.AfterFunc(timeout, func() { ch <- struct{}{} }) } -var ( - errSelf = errors.New("is self") - errAlreadyDialing = errors.New("already dialing") - errAlreadyConnected = errors.New("already connected") - errRecentlyDialed = errors.New("recently dialed") - errNotWhitelisted = errors.New("not contained in netrestrict whitelist") -) +// stopHistoryTimer stops the timer and drains the channel it sends on. +func (d *dialScheduler) stopHistoryTimer(ch chan struct{}) { + if d.historyTimer != nil && !d.historyTimer.Stop() { + <-ch + } +} + +// expireHistory removes expired items from d.history. 
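rearmHistoryTimer and stopHistoryTimer above follow the stop-and-drain discipline for timers that notify over a channel: if Stop reports that the timer already fired, the pending notification must be drained before rearming, or the loop would later wake on a stale expiry. A minimal sketch with time.AfterFunc (the patch uses mclock.Clock instead, so a simulated clock can drive it in tests); the buffer of 1 matches the historyExp channel in loop:

```go
// Minimal sketch of the stop-and-drain timer idiom behind rearmHistoryTimer
// and stopHistoryTimer, using the standard library instead of mclock.
package main

import (
	"fmt"
	"time"
)

func main() {
	ch := make(chan struct{}, 1) // capacity 1: the AfterFunc send never blocks
	t := time.AfterFunc(time.Hour, func() { ch <- struct{}{} })

	// Rearm: stop the old timer first. If Stop returns false the timer has
	// already fired, so drain the notification still sitting in the buffer.
	if !t.Stop() {
		<-ch
	}
	t = time.AfterFunc(10*time.Millisecond, func() { ch <- struct{}{} })

	<-ch // only the rearmed timer's expiry is observed
	fmt.Println("clean expiry, no stale notification")
}
```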
+func (d *dialScheduler) expireHistory() { + d.historyTimer.Stop() + d.historyTimer = nil + d.historyTimerTime = 0 + d.history.expire(d.clock.Now(), func(hkey string) { + var id enode.ID + copy(id[:], hkey) + d.updateStaticPool(id) + }) +} + +// freeDialSlots returns the number of free dial slots. The result can be negative +// when peers are connected while their task is still running, or because static dials +// are exempt from the limit. +func (d *dialScheduler) freeDialSlots() int { + slots := (d.maxDialPeers - d.dialPeers) * 2 + if slots > d.maxActiveDials { + slots = d.maxActiveDials + } + free := slots - len(d.dialing) + return free +} -func (s *dialstate) checkDial(n *enode.Node, peers map[enode.ID]*Peer) error { - _, dialing := s.dialing[n.ID()] - switch { - case dialing: +// checkDial returns an error if node n should not be dialed. +func (d *dialScheduler) checkDial(n *enode.Node) error { + if n.ID() == d.self { + return errSelf + } + if _, ok := d.dialing[n.ID()]; ok { return errAlreadyDialing - case peers[n.ID()] != nil: + } + if _, ok := d.peers[n.ID()]; ok { return errAlreadyConnected - case n.ID() == s.self: - return errSelf - case s.netrestrict != nil && !s.netrestrict.Contains(n.IP()): + } + if d.netRestrict != nil && !d.netRestrict.Contains(n.IP()) { return errNotWhitelisted - case s.hist.contains(string(n.ID().Bytes())): + } + if d.history.contains(string(n.ID().Bytes())) { return errRecentlyDialed } return nil } -func (s *dialstate) taskDone(t task, now time.Time) { - switch t := t.(type) { - case *dialTask: - s.hist.add(string(t.dest.ID().Bytes()), now.Add(dialHistoryExpiration)) - delete(s.dialing, t.dest.ID()) - case *discoverTask: - s.lookupRunning = false - s.lookupBuf = append(s.lookupBuf, t.results...) +// startStaticDials starts static dials for all nodes in the static pool. +func (d *dialScheduler) startStaticDials() (started int) { + for started = 0; len(d.staticPool) > 0; started++ { + idx := 0 + task := d.staticPool[idx] + d.startDial(task) + d.removeFromStaticPool(idx) + } + return started +} + +// updateStaticPool attempts to move the given static dial back into staticPool. +func (d *dialScheduler) updateStaticPool(id enode.ID) { + task, ok := d.static[id] + if ok && task.staticPoolIndex < 0 && d.checkDial(task.dest) == nil { + d.addToStaticPool(task) + } +} + +func (d *dialScheduler) addToStaticPool(task *dialTask) { + if task.staticPoolIndex >= 0 { + panic("attempt to add task to staticPool twice") } + d.staticPool = append(d.staticPool, task) + task.staticPoolIndex = len(d.staticPool) - 1 } -// A dialTask is generated for each node that is dialed. Its -// fields cannot be accessed while the task is running. +// removeFromStaticPool removes the task at idx from staticPool. It does that by moving the +// current last element of the pool to idx and then shortening the pool by one. +func (d *dialScheduler) removeFromStaticPool(idx int) { + task := d.staticPool[idx] + end := len(d.staticPool) - 1 + d.staticPool[idx] = d.staticPool[end] + d.staticPool[idx].staticPoolIndex = idx + d.staticPool[end] = nil + d.staticPool = d.staticPool[:end] + task.staticPoolIndex = -1 +} + +// startDial runs the given dial task in a separate goroutine. 
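The staticPool helpers above are an indexed swap-remove: each task records its own position, so removal is O(1) and "is this task pooled?" is a single field check (staticPoolIndex < 0 means not pooled). The same data structure in a stripped-down, runnable form:

```go
// Stripped-down version of the indexed swap-remove behind staticPool.
// Removal moves the last element into the vacated slot, patches that
// element's recorded index, and truncates the slice; nothing is searched
// or shifted.
package main

import "fmt"

type task struct {
	name string
	idx  int // position in the pool, or -1 when not pooled
}

type pool []*task

func (p *pool) add(t *task) {
	t.idx = len(*p)
	*p = append(*p, t)
}

func (p *pool) remove(idx int) {
	s := *p
	removed := s[idx]
	end := len(s) - 1
	s[idx] = s[end]  // move the last element into the hole
	s[idx].idx = idx // fix its recorded position
	s[end] = nil     // release the slot for GC
	*p = s[:end]
	removed.idx = -1
}

func main() {
	var p pool
	a := &task{name: "a"}
	p.add(a)
	p.add(&task{name: "b"})
	p.add(&task{name: "c"})
	p.remove(0)                      // "c" is swapped into slot 0
	fmt.Println(p[0].name, p[0].idx) // c 0
	fmt.Println(a.idx, len(p))       // -1 2
}
```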
+func (d *dialScheduler) startDial(task *dialTask) { + d.log.Trace("Starting p2p dial", "id", task.dest.ID(), "ip", task.dest.IP(), "flag", task.flags) + hkey := string(task.dest.ID().Bytes()) + d.history.add(hkey, d.clock.Now().Add(dialHistoryExpiration)) + d.dialing[task.dest.ID()] = task + go func() { + task.run(d) + d.doneCh <- task + }() +} + +// A dialTask is generated for each node that is dialed. type dialTask struct { - flags connFlag + staticPoolIndex int + flags connFlag + // These fields are private to the task and should not be + // accessed by dialScheduler while the task is running. dest *enode.Node - lastResolved time.Time + lastResolved mclock.AbsTime resolveDelay time.Duration } -func (t *dialTask) Do(srv *Server) { +func newDialTask(dest *enode.Node, flags connFlag) *dialTask { + return &dialTask{dest: dest, flags: flags, staticPoolIndex: -1} +} + +type dialError struct { + error +} + +func (t *dialTask) run(d *dialScheduler) { if t.dest.Incomplete() { - if !t.resolve(srv) { + if !t.resolve(d) { return } } + + err := t.dial(d, t.dest) if err != nil { - srv.log.Trace("Dial error", "task", t, "err", err) // Try resolving the ID of static nodes if dialing failed. if _, ok := err.(*dialError); ok && t.flags&staticDialedConn != 0 { - if t.resolve(srv) { - t.dial(srv, t.dest) + if t.resolve(d) { + t.dial(d, t.dest) } } } @@ -271,46 +499,42 @@ func (t *dialTask) Do(srv *Server) { // Resolve operations are throttled with backoff to avoid flooding the // discovery network with useless queries for nodes that don't exist. // The backoff delay resets when the node is found. -func (t *dialTask) resolve(srv *Server) bool { - if srv.staticNodeResolver == nil { - srv.log.Debug("Can't resolve node", "id", t.dest.ID(), "err", "discovery is disabled") +func (t *dialTask) resolve(d *dialScheduler) bool { + if d.resolver == nil { return false } if t.resolveDelay == 0 { t.resolveDelay = initialResolveDelay } - if time.Since(t.lastResolved) < t.resolveDelay { + if t.lastResolved > 0 && time.Duration(d.clock.Now()-t.lastResolved) < t.resolveDelay { return false } - resolved := srv.staticNodeResolver.Resolve(t.dest) - t.lastResolved = time.Now() + resolved := d.resolver.Resolve(t.dest) + t.lastResolved = d.clock.Now() if resolved == nil { t.resolveDelay *= 2 if t.resolveDelay > maxResolveDelay { t.resolveDelay = maxResolveDelay } - srv.log.Debug("Resolving node failed", "id", t.dest.ID(), "newdelay", t.resolveDelay) + d.log.Debug("Resolving node failed", "id", t.dest.ID(), "newdelay", t.resolveDelay) return false } // The node was found. t.resolveDelay = initialResolveDelay t.dest = resolved - srv.log.Debug("Resolved node", "id", t.dest.ID(), "addr", &net.TCPAddr{IP: t.dest.IP(), Port: t.dest.TCP()}) + d.log.Debug("Resolved node", "id", t.dest.ID(), "addr", &net.TCPAddr{IP: t.dest.IP(), Port: t.dest.TCP()}) return true } -type dialError struct { - error -} - // dial performs the actual connection attempt.
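resolve above throttles discovery lookups with bounded exponential backoff: the delay doubles after each failed Resolve up to maxResolveDelay and resets to initialResolveDelay once the node is found. Here is the schedule in isolation; the 30-minute cap is an assumed stand-in, since the patch's maxResolveDelay constant sits outside the visible hunk:

```go
// Sketch of the bounded backoff schedule used by dialTask.resolve. Only
// initialResolveDelay (60s) is visible in the hunk above; the cap here is
// an assumption for illustration.
package main

import (
	"fmt"
	"time"
)

const (
	initialDelay = 60 * time.Second
	maxDelay     = 30 * time.Minute // assumed cap, not the patch's constant
)

// nextDelay returns the wait before the next Resolve attempt.
func nextDelay(cur time.Duration, found bool) time.Duration {
	if found || cur == 0 {
		return initialDelay // reset on success, initialize on first use
	}
	cur *= 2
	if cur > maxDelay {
		cur = maxDelay
	}
	return cur
}

func main() {
	var d time.Duration
	for i := 0; i < 7; i++ {
		d = nextDelay(d, false)
		fmt.Println(d) // 1m0s 2m0s 4m0s 8m0s 16m0s 30m0s 30m0s
	}
}
```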
-func (t *dialTask) dial(srv *Server, dest *enode.Node) error { - fd, err := srv.Dialer.Dial(dest) +func (t *dialTask) dial(d *dialScheduler, dest *enode.Node) error { + fd, err := d.dialer.Dial(d.ctx, t.dest) if err != nil { + d.log.Trace("Dial error", "id", t.dest.ID(), "addr", nodeAddr(t.dest), "conn", t.flags, "err", cleanupDialErr(err)) return &dialError{err} } mfd := newMeteredConn(fd, false, &net.TCPAddr{IP: dest.IP(), Port: dest.TCP()}) - return srv.SetupConn(mfd, t.flags, dest) + return d.setupFunc(mfd, t.flags, dest) } func (t *dialTask) String() string { @@ -318,38 +542,9 @@ func (t *dialTask) String() string { return fmt.Sprintf("%v %x %v:%d", t.flags, id[:8], t.dest.IP(), t.dest.TCP()) } -// discoverTask runs discovery table operations. -// Only one discoverTask is active at any time. -// discoverTask.Do performs a random lookup. -type discoverTask struct { - want int - results []*enode.Node -} - -func (t *discoverTask) Do(srv *Server) { - t.results = enode.ReadNodes(srv.discmix, t.want) - discoveredPeersCounter.Inc(int64(len(t.results))) -} - -func (t *discoverTask) String() string { - s := "discovery query" - if len(t.results) > 0 { - s += fmt.Sprintf(" (%d results)", len(t.results)) - } else { - s += fmt.Sprintf(" (want %d)", t.want) +func cleanupDialErr(err error) error { + if netErr, ok := err.(*net.OpError); ok && netErr.Op == "dial" { + return netErr.Err } - return s -} - -// A waitExpireTask is generated if there are no other tasks -// to keep the loop in Server.run ticking. -type waitExpireTask struct { - time.Duration -} - -func (t waitExpireTask) Do(*Server) { - time.Sleep(t.Duration) -} -func (t waitExpireTask) String() string { - return fmt.Sprintf("wait for dial hist expire (%v)", t.Duration) + return err } diff --git a/p2p/dial_test.go b/p2p/dial_test.go index 6189ec4d0b85..b25677e3f767 100644 --- a/p2p/dial_test.go +++ b/p2p/dial_test.go @@ -17,574 +17,627 @@ package p2p import ( - "encoding/binary" + "context" + "errors" + "fmt" + "math/rand" "net" "reflect" - "strings" + "sync" "testing" "time" - "github.com/davecgh/go-spew/spew" + "github.com/ethereum/go-ethereum/common/mclock" "github.com/ethereum/go-ethereum/internal/testlog" "github.com/ethereum/go-ethereum/log" "github.com/ethereum/go-ethereum/p2p/enode" - "github.com/ethereum/go-ethereum/p2p/enr" "github.com/ethereum/go-ethereum/p2p/netutil" ) -func init() { - spew.Config.Indent = "\t" -} - -type dialtest struct { - init *dialstate // state before and after the test. - rounds []round -} - -type round struct { - peers []*Peer // current peer set - done []task // tasks that got done this round - new []task // the result must match this one -} +// This test checks that dynamic dials are launched from discovery results. +func TestDialSchedDynDial(t *testing.T) { + t.Parallel() -func runDialTest(t *testing.T, test dialtest) { - var ( - vtime time.Time - running int - ) - pm := func(ps []*Peer) map[enode.ID]*Peer { - m := make(map[enode.ID]*Peer) - for _, p := range ps { - m[p.ID()] = p - } - return m + config := dialConfig{ + maxActiveDials: 5, + maxDialPeers: 4, } - for i, round := range test.rounds { - for _, task := range round.done { - running-- - if running < 0 { - panic("running task counter underflow") - } - test.init.taskDone(task, vtime) - } + runDialTest(t, config, []dialTestRound{ + // 3 out of 4 peers are connected, leaving 2 dial slots. + // 8 nodes are discovered, but only 2 are dialed.
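The round comments in this test all trace back to freeDialSlots above: up to twice the remaining dialed-peer budget may be in flight at once, capped by maxActiveDials, minus dials already running. The arithmetic worked out with this test's config (maxDialPeers = 4, maxActiveDials = 5):

```go
// Worked version of the freeDialSlots arithmetic from p2p/dial.go, using
// this test's configuration. The two calls reproduce the "leaving 2 dial
// slots" and "5 of those slots are used" comments.
package main

import "fmt"

func freeDialSlots(maxDialPeers, maxActiveDials, dialPeers, dialing int) int {
	slots := (maxDialPeers - dialPeers) * 2
	if slots > maxActiveDials {
		slots = maxActiveDials
	}
	return slots - dialing
}

func main() {
	fmt.Println(freeDialSlots(4, 5, 3, 0)) // 2: three dialed peers leave two slots
	fmt.Println(freeDialSlots(4, 5, 1, 0)) // 5: six would be free, capped by maxActiveDials
}
```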
{ + update: func(d *dialScheduler) { + d.addStatic(newNode(uintID(0x01), "127.0.0.1:30303")) + }, + peersAdded: []*conn{ + {flags: staticDialedConn, node: newNode(uintID(0x01), "")}, + {flags: dynDialedConn, node: newNode(uintID(0x02), "")}, + {flags: dynDialedConn, node: newNode(uintID(0x03), "")}, + }, + discovered: []*enode.Node{ + newNode(uintID(0x01), "127.0.0.1:30303"), // not dialed because already connected as static peer + newNode(uintID(0x03), "127.0.0.1:30303"), // ... + newNode(uintID(0x04), "127.0.0.1:30303"), + newNode(uintID(0x05), "127.0.0.1:30303"), + newNode(uintID(0x06), "127.0.0.1:30303"), // not dialed because there are only two slots + newNode(uintID(0x07), "127.0.0.1:30303"), // ... + newNode(uintID(0x08), "127.0.0.1:30303"), // ... + newNode(uintID(0x09), "127.0.0.1:30303"), // ... + }, + wantNewDials: []*enode.Node{ + newNode(uintID(0x04), "127.0.0.1:30303"), + newNode(uintID(0x05), "127.0.0.1:30303"), + }, + }, + // One dial completes, freeing one dial slot. + { + failed: []enode.ID{ + uintID(0x05), + }, + wantNewDials: []*enode.Node{ + newNode(uintID(0x06), "127.0.0.1:30303"), + }, + }, + // Dial to 0x04 completes, filling the last remaining peer slot. + { + succeeded: []enode.ID{ + uintID(0x04), + }, + failed: []enode.ID{ + uintID(0x06), + }, + discovered: []*enode.Node{ + newNode(uintID(0x0a), "127.0.0.1:30303"), // not dialed because there are no free slots + }, + }, -// This test checks that dynamic dials are launched from discovery results. -func TestDialStateDynDial(t *testing.T) { - config := &Config{Logger: testlog.Logger(t, log.LvlTrace)} - runDialTest(t, dialtest{ - init: newDialState(enode.ID{}, 5, config), - rounds: []round{ - // A discovery query is launched. - { - peers: []*Peer{ - {rw: &conn{flags: staticDialedConn, node: newNode(uintID(0), nil)}}, - {rw: &conn{flags: dynDialedConn, node: newNode(uintID(1), nil)}}, - {rw: &conn{flags: dynDialedConn, node: newNode(uintID(2), nil)}}, - }, - new: []task{ - &discoverTask{want: 3}, - }, - }, - // Dynamic dials are launched when it completes. - { - peers: []*Peer{ - {rw: &conn{flags: staticDialedConn, node: newNode(uintID(0), nil)}}, - {rw: &conn{flags: dynDialedConn, node: newNode(uintID(1), nil)}}, - {rw: &conn{flags: dynDialedConn, node: newNode(uintID(2), nil)}}, - }, - done: []task{ - &discoverTask{results: []*enode.Node{ - newNode(uintID(2), nil), // this one is already connected and not dialed. - newNode(uintID(3), nil), - newNode(uintID(4), nil), - newNode(uintID(5), nil), - newNode(uintID(6), nil), // these are not tried because max dyn dials is 5 - newNode(uintID(7), nil), // ... - }}, - }, - new: []task{ - &dialTask{flags: dynDialedConn, dest: newNode(uintID(3), nil)}, - &dialTask{flags: dynDialedConn, dest: newNode(uintID(4), nil)}, - &dialTask{flags: dynDialedConn, dest: newNode(uintID(5), nil)}, - }, - }, - // Some of the dials complete but no new ones are launched yet because - // the sum of active dial count and dynamic peer count is == maxDynDials.
- { - peers: []*Peer{ - {rw: &conn{flags: staticDialedConn, node: newNode(uintID(0), nil)}}, - {rw: &conn{flags: dynDialedConn, node: newNode(uintID(1), nil)}}, - {rw: &conn{flags: dynDialedConn, node: newNode(uintID(2), nil)}}, - {rw: &conn{flags: dynDialedConn, node: newNode(uintID(3), nil)}}, - {rw: &conn{flags: dynDialedConn, node: newNode(uintID(4), nil)}}, - }, - done: []task{ - &dialTask{flags: dynDialedConn, dest: newNode(uintID(3), nil)}, - &dialTask{flags: dynDialedConn, dest: newNode(uintID(4), nil)}, - }, - }, - // No new dial tasks are launched in the this round because - // maxDynDials has been reached. - { - peers: []*Peer{ - {rw: &conn{flags: staticDialedConn, node: newNode(uintID(0), nil)}}, - {rw: &conn{flags: dynDialedConn, node: newNode(uintID(1), nil)}}, - {rw: &conn{flags: dynDialedConn, node: newNode(uintID(2), nil)}}, - {rw: &conn{flags: dynDialedConn, node: newNode(uintID(3), nil)}}, - {rw: &conn{flags: dynDialedConn, node: newNode(uintID(4), nil)}}, - {rw: &conn{flags: dynDialedConn, node: newNode(uintID(5), nil)}}, - }, - done: []task{ - &dialTask{flags: dynDialedConn, dest: newNode(uintID(5), nil)}, - }, - new: []task{ - &waitExpireTask{Duration: 19 * time.Second}, - }, - }, - // In this round, the peer with id 2 drops off. The query - // results from last discovery lookup are reused. - { - peers: []*Peer{ - {rw: &conn{flags: staticDialedConn, node: newNode(uintID(0), nil)}}, - {rw: &conn{flags: dynDialedConn, node: newNode(uintID(1), nil)}}, - {rw: &conn{flags: dynDialedConn, node: newNode(uintID(3), nil)}}, - {rw: &conn{flags: dynDialedConn, node: newNode(uintID(4), nil)}}, - {rw: &conn{flags: dynDialedConn, node: newNode(uintID(5), nil)}}, - }, - new: []task{ - &dialTask{flags: dynDialedConn, dest: newNode(uintID(6), nil)}, - }, - }, - // More peers (3,4) drop off and dial for ID 6 completes. - // The last query result from the discovery lookup is reused - // and a new one is spawned because more candidates are needed. - { - peers: []*Peer{ - {rw: &conn{flags: staticDialedConn, node: newNode(uintID(0), nil)}}, - {rw: &conn{flags: dynDialedConn, node: newNode(uintID(1), nil)}}, - {rw: &conn{flags: dynDialedConn, node: newNode(uintID(5), nil)}}, - }, - done: []task{ - &dialTask{flags: dynDialedConn, dest: newNode(uintID(6), nil)}, - }, - new: []task{ - &dialTask{flags: dynDialedConn, dest: newNode(uintID(7), nil)}, - &discoverTask{want: 2}, - }, - }, - // Peer 7 is connected, but there still aren't enough dynamic peers - // (4 out of 5). However, a discovery is already running, so ensure - // no new is started. - { - peers: []*Peer{ - {rw: &conn{flags: staticDialedConn, node: newNode(uintID(0), nil)}}, - {rw: &conn{flags: dynDialedConn, node: newNode(uintID(1), nil)}}, - {rw: &conn{flags: dynDialedConn, node: newNode(uintID(5), nil)}}, - {rw: &conn{flags: dynDialedConn, node: newNode(uintID(7), nil)}}, - }, - done: []task{ - &dialTask{flags: dynDialedConn, dest: newNode(uintID(7), nil)}, - }, - }, - // Finish the running node discovery with an empty set. A new lookup - // should be immediately requested. - { - peers: []*Peer{ - {rw: &conn{flags: staticDialedConn, node: newNode(uintID(0), nil)}}, - {rw: &conn{flags: dynDialedConn, node: newNode(uintID(1), nil)}}, - {rw: &conn{flags: dynDialedConn, node: newNode(uintID(5), nil)}}, - {rw: &conn{flags: dynDialedConn, node: newNode(uintID(7), nil)}}, - }, - done: []task{ - &discoverTask{}, - }, - new: []task{ - &discoverTask{want: 2}, - }, + // 3 peers drop off, creating 6 dial slots. 
Check that 5 of those slots + // (i.e. up to maxActiveDials) are used. + { + peersRemoved: []enode.ID{ + uintID(0x02), + uintID(0x03), + uintID(0x04), + }, + discovered: []*enode.Node{ + newNode(uintID(0x0b), "127.0.0.1:30303"), + newNode(uintID(0x0c), "127.0.0.1:30303"), + newNode(uintID(0x0d), "127.0.0.1:30303"), + newNode(uintID(0x0e), "127.0.0.1:30303"), + }, + wantNewDials: []*enode.Node{ + newNode(uintID(0x07), "127.0.0.1:30303"), + newNode(uintID(0x08), "127.0.0.1:30303"), + newNode(uintID(0x09), "127.0.0.1:30303"), + newNode(uintID(0x0a), "127.0.0.1:30303"), + newNode(uintID(0x0b), "127.0.0.1:30303"), }, }, }) } -// Tests that bootnodes are dialed if no peers are connectd, but not otherwise. -func TestDialStateDynDialBootnode(t *testing.T) { - config := &Config{ - BootstrapNodes: []*enode.Node{ - newNode(uintID(1), nil), - newNode(uintID(2), nil), - newNode(uintID(3), nil), - }, - Logger: testlog.Logger(t, log.LvlTrace), + // This test checks that candidates that do not match the netrestrict list are not dialed. +func TestDialSchedNetRestrict(t *testing.T) { + t.Parallel() + + nodes := []*enode.Node{ + newNode(uintID(0x01), "127.0.0.1:30303"), + newNode(uintID(0x02), "127.0.0.2:30303"), + newNode(uintID(0x03), "127.0.0.3:30303"), + newNode(uintID(0x04), "127.0.0.4:30303"), + newNode(uintID(0x05), "127.0.2.5:30303"), + newNode(uintID(0x06), "127.0.2.6:30303"), + newNode(uintID(0x07), "127.0.2.7:30303"), + newNode(uintID(0x08), "127.0.2.8:30303"), } - runDialTest(t, dialtest{ - init: newDialState(enode.ID{}, 5, config), - rounds: []round{ - { - new: []task{ - &discoverTask{want: 5}, - }, - }, - { - done: []task{ - &discoverTask{ - results: []*enode.Node{ - newNode(uintID(4), nil), - newNode(uintID(5), nil), - }, - }, - }, - new: []task{ - &dialTask{flags: dynDialedConn, dest: newNode(uintID(4), nil)}, - &dialTask{flags: dynDialedConn, dest: newNode(uintID(5), nil)}, - &discoverTask{want: 3}, - }, - }, - // No dials succeed, bootnodes still pending fallback interval - {}, - // 1 bootnode attempted as fallback interval was reached - { - done: []task{ - &dialTask{flags: dynDialedConn, dest: newNode(uintID(4), nil)}, - &dialTask{flags: dynDialedConn, dest: newNode(uintID(5), nil)}, - }, - new: []task{ - &dialTask{flags: dynDialedConn, dest: newNode(uintID(1), nil)}, - }, - }, - // No dials succeed, 2nd bootnode is attempted - { - done: []task{ - &dialTask{flags: dynDialedConn, dest: newNode(uintID(1), nil)}, - }, - new: []task{ - &dialTask{flags: dynDialedConn, dest: newNode(uintID(2), nil)}, - }, - }, - // No dials succeed, 3rd bootnode is attempted - { - done: []task{ - &dialTask{flags: dynDialedConn, dest: newNode(uintID(2), nil)}, - }, - new: []task{ - &dialTask{flags: dynDialedConn, dest: newNode(uintID(3), nil)}, - }, - }, - // No dials succeed, 1st bootnode is attempted again, expired random nodes retried - { - done: []task{ - &dialTask{flags: dynDialedConn, dest: newNode(uintID(3), nil)}, - &discoverTask{results: []*enode.Node{ - newNode(uintID(6), nil), - }}, - }, - new: []task{ - &dialTask{flags: dynDialedConn, dest: newNode(uintID(6), nil)}, - &discoverTask{want: 4}, - }, - }, - // Random dial succeeds, no more bootnodes are attempted - { - peers: []*Peer{ - {rw: &conn{flags: dynDialedConn, node: newNode(uintID(6), nil)}}, - }, + config := dialConfig{ + netRestrict: new(netutil.Netlist), + maxActiveDials: 10, + maxDialPeers: 10, + } + config.netRestrict.Add("127.0.2.0/24") + runDialTest(t, config, []dialTestRound{ + { + discovered: nodes, + wantNewDials: nodes[4:8], + }, +
{ + succeeded: []enode.ID{ + nodes[4].ID(), + nodes[5].ID(), + nodes[6].ID(), + nodes[7].ID(), }, }, }) } -func newNode(id enode.ID, ip net.IP) *enode.Node { - var r enr.Record - if ip != nil { - r.Set(enr.IP(ip)) - } - return enode.SignNull(&r, id) -} +// This test checks that static dials work and are exempt from the limits, but do +// count towards the limit when considering whether to start dynamic dials. +func TestDialSchedStaticDial(t *testing.T) { + t.Parallel() -// // This test checks that candidates that do not match the netrestrict list are not dialed. -func TestDialStateNetRestrict(t *testing.T) { - // This table always returns the same random nodes - // in the order given below. - nodes := []*enode.Node{ - newNode(uintID(1), net.ParseIP("127.0.0.1")), - newNode(uintID(2), net.ParseIP("127.0.0.2")), - newNode(uintID(3), net.ParseIP("127.0.0.3")), - newNode(uintID(4), net.ParseIP("127.0.0.4")), - newNode(uintID(5), net.ParseIP("127.0.2.5")), - newNode(uintID(6), net.ParseIP("127.0.2.6")), - newNode(uintID(7), net.ParseIP("127.0.2.7")), - newNode(uintID(8), net.ParseIP("127.0.2.8")), + config := dialConfig{ + maxActiveDials: 5, + maxDialPeers: 4, } - restrict := new(netutil.Netlist) - restrict.Add("127.0.2.0/24") - - runDialTest(t, dialtest{ - init: newDialState(enode.ID{}, 10, &Config{NetRestrict: restrict}), - rounds: []round{ - { - new: []task{ - &discoverTask{want: 10}, - }, - }, - { - done: []task{ - &discoverTask{results: nodes}, - }, - new: []task{ - &dialTask{flags: dynDialedConn, dest: nodes[4]}, - &dialTask{flags: dynDialedConn, dest: nodes[5]}, - &dialTask{flags: dynDialedConn, dest: nodes[6]}, - &dialTask{flags: dynDialedConn, dest: nodes[7]}, - &discoverTask{want: 6}, - }, + runDialTest(t, config, []dialTestRound{ + // Static dials are launched for the nodes that + // aren't yet connected. + { + peersAdded: []*conn{ + {flags: dynDialedConn, node: newNode(uintID(0x01), "127.0.0.1:30303")}, + {flags: dynDialedConn, node: newNode(uintID(0x02), "127.0.0.2:30303")}, + {flags: dynDialedConn, node: newNode(uintID(0x03), "127.0.0.3:30303")}, + }, + update: func(d *dialScheduler) { + // These three are not dialed because they're already connected + // as dynamic peers. + d.addStatic(newNode(uintID(0x01), "127.0.0.1:30303")) + d.addStatic(newNode(uintID(0x02), "127.0.0.2:30303")) + d.addStatic(newNode(uintID(0x03), "127.0.0.3:30303")) + // These nodes will be dialed, even though only 2 dial slots are available, + // because these are static and are exempt from the limit. + d.addStatic(newNode(uintID(0x04), "127.0.0.4:30303")) + d.addStatic(newNode(uintID(0x05), "127.0.0.5:30303")) + d.addStatic(newNode(uintID(0x06), "127.0.0.6:30303")) + }, + discovered: []*enode.Node{ + newNode(uintID(0x07), "127.0.0.7:30303"), + newNode(uintID(0x08), "127.0.0.8:30303"), + newNode(uintID(0x09), "127.0.0.9:30303"), + }, + wantNewDials: []*enode.Node{ + newNode(uintID(0x04), "127.0.0.4:30303"), + newNode(uintID(0x05), "127.0.0.5:30303"), + newNode(uintID(0x06), "127.0.0.6:30303"), + }, + }, + // Dial to 0x04 completes, filling the last peer slot. + // 0x05 and 0x06 are static, but won't be dialed again yet because they're in the history. + { + succeeded: []enode.ID{ + uintID(0x04), + }, + failed: []enode.ID{ + uintID(0x05), + uintID(0x06), + }, + wantResolves: map[enode.ID]*enode.Node{ + uintID(0x05): nil, + uintID(0x06): nil, + }, + }, + // Peer 0x01 drops, freeing two dial slots.
They're filled by 0x01 (static) and 0x07 + { + peersRemoved: []enode.ID{ + uintID(0x01), + }, + wantNewDials: []*enode.Node{ + newNode(uintID(0x01), "127.0.0.1:30303"), + newNode(uintID(0x07), "127.0.0.7:30303"), + }, + }, + // 0x05 and 0x06 will be dialed now, having expired from the history + { + wantNewDials: []*enode.Node{ + newNode(uintID(0x05), "127.0.0.5:30303"), + newNode(uintID(0x06), "127.0.0.6:30303"), }, }, }) } -// This test checks that static dials are launched. -func TestDialStateStaticDial(t *testing.T) { - config := &Config{ - StaticNodes: []*enode.Node{ - newNode(uintID(1), nil), - newNode(uintID(2), nil), - newNode(uintID(3), nil), - newNode(uintID(4), nil), - newNode(uintID(5), nil), - }, - Logger: testlog.Logger(t, log.LvlTrace), +// This test checks that removing static nodes stops connecting to them. +func TestDialSchedRemoveStatic(t *testing.T) { + t.Parallel() + + config := dialConfig{ + maxActiveDials: 1, + maxDialPeers: 1, } - runDialTest(t, dialtest{ - init: newDialState(enode.ID{}, 0, config), - rounds: []round{ - // Static dials are launched for the nodes that - // aren't yet connected. - { - peers: []*Peer{ - {rw: &conn{flags: dynDialedConn, node: newNode(uintID(1), nil)}}, - {rw: &conn{flags: dynDialedConn, node: newNode(uintID(2), nil)}}, - }, - new: []task{ - &dialTask{flags: staticDialedConn, dest: newNode(uintID(3), nil)}, - &dialTask{flags: staticDialedConn, dest: newNode(uintID(4), nil)}, - &dialTask{flags: staticDialedConn, dest: newNode(uintID(5), nil)}, - }, - }, - // No new tasks are launched in this round because all static - // nodes are either connected or still being dialed. - { - peers: []*Peer{ - {rw: &conn{flags: dynDialedConn, node: newNode(uintID(1), nil)}}, - {rw: &conn{flags: dynDialedConn, node: newNode(uintID(2), nil)}}, - {rw: &conn{flags: staticDialedConn, node: newNode(uintID(3), nil)}}, - }, - done: []task{ - &dialTask{flags: staticDialedConn, dest: newNode(uintID(3), nil)}, - }, - }, - // No new dial tasks are launched because all static - // nodes are now connected. - { - peers: []*Peer{ - {rw: &conn{flags: dynDialedConn, node: newNode(uintID(1), nil)}}, - {rw: &conn{flags: dynDialedConn, node: newNode(uintID(2), nil)}}, - {rw: &conn{flags: staticDialedConn, node: newNode(uintID(3), nil)}}, - {rw: &conn{flags: staticDialedConn, node: newNode(uintID(4), nil)}}, - {rw: &conn{flags: staticDialedConn, node: newNode(uintID(5), nil)}}, - }, - done: []task{ - &dialTask{flags: staticDialedConn, dest: newNode(uintID(4), nil)}, - &dialTask{flags: staticDialedConn, dest: newNode(uintID(5), nil)}, - }, - new: []task{ - &waitExpireTask{Duration: 19 * time.Second}, - }, - }, - // Wait a round for dial history to expire, no new tasks should spawn. - { - peers: []*Peer{ - {rw: &conn{flags: dynDialedConn, node: newNode(uintID(1), nil)}}, - {rw: &conn{flags: dynDialedConn, node: newNode(uintID(2), nil)}}, - {rw: &conn{flags: staticDialedConn, node: newNode(uintID(3), nil)}}, - {rw: &conn{flags: staticDialedConn, node: newNode(uintID(4), nil)}}, - {rw: &conn{flags: staticDialedConn, node: newNode(uintID(5), nil)}}, - }, - }, - // If a static node is dropped, it should be immediately redialed, - // irrespective whether it was originally static or dynamic. 
- { - done: []task{ - &waitExpireTask{Duration: 19 * time.Second}, - }, - peers: []*Peer{ - {rw: &conn{flags: dynDialedConn, node: newNode(uintID(1), nil)}}, - {rw: &conn{flags: staticDialedConn, node: newNode(uintID(3), nil)}}, - {rw: &conn{flags: staticDialedConn, node: newNode(uintID(5), nil)}}, - }, - new: []task{ - &dialTask{flags: staticDialedConn, dest: newNode(uintID(2), nil)}, - }, + runDialTest(t, config, []dialTestRound{ + // Add static nodes. + { + update: func(d *dialScheduler) { + d.addStatic(newNode(uintID(0x01), "127.0.0.1:30303")) + d.addStatic(newNode(uintID(0x02), "127.0.0.2:30303")) + }, + wantNewDials: []*enode.Node{ + newNode(uintID(0x01), "127.0.0.1:30303"), + newNode(uintID(0x02), "127.0.0.2:30303"), + }, + }, + // Dial to 0x01 fails. + { + failed: []enode.ID{ + uintID(0x01), + }, + wantResolves: map[enode.ID]*enode.Node{ + uintID(0x01): nil, + }, + }, + // All static nodes are removed. 0x01 is in history, 0x02 is being dialed + { + update: func(d *dialScheduler) { + d.removeStatic(newNode(uintID(0x01), "127.0.0.1:30303")) + d.removeStatic(newNode(uintID(0x02), "127.0.0.2:30303")) + d.removeStatic(newNode(uintID(0x03), "127.0.0.3:30303")) }, }, + // 0x02 fails and moves to the history, 0x01 is no longer in the history + { + failed: []enode.ID{ + uintID(0x02), + }, + wantResolves: map[enode.ID]*enode.Node{ + uintID(0x02): nil, + }, + }, + // Since all static nodes are removed, they should not be dialed again. + {}, {}, {}, }) } // This test checks that past dials are not retried for some time. -func TestDialStateCache(t *testing.T) { - config := &Config{ - StaticNodes: []*enode.Node{ - newNode(uintID(1), nil), - newNode(uintID(2), nil), - newNode(uintID(3), nil), +func TestDialSchedHistory(t *testing.T) { + t.Parallel() + + config := dialConfig{ + maxActiveDials: 3, + maxDialPeers: 3, + } + runDialTest(t, config, []dialTestRound{ + { + update: func(d *dialScheduler) { + d.addStatic(newNode(uintID(0x01), "127.0.0.1:30303")) + d.addStatic(newNode(uintID(0x02), "127.0.0.2:30303")) + d.addStatic(newNode(uintID(0x03), "127.0.0.3:30303")) + }, + wantNewDials: []*enode.Node{ + newNode(uintID(0x01), "127.0.0.1:30303"), + newNode(uintID(0x02), "127.0.0.2:30303"), + newNode(uintID(0x03), "127.0.0.3:30303"), + }, + }, + // No new tasks are launched in this round because all static + // nodes are either connected or still being dialed. + { + succeeded: []enode.ID{ + uintID(0x01), + uintID(0x02), + }, + failed: []enode.ID{ + uintID(0x03), + }, + wantResolves: map[enode.ID]*enode.Node{ + uintID(0x03): nil, + }, }, - Logger: testlog.Logger(t, log.LvlTrace), + // Nothing happens in this round because we're waiting for + // node 0x3's history entry to expire. + {}, + // The cache entry for node 0x03 has expired and is retried. + { + wantNewDials: []*enode.Node{ + newNode(uintID(0x03), "127.0.0.3:30303"), + }, + }, + }) +} + +func TestDialSchedResolve(t *testing.T) { + t.Parallel() + + config := dialConfig{ + maxActiveDials: 1, + maxDialPeers: 1, } - runDialTest(t, dialtest{ - init: newDialState(enode.ID{}, 0, config), - rounds: []round{ - // Static dials are launched for the nodes that - // aren't yet connected. - { - peers: nil, - new: []task{ - &dialTask{flags: staticDialedConn, dest: newNode(uintID(1), nil)}, - &dialTask{flags: staticDialedConn, dest: newNode(uintID(2), nil)}, - &dialTask{flags: staticDialedConn, dest: newNode(uintID(3), nil)}, - }, - }, - // No new tasks are launched in this round because all static - // nodes are either connected or still being dialed. 
- { - peers: []*Peer{ - {rw: &conn{flags: staticDialedConn, node: newNode(uintID(1), nil)}}, - {rw: &conn{flags: staticDialedConn, node: newNode(uintID(2), nil)}}, - }, - done: []task{ - &dialTask{flags: staticDialedConn, dest: newNode(uintID(1), nil)}, - &dialTask{flags: staticDialedConn, dest: newNode(uintID(2), nil)}, - }, - }, - // A salvage task is launched to wait for node 3's history - // entry to expire. - { - peers: []*Peer{ - {rw: &conn{flags: staticDialedConn, node: newNode(uintID(1), nil)}}, - {rw: &conn{flags: staticDialedConn, node: newNode(uintID(2), nil)}}, - }, - done: []task{ - &dialTask{flags: staticDialedConn, dest: newNode(uintID(3), nil)}, - }, - new: []task{ - &waitExpireTask{Duration: 19 * time.Second}, - }, - }, - // Still waiting for node 3's entry to expire in the cache. - { - peers: []*Peer{ - {rw: &conn{flags: staticDialedConn, node: newNode(uintID(1), nil)}}, - {rw: &conn{flags: staticDialedConn, node: newNode(uintID(2), nil)}}, - }, - }, - { - peers: []*Peer{ - {rw: &conn{flags: staticDialedConn, node: newNode(uintID(1), nil)}}, - {rw: &conn{flags: staticDialedConn, node: newNode(uintID(2), nil)}}, - }, - }, - // The cache entry for node 3 has expired and is retried. - { - done: []task{ - &waitExpireTask{Duration: 19 * time.Second}, - }, - peers: []*Peer{ - {rw: &conn{flags: staticDialedConn, node: newNode(uintID(1), nil)}}, - {rw: &conn{flags: staticDialedConn, node: newNode(uintID(2), nil)}}, - }, - new: []task{ - &dialTask{flags: staticDialedConn, dest: newNode(uintID(3), nil)}, - }, + node := newNode(uintID(0x01), "") + resolved := newNode(uintID(0x01), "127.0.0.1:30303") + resolved2 := newNode(uintID(0x01), "127.0.0.55:30303") + runDialTest(t, config, []dialTestRound{ + { + update: func(d *dialScheduler) { + d.addStatic(node) + }, + wantResolves: map[enode.ID]*enode.Node{ + uintID(0x01): resolved, + }, + wantNewDials: []*enode.Node{ + resolved, + }, + }, + { + failed: []enode.ID{ + uintID(0x01), + }, + wantResolves: map[enode.ID]*enode.Node{ + uintID(0x01): resolved2, + }, + wantNewDials: []*enode.Node{ + resolved2, }, }, }) } -func TestDialResolve(t *testing.T) { - config := &Config{ - Logger: testlog.Logger(t, log.LvlTrace), - Dialer: TCPDialer{&net.Dialer{Deadline: time.Now().Add(-5 * time.Minute)}}, +// ------- +// Code below here is the framework for the tests above. + +type dialTestRound struct { + peersAdded []*conn + peersRemoved []enode.ID + update func(*dialScheduler) // called at beginning of round + discovered []*enode.Node // newly discovered nodes + succeeded []enode.ID // dials which succeed this round + failed []enode.ID // dials which fail this round + wantResolves map[enode.ID]*enode.Node + wantNewDials []*enode.Node // dials that should be launched in this round +} + +func runDialTest(t *testing.T, config dialConfig, rounds []dialTestRound) { + var ( + clock = new(mclock.Simulated) + iterator = newDialTestIterator() + dialer = newDialTestDialer() + resolver = new(dialTestResolver) + peers = make(map[enode.ID]*conn) + setupCh = make(chan *conn) + ) + + // Override config. + config.clock = clock + config.dialer = dialer + config.resolver = resolver + config.log = testlog.Logger(t, log.LvlTrace) + config.rand = rand.New(rand.NewSource(0x1111)) + + // Set up the dialer. The setup function below runs on the dialTask + // goroutine and adds the peer. 
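Each dialTestRound above is one step of a declarative script: apply peer-set changes, feed discovered nodes to the iterator, complete or fail the dials blocked in the fake dialer, then assert which new dials are launched. An illustrative round against this framework (node 0x42 and its address are made up; this round is not part of the patch):

```go
// Illustrative dialTestRound, assuming the surrounding p2p test package:
// add one static node and expect exactly one dial to be launched for it.
package p2p

import "github.com/ethereum/go-ethereum/p2p/enode"

var exampleRound = dialTestRound{
	update: func(d *dialScheduler) {
		d.addStatic(newNode(uintID(0x42), "127.0.0.66:30303"))
	},
	wantNewDials: []*enode.Node{
		newNode(uintID(0x42), "127.0.0.66:30303"),
	},
}
```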
+ var dialsched *dialScheduler + setup := func(fd net.Conn, f connFlag, node *enode.Node) error { + conn := &conn{flags: f, node: node} + dialsched.peerAdded(conn) + setupCh <- conn + return nil } - resolved := newNode(uintID(1), net.IP{127, 0, 55, 234}) - resolver := &resolveMock{answer: resolved} - state := newDialState(enode.ID{}, 0, config) - - // Check that the task is generated with an incomplete ID. - dest := newNode(uintID(1), nil) - state.addStatic(dest) - tasks := state.newTasks(0, nil, time.Time{}) - if !reflect.DeepEqual(tasks, []task{&dialTask{flags: staticDialedConn, dest: dest}}) { - t.Fatalf("expected dial task, got %#v", tasks) + dialsched = newDialScheduler(config, iterator, setup) + defer dialsched.stop() + + for i, round := range rounds { + // Apply peer set updates. + for _, c := range round.peersAdded { + if peers[c.node.ID()] != nil { + t.Fatalf("round %d: peer %v already connected", i, c.node.ID()) + } + dialsched.peerAdded(c) + peers[c.node.ID()] = c + } + for _, id := range round.peersRemoved { + c := peers[id] + if c == nil { + t.Fatalf("round %d: can't remove non-existent peer %v", i, id) + } + dialsched.peerRemoved(c) + delete(peers, c.node.ID()) + } + + // Init round. + t.Logf("round %d (%d peers)", i, len(peers)) + resolver.setAnswers(round.wantResolves) + if round.update != nil { + round.update(dialsched) + } + iterator.addNodes(round.discovered) + + // Unblock dialTask goroutines. + if err := dialer.completeDials(round.succeeded, nil); err != nil { + t.Fatalf("round %d: %v", i, err) + } + for range round.succeeded { + conn := <-setupCh + peers[conn.node.ID()] = conn + } + if err := dialer.completeDials(round.failed, errors.New("oops")); err != nil { + t.Fatalf("round %d: %v", i, err) + } + + // Wait for new tasks. + if err := dialer.waitForDials(round.wantNewDials); err != nil { + t.Fatalf("round %d: %v", i, err) + } + if !resolver.checkCalls() { + t.Fatalf("unexpected calls to Resolve: %v", resolver.calls) + } + + clock.Run(16 * time.Second) } +} + +// dialTestIterator is the input iterator for dialer tests. This works a bit like a channel +// with infinite buffer: nodes are added to the buffer with addNodes, which unblocks Next +// and returns them from the iterator. +type dialTestIterator struct { + cur *enode.Node + + mu sync.Mutex + buf []*enode.Node + cond *sync.Cond + closed bool +} + +func newDialTestIterator() *dialTestIterator { + it := &dialTestIterator{} + it.cond = sync.NewCond(&it.mu) + return it +} + +// addNodes adds nodes to the iterator buffer and unblocks Next. +func (it *dialTestIterator) addNodes(nodes []*enode.Node) { + it.mu.Lock() + defer it.mu.Unlock() + + it.buf = append(it.buf, nodes...) + it.cond.Signal() +} + +// Node returns the current node. +func (it *dialTestIterator) Node() *enode.Node { + return it.cur +} - // Now run the task, it should resolve the ID once. - srv := &Server{ - Config: *config, - log: config.Logger, - staticNodeResolver: resolver, +// Next moves to the next node. +func (it *dialTestIterator) Next() bool { + it.mu.Lock() + defer it.mu.Unlock() + + it.cur = nil + for len(it.buf) == 0 && !it.closed { + it.cond.Wait() } - tasks[0].Do(srv) - if !reflect.DeepEqual(resolver.calls, []*enode.Node{dest}) { - t.Fatalf("wrong resolve calls, got %v", resolver.calls) + if it.closed { + return false } + it.cur = it.buf[0] + copy(it.buf[:], it.buf[1:]) + it.buf = it.buf[:len(it.buf)-1] + return true +} + +// Close ends the iterator, unblocking Next. 
+func (it *dialTestIterator) Close() { + it.mu.Lock() + defer it.mu.Unlock() - // Report it as done to the dialer, which should update the static node record. - state.taskDone(tasks[0], time.Now()) - if state.static[uintID(1)].dest != resolved { - t.Fatalf("state.dest not updated") + it.closed = true + it.buf = nil + it.cond.Signal() +} + +// dialTestDialer is the NodeDialer used by runDialTest. +type dialTestDialer struct { + init chan *dialTestReq + blocked map[enode.ID]*dialTestReq +} + +type dialTestReq struct { + n *enode.Node + unblock chan error +} + +func newDialTestDialer() *dialTestDialer { + return &dialTestDialer{ + init: make(chan *dialTestReq), + blocked: make(map[enode.ID]*dialTestReq), } } -// compares task lists but doesn't care about the order. -func sametasks(a, b []task) bool { - if len(a) != len(b) { - return false +// Dial implements NodeDialer. +func (d *dialTestDialer) Dial(ctx context.Context, n *enode.Node) (net.Conn, error) { + req := &dialTestReq{n: n, unblock: make(chan error, 1)} + select { + case d.init <- req: + select { + case err := <-req.unblock: + pipe, _ := net.Pipe() + return pipe, err + case <-ctx.Done(): + return nil, ctx.Err() + } + case <-ctx.Done(): + return nil, ctx.Err() } -next: - for _, ta := range a { - for _, tb := range b { - if reflect.DeepEqual(ta, tb) { - continue next +} + +// waitForDials waits for calls to Dial with the given nodes as argument. +// Those calls will be held blocking until completeDials is called with the same nodes. +func (d *dialTestDialer) waitForDials(nodes []*enode.Node) error { + waitset := make(map[enode.ID]*enode.Node) + for _, n := range nodes { + waitset[n.ID()] = n + } + timeout := time.NewTimer(1 * time.Second) + defer timeout.Stop() + + for len(waitset) > 0 { + select { + case req := <-d.init: + want, ok := waitset[req.n.ID()] + if !ok { + return fmt.Errorf("attempt to dial unexpected node %v", req.n.ID()) + } + if !reflect.DeepEqual(req.n, want) { + return fmt.Errorf("ENR of dialed node %v does not match test", req.n.ID()) + } + delete(waitset, req.n.ID()) + d.blocked[req.n.ID()] = req + case <-timeout.C: + var waitlist []enode.ID + for id := range waitset { + waitlist = append(waitlist, id) } + return fmt.Errorf("timed out waiting for dials to %v", waitlist) } - return false } - return true + + return d.checkUnexpectedDial() +} + +func (d *dialTestDialer) checkUnexpectedDial() error { + select { + case req := <-d.init: + return fmt.Errorf("attempt to dial unexpected node %v", req.n.ID()) + case <-time.After(150 * time.Millisecond): + return nil + } +} + +// completeDials unblocks calls to Dial for the given nodes. +func (d *dialTestDialer) completeDials(ids []enode.ID, err error) error { + for _, id := range ids { + req := d.blocked[id] + if req == nil { + return fmt.Errorf("can't complete dial to %v", id) + } + req.unblock <- err + } + return nil } -func uintID(i uint32) enode.ID { - var id enode.ID - binary.BigEndian.PutUint32(id[:], i) - return id +// dialTestResolver tracks calls to resolve. 
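dialTestDialer above turns every Dial into a rendezvous: the call parks itself on the init channel, and the test later decides its fate through the per-request unblock channel, which is how waitForDials and completeDials steer each round. The same pattern in miniature, with invented names:

```go
// Miniature version of the blocked-dial rendezvous used by dialTestDialer.
// The code under test blocks on the outcome channel; the test observes the
// attempt and completes it on its own schedule.
package main

import (
	"errors"
	"fmt"
)

type dialReq struct {
	target  string
	outcome chan error
}

func main() {
	calls := make(chan *dialReq)
	done := make(chan struct{})

	// "Code under test": dial, then block until the test supplies a result.
	go func() {
		defer close(done)
		req := &dialReq{target: "node-1", outcome: make(chan error, 1)}
		calls <- req
		fmt.Println("dial finished:", <-req.outcome)
	}()

	// "Test": observe the attempt, then fail it deliberately.
	req := <-calls
	fmt.Println("saw dial to", req.target)
	req.outcome <- errors.New("oops")
	<-done
}
```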
+type dialTestResolver struct { + mu sync.Mutex + calls []enode.ID + answers map[enode.ID]*enode.Node } -// for TestDialResolve -type resolveMock struct { - calls []*enode.Node - answer *enode.Node +func (t *dialTestResolver) setAnswers(m map[enode.ID]*enode.Node) { + t.mu.Lock() + defer t.mu.Unlock() + + t.answers = m + t.calls = nil } -func (t *resolveMock) Resolve(n *enode.Node) *enode.Node { - t.calls = append(t.calls, n) - return t.answer +func (t *dialTestResolver) checkCalls() bool { + t.mu.Lock() + defer t.mu.Unlock() + + for _, id := range t.calls { + if _, ok := t.answers[id]; !ok { + return false + } + } + return true +} + +func (t *dialTestResolver) Resolve(n *enode.Node) *enode.Node { + t.mu.Lock() + defer t.mu.Unlock() + + t.calls = append(t.calls, n.ID()) + return t.answers[n.ID()] } diff --git a/p2p/peer_test.go b/p2p/peer_test.go index a55ce6861498..c0d9674d6bd0 100644 --- a/p2p/peer_test.go +++ b/p2p/peer_test.go @@ -17,15 +17,20 @@ package p2p import ( + "encoding/binary" "errors" "fmt" "math/rand" "net" "reflect" + "strconv" + "strings" "testing" "time" "github.com/ethereum/go-ethereum/log" + "github.com/ethereum/go-ethereum/p2p/enode" + "github.com/ethereum/go-ethereum/p2p/enr" ) var discard = Protocol{ @@ -45,10 +50,45 @@ var discard = Protocol{ }, } +// uintID encodes i into a node ID. +func uintID(i uint16) enode.ID { + var id enode.ID + binary.BigEndian.PutUint16(id[:], i) + return id +} + +// newNode creates a node record with the given address. +func newNode(id enode.ID, addr string) *enode.Node { + var r enr.Record + if addr != "" { + // Set the port if present. + if strings.Contains(addr, ":") { + hs, ps, err := net.SplitHostPort(addr) + if err != nil { + panic(fmt.Errorf("invalid address %q", addr)) + } + port, err := strconv.Atoi(ps) + if err != nil { + panic(fmt.Errorf("invalid port in %q", addr)) + } + r.Set(enr.TCP(port)) + r.Set(enr.UDP(port)) + addr = hs + } + // Set the IP. + ip := net.ParseIP(addr) + if ip == nil { + panic(fmt.Errorf("invalid IP %q", addr)) + } + r.Set(enr.IP(ip)) + } + return enode.SignNull(&r, id) +} + func testPeer(protos []Protocol) (func(), *conn, *Peer, <-chan error) { fd1, fd2 := net.Pipe() - c1 := &conn{fd: fd1, node: newNode(randomID(), nil), transport: newTestTransport(&newkey().PublicKey, fd1)} - c2 := &conn{fd: fd2, node: newNode(randomID(), nil), transport: newTestTransport(&newkey().PublicKey, fd2)} + c1 := &conn{fd: fd1, node: newNode(randomID(), ""), transport: newTestTransport(&newkey().PublicKey, fd1)} + c2 := &conn{fd: fd2, node: newNode(randomID(), ""), transport: newTestTransport(&newkey().PublicKey, fd2)} for _, p := range protos { c1.caps = append(c1.caps, p.cap()) c2.caps = append(c2.caps, p.cap()) diff --git a/p2p/server.go b/p2p/server.go index c91601b28f57..eddd6bd1a1da 100644 --- a/p2p/server.go +++ b/p2p/server.go @@ -51,7 +51,6 @@ const ( discmixTimeout = 5 * time.Second // Connectivity defaults. - maxActiveDialTasks = 16 defaultMaxPendingPeers = 50 defaultDialRatio = 3 @@ -166,6 +165,8 @@ type Config struct { // Logger is a custom logger to use with the p2p.Server. Logger log.Logger `toml:",omitempty"` + + clock mclock.Clock } // Server manages all peer connections. @@ -193,13 +194,12 @@ type Server struct { ntab *discover.UDPv4 DiscV5 *discv5.Network discmix *enode.FairMix - - staticNodeResolver nodeResolver + dialsched *dialScheduler // Channels into the run loop. 
 	quit          chan struct{}
 	addstatic     chan *nodeArgs
-	removestatic  chan *nodeArgs
+	removestatic  chan removestaticArgs
 	addtrusted    chan *nodeArgs
 	removetrusted chan *nodeArgs
 	peerOp        chan peerOpFunc
@@ -274,6 +274,11 @@ type nodeArgs struct {
 	purpose PurposeFlag
 }
 
+type removestaticArgs struct {
+	*nodeArgs
+	done chan<- struct{}
+}
+
 type peerDrop struct {
 	*Peer
 	err error
@@ -372,31 +377,20 @@ func (srv *Server) LocalNode() *enode.LocalNode {
 // Peers returns all connected peers.
 func (srv *Server) Peers() []*Peer {
 	var ps []*Peer
-	select {
-	// Note: We'd love to put this function into a variable but
-	// that seems to cause a weird compiler error in some
-	// environments.
-	case srv.peerOp <- func(peers map[enode.ID]*Peer) {
+	srv.doPeerOp(func(peers map[enode.ID]*Peer) {
 		for _, p := range peers {
 			ps = append(ps, p)
 		}
-	}:
-		<-srv.peerOpDone
-	case <-srv.quit:
-	}
+	})
 	return ps
 }
 
 // PeerCount returns the number of connected peers.
 func (srv *Server) PeerCount() int {
 	var count int
-	select {
-	case srv.peerOp <- func(ps map[enode.ID]*Peer) {
+	srv.doPeerOp(func(ps map[enode.ID]*Peer) {
 		count = len(ps)
-	}:
-		<-srv.peerOpDone
-	case <-srv.quit:
-	}
+	})
 	return count
 }
 
@@ -425,8 +419,10 @@ func (srv *Server) AddPeer(node *enode.Node, purpose PurposeFlag) {
-// RemovePeer disconnects from the given node
+// RemovePeer disconnects from the given node. It blocks until any resulting
+// disconnect has completed.
 func (srv *Server) RemovePeer(node *enode.Node, purpose PurposeFlag) {
+	done := make(chan struct{})
 	select {
-	case srv.removestatic <- &nodeArgs{node: node, purpose: purpose}:
+	case srv.removestatic <- removestaticArgs{nodeArgs: &nodeArgs{node: node, purpose: purpose}, done: done}:
+		<-done
 	case <-srv.quit:
 	}
 }
@@ -531,6 +527,9 @@ func (srv *Server) Start() (err error) {
 	if srv.log == nil {
 		srv.log = log.Root()
 	}
+	if srv.clock == nil {
+		srv.clock = mclock.System{}
+	}
 	if srv.NoDial && srv.ListenAddr == "" {
 		srv.log.Warn("P2P server will be useless, neither dialing nor listening")
 	}
@@ -545,15 +544,12 @@
 	if srv.listenFunc == nil {
 		srv.listenFunc = net.Listen
 	}
-	if srv.Dialer == nil {
-		srv.Dialer = TCPDialer{&net.Dialer{Timeout: defaultDialTimeout}}
-	}
 	srv.quit = make(chan struct{})
 	srv.delpeer = make(chan peerDrop)
 	srv.checkpointPostHandshake = make(chan *conn)
 	srv.checkpointAddPeer = make(chan *conn)
 	srv.addstatic = make(chan *nodeArgs)
-	srv.removestatic = make(chan *nodeArgs)
+	srv.removestatic = make(chan removestaticArgs)
 	srv.addtrusted = make(chan *nodeArgs)
 	srv.removetrusted = make(chan *nodeArgs)
 	srv.peerOp = make(chan peerOpFunc)
@@ -572,11 +568,10 @@
 	if err := srv.setupDiscovery(); err != nil {
 		return err
 	}
+	srv.setupDialScheduler()
 
-	dynPeers := srv.maxDialedConns()
-	dialer := newDialState(srv.localnode.ID(), dynPeers, &srv.Config)
 	srv.loopWG.Add(1)
-	go srv.run(dialer)
+	go srv.run()
 	return nil
 }
 
@@ -688,7 +683,6 @@
 		}
 		srv.ntab = ntab
 		srv.discmix.AddSource(ntab.RandomNodes())
-		srv.staticNodeResolver = ntab
 	}
 
 	// Discovery V5
@@ -711,6 +705,47 @@
 	return nil
 }
 
+func (srv *Server) setupDialScheduler() {
+	config := dialConfig{
+		self:           srv.localnode.ID(),
+		maxDialPeers:   srv.maxDialedConns(),
+		maxActiveDials: srv.MaxPendingPeers,
+		log:            srv.Logger,
+		netRestrict:    srv.NetRestrict,
+		dialer:         srv.Dialer,
+		clock:          srv.clock,
+	}
+	if srv.ntab != nil {
+		config.resolver = srv.ntab
+	}
+	if config.dialer == nil {
+		config.dialer = tcpDialer{&net.Dialer{Timeout: defaultDialTimeout}}
+	}
+	srv.dialsched =
newDialScheduler(config, srv.discmix, srv.SetupConn) + for _, n := range srv.StaticNodes { + srv.dialsched.addStatic(n) + } +} + +func (srv *Server) maxInboundConns() int { + return srv.MaxPeers - srv.maxDialedConns() +} + +func (srv *Server) maxDialedConns() (limit int) { + if srv.NoDial || srv.MaxPeers == 0 { + return 0 + } + if srv.DialRatio == 0 { + limit = srv.MaxPeers / defaultDialRatio + } else { + limit = srv.MaxPeers / srv.DialRatio + } + if limit == 0 { + limit = 1 + } + return limit +} + func (srv *Server) setupListening() error { // Launch the listener. listener, err := srv.listenFunc("tcp", srv.ListenAddr) @@ -737,32 +772,29 @@ func (srv *Server) setupListening() error { return nil } -type dialer interface { - newTasks(running int, peers map[enode.ID]*Peer, now time.Time) []task - taskDone(task, time.Time) - addStatic(*enode.Node) - removeStatic(*enode.Node) - isStatic(*enode.Node) bool +// doPeerOp runs fn on the main loop. +func (srv *Server) doPeerOp(fn peerOpFunc) { + select { + case srv.peerOp <- fn: + <-srv.peerOpDone + case <-srv.quit: + } } -func (srv *Server) run(dialstate dialer) { +// run is the main loop of the server. +func (srv *Server) run() { srv.log.Info("Started P2P networking", "self", srv.localnode.Node().URLv4()) defer srv.loopWG.Done() defer srv.nodedb.Close() defer srv.discmix.Close() + defer srv.dialsched.stop() var ( static = make(map[enode.ID]PurposeFlag, len(srv.StaticNodes)) peers = make(map[enode.ID]*Peer) inboundCount = 0 trusted = make(map[enode.ID]PurposeFlag, len(srv.TrustedNodes)) - taskdone = make(chan task, maxActiveDialTasks) - tick = time.NewTicker(30 * time.Second) - runningTasks []task - queuedTasks []task // tasks that can't run yet ) - defer tick.Stop() - // Put trusted nodes into a map to speed up checks. // Trusted peers are loaded on startup or added via AddTrustedPeer RPC. for _, n := range srv.TrustedNodes { @@ -774,36 +806,6 @@ func (srv *Server) run(dialstate dialer) { static[n.ID()] = ExplicitStaticPurpose } - // removes t from runningTasks - delTask := func(t task) { - for i := range runningTasks { - if runningTasks[i] == t { - runningTasks = append(runningTasks[:i], runningTasks[i+1:]...) - break - } - } - } - // starts until max number of active tasks is satisfied - startTasks := func(ts []task) (rest []task) { - i := 0 - for ; len(runningTasks) < maxActiveDialTasks && i < len(ts); i++ { - t := ts[i] - srv.log.Trace("New dial task", "task", t) - go func() { t.Do(srv); taskdone <- t }() - runningTasks = append(runningTasks, t) - } - return ts[i:] - } - scheduleTasks := func() { - // Start from queue first. - queuedTasks = append(queuedTasks[:0], startTasks(queuedTasks)...) - // Query dialer for new tasks and start as many as possible now. - if len(runningTasks) < maxActiveDialTasks { - nt := dialstate.newTasks(len(runningTasks)+len(queuedTasks), peers, time.Now()) - queuedTasks = append(queuedTasks, startTasks(nt)...) 
-		}
-	}
-
 	addStatic := func(n *enode.Node, purpose PurposeFlag) {
 		newPurpose := static[n.ID()].Add(purpose)
 		static[n.ID()] = newPurpose
@@ -813,16 +815,29 @@
 			p.AddPurpose(purpose)
 		}
 
-		dialstate.addStatic(n)
+		srv.dialsched.addStatic(n)
 	}
 
-	removeStatic := func(n *enode.Node, purpose PurposeFlag) {
+	removeStatic := func(n *enode.Node, purpose PurposeFlag, done chan<- struct{}) {
 		newPurpose := static[n.ID()].Remove(purpose)
-
+		disconnecting := false
 		if newPurpose.HasNoPurpose() {
-			dialstate.removeStatic(n)
+			srv.dialsched.removeStatic(n)
 			delete(static, n.ID())
 			if p, ok := peers[n.ID()]; ok {
+				disconnecting = true
+				ch := make(chan *PeerEvent, 1)
+				sub := srv.peerFeed.Subscribe(ch)
+				// Wait for the peer connection to end, in a new goroutine.
+				go func() {
+					for ev := range ch {
+						if ev.Peer == n.ID() && ev.Type == PeerEventTypeDrop {
+							sub.Unsubscribe()
+							close(done)
+							return
+						}
+					}
+				}()
 				// Since we're disconnecting (destroying), no need to remove the purpose
 				p.Disconnect(DiscRequested)
 			}
@@ -832,6 +847,10 @@
 			p.RemovePurpose(purpose)
 		}
 	}
+		if !disconnecting {
+			// We aren't disconnecting the peer, so no need to wait for anything further.
+			close(done)
+		}
 	}
 
 	addTrusted := func(n *enode.Node, purpose PurposeFlag) {
@@ -865,12 +884,7 @@
 running:
 	for {
-		scheduleTasks()
-
 		select {
-		case <-tick.C:
-			// This is just here to ensure the dial scheduler runs occasionally.
-
 		case <-srv.quit:
 			// The server was stopped. Run the cleanup logic.
 			break running
@@ -885,7 +899,7 @@
 			// disconnect request to a peer and stop keeping the node connected.
 			srv.log.Trace("Removing static node", "node", removeStaticArgs.node, "purpose", removeStaticArgs.purpose)
-			removeStatic(removeStaticArgs.node, removeStaticArgs.purpose)
+			removeStatic(removeStaticArgs.node, removeStaticArgs.purpose, removeStaticArgs.done)
 		case addTrustedArgs := <-srv.addtrusted:
 			// This channel is used by AddTrustedPeer to add an enode
 			// to the trusted node set.
@@ -903,14 +917,6 @@
 		case cb := <-srv.getInboundCount:
 			cb(inboundCount)
 			srv.getInboundCountDone <- struct{}{}
-		case t := <-taskdone:
-			// A task got done. Tell dialstate about it so it
-			// can update its state and remove it from the active
-			// tasks list.
-			srv.log.Trace("Dial task done", "task", t)
-			dialstate.taskDone(t, time.Now())
-			delTask(t)
-
 		case c := <-srv.checkpointPostHandshake:
 			// A connection has passed the encryption handshake so
 			// the remote identity is known (but hasn't been verified yet).
@@ -926,36 +932,28 @@
 			err := srv.addPeerChecks(peers, c)
 			if err == nil {
 				// The handshakes are done and it passed all checks.
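+				// purpose is the union of the node's static and trusted
+				// flags; the peer then stays connected until every purpose
+				// has been removed again (removeStatic above disconnects
+				// only once newPurpose.HasNoPurpose() holds).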
-
 				purpose := static[c.node.ID()].Add(trusted[c.node.ID()])
-
-				// TODO add Peer potential "connected" point
-				p := newPeer(srv.log, c, srv.Protocols, purpose, srv)
-				// If message events are enabled, pass the peerFeed
-				// to the peer
-				if srv.EnableMsgEvents {
-					p.events = &srv.peerFeed
-				}
-				name := truncateName(c.name)
-				p.log.Debug("Adding p2p peer", "addr", p.RemoteAddr(), "peers", len(peers)+1, "name", name)
+				p := srv.launchPeer(c, purpose)
 				peers[c.node.ID()] = p
-				if p.Inbound() {
-					inboundCount++
-				}
-				go srv.runPeer(p)
+				srv.log.Debug("Adding p2p peer", "peercount", len(peers), "id", p.ID(), "conn", c.flags, "addr", p.RemoteAddr(), "name", truncateName(c.name))
+				srv.dialsched.peerAdded(c)
 				if conn, ok := c.fd.(*meteredConn); ok {
 					conn.handshakeDone(p)
 				}
+				if p.Inbound() {
+					inboundCount++
+				}
 			}
-			// The dialer logic relies on the assumption that
-			// dial tasks complete after the peer has been added or
-			// discarded. Unblock the task last.
 			c.cont <- err
 
 		case pd := <-srv.delpeer:
 			// A peer disconnected.
 			d := common.PrettyDuration(mclock.Now() - pd.created)
-			pd.log.Debug("Removing p2p peer", "addr", pd.RemoteAddr(), "peers", len(peers)-1, "duration", d, "req", pd.requested, "err", pd.err)
 			delete(peers, pd.ID())
+			srv.log.Debug("Removing p2p peer", "peercount", len(peers), "id", pd.ID(), "duration", d, "req", pd.requested, "err", pd.err)
+			srv.dialsched.peerRemoved(pd.rw)
 			if pd.Inbound() {
 				inboundCount--
 			}
@@ -980,7 +978,7 @@
 		// is closed.
 		for len(peers) > 0 {
 			p := <-srv.delpeer
-			p.log.Trace("<-delpeer (spindown)", "remainingTasks", len(runningTasks))
+			p.log.Trace("<-delpeer (spindown)")
 			delete(peers, p.ID())
 		}
 	}
@@ -1024,21 +1022,6 @@ func (srv *Server) CheckPeerCounts(peer *Peer) error {
 	}
 }
 
-func (srv *Server) maxInboundConns() int {
-	return srv.MaxPeers - srv.maxDialedConns()
-}
-
-func (srv *Server) maxDialedConns() int {
-	if srv.NoDiscovery || srv.NoDial {
-		return 0
-	}
-	r := srv.DialRatio
-	if r == 0 {
-		r = defaultDialRatio
-	}
-	return srv.MaxPeers / r
-}
-
 // listenLoop runs in its own goroutine and accepts
 // inbound connections.
 func (srv *Server) listenLoop() {
@@ -1107,18 +1090,20 @@
 }
 
 func (srv *Server) checkInboundConn(fd net.Conn, remoteIP net.IP) error {
-	if remoteIP != nil {
-		// Reject connections that do not match NetRestrict.
-		if srv.NetRestrict != nil && !srv.NetRestrict.Contains(remoteIP) {
-			return fmt.Errorf("not whitelisted in NetRestrict")
-		}
-		// Reject Internet peers that try too often.
-		srv.inboundHistory.expire(time.Now())
-		if !netutil.IsLAN(remoteIP) && srv.inboundHistory.contains(remoteIP.String()) {
-			return fmt.Errorf("too many attempts")
-		}
-		srv.inboundHistory.add(remoteIP.String(), time.Now().Add(inboundThrottleTime))
+	if remoteIP == nil {
+		return nil
+	}
+	// Reject connections that do not match NetRestrict.
+	if srv.NetRestrict != nil && !srv.NetRestrict.Contains(remoteIP) {
+		return fmt.Errorf("not whitelisted in NetRestrict")
 	}
+	// Reject Internet peers that try too often.
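+	// For example, with inboundThrottleTime assumed at 30s, a second attempt
+	// from the same non-LAN address inside that window is rejected:
+	//
+	//	srv.inboundHistory.add("203.0.113.5", now.Add(inboundThrottleTime))
+	//	srv.inboundHistory.contains("203.0.113.5") // true until the entry expires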
+	now := srv.clock.Now()
+	srv.inboundHistory.expire(now, nil)
+	if !netutil.IsLAN(remoteIP) && srv.inboundHistory.contains(remoteIP.String()) {
+		return fmt.Errorf("too many attempts")
+	}
+	srv.inboundHistory.add(remoteIP.String(), now.Add(inboundThrottleTime))
 	return nil
 }
 
@@ -1130,7 +1115,6 @@
 	err := srv.setupConn(c, flags, dialDest)
 	if err != nil {
 		c.close(err)
-		srv.log.Trace("Setting up connection failed", "addr", fd.RemoteAddr(), "err", err)
 	}
 	return err
 }
@@ -1149,7 +1133,9 @@
 	if dialDest != nil {
 		dialPubkey = new(ecdsa.PublicKey)
 		if err := dialDest.Load((*enode.Secp256k1)(dialPubkey)); err != nil {
-			return errors.New("dial destination doesn't have a secp256k1 public key")
+			err = errors.New("dial destination doesn't have a secp256k1 public key")
+			srv.log.Trace("Setting up connection failed", "addr", c.fd.RemoteAddr(), "conn", c.flags, "err", err)
+			return err
 		}
 	}
 
@@ -1178,7 +1164,7 @@
 	// Run the capability negotiation handshake.
 	phs, err := c.doProtoHandshake(srv.ourHandshake)
 	if err != nil {
-		clog.Trace("Failed proto handshake", "err", err)
+		clog.Trace("Failed p2p handshake", "err", err)
 		return err
 	}
 	if id := c.node.ID(); !bytes.Equal(crypto.Keccak256(phs.ID), id[:]) {
@@ -1192,9 +1178,6 @@
 		return err
 	}
 
-	// If the checks completed successfully, the connection has been added as a peer and
-	// runPeer has been launched.
-	clog.Trace("Connection set up", "inbound", dialDest == nil)
 	return nil
 }
 
@@ -1226,15 +1209,22 @@
 	return <-c.cont
 }
 
+func (srv *Server) launchPeer(c *conn, purpose PurposeFlag) *Peer {
+	p := newPeer(srv.log, c, srv.Protocols, purpose, srv)
+	if srv.EnableMsgEvents {
+		// If message events are enabled, pass the peerFeed
+		// to the peer.
+		p.events = &srv.peerFeed
+	}
+	go srv.runPeer(p)
+	return p
+}
+
 // runPeer runs in its own goroutine for each peer.
-// it waits until the Peer logic returns and removes
-// the peer.
 func (srv *Server) runPeer(p *Peer) {
 	if srv.newPeerHook != nil {
 		srv.newPeerHook(p)
 	}
-
-	// broadcast peer add
 	srv.peerFeed.Send(&PeerEvent{
 		Type:         PeerEventTypeAdd,
 		Peer:         p.ID(),
 		LocalAddress: p.LocalAddr().String(),
 	})
 
-	// run the protocol
+	// Run the per-peer main loop.
 	remoteRequested, err := p.run()
 
-	// broadcast peer drop
+	// Announce disconnect on the main loop to update the peer set.
+	// The main loop waits for existing peers to be sent on srv.delpeer
+	// before returning, so this send should not select on srv.quit.
+	srv.delpeer <- peerDrop{p, err, remoteRequested}
+
+	// Broadcast peer drop to external subscribers. This needs to be
+	// after the send to delpeer so subscribers have a consistent view of
+	// the peer set (i.e. Server.Peers() doesn't include the peer when the
+	// event is received).
 	srv.peerFeed.Send(&PeerEvent{
 		Type:          PeerEventTypeDrop,
 		Peer:          p.ID(),
 		RemoteAddress: p.RemoteAddr().String(),
 		LocalAddress:  p.LocalAddr().String(),
 	})
-
-	// Note: run waits for existing peers to be sent on srv.delpeer
-	// before returning, so this send should not select on srv.quit.
- srv.delpeer <- peerDrop{p, err, remoteRequested} } // NodeInfo represents a short summary of the information known about the host. diff --git a/p2p/server_test.go b/p2p/server_test.go index b449f1a12959..08c2a930ef84 100644 --- a/p2p/server_test.go +++ b/p2p/server_test.go @@ -34,10 +34,6 @@ import ( "golang.org/x/crypto/sha3" ) -// func init() { -// log.Root().SetHandler(log.LvlFilterHandler(log.LvlTrace, log.StreamHandler(os.Stderr, log.TerminalFormat(false)))) -// } - type testTransport struct { rpub *ecdsa.PublicKey *rlpx @@ -72,11 +68,12 @@ func (c *testTransport) close(err error) { func startTestServer(t *testing.T, remoteKey *ecdsa.PublicKey, pf func(*Peer)) *Server { config := Config{ - Name: "test", - MaxPeers: 10, - ListenAddr: "127.0.0.1:0", - PrivateKey: newkey(), - Logger: testlog.Logger(t, log.LvlTrace), + Name: "test", + MaxPeers: 10, + ListenAddr: "127.0.0.1:0", + NoDiscovery: true, + PrivateKey: newkey(), + Logger: testlog.Logger(t, log.LvlTrace), } server := &Server{ Config: config, @@ -131,11 +128,10 @@ func TestServerDial(t *testing.T) { t.Fatalf("could not setup listener: %v", err) } defer listener.Close() - accepted := make(chan net.Conn) + accepted := make(chan net.Conn, 1) go func() { conn, err := listener.Accept() if err != nil { - t.Error("accept error:", err) return } accepted <- conn @@ -205,158 +201,84 @@ func TestServerDial(t *testing.T) { } } -// This test checks that tasks generated by dialstate are -// actually executed and taskdone is called for them. -func TestServerTaskScheduling(t *testing.T) { - var ( - done = make(chan *testTask) - quit, returned = make(chan struct{}), make(chan struct{}) - tc = 0 - tg = taskgen{ - newFunc: func(running int, peers map[enode.ID]*Peer) []task { - tc++ - return []task{&testTask{index: tc - 1}} - }, - doneFunc: func(t task) { - select { - case done <- t.(*testTask): - case <-quit: - } - }, - } - ) - - // The Server in this test isn't actually running - // because we're only interested in what run does. - db, _ := enode.OpenDB("") - srv := &Server{ - Config: Config{MaxPeers: 10}, - localnode: enode.NewLocalNode(db, newkey(), 1), - nodedb: db, - discmix: enode.NewFairMix(0), - quit: make(chan struct{}), - running: true, - log: log.New(), - } - srv.loopWG.Add(1) - go func() { - srv.run(tg) - close(returned) - }() - - var gotdone []*testTask - for i := 0; i < 100; i++ { - gotdone = append(gotdone, <-done) - } - for i, task := range gotdone { - if task.index != i { - t.Errorf("task %d has wrong index, got %d", i, task.index) - break - } - if !task.called { - t.Errorf("task %d was not called", i) - break - } - } - - close(quit) - srv.Stop() - select { - case <-returned: - case <-time.After(500 * time.Millisecond): - t.Error("Server.run did not return within 500ms") - } -} - -// This test checks that Server doesn't drop tasks, -// even if newTasks returns more than the maximum number of tasks. -func TestServerManyTasks(t *testing.T) { - alltasks := make([]task, 300) - for i := range alltasks { - alltasks[i] = &testTask{index: i} - } - - var ( - db, _ = enode.OpenDB("") - srv = &Server{ - quit: make(chan struct{}), - localnode: enode.NewLocalNode(db, newkey(), 1), - nodedb: db, - running: true, - log: log.New(), - discmix: enode.NewFairMix(0), - } - done = make(chan *testTask) - start, end = 0, 0 - ) +// This test checks that RemovePeer disconnects the peer if it is connected. 
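+// RemovePeer is synchronous: it returns only once any resulting disconnect
+// has completed, which is what makes the PeerCount assertion below race-free:
+//
+//	srv1.RemovePeer(srv2.Self(), ExplicitStaticPurpose)
+//	// the connection is guaranteed to be gone at this point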
+func TestServerRemovePeerDisconnect(t *testing.T) {
+	srv1 := &Server{Config: Config{
+		PrivateKey:  newkey(),
+		MaxPeers:    1,
+		NoDiscovery: true,
+		Logger:      testlog.Logger(t, log.LvlTrace).New("server", "1"),
+	}}
+	srv2 := &Server{Config: Config{
+		PrivateKey:  newkey(),
+		MaxPeers:    1,
+		NoDiscovery: true,
+		NoDial:      true,
+		ListenAddr:  "127.0.0.1:0",
+		Logger:      testlog.Logger(t, log.LvlTrace).New("server", "2"),
+	}}
+	srv1.Start()
+	defer srv1.Stop()
+	srv2.Start()
+	defer srv2.Stop()
+
+	if !syncAddPeer(srv1, srv2.Self()) {
+		t.Fatal("peer not connected")
+	}
+	srv1.RemovePeer(srv2.Self(), ExplicitStaticPurpose)
+	if srv1.PeerCount() > 0 {
+		t.Fatal("removed peer still connected")
+	}
+}
+
+// This test checks that RemovePeer returns even when the peer to remove is
+// not currently connected.
+func TestServerRemovePeerNotConnected(t *testing.T) {
+	srv := &Server{Config: Config{
+		PrivateKey:  newkey(),
+		MaxPeers:    1,
+		NoDiscovery: true,
+		Logger:      testlog.Logger(t, log.LvlTrace).New("server", "1"),
+	}}
+	srv.Start()
 	defer srv.Stop()
-	srv.loopWG.Add(1)
-	go srv.run(taskgen{
-		newFunc: func(running int, peers map[enode.ID]*Peer) []task {
-			start, end = end, end+maxActiveDialTasks+10
-			if end > len(alltasks) {
-				end = len(alltasks)
-			}
-			return alltasks[start:end]
-		},
-		doneFunc: func(tt task) {
-			done <- tt.(*testTask)
-		},
-	})
-	doneset := make(map[int]bool)
-	timeout := time.After(2 * time.Second)
-	for len(doneset) < len(alltasks) {
-		select {
-		case tt := <-done:
-			if doneset[tt.index] {
-				t.Errorf("task %d got done more than once", tt.index)
-			} else {
-				doneset[tt.index] = true
-			}
-		case <-timeout:
-			t.Errorf("%d of %d tasks got done within 2s", len(doneset), len(alltasks))
-			for i := 0; i < len(alltasks); i++ {
-				if !doneset[i] {
-					t.Logf("task %d not done", i)
-				}
-			}
-			return
-		}
-	}
+	peer := enode.NewV4(&newkey().PublicKey, net.IP{127, 0, 0, 1}, 1, 1)
+	srv.RemovePeer(peer, ValidatorPurpose)
 }
 
-type taskgen struct {
-	newFunc  func(running int, peers map[enode.ID]*Peer) []task
-	doneFunc func(task)
-}
-
-func (tg taskgen) newTasks(running int, peers map[enode.ID]*Peer, now time.Time) []task {
-	return tg.newFunc(running, peers)
-}
-func (tg taskgen) taskDone(t task, now time.Time) {
-	tg.doneFunc(t)
-}
-func (tg taskgen) addStatic(*enode.Node) {
-}
-func (tg taskgen) removeStatic(*enode.Node) {
-}
-func (tg taskgen) isStatic(*enode.Node) bool {
-	return true
-}
+// This test checks that RemovePeer returns (without disconnecting the peer)
+// if the peer still has a remaining purpose.
+func TestServerRemovePeerNoDisconnect(t *testing.T) {
+	srv1 := &Server{Config: Config{
+		PrivateKey:  newkey(),
+		MaxPeers:    1,
+		NoDiscovery: true,
+		Logger:      testlog.Logger(t, log.LvlTrace).New("server", "1"),
+	}}
+	srv2 := &Server{Config: Config{
+		PrivateKey:  newkey(),
+		MaxPeers:    1,
+		NoDiscovery: true,
+		NoDial:      true,
+		ListenAddr:  "127.0.0.1:0",
+		Logger:      testlog.Logger(t, log.LvlTrace).New("server", "2"),
+	}}
+	srv1.Start()
+	defer srv1.Stop()
+	srv2.Start()
+	defer srv2.Stop()
 
-type testTask struct {
-	index  int
-	called bool
-}
+	if !syncAddPeer(srv1, srv2.Self()) {
+		t.Fatal("peer not connected")
+	}
 
-func (t *testTask) Do(srv *Server) {
-	t.called = true
+	srv1.RemovePeer(srv2.Self(), ValidatorPurpose)
+	if srv1.PeerCount() == 0 {
+		t.Fatal("peer was removed, but shouldn't have been")
+	}
 }
+// This test checks that connections are disconnected just after the encryption handshake +// when the server is at capacity. Trusted connections should still be accepted. func TestServerAtCap(t *testing.T) { trustedNode := newkey() trustedID := enode.PubkeyToIDV4(&trustedNode.PublicKey) @@ -366,7 +288,8 @@ func TestServerAtCap(t *testing.T) { MaxPeers: 10, NoDial: true, NoDiscovery: true, - TrustedNodes: []*enode.Node{newNode(trustedID, nil)}, + TrustedNodes: []*enode.Node{newNode(trustedID, "")}, + Logger: testlog.Logger(t, log.LvlTrace), }, } if err := srv.Start(); err != nil { @@ -409,7 +332,7 @@ func TestServerAtCap(t *testing.T) { // Remove from trusted set and try again /* KJUE - Re-enable after restoring peer check in server.go - srv.RemoveTrustedPeer(newNode(trustedID, nil), ExplicitTrustedPurpose) + srv.RemoveTrustedPeer(newNode(trustedID, ""), ExplicitTrustedPurpose) c = newconn(trustedID) if err := srv.checkpoint(c, srv.checkpointPostHandshake); err != DiscTooManyPeers { @@ -418,7 +341,7 @@ func TestServerAtCap(t *testing.T) { */ // Add anotherID to trusted set and try again - srv.AddTrustedPeer(newNode(anotherID, nil), ExplicitTrustedPurpose) + srv.AddTrustedPeer(newNode(anotherID, ""), ExplicitTrustedPurpose) c = newconn(anotherID) if err := srv.checkpoint(c, srv.checkpointPostHandshake); err != nil { t.Error("unexpected error for trusted conn @posthandshake:", err) @@ -449,9 +372,9 @@ func TestServerPeerLimits(t *testing.T) { NoDial: true, NoDiscovery: true, Protocols: []Protocol{discard}, + Logger: testlog.Logger(t, log.LvlTrace), }, newTransport: func(fd net.Conn) transport { return tp }, - log: log.New(), } if err := srv.Start(); err != nil { t.Fatalf("couldn't start server: %v", err) @@ -741,3 +664,23 @@ func (l *fakeAddrListener) Accept() (net.Conn, error) { func (c *fakeAddrConn) RemoteAddr() net.Addr { return c.remoteAddr } + +func syncAddPeer(srv *Server, node *enode.Node) bool { + var ( + ch = make(chan *PeerEvent) + sub = srv.SubscribeEvents(ch) + timeout = time.After(2 * time.Second) + ) + defer sub.Unsubscribe() + srv.AddPeer(node, ExplicitStaticPurpose) + for { + select { + case ev := <-ch: + if ev.Type == PeerEventTypeAdd && ev.Peer == node.ID() { + return true + } + case <-timeout: + return false + } + } +} diff --git a/p2p/simulations/adapters/inproc.go b/p2p/simulations/adapters/inproc.go index 9787082e1825..651d9546aec2 100644 --- a/p2p/simulations/adapters/inproc.go +++ b/p2p/simulations/adapters/inproc.go @@ -17,6 +17,7 @@ package adapters import ( + "context" "errors" "fmt" "math" @@ -126,7 +127,7 @@ func (s *SimAdapter) NewNode(config *NodeConfig) (Node, error) { // Dial implements the p2p.NodeDialer interface by connecting to the node using // an in-memory net.Pipe -func (s *SimAdapter) Dial(dest *enode.Node) (conn net.Conn, err error) { +func (s *SimAdapter) Dial(ctx context.Context, dest *enode.Node) (conn net.Conn, err error) { node, ok := s.GetNode(dest.ID()) if !ok { return nil, fmt.Errorf("unknown node: %s", dest.ID()) diff --git a/p2p/util.go b/p2p/util.go index 018cc40e98a9..3c5f6b8508d5 100644 --- a/p2p/util.go +++ b/p2p/util.go @@ -18,7 +18,8 @@ package p2p import ( "container/heap" - "time" + + "github.com/ethereum/go-ethereum/common/mclock" ) // expHeap tracks strings and their expiry time. @@ -27,16 +28,16 @@ type expHeap []expItem // expItem is an entry in addrHistory. type expItem struct { item string - exp time.Time + exp mclock.AbsTime } // nextExpiry returns the next expiry time. 
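+// Expiry times are mclock.AbsTime values (nanoseconds on a monotonic clock),
+// so they order with plain integer comparison. A sketch of driving the heap
+// from a simulated clock (mclock.Simulated usage assumed from common/mclock):
+//
+//	var h expHeap
+//	clock := new(mclock.Simulated)
+//	h.add("203.0.113.5", clock.Now().Add(inboundThrottleTime))
+//	clock.Run(inboundThrottleTime + time.Second)
+//	h.expire(clock.Now(), nil) // evicts the entry; onExp may be nil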
-func (h *expHeap) nextExpiry() time.Time { +func (h *expHeap) nextExpiry() mclock.AbsTime { return (*h)[0].exp } // add adds an item and sets its expiry time. -func (h *expHeap) add(item string, exp time.Time) { +func (h *expHeap) add(item string, exp mclock.AbsTime) { heap.Push(h, expItem{item, exp}) } @@ -51,15 +52,18 @@ func (h expHeap) contains(item string) bool { } // expire removes items with expiry time before 'now'. -func (h *expHeap) expire(now time.Time) { - for h.Len() > 0 && h.nextExpiry().Before(now) { - heap.Pop(h) +func (h *expHeap) expire(now mclock.AbsTime, onExp func(string)) { + for h.Len() > 0 && h.nextExpiry() < now { + item := heap.Pop(h) + if onExp != nil { + onExp(item.(expItem).item) + } } } // heap.Interface boilerplate func (h expHeap) Len() int { return len(h) } -func (h expHeap) Less(i, j int) bool { return h[i].exp.Before(h[j].exp) } +func (h expHeap) Less(i, j int) bool { return h[i].exp < h[j].exp } func (h expHeap) Swap(i, j int) { h[i], h[j] = h[j], h[i] } func (h *expHeap) Push(x interface{}) { *h = append(*h, x.(expItem)) } func (h *expHeap) Pop() interface{} { diff --git a/p2p/util_test.go b/p2p/util_test.go index c9f2648dc92f..cc0d2b215fee 100644 --- a/p2p/util_test.go +++ b/p2p/util_test.go @@ -19,30 +19,32 @@ package p2p import ( "testing" "time" + + "github.com/ethereum/go-ethereum/common/mclock" ) func TestExpHeap(t *testing.T) { var h expHeap var ( - basetime = time.Unix(4000, 0) + basetime = mclock.AbsTime(10) exptimeA = basetime.Add(2 * time.Second) exptimeB = basetime.Add(3 * time.Second) exptimeC = basetime.Add(4 * time.Second) ) - h.add("a", exptimeA) h.add("b", exptimeB) + h.add("a", exptimeA) h.add("c", exptimeC) - if !h.nextExpiry().Equal(exptimeA) { + if h.nextExpiry() != exptimeA { t.Fatal("wrong nextExpiry") } if !h.contains("a") || !h.contains("b") || !h.contains("c") { t.Fatal("heap doesn't contain all live items") } - h.expire(exptimeA.Add(1)) - if !h.nextExpiry().Equal(exptimeB) { + h.expire(exptimeA.Add(1), nil) + if h.nextExpiry() != exptimeB { t.Fatal("wrong nextExpiry") } if h.contains("a") {