Skip to content

Commit

Permalink
p2p, p2p/discover: add signed ENR generation (ethereum#17753)
Browse files Browse the repository at this point in the history
This PR adds enode.LocalNode and integrates it into the p2p
subsystem. This new object is the keeper of the local node
record. For now, a new version of the record is produced every
time the client restarts. We'll make it smarter to avoid that in
the future.

There are a couple of other changes in this commit: discovery now
waits for all of its goroutines at shutdown and the p2p server
now closes the node database after discovery has shut down. This
fixes a leveldb crash in tests. p2p server startup is faster
because it doesn't need to wait for the external IP query
anymore.
  • Loading branch information
fjl authored and wanwiset25 committed Jun 28, 2024
1 parent c72aabb commit 3ca7a7c
Show file tree
Hide file tree
Showing 24 changed files with 976 additions and 275 deletions.
11 changes: 6 additions & 5 deletions cmd/bootnode/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -119,16 +119,17 @@ func main() {
}

if *runv5 {
if _, err := discv5.ListenUDP(nodeKey, conn, realaddr, "", restrictList); err != nil {
if _, err := discv5.ListenUDP(nodeKey, conn, "", restrictList); err != nil {
utils.Fatalf("%v", err)
}
} else {
db, _ := enode.OpenDB("")
ln := enode.NewLocalNode(db, nodeKey)
cfg := discover.Config{
PrivateKey: nodeKey,
AnnounceAddr: realaddr,
NetRestrict: restrictList,
PrivateKey: nodeKey,
NetRestrict: restrictList,
}
if _, err := discover.ListenUDP(conn, cfg); err != nil {
if _, err := discover.ListenUDP(conn, ln, cfg); err != nil {
utils.Fatalf("%v", err)
}
}
Expand Down
8 changes: 4 additions & 4 deletions node/node_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -453,9 +453,9 @@ func TestProtocolGather(t *testing.T) {
Count int
Maker InstrumentingWrapper
}{
"Zero Protocols": {0, InstrumentedServiceMakerA},
"Single Protocol": {1, InstrumentedServiceMakerB},
"Many Protocols": {25, InstrumentedServiceMakerC},
"zero": {0, InstrumentedServiceMakerA},
"one": {1, InstrumentedServiceMakerB},
"many": {10, InstrumentedServiceMakerC},
}
for id, config := range services {
protocols := make([]p2p.Protocol, config.Count)
Expand All @@ -479,7 +479,7 @@ func TestProtocolGather(t *testing.T) {
defer stack.Stop()

protocols := stack.Server().Protocols
if len(protocols) != 26 {
if len(protocols) != 11 {
t.Fatalf("mismatching number of protocols launched: have %d, want %d", len(protocols), 26)
}
for id, config := range services {
Expand Down
7 changes: 4 additions & 3 deletions p2p/dial.go
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,7 @@ type dialstate struct {
maxDynDials int
ntab discoverTable
netrestrict *netutil.Netlist
self enode.ID

lookupRunning bool
dialing map[enode.ID]connFlag
Expand All @@ -84,7 +85,6 @@ type dialstate struct {
}

type discoverTable interface {
Self() *enode.Node
Close()
Resolve(*enode.Node) *enode.Node
LookupRandom() []*enode.Node
Expand Down Expand Up @@ -126,10 +126,11 @@ type waitExpireTask struct {
time.Duration
}

func newDialState(static []*enode.Node, bootnodes []*enode.Node, ntab discoverTable, maxdyn int, netrestrict *netutil.Netlist) *dialstate {
func newDialState(self enode.ID, static []*enode.Node, bootnodes []*enode.Node, ntab discoverTable, maxdyn int, netrestrict *netutil.Netlist) *dialstate {
s := &dialstate{
maxDynDials: maxdyn,
ntab: ntab,
self: self,
netrestrict: netrestrict,
static: make(map[enode.ID]*dialTask),
dialing: make(map[enode.ID]connFlag),
Expand Down Expand Up @@ -266,7 +267,7 @@ func (s *dialstate) checkDial(n *enode.Node, peers map[enode.ID]*Peer) error {
return errAlreadyDialing
case peers[n.ID()] != nil:
return errAlreadyConnected
case s.ntab != nil && n.ID() == s.ntab.Self().ID():
case n.ID() == s.self:
return errSelf
case s.netrestrict != nil && !s.netrestrict.Contains(n.IP()):
return errNotWhitelisted
Expand Down
16 changes: 8 additions & 8 deletions p2p/dial_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ func (t fakeTable) ReadRandomNodes(buf []*enode.Node) int { return copy(buf, t)
// This test checks that dynamic dials are launched from discovery results.
func TestDialStateDynDial(t *testing.T) {
runDialTest(t, dialtest{
init: newDialState(nil, nil, fakeTable{}, 5, nil),
init: newDialState(enode.ID{}, nil, nil, fakeTable{}, 5, nil),
rounds: []round{
// A discovery query is launched.
{
Expand Down Expand Up @@ -236,7 +236,7 @@ func TestDialStateDynDialBootnode(t *testing.T) {
newNode(uintID(8), nil),
}
runDialTest(t, dialtest{
init: newDialState(nil, bootnodes, table, 5, nil),
init: newDialState(enode.ID{}, nil, bootnodes, table, 5, nil),
rounds: []round{
// 2 dynamic dials attempted, bootnodes pending fallback interval
{
Expand Down Expand Up @@ -324,7 +324,7 @@ func TestDialStateDynDialFromTable(t *testing.T) {
}

runDialTest(t, dialtest{
init: newDialState(nil, nil, table, 10, nil),
init: newDialState(enode.ID{}, nil, nil, table, 10, nil),
rounds: []round{
// 5 out of 8 of the nodes returned by ReadRandomNodes are dialed.
{
Expand Down Expand Up @@ -433,7 +433,7 @@ func TestDialStateNetRestrict(t *testing.T) {
restrict.Add("127.0.2.0/24")

runDialTest(t, dialtest{
init: newDialState(nil, nil, table, 10, restrict),
init: newDialState(enode.ID{}, nil, nil, table, 10, restrict),
rounds: []round{
{
new: []task{
Expand All @@ -456,7 +456,7 @@ func TestDialStateStaticDial(t *testing.T) {
}

runDialTest(t, dialtest{
init: newDialState(wantStatic, nil, fakeTable{}, 0, nil),
init: newDialState(enode.ID{}, wantStatic, nil, fakeTable{}, 0, nil),
rounds: []round{
// Static dials are launched for the nodes that
// aren't yet connected.
Expand Down Expand Up @@ -554,7 +554,7 @@ func TestDialStaticAfterReset(t *testing.T) {
},
}
dTest := dialtest{
init: newDialState(wantStatic, nil, fakeTable{}, 0, nil),
init: newDialState(enode.ID{}, wantStatic, nil, fakeTable{}, 0, nil),
rounds: rounds,
}
runDialTest(t, dTest)
Expand All @@ -577,7 +577,7 @@ func TestDialStateCache(t *testing.T) {
}

runDialTest(t, dialtest{
init: newDialState(wantStatic, nil, fakeTable{}, 0, nil),
init: newDialState(enode.ID{}, wantStatic, nil, fakeTable{}, 0, nil),
rounds: []round{
// Static dials are launched for the nodes that
// aren't yet connected.
Expand Down Expand Up @@ -636,7 +636,7 @@ func TestDialStateCache(t *testing.T) {
func TestDialResolve(t *testing.T) {
resolved := newNode(uintID(1), net.IP{127, 0, 55, 234})
table := &resolveMock{answer: resolved}
state := newDialState(nil, nil, table, 0, nil)
state := newDialState(enode.ID{}, nil, nil, table, 0, nil)

// Check that the task is generated with an incomplete ID.
dest := newNode(uintID(1), nil)
Expand Down
41 changes: 22 additions & 19 deletions p2p/discover/table.go
Original file line number Diff line number Diff line change
Expand Up @@ -72,21 +72,20 @@ type Table struct {
ips netutil.DistinctNetSet

db *enode.DB // database of known nodes
net transport
refreshReq chan chan struct{}
initDone chan struct{}
closeReq chan struct{}
closed chan struct{}

nodeAddedHook func(*node) // for testing

net transport
self *node // metadata of the local node
}

// transport is implemented by the UDP transport.
// it is an interface so we can test without opening lots of UDP
// sockets and without generating a private key.
type transport interface {
self() *enode.Node
ping(enode.ID, *net.UDPAddr) error
findnode(toid enode.ID, addr *net.UDPAddr, target encPubkey) ([]*node, error)
close()
Expand All @@ -100,11 +99,10 @@ type bucket struct {
ips netutil.DistinctNetSet
}

func newTable(t transport, self *enode.Node, db *enode.DB, bootnodes []*enode.Node) (*Table, error) {
func newTable(t transport, db *enode.DB, bootnodes []*enode.Node) (*Table, error) {
tab := &Table{
net: t,
db: db,
self: wrapNode(self),
refreshReq: make(chan chan struct{}),
initDone: make(chan struct{}),
closeReq: make(chan struct{}),
Expand All @@ -127,6 +125,10 @@ func newTable(t transport, self *enode.Node, db *enode.DB, bootnodes []*enode.No
return tab, nil
}

func (tab *Table) self() *enode.Node {
return tab.net.self()
}

func (tab *Table) seedRand() {
var b [8]byte
crand.Read(b[:])
Expand All @@ -136,11 +138,6 @@ func (tab *Table) seedRand() {
tab.mutex.Unlock()
}

// Self returns the local node.
func (tab *Table) Self() *enode.Node {
return unwrapNode(tab.self)
}

// ReadRandomNodes fills the given slice with random nodes from the table. The results
// are guaranteed to be unique for a single invocation, no node will appear twice.
func (tab *Table) ReadRandomNodes(buf []*enode.Node) (n int) {
Expand Down Expand Up @@ -183,6 +180,10 @@ func (tab *Table) ReadRandomNodes(buf []*enode.Node) (n int) {

// Close terminates the network listener and flushes the node database.
func (tab *Table) Close() {
if tab.net != nil {
tab.net.close()
}

select {
case <-tab.closed:
// already closed.
Expand Down Expand Up @@ -257,7 +258,7 @@ func (tab *Table) lookup(targetKey encPubkey, refreshIfEmpty bool) []*node {
)
// don't query further if we hit ourself.
// unlikely to happen often in practice.
asked[tab.self.ID()] = true
asked[tab.self().ID()] = true

for {
tab.mutex.Lock()
Expand Down Expand Up @@ -340,8 +341,8 @@ func (tab *Table) loop() {
revalidate = time.NewTimer(tab.nextRevalidateTime())
refresh = time.NewTicker(refreshInterval)
copyNodes = time.NewTicker(copyNodesInterval)
revalidateDone = make(chan struct{})
refreshDone = make(chan struct{}) // where doRefresh reports completion
revalidateDone chan struct{} // where doRevalidate reports completion
waiting = []chan struct{}{tab.initDone} // holds waiting callers while doRefresh runs
)
defer refresh.Stop()
Expand Down Expand Up @@ -372,25 +373,27 @@ loop:
}
waiting, refreshDone = nil, nil
case <-revalidate.C:
revalidateDone = make(chan struct{})
go tab.doRevalidate(revalidateDone)
case <-revalidateDone:
revalidate.Reset(tab.nextRevalidateTime())
revalidateDone = nil
case <-copyNodes.C:
go tab.copyLiveNodes()
case <-tab.closeReq:
break loop
}
}

if tab.net != nil {
tab.net.close()
}
if refreshDone != nil {
<-refreshDone
}
for _, ch := range waiting {
close(ch)
}
if revalidateDone != nil {
<-revalidateDone
}
close(tab.closed)
}

Expand All @@ -408,7 +411,7 @@ func (tab *Table) doRefresh(done chan struct{}) {
// Run self lookup to discover new neighbor nodes.
// We can only do this if we have a secp256k1 identity.
var key ecdsa.PublicKey
if err := tab.self.Load((*enode.Secp256k1)(&key)); err == nil {
if err := tab.self().Load((*enode.Secp256k1)(&key)); err == nil {
tab.lookup(encodePubkey(&key), false)
}

Expand Down Expand Up @@ -530,7 +533,7 @@ func (tab *Table) len() (n int) {

// bucket returns the bucket for the given node ID hash.
func (tab *Table) bucket(id enode.ID) *bucket {
d := enode.LogDist(tab.self.ID(), id)
d := enode.LogDist(tab.self().ID(), id)
if d <= bucketMinDistance {
return tab.buckets[0]
}
Expand All @@ -543,7 +546,7 @@ func (tab *Table) bucket(id enode.ID) *bucket {
//
// The caller must not hold tab.mutex.
func (tab *Table) add(n *node) {
if n.ID() == tab.self.ID() {
if n.ID() == tab.self().ID() {
return
}

Expand Down Expand Up @@ -576,7 +579,7 @@ func (tab *Table) stuff(nodes []*node) {
defer tab.mutex.Unlock()

for _, n := range nodes {
if n.ID() == tab.self.ID() {
if n.ID() == tab.self().ID() {
continue // don't add self
}
b := tab.bucket(n.ID())
Expand Down
10 changes: 7 additions & 3 deletions p2p/discover/table_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -142,7 +142,7 @@ func TestTable_IPLimit(t *testing.T) {
defer db.Close()

for i := 0; i < tableIPLimit+1; i++ {
n := nodeAtDistance(tab.self.ID(), i, net.IP{172, 0, 1, byte(i)})
n := nodeAtDistance(tab.self().ID(), i, net.IP{172, 0, 1, byte(i)})
tab.add(n)
}
if tab.len() > tableIPLimit {
Expand All @@ -159,7 +159,7 @@ func TestTable_BucketIPLimit(t *testing.T) {

d := 3
for i := 0; i < bucketIPLimit+1; i++ {
n := nodeAtDistance(tab.self.ID(), d, net.IP{172, 0, 1, byte(i)})
n := nodeAtDistance(tab.self().ID(), d, net.IP{172, 0, 1, byte(i)})
tab.add(n)
}
if tab.len() > bucketIPLimit {
Expand Down Expand Up @@ -241,7 +241,7 @@ func TestTable_ReadRandomNodesGetAll(t *testing.T) {

for i := 0; i < len(buf); i++ {
ld := cfg.Rand.Intn(len(tab.buckets))
tab.stuff([]*node{nodeAtDistance(tab.self.ID(), ld, intIP(ld))})
tab.stuff([]*node{nodeAtDistance(tab.self().ID(), ld, intIP(ld))})
}
gotN := tab.ReadRandomNodes(buf)
if gotN != tab.len() {
Expand Down Expand Up @@ -511,6 +511,10 @@ type preminedTestnet struct {
dists [hashBits + 1][]encPubkey
}

func (tn *preminedTestnet) self() *enode.Node {
return nullNode
}

func (tn *preminedTestnet) findnode(toid enode.ID, toaddr *net.UDPAddr, target encPubkey) ([]*node, error) {
// current log distance is encoded in port number
// fmt.Println("findnode query at dist", toaddr.Port)
Expand Down
Loading

0 comments on commit 3ca7a7c

Please sign in to comment.