Skip to content

Commit

Permalink
Discovery: add Context to Listen. (#3577)
Browse files Browse the repository at this point in the history
Add explicit Context to ListenV4 and ListenV5.
This makes it possible to stop listening by an external signal.
  • Loading branch information
battlmonstr authored and Alexey Sharp committed Mar 14, 2022
1 parent 9d9725b commit 0dbcd66
Show file tree
Hide file tree
Showing 9 changed files with 47 additions and 29 deletions.
9 changes: 7 additions & 2 deletions cmd/bootnode/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ import (
"crypto/ecdsa"
"flag"
"fmt"
"github.com/ledgerwatch/erigon-lib/common"
"net"
"os"

Expand Down Expand Up @@ -129,12 +130,16 @@ func main() {
PrivateKey: nodeKey,
NetRestrict: restrictList,
}

ctx, cancel := common.RootContext()
defer cancel()

if *runv5 {
if _, err := discover.ListenV5(conn, ln, cfg); err != nil {
if _, err := discover.ListenV5(ctx, conn, ln, cfg); err != nil {
utils.Fatalf("%v", err)
}
} else {
if _, err := discover.ListenUDP(conn, ln, cfg); err != nil {
if _, err := discover.ListenUDP(ctx, conn, ln, cfg); err != nil {
utils.Fatalf("%v", err)
}
}
Expand Down
4 changes: 2 additions & 2 deletions cmd/sentry/download/sentry.go
Original file line number Diff line number Diff line change
Expand Up @@ -852,7 +852,7 @@ func (ss *SentryServerImpl) HandShake(context.Context, *emptypb.Empty) (*proto_s
return reply, nil
}

func (ss *SentryServerImpl) SetStatus(_ context.Context, statusData *proto_sentry.StatusData) (*proto_sentry.SetStatusReply, error) {
func (ss *SentryServerImpl) SetStatus(ctx context.Context, statusData *proto_sentry.StatusData) (*proto_sentry.SetStatusReply, error) {
genesisHash := gointerfaces.ConvertH256ToHash(statusData.ForkData.Genesis)

ss.lock.Lock()
Expand All @@ -879,7 +879,7 @@ func (ss *SentryServerImpl) SetStatus(_ context.Context, statusData *proto_sentr
}

// Add protocol
if err = srv.Start(); err != nil {
if err = srv.Start(ctx); err != nil {
srv.Stop()
return reply, fmt.Errorf("could not start server: %w", err)
}
Expand Down
5 changes: 3 additions & 2 deletions p2p/discover/common.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
package discover

import (
"context"
"crypto/ecdsa"
"net"

Expand Down Expand Up @@ -63,8 +64,8 @@ func (cfg Config) withDefaults() Config {
}

// ListenUDP starts listening for discovery packets on the given UDP socket.
func ListenUDP(c UDPConn, ln *enode.LocalNode, cfg Config) (*UDPv4, error) {
return ListenV4(c, ln, cfg)
func ListenUDP(ctx context.Context, c UDPConn, ln *enode.LocalNode, cfg Config) (*UDPv4, error) {
return ListenV4(ctx, c, ln, cfg)
}

// ReadPacket is a packet that couldn't be handled. Those packets are sent to the unhandled
Expand Down
4 changes: 2 additions & 2 deletions p2p/discover/v4_udp.go
Original file line number Diff line number Diff line change
Expand Up @@ -128,9 +128,9 @@ type reply struct {
matched chan<- bool
}

func ListenV4(c UDPConn, ln *enode.LocalNode, cfg Config) (*UDPv4, error) {
func ListenV4(ctx context.Context, c UDPConn, ln *enode.LocalNode, cfg Config) (*UDPv4, error) {
cfg = cfg.withDefaults()
closeCtx, cancel := context.WithCancel(context.Background())
closeCtx, cancel := context.WithCancel(ctx)
t := &UDPv4{
conn: c,
priv: cfg.PrivateKey,
Expand Down
7 changes: 5 additions & 2 deletions p2p/discover/v4_udp_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ package discover

import (
"bytes"
"context"
"crypto/ecdsa"
crand "crypto/rand"
"encoding/binary"
Expand Down Expand Up @@ -74,7 +75,8 @@ func newUDPTest(t *testing.T) *udpTest {
panic(err)
}
ln := enode.NewLocalNode(test.db, test.localkey)
test.udp, _ = ListenV4(test.pipe, ln, Config{
ctx := context.Background()
test.udp, _ = ListenV4(ctx, test.pipe, ln, Config{
PrivateKey: test.localkey,
Log: testlog.Logger(t, log.LvlInfo),
})
Expand Down Expand Up @@ -583,7 +585,8 @@ func startLocalhostV4(t *testing.T, cfg Config) *UDPv4 {
realaddr := socket.LocalAddr().(*net.UDPAddr)
ln.SetStaticIP(realaddr.IP)
ln.SetFallbackUDP(realaddr.Port)
udp, err := ListenV4(socket, ln, cfg)
ctx := context.Background()
udp, err := ListenV4(ctx, socket, ln, cfg)
if err != nil {
t.Fatal(err)
}
Expand Down
8 changes: 4 additions & 4 deletions p2p/discover/v5_udp.go
Original file line number Diff line number Diff line change
Expand Up @@ -123,8 +123,8 @@ type callTimeout struct {
}

// ListenV5 listens on the given connection.
func ListenV5(conn UDPConn, ln *enode.LocalNode, cfg Config) (*UDPv5, error) {
t, err := newUDPv5(conn, ln, cfg)
func ListenV5(ctx context.Context, conn UDPConn, ln *enode.LocalNode, cfg Config) (*UDPv5, error) {
t, err := newUDPv5(ctx, conn, ln, cfg)
if err != nil {
return nil, err
}
Expand All @@ -136,8 +136,8 @@ func ListenV5(conn UDPConn, ln *enode.LocalNode, cfg Config) (*UDPv5, error) {
}

// newUDPv5 creates a UDPv5 transport, but doesn't start any goroutines.
func newUDPv5(conn UDPConn, ln *enode.LocalNode, cfg Config) (*UDPv5, error) {
closeCtx, cancelCloseCtx := context.WithCancel(context.Background())
func newUDPv5(ctx context.Context, conn UDPConn, ln *enode.LocalNode, cfg Config) (*UDPv5, error) {
closeCtx, cancelCloseCtx := context.WithCancel(ctx)
cfg = cfg.withDefaults()
t := &UDPv5{
// static fields
Expand Down
7 changes: 5 additions & 2 deletions p2p/discover/v5_udp_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ package discover

import (
"bytes"
"context"
"crypto/ecdsa"
"encoding/binary"
"errors"
Expand Down Expand Up @@ -101,7 +102,8 @@ func startLocalhostV5(t *testing.T, cfg Config) *UDPv5 {
realaddr := socket.LocalAddr().(*net.UDPAddr)
ln.SetStaticIP(realaddr.IP)
ln.Set(enr.UDP(realaddr.Port))
udp, err := ListenV5(socket, ln, cfg)
ctx := context.Background()
udp, err := ListenV5(ctx, socket, ln, cfg)
if err != nil {
t.Fatal(err)
}
Expand Down Expand Up @@ -725,7 +727,8 @@ func newUDPV5Test(t *testing.T) *udpV5Test {
ln := enode.NewLocalNode(test.db, test.localkey)
ln.SetStaticIP(net.IP{10, 0, 0, 1})
ln.Set(enr.UDP(30303))
test.udp, err = ListenV5(test.pipe, ln, Config{
ctx := context.Background()
test.udp, err = ListenV5(ctx, test.pipe, ln, Config{
PrivateKey: test.localkey,
Log: testlog.Logger(t, log.LvlInfo),
ValidSchemes: enode.ValidSchemesForTesting,
Expand Down
13 changes: 7 additions & 6 deletions p2p/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ package p2p

import (
"bytes"
"context"
"crypto/ecdsa"
"encoding/hex"
"errors"
Expand Down Expand Up @@ -452,7 +453,7 @@ func (srv *Server) Running() bool {

// Start starts running the server.
// Servers can not be re-used after stopping.
func (srv *Server) Start() error {
func (srv *Server) Start(ctx context.Context) error {
srv.lock.Lock()
defer srv.lock.Unlock()
if srv.running {
Expand Down Expand Up @@ -497,7 +498,7 @@ func (srv *Server) Start() error {
return err
}
}
if err := srv.setupDiscovery(); err != nil {
if err := srv.setupDiscovery(ctx); err != nil {
return err
}
srv.setupDialScheduler()
Expand Down Expand Up @@ -552,7 +553,7 @@ func (srv *Server) setupLocalNode() error {
return nil
}

func (srv *Server) setupDiscovery() error {
func (srv *Server) setupDiscovery(ctx context.Context) error {
srv.discmix = enode.NewFairMix(discmixTimeout)

// Add protocol-specific discovery sources.
Expand Down Expand Up @@ -606,7 +607,7 @@ func (srv *Server) setupDiscovery() error {
Unhandled: unhandled,
Log: srv.log,
}
ntab, err := discover.ListenV4(conn, srv.localnode, cfg)
ntab, err := discover.ListenV4(ctx, conn, srv.localnode, cfg)
if err != nil {
return err
}
Expand All @@ -624,9 +625,9 @@ func (srv *Server) setupDiscovery() error {
}
var err error
if sconn != nil {
srv.DiscV5, err = discover.ListenV5(sconn, srv.localnode, cfg)
srv.DiscV5, err = discover.ListenV5(ctx, sconn, srv.localnode, cfg)
} else {
srv.DiscV5, err = discover.ListenV5(conn, srv.localnode, cfg)
srv.DiscV5, err = discover.ListenV5(ctx, conn, srv.localnode, cfg)
}
if err != nil {
return err
Expand Down
19 changes: 12 additions & 7 deletions p2p/server_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
package p2p

import (
"context"
"crypto/ecdsa"
"crypto/sha256"
"errors"
Expand Down Expand Up @@ -82,12 +83,16 @@ func startTestServer(t *testing.T, remoteKey *ecdsa.PublicKey, pf func(*Peer)) *
return newTestTransport(remoteKey, fd, dialDest)
},
}
if err := server.Start(); err != nil {
if err := server.TestStart(); err != nil {
t.Fatalf("Could not start server: %v", err)
}
return server
}

func (srv *Server) TestStart() error {
return srv.Start(context.Background())
}

func TestServerListen(t *testing.T) {
// start the test server
connected := make(chan *Peer)
Expand Down Expand Up @@ -219,11 +224,11 @@ func TestServerRemovePeerDisconnect(t *testing.T) {
ListenAddr: "127.0.0.1:0",
Logger: testlog.Logger(t, log.LvlTrace).New("server", "2"),
}}
if err := srv1.Start(); err != nil {
if err := srv1.TestStart(); err != nil {
t.Fatal("cant start srv1")
}
defer srv1.Stop()
if err := srv2.Start(); err != nil {
if err := srv2.TestStart(); err != nil {
t.Fatal("cant start srv2")
}
defer srv2.Stop()
Expand Down Expand Up @@ -252,7 +257,7 @@ func TestServerAtCap(t *testing.T) {
Logger: testlog.Logger(t, log.LvlTrace),
},
}
if err := srv.Start(); err != nil {
if err := srv.TestStart(); err != nil {
t.Fatalf("could not start: %v", err)
}
defer srv.Stop()
Expand Down Expand Up @@ -329,7 +334,7 @@ func TestServerPeerLimits(t *testing.T) {
},
newTransport: func(fd net.Conn, dialDest *ecdsa.PublicKey) transport { return tp },
}
if err := srv.Start(); err != nil {
if err := srv.TestStart(); err != nil {
t.Fatalf("couldn't start server: %v", err)
}
defer srv.Stop()
Expand Down Expand Up @@ -440,7 +445,7 @@ func TestServerSetupConn(t *testing.T) {
log: cfg.Logger,
}
if !test.dontstart {
if err := srv.Start(); err != nil {
if err := srv.TestStart(); err != nil {
t.Fatalf("couldn't start server: %v", err)
}
defer srv.Stop()
Expand Down Expand Up @@ -530,7 +535,7 @@ func TestServerInboundThrottle(t *testing.T) {
return listenFakeAddr(network, laddr, fakeAddr)
},
}
if err := srv.Start(); err != nil {
if err := srv.TestStart(); err != nil {
t.Fatal("can't start: ", err)
}
defer srv.Stop()
Expand Down

0 comments on commit 0dbcd66

Please sign in to comment.