diff --git a/cmd/clef/README.md b/cmd/clef/README.md
index 1180f369b6..bffa36fcce 100644
--- a/cmd/clef/README.md
+++ b/cmd/clef/README.md
@@ -115,9 +115,9 @@ Some snags and todos
Clef listens to HTTP requests on `rpcaddr`:`rpcport` (or to IPC on `ipcpath`), with the same JSON-RPC standard as Geth. The messages are expected to be [JSON-RPC 2.0 standard](https://www.jsonrpc.org/specification).
-Some of these call can require user interaction. Clients must be aware that responses may be delayed significantly or may never be received if a users decides to ignore the confirmation request.
+Some of these calls can require user interaction. Clients must be aware that responses may be delayed significantly or may never be received if a user decides to ignore the confirmation request.
-The External API is **untrusted**: it does not accept credentials over this API, nor does it expect that requests have any authority.
+The External API is **untrusted**: it does not accept credentials, nor does it expect that requests have any authority.
### Internal UI API
@@ -172,9 +172,9 @@ None
Response
```json
{
- "id": 0,
- "jsonrpc": "2.0",
- "result": "0xbea9183f8f4f03d427f6bcea17388bdff1cab133"
+ "id": 0,
+ "jsonrpc": "2.0",
+ "result": "0xbea9183f8f4f03d427f6bcea17388bdff1cab133"
}
```
@@ -370,7 +370,7 @@ Response
### account_signTypedData
#### Sign data
- Signs a chunk of structured data conformant to [EIP712]([EIP-712](https://github.com/ethereum/EIPs/blob/master/EIPS/eip-712.md)) and returns the calculated signature.
+ Signs a chunk of structured data conformant to [EIP-712](https://github.com/ethereum/EIPs/blob/master/EIPS/eip-712.md) and returns the calculated signature.
#### Arguments
- account [address]: account to sign with
diff --git a/cmd/devp2p/discv4cmd.go b/cmd/devp2p/discv4cmd.go
index 8580c61216..99b0957ab3 100644
--- a/cmd/devp2p/discv4cmd.go
+++ b/cmd/devp2p/discv4cmd.go
@@ -19,11 +19,14 @@ package main
import (
"fmt"
"net"
+ "os"
"strings"
"time"
+ "github.com/ethereum/go-ethereum/cmd/devp2p/internal/v4test"
"github.com/ethereum/go-ethereum/common"
"github.com/ethereum/go-ethereum/crypto"
+ "github.com/ethereum/go-ethereum/internal/utesting"
"github.com/ethereum/go-ethereum/p2p/discover"
"github.com/ethereum/go-ethereum/p2p/enode"
"github.com/ethereum/go-ethereum/params"
@@ -40,6 +43,7 @@ var (
discv4ResolveCommand,
discv4ResolveJSONCommand,
discv4CrawlCommand,
+ discv4TestCommand,
},
}
discv4PingCommand = cli.Command{
@@ -74,6 +78,12 @@ var (
Action: discv4Crawl,
Flags: []cli.Flag{bootnodesFlag, crawlTimeoutFlag},
}
+ discv4TestCommand = cli.Command{
+ Name: "test",
+ Usage: "Runs tests against a node",
+ Action: discv4Test,
+ Flags: []cli.Flag{remoteEnodeFlag, testPatternFlag, testListen1Flag, testListen2Flag},
+ }
)
var (
@@ -98,6 +108,25 @@ var (
Usage: "Time limit for the crawl.",
Value: 30 * time.Minute,
}
+ remoteEnodeFlag = cli.StringFlag{
+ Name: "remote",
+ Usage: "Enode of the remote node under test",
+ EnvVar: "REMOTE_ENODE",
+ }
+ testPatternFlag = cli.StringFlag{
+ Name: "run",
+ Usage: "Pattern of test suite(s) to run",
+ }
+ testListen1Flag = cli.StringFlag{
+ Name: "listen1",
+ Usage: "IP address of the first tester",
+ Value: v4test.Listen1,
+ }
+ testListen2Flag = cli.StringFlag{
+ Name: "listen2",
+ Usage: "IP address of the second tester",
+ Value: v4test.Listen2,
+ }
)
func discv4Ping(ctx *cli.Context) error {
@@ -184,6 +213,28 @@ func discv4Crawl(ctx *cli.Context) error {
return nil
}
+func discv4Test(ctx *cli.Context) error {
+ // Configure test package globals.
+ if !ctx.IsSet(remoteEnodeFlag.Name) {
+ return fmt.Errorf("Missing -%v", remoteEnodeFlag.Name)
+ }
+ v4test.Remote = ctx.String(remoteEnodeFlag.Name)
+ v4test.Listen1 = ctx.String(testListen1Flag.Name)
+ v4test.Listen2 = ctx.String(testListen2Flag.Name)
+
+ // Filter and run test cases.
+ tests := v4test.AllTests
+ if ctx.IsSet(testPatternFlag.Name) {
+ tests = utesting.MatchTests(tests, ctx.String(testPatternFlag.Name))
+ }
+ results := utesting.RunTests(tests, os.Stdout)
+ if fails := utesting.CountFailures(results); fails > 0 {
+ return fmt.Errorf("%v/%v tests passed.", len(tests)-fails, len(tests))
+ }
+ fmt.Printf("%v/%v passed\n", len(tests), len(tests))
+ return nil
+}
+
// startV4 starts an ephemeral discovery V4 node.
func startV4(ctx *cli.Context) *discover.UDPv4 {
ln, config := makeDiscoveryConfig(ctx)
diff --git a/cmd/devp2p/internal/v4test/discv4tests.go b/cmd/devp2p/internal/v4test/discv4tests.go
new file mode 100644
index 0000000000..140b96bfa5
--- /dev/null
+++ b/cmd/devp2p/internal/v4test/discv4tests.go
@@ -0,0 +1,467 @@
+// Copyright 2020 The go-ethereum Authors
+// This file is part of go-ethereum.
+//
+// go-ethereum is free software: you can redistribute it and/or modify
+// it under the terms of the GNU General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+//
+// go-ethereum is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU General Public License for more details.
+//
+// You should have received a copy of the GNU General Public License
+// along with go-ethereum. If not, see .
+
+package v4test
+
+import (
+ "bytes"
+ "crypto/rand"
+ "fmt"
+ "net"
+ "reflect"
+ "time"
+
+ "github.com/ethereum/go-ethereum/crypto"
+ "github.com/ethereum/go-ethereum/internal/utesting"
+ "github.com/ethereum/go-ethereum/p2p/discover/v4wire"
+)
+
+const (
+ expiration = 20 * time.Second
+ wrongPacket = 66
+ macSize = 256 / 8
+)
+
+var (
+ // Remote node under test
+ Remote string
+ // IP where the first tester is listening, port will be assigned
+ Listen1 string = "127.0.0.1"
+ // IP where the second tester is listening, port will be assigned
+ // Before running the test, you may have to `sudo ifconfig lo0 add 127.0.0.2` (on MacOS at least)
+ Listen2 string = "127.0.0.2"
+)
+
+type pingWithJunk struct {
+ Version uint
+ From, To v4wire.Endpoint
+ Expiration uint64
+ JunkData1 uint
+ JunkData2 []byte
+}
+
+func (req *pingWithJunk) Name() string { return "PING/v4" }
+func (req *pingWithJunk) Kind() byte { return v4wire.PingPacket }
+
+type pingWrongType struct {
+ Version uint
+ From, To v4wire.Endpoint
+ Expiration uint64
+}
+
+func (req *pingWrongType) Name() string { return "WRONG/v4" }
+func (req *pingWrongType) Kind() byte { return wrongPacket }
+
+func futureExpiration() uint64 {
+ return uint64(time.Now().Add(expiration).Unix())
+}
+
+// This test just sends a PING packet and expects a response.
+func BasicPing(t *utesting.T) {
+ te := newTestEnv(Remote, Listen1, Listen2)
+ defer te.close()
+
+ pingHash := te.send(te.l1, &v4wire.Ping{
+ Version: 4,
+ From: te.localEndpoint(te.l1),
+ To: te.remoteEndpoint(),
+ Expiration: futureExpiration(),
+ })
+
+ reply, _, _ := te.read(te.l1)
+ if err := te.checkPong(reply, pingHash); err != nil {
+ t.Fatal(err)
+ }
+}
+
+// checkPong verifies that reply is a valid PONG matching the given ping hash.
+func (te *testenv) checkPong(reply v4wire.Packet, pingHash []byte) error {
+ if reply == nil || reply.Kind() != v4wire.PongPacket {
+ return fmt.Errorf("expected PONG reply, got %v", reply)
+ }
+ pong := reply.(*v4wire.Pong)
+ if !bytes.Equal(pong.ReplyTok, pingHash) {
+ return fmt.Errorf("PONG reply token mismatch: got %x, want %x", pong.ReplyTok, pingHash)
+ }
+ wantEndpoint := te.localEndpoint(te.l1)
+ if !reflect.DeepEqual(pong.To, wantEndpoint) {
+ return fmt.Errorf("PONG 'to' endpoint mismatch: got %+v, want %+v", pong.To, wantEndpoint)
+ }
+ if v4wire.Expired(pong.Expiration) {
+ return fmt.Errorf("PONG is expired (%v)", pong.Expiration)
+ }
+ return nil
+}
+
+// This test sends a PING packet with wrong 'to' field and expects a PONG response.
+func PingWrongTo(t *utesting.T) {
+ te := newTestEnv(Remote, Listen1, Listen2)
+ defer te.close()
+
+ wrongEndpoint := v4wire.Endpoint{IP: net.ParseIP("192.0.2.0")}
+ pingHash := te.send(te.l1, &v4wire.Ping{
+ Version: 4,
+ From: te.localEndpoint(te.l1),
+ To: wrongEndpoint,
+ Expiration: futureExpiration(),
+ })
+
+ reply, _, _ := te.read(te.l1)
+ if err := te.checkPong(reply, pingHash); err != nil {
+ t.Fatal(err)
+ }
+}
+
+// This test sends a PING packet with wrong 'from' field and expects a PONG response.
+func PingWrongFrom(t *utesting.T) {
+ te := newTestEnv(Remote, Listen1, Listen2)
+ defer te.close()
+
+ wrongEndpoint := v4wire.Endpoint{IP: net.ParseIP("192.0.2.0")}
+ pingHash := te.send(te.l1, &v4wire.Ping{
+ Version: 4,
+ From: wrongEndpoint,
+ To: te.remoteEndpoint(),
+ Expiration: futureExpiration(),
+ })
+
+ reply, _, _ := te.read(te.l1)
+ if err := te.checkPong(reply, pingHash); err != nil {
+ t.Fatal(err)
+ }
+}
+
+// This test sends a PING packet with additional data at the end and expects a PONG
+// response. The remote node should respond because EIP-8 mandates ignoring additional
+// trailing data.
+func PingExtraData(t *utesting.T) {
+ te := newTestEnv(Remote, Listen1, Listen2)
+ defer te.close()
+
+ pingHash := te.send(te.l1, &pingWithJunk{
+ Version: 4,
+ From: te.localEndpoint(te.l1),
+ To: te.remoteEndpoint(),
+ Expiration: futureExpiration(),
+ JunkData1: 42,
+ JunkData2: []byte{9, 8, 7, 6, 5, 4, 3, 2, 1},
+ })
+
+ reply, _, _ := te.read(te.l1)
+ if err := te.checkPong(reply, pingHash); err != nil {
+ t.Fatal(err)
+ }
+}
+
+// This test sends a PING packet with additional data and wrong 'from' field
+// and expects a PONG response.
+func PingExtraDataWrongFrom(t *utesting.T) {
+ te := newTestEnv(Remote, Listen1, Listen2)
+ defer te.close()
+
+ wrongEndpoint := v4wire.Endpoint{IP: net.ParseIP("192.0.2.0")}
+ req := pingWithJunk{
+ Version: 4,
+ From: wrongEndpoint,
+ To: te.remoteEndpoint(),
+ Expiration: futureExpiration(),
+ JunkData1: 42,
+ JunkData2: []byte{9, 8, 7, 6, 5, 4, 3, 2, 1},
+ }
+ pingHash := te.send(te.l1, &req)
+ reply, _, _ := te.read(te.l1)
+ if err := te.checkPong(reply, pingHash); err != nil {
+ t.Fatal(err)
+ }
+}
+
+// This test sends a PING packet with an expiration in the past.
+// The remote node should not respond.
+func PingPastExpiration(t *utesting.T) {
+ te := newTestEnv(Remote, Listen1, Listen2)
+ defer te.close()
+
+ te.send(te.l1, &v4wire.Ping{
+ Version: 4,
+ From: te.localEndpoint(te.l1),
+ To: te.remoteEndpoint(),
+ Expiration: -futureExpiration(),
+ })
+
+ reply, _, _ := te.read(te.l1)
+ if reply != nil {
+ t.Fatal("Expected no reply, got", reply)
+ }
+}
+
+// This test sends an invalid packet. The remote node should not respond.
+func WrongPacketType(t *utesting.T) {
+ te := newTestEnv(Remote, Listen1, Listen2)
+ defer te.close()
+
+ te.send(te.l1, &pingWrongType{
+ Version: 4,
+ From: te.localEndpoint(te.l1),
+ To: te.remoteEndpoint(),
+ Expiration: futureExpiration(),
+ })
+
+ reply, _, _ := te.read(te.l1)
+ if reply != nil {
+ t.Fatal("Expected no reply, got", reply)
+ }
+}
+
+// This test verifies that the default behaviour of ignoring 'from' fields is unaffected by
+// the bonding process. After bonding, it pings the target with a different from endpoint.
+func BondThenPingWithWrongFrom(t *utesting.T) {
+ te := newTestEnv(Remote, Listen1, Listen2)
+ defer te.close()
+ bond(t, te)
+
+ wrongEndpoint := v4wire.Endpoint{IP: net.ParseIP("192.0.2.0")}
+ pingHash := te.send(te.l1, &v4wire.Ping{
+ Version: 4,
+ From: wrongEndpoint,
+ To: te.remoteEndpoint(),
+ Expiration: futureExpiration(),
+ })
+
+ reply, _, _ := te.read(te.l1)
+ if err := te.checkPong(reply, pingHash); err != nil {
+ t.Fatal(err)
+ }
+}
+
+// This test just sends FINDNODE. The remote node should not reply
+// because the endpoint proof has not completed.
+func FindnodeWithoutEndpointProof(t *utesting.T) {
+ te := newTestEnv(Remote, Listen1, Listen2)
+ defer te.close()
+
+ req := v4wire.Findnode{Expiration: futureExpiration()}
+ rand.Read(req.Target[:])
+ te.send(te.l1, &req)
+
+ reply, _, _ := te.read(te.l1)
+ if reply != nil {
+ t.Fatal("Expected no response, got", reply)
+ }
+}
+
+// BasicFindnode sends a FINDNODE request after performing the endpoint
+// proof. The remote node should respond.
+func BasicFindnode(t *utesting.T) {
+ te := newTestEnv(Remote, Listen1, Listen2)
+ defer te.close()
+ bond(t, te)
+
+ findnode := v4wire.Findnode{Expiration: futureExpiration()}
+ rand.Read(findnode.Target[:])
+ te.send(te.l1, &findnode)
+
+ reply, _, err := te.read(te.l1)
+ if err != nil {
+ t.Fatal("read find nodes", err)
+ }
+ if reply.Kind() != v4wire.NeighborsPacket {
+ t.Fatal("Expected neighbors, got", reply.Name())
+ }
+}
+
+// This test sends an unsolicited NEIGHBORS packet after the endpoint proof, then sends
+// FINDNODE to read the remote table. The remote node should not return the node contained
+// in the unsolicited NEIGHBORS packet.
+func UnsolicitedNeighbors(t *utesting.T) {
+ te := newTestEnv(Remote, Listen1, Listen2)
+ defer te.close()
+ bond(t, te)
+
+ // Send unsolicited NEIGHBORS response.
+ fakeKey, _ := crypto.GenerateKey()
+ encFakeKey := v4wire.EncodePubkey(&fakeKey.PublicKey)
+ neighbors := v4wire.Neighbors{
+ Expiration: futureExpiration(),
+ Nodes: []v4wire.Node{{
+ ID: encFakeKey,
+ IP: net.IP{1, 2, 3, 4},
+ UDP: 30303,
+ TCP: 30303,
+ }},
+ }
+ te.send(te.l1, &neighbors)
+
+ // Check if the remote node included the fake node.
+ te.send(te.l1, &v4wire.Findnode{
+ Expiration: futureExpiration(),
+ Target: encFakeKey,
+ })
+
+ reply, _, err := te.read(te.l1)
+ if err != nil {
+ t.Fatal("read find nodes", err)
+ }
+ if reply.Kind() != v4wire.NeighborsPacket {
+ t.Fatal("Expected neighbors, got", reply.Name())
+ }
+ nodes := reply.(*v4wire.Neighbors).Nodes
+ if contains(nodes, encFakeKey) {
+ t.Fatal("neighbors response contains node from earlier unsolicited neighbors response")
+ }
+}
+
+// This test sends FINDNODE with an expiration timestamp in the past.
+// The remote node should not respond.
+func FindnodePastExpiration(t *utesting.T) {
+ te := newTestEnv(Remote, Listen1, Listen2)
+ defer te.close()
+ bond(t, te)
+
+ findnode := v4wire.Findnode{Expiration: -futureExpiration()}
+ rand.Read(findnode.Target[:])
+ te.send(te.l1, &findnode)
+
+ for {
+ reply, _, _ := te.read(te.l1)
+ if reply == nil {
+ return
+ } else if reply.Kind() == v4wire.NeighborsPacket {
+ t.Fatal("Unexpected NEIGHBORS response for expired FINDNODE request")
+ }
+ }
+}
+
+// bond performs the endpoint proof with the remote node.
+func bond(t *utesting.T, te *testenv) {
+ te.send(te.l1, &v4wire.Ping{
+ Version: 4,
+ From: te.localEndpoint(te.l1),
+ To: te.remoteEndpoint(),
+ Expiration: futureExpiration(),
+ })
+
+ var gotPing, gotPong bool
+ for !gotPing || !gotPong {
+ req, hash, err := te.read(te.l1)
+ if err != nil {
+ t.Fatal(err)
+ }
+ switch req.(type) {
+ case *v4wire.Ping:
+ te.send(te.l1, &v4wire.Pong{
+ To: te.remoteEndpoint(),
+ ReplyTok: hash,
+ Expiration: futureExpiration(),
+ })
+ gotPing = true
+ case *v4wire.Pong:
+ // TODO: maybe verify pong data here
+ gotPong = true
+ }
+ }
+}
+
+// This test attempts to perform a traffic amplification attack against a
+// 'victim' endpoint using FINDNODE. In this attack scenario, the attacker
+// attempts to complete the endpoint proof non-interactively by sending a PONG
+// with mismatching reply token from the 'victim' endpoint. The attack works if
+// the remote node does not verify the PONG reply token field correctly. The
+// attacker could then perform traffic amplification by sending many FINDNODE
+// requests to the discovery node, which would reply to the 'victim' address.
+func FindnodeAmplificationInvalidPongHash(t *utesting.T) {
+ te := newTestEnv(Remote, Listen1, Listen2)
+ defer te.close()
+
+ // Send PING to start endpoint verification.
+ te.send(te.l1, &v4wire.Ping{
+ Version: 4,
+ From: te.localEndpoint(te.l1),
+ To: te.remoteEndpoint(),
+ Expiration: futureExpiration(),
+ })
+
+ var gotPing, gotPong bool
+ for !gotPing || !gotPong {
+ req, _, err := te.read(te.l1)
+ if err != nil {
+ t.Fatal(err)
+ }
+ switch req.(type) {
+ case *v4wire.Ping:
+ // Send PONG from this node ID, but with invalid ReplyTok.
+ te.send(te.l1, &v4wire.Pong{
+ To: te.remoteEndpoint(),
+ ReplyTok: make([]byte, macSize),
+ Expiration: futureExpiration(),
+ })
+ gotPing = true
+ case *v4wire.Pong:
+ gotPong = true
+ }
+ }
+
+ // Now send FINDNODE. The remote node should not respond because our
+ // PONG did not reference the PING hash.
+ findnode := v4wire.Findnode{Expiration: futureExpiration()}
+ rand.Read(findnode.Target[:])
+ te.send(te.l1, &findnode)
+
+ // If we receive a NEIGHBORS response, the attack worked and the test fails.
+ reply, _, _ := te.read(te.l1)
+ if reply != nil && reply.Kind() == v4wire.NeighborsPacket {
+ t.Error("Got neighbors")
+ }
+}
+
+// This test attempts to perform a traffic amplification attack using FINDNODE.
+// The attack works if the remote node does not verify the IP address of FINDNODE
+// against the endpoint verification proof done by PING/PONG.
+func FindnodeAmplificationWrongIP(t *utesting.T) {
+ te := newTestEnv(Remote, Listen1, Listen2)
+ defer te.close()
+
+ // Do the endpoint proof from the l1 IP.
+ bond(t, te)
+
+ // Now send FINDNODE from the same node ID, but different IP address.
+ // The remote node should not respond.
+ findnode := v4wire.Findnode{Expiration: futureExpiration()}
+ rand.Read(findnode.Target[:])
+ te.send(te.l2, &findnode)
+
+ // If we receive a NEIGHBORS response, the attack worked and the test fails.
+ reply, _, _ := te.read(te.l2)
+ if reply != nil {
+ t.Error("Got NEIGHORS response for FINDNODE from wrong IP")
+ }
+}
+
+var AllTests = []utesting.Test{
+ {Name: "Ping/Basic", Fn: BasicPing},
+ {Name: "Ping/WrongTo", Fn: PingWrongTo},
+ {Name: "Ping/WrongFrom", Fn: PingWrongFrom},
+ {Name: "Ping/ExtraData", Fn: PingExtraData},
+ {Name: "Ping/ExtraDataWrongFrom", Fn: PingExtraDataWrongFrom},
+ {Name: "Ping/PastExpiration", Fn: PingPastExpiration},
+ {Name: "Ping/WrongPacketType", Fn: WrongPacketType},
+ {Name: "Ping/BondThenPingWithWrongFrom", Fn: BondThenPingWithWrongFrom},
+ {Name: "Findnode/WithoutEndpointProof", Fn: FindnodeWithoutEndpointProof},
+ {Name: "Findnode/BasicFindnode", Fn: BasicFindnode},
+ {Name: "Findnode/UnsolicitedNeighbors", Fn: UnsolicitedNeighbors},
+ {Name: "Findnode/PastExpiration", Fn: FindnodePastExpiration},
+ {Name: "Amplification/InvalidPongHash", Fn: FindnodeAmplificationInvalidPongHash},
+ {Name: "Amplification/WrongIP", Fn: FindnodeAmplificationWrongIP},
+}
diff --git a/cmd/devp2p/internal/v4test/framework.go b/cmd/devp2p/internal/v4test/framework.go
new file mode 100644
index 0000000000..9286594181
--- /dev/null
+++ b/cmd/devp2p/internal/v4test/framework.go
@@ -0,0 +1,123 @@
+// Copyright 2020 The go-ethereum Authors
+// This file is part of go-ethereum.
+//
+// go-ethereum is free software: you can redistribute it and/or modify
+// it under the terms of the GNU General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+//
+// go-ethereum is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU General Public License for more details.
+//
+// You should have received a copy of the GNU General Public License
+// along with go-ethereum. If not, see .
+
+package v4test
+
+import (
+ "crypto/ecdsa"
+ "fmt"
+ "net"
+ "time"
+
+ "github.com/ethereum/go-ethereum/crypto"
+ "github.com/ethereum/go-ethereum/p2p/discover/v4wire"
+ "github.com/ethereum/go-ethereum/p2p/enode"
+)
+
+const waitTime = 300 * time.Millisecond
+
+type testenv struct {
+ l1, l2 net.PacketConn
+ key *ecdsa.PrivateKey
+ remote *enode.Node
+ remoteAddr *net.UDPAddr
+}
+
+func newTestEnv(remote string, listen1, listen2 string) *testenv {
+ l1, err := net.ListenPacket("udp", fmt.Sprintf("%v:0", listen1))
+ if err != nil {
+ panic(err)
+ }
+ l2, err := net.ListenPacket("udp", fmt.Sprintf("%v:0", listen2))
+ if err != nil {
+ panic(err)
+ }
+ key, err := crypto.GenerateKey()
+ if err != nil {
+ panic(err)
+ }
+ node, err := enode.Parse(enode.ValidSchemes, remote)
+ if err != nil {
+ panic(err)
+ }
+ if node.IP() == nil || node.UDP() == 0 {
+ var ip net.IP
+ var tcpPort, udpPort int
+ if ip = node.IP(); ip == nil {
+ ip = net.ParseIP("127.0.0.1")
+ }
+ if tcpPort = node.TCP(); tcpPort == 0 {
+ tcpPort = 30303
+ }
+ if udpPort = node.TCP(); udpPort == 0 {
+ udpPort = 30303
+ }
+ node = enode.NewV4(node.Pubkey(), ip, tcpPort, udpPort)
+ }
+ addr := &net.UDPAddr{IP: node.IP(), Port: node.UDP()}
+ return &testenv{l1, l2, key, node, addr}
+}
+
+func (te *testenv) close() {
+ te.l1.Close()
+ te.l2.Close()
+}
+
+func (te *testenv) send(c net.PacketConn, req v4wire.Packet) []byte {
+ packet, hash, err := v4wire.Encode(te.key, req)
+ if err != nil {
+ panic(fmt.Errorf("can't encode %v packet: %v", req.Name(), err))
+ }
+ if _, err := c.WriteTo(packet, te.remoteAddr); err != nil {
+ panic(fmt.Errorf("can't send %v: %v", req.Name(), err))
+ }
+ return hash
+}
+
+func (te *testenv) read(c net.PacketConn) (v4wire.Packet, []byte, error) {
+ buf := make([]byte, 2048)
+ if err := c.SetReadDeadline(time.Now().Add(waitTime)); err != nil {
+ return nil, nil, err
+ }
+ n, _, err := c.ReadFrom(buf)
+ if err != nil {
+ return nil, nil, err
+ }
+ p, _, hash, err := v4wire.Decode(buf[:n])
+ return p, hash, err
+}
+
+func (te *testenv) localEndpoint(c net.PacketConn) v4wire.Endpoint {
+ addr := c.LocalAddr().(*net.UDPAddr)
+ return v4wire.Endpoint{
+ IP: addr.IP.To4(),
+ UDP: uint16(addr.Port),
+ TCP: 0,
+ }
+}
+
+func (te *testenv) remoteEndpoint() v4wire.Endpoint {
+ return v4wire.NewEndpoint(te.remoteAddr, 0)
+}
+
+func contains(ns []v4wire.Node, key v4wire.Pubkey) bool {
+ for _, n := range ns {
+ if n.ID == key {
+ return true
+ }
+ }
+ return false
+}
diff --git a/common/math/integer.go b/common/math/integer.go
index 46d91ab2a4..93b1d036dd 100644
--- a/common/math/integer.go
+++ b/common/math/integer.go
@@ -18,7 +18,6 @@ package math
import (
"fmt"
- "math/bits"
"strconv"
)
@@ -88,12 +87,13 @@ func SafeSub(x, y uint64) (uint64, bool) {
// SafeAdd returns the result and whether overflow occurred.
func SafeAdd(x, y uint64) (uint64, bool) {
- sum, carry := bits.Add64(x, y, 0)
- return sum, carry != 0
+ return x + y, y > MaxUint64-x
}
// SafeMul returns multiplication result and whether overflow occurred.
func SafeMul(x, y uint64) (uint64, bool) {
- hi, lo := bits.Mul64(x, y)
- return lo, hi != 0
+ if x == 0 || y == 0 {
+ return 0, false
+ }
+ return x * y, y > MaxUint64/x
}
diff --git a/core/tx_list.go b/core/tx_list.go
index 8beb28bba9..164c73006b 100644
--- a/core/tx_list.go
+++ b/core/tx_list.go
@@ -99,30 +99,7 @@ func (m *txSortedMap) Forward(threshold uint64) types.Transactions {
// Filter iterates over the list of transactions and removes all of them for which
// the specified function evaluates to true.
-// Filter, as opposed to 'filter', re-initialises the heap after the operation is done.
-// If you want to do several consecutive filterings, it's therefore better to first
-// do a .filter(func1) followed by .Filter(func2) or reheap()
func (m *txSortedMap) Filter(filter func(*types.Transaction) bool) types.Transactions {
- removed := m.filter(filter)
- // If transactions were removed, the heap and cache are ruined
- if len(removed) > 0 {
- m.reheap()
- }
- return removed
-}
-
-func (m *txSortedMap) reheap() {
- *m.index = make([]uint64, 0, len(m.items))
- for nonce := range m.items {
- *m.index = append(*m.index, nonce)
- }
- heap.Init(m.index)
- m.cache = nil
-}
-
-// filter is identical to Filter, but **does not** regenerate the heap. This method
-// should only be used if followed immediately by a call to Filter or reheap()
-func (m *txSortedMap) filter(filter func(*types.Transaction) bool) types.Transactions {
var removed types.Transactions
// Collect all the transactions to filter out
@@ -132,7 +109,14 @@ func (m *txSortedMap) filter(filter func(*types.Transaction) bool) types.Transac
delete(m.items, nonce)
}
}
+ // If transactions were removed, the heap and cache are ruined
if len(removed) > 0 {
+ *m.index = make([]uint64, 0, len(m.items))
+ for nonce := range m.items {
+ *m.index = append(*m.index, nonce)
+ }
+ heap.Init(m.index)
+
m.cache = nil
}
return removed
@@ -213,7 +197,10 @@ func (m *txSortedMap) Len() int {
return len(m.items)
}
-func (m *txSortedMap) flatten() types.Transactions {
+// Flatten creates a nonce-sorted slice of transactions based on the loosely
+// sorted internal representation. The result of the sorting is cached in case
+// it's requested again before any modifications are made to the contents.
+func (m *txSortedMap) Flatten() types.Transactions {
// If the sorting was not cached yet, create and cache it
if m.cache == nil {
m.cache = make(types.Transactions, 0, len(m.items))
@@ -222,27 +209,12 @@ func (m *txSortedMap) flatten() types.Transactions {
}
sort.Sort(types.TxByNonce(m.cache))
}
- return m.cache
-}
-
-// Flatten creates a nonce-sorted slice of transactions based on the loosely
-// sorted internal representation. The result of the sorting is cached in case
-// it's requested again before any modifications are made to the contents.
-func (m *txSortedMap) Flatten() types.Transactions {
// Copy the cache to prevent accidental modifications
- cache := m.flatten()
- txs := make(types.Transactions, len(cache))
- copy(txs, cache)
+ txs := make(types.Transactions, len(m.cache))
+ copy(txs, m.cache)
return txs
}
-// LastElement returns the last element of a flattened list, thus, the
-// transaction with the highest nonce
-func (m *txSortedMap) LastElement() *types.Transaction {
- cache := m.flatten()
- return cache[len(cache)-1]
-}
-
// txList is a "list" of transactions belonging to an account, sorted by account
// nonce. The same type can be used both for storing contiguous transactions for
// the executable/pending queue; and for storing gapped transactions for the non-
@@ -251,16 +223,17 @@ type txList struct {
strict bool // Whether nonces are strictly continuous or not
txs *txSortedMap // Heap indexed sorted hash map of the transactions
- costcap uint64 // Price of the highest costing transaction (reset only if exceeds balance)
- gascap uint64 // Gas limit of the highest spending transaction (reset only if exceeds block limit)
+ costcap *big.Int // Price of the highest costing transaction (reset only if exceeds balance)
+ gascap uint64 // Gas limit of the highest spending transaction (reset only if exceeds block limit)
}
// newTxList create a new transaction list for maintaining nonce-indexable fast,
// gapped, sortable transaction lists.
func newTxList(strict bool) *txList {
return &txList{
- strict: strict,
- txs: newTxSortedMap(),
+ strict: strict,
+ txs: newTxSortedMap(),
+ costcap: new(big.Int),
}
}
@@ -279,11 +252,7 @@ func (l *txList) Add(tx *types.Transaction, priceBump uint64) (bool, *types.Tran
// If there's an older better transaction, abort
old := l.txs.Get(tx.Nonce())
if old != nil {
- // threshold = oldGP * (100 + priceBump) / 100
- a := big.NewInt(100 + int64(priceBump))
- a = a.Mul(a, old.GasPrice())
- b := big.NewInt(100)
- threshold := a.Div(a, b)
+ threshold := new(big.Int).Div(new(big.Int).Mul(old.GasPrice(), big.NewInt(100+int64(priceBump))), big.NewInt(100))
// Have to ensure that the new gas price is higher than the old gas
// price as well as checking the percentage threshold to ensure that
// this is accurate for low (Wei-level) gas price replacements
@@ -291,14 +260,9 @@ func (l *txList) Add(tx *types.Transaction, priceBump uint64) (bool, *types.Tran
return false, nil
}
}
- cost, overflow := tx.CostU64()
- if overflow {
- log.Warn("transaction cost overflown, txHash: %v txCost: %v", tx.Hash(), cost)
- return false, nil
- }
// Otherwise overwrite the old transaction with the current one
l.txs.Put(tx)
- if l.costcap < cost {
+ if cost := tx.Cost(); l.costcap.Cmp(cost) < 0 {
l.costcap = cost
}
if gas := tx.Gas(); l.gascap < gas {
@@ -323,35 +287,29 @@ func (l *txList) Forward(threshold uint64) types.Transactions {
// a point in calculating all the costs or if the balance covers all. If the threshold
// is lower than the costgas cap, the caps will be reset to a new high after removing
// the newly invalidated transactions.
-func (l *txList) Filter(costLimit uint64, gasLimit uint64) (types.Transactions, types.Transactions) {
+func (l *txList) Filter(costLimit *big.Int, gasLimit uint64) (types.Transactions, types.Transactions) {
// If all transactions are below the threshold, short circuit
- if l.costcap <= costLimit && l.gascap <= gasLimit {
+ if l.costcap.Cmp(costLimit) <= 0 && l.gascap <= gasLimit {
return nil, nil
}
- l.costcap = costLimit // Lower the caps to the thresholds
+ l.costcap = new(big.Int).Set(costLimit) // Lower the caps to the thresholds
l.gascap = gasLimit
// Filter out all the transactions above the account's funds
- removed := l.txs.filter(func(tx *types.Transaction) bool {
- cost, _ := tx.CostU64()
- return cost > costLimit || tx.Gas() > gasLimit
- })
+ removed := l.txs.Filter(func(tx *types.Transaction) bool { return tx.Cost().Cmp(costLimit) > 0 || tx.Gas() > gasLimit })
- if len(removed) == 0 {
- return nil, nil
- }
- var invalids types.Transactions
// If the list was strict, filter anything above the lowest nonce
- if l.strict {
+ var invalids types.Transactions
+
+ if l.strict && len(removed) > 0 {
lowest := uint64(math.MaxUint64)
for _, tx := range removed {
if nonce := tx.Nonce(); lowest > nonce {
lowest = nonce
}
}
- invalids = l.txs.filter(func(tx *types.Transaction) bool { return tx.Nonce() > lowest })
+ invalids = l.txs.Filter(func(tx *types.Transaction) bool { return tx.Nonce() > lowest })
}
- l.txs.reheap()
return removed, invalids
}
@@ -405,12 +363,6 @@ func (l *txList) Flatten() types.Transactions {
return l.txs.Flatten()
}
-// LastElement returns the last element of a flattened list, thus, the
-// transaction with the highest nonce
-func (l *txList) LastElement() *types.Transaction {
- return l.txs.LastElement()
-}
-
// priceHeap is a heap.Interface implementation over transactions for retrieving
// price-sorted transactions to discard when the pool fills up.
type priceHeap []*types.Transaction
@@ -543,29 +495,8 @@ func (l *txPricedList) Underpriced(tx *types.Transaction, local *accountSet) boo
// Discard finds a number of most underpriced transactions, removes them from the
// priced list and returns them for further removal from the entire pool.
func (l *txPricedList) Discard(slots int, local *accountSet) types.Transactions {
- // If we have some local accountset, those will not be discarded
- if !local.empty() {
- // In case the list is filled to the brim with 'local' txs, we do this
- // little check to avoid unpacking / repacking the heap later on, which
- // is very expensive
- discardable := 0
- for _, tx := range *l.items {
- if !local.containsTx(tx) {
- discardable++
- }
- if discardable >= slots {
- break
- }
- }
- if slots > discardable {
- slots = discardable
- }
- }
- if slots == 0 {
- return nil
- }
- drop := make(types.Transactions, 0, slots) // Remote underpriced transactions to drop
- save := make(types.Transactions, 0, len(*l.items)-slots) // Local underpriced transactions to keep
+ drop := make(types.Transactions, 0, slots) // Remote underpriced transactions to drop
+ save := make(types.Transactions, 0, 64) // Local underpriced transactions to keep
for len(*l.items) > 0 && slots > 0 {
// Discard stale transactions if found during cleanup
diff --git a/core/tx_list_test.go b/core/tx_list_test.go
index d9f4eba267..3a5842d2e8 100644
--- a/core/tx_list_test.go
+++ b/core/tx_list_test.go
@@ -17,6 +17,7 @@
package core
import (
+ "math/big"
"math/rand"
"testing"
@@ -50,22 +51,20 @@ func TestStrictTxListAdd(t *testing.T) {
}
}
-func BenchmarkTxListAdd(b *testing.B) {
+func BenchmarkTxListAdd(t *testing.B) {
// Generate a list of transactions to insert
key, _ := crypto.GenerateKey()
- txs := make(types.Transactions, 2000)
+ txs := make(types.Transactions, 100000)
for i := 0; i < len(txs); i++ {
txs[i] = transaction(uint64(i), 0, key)
}
// Insert the transactions in a random order
- b.ResetTimer()
- priceLimit := DefaultTxPoolConfig.PriceLimit
- for i := 0; i < b.N; i++ {
- list := newTxList(true)
- for _, v := range rand.Perm(len(txs)) {
- list.Add(txs[v], DefaultTxPoolConfig.PriceBump)
- list.Filter(priceLimit, DefaultTxPoolConfig.PriceBump)
- }
+ list := newTxList(true)
+ priceLimit := big.NewInt(int64(DefaultTxPoolConfig.PriceLimit))
+ t.ResetTimer()
+ for _, v := range rand.Perm(len(txs)) {
+ list.Add(txs[v], DefaultTxPoolConfig.PriceBump)
+ list.Filter(priceLimit, DefaultTxPoolConfig.PriceBump)
}
}
diff --git a/core/tx_pool.go b/core/tx_pool.go
index 6ad1679cfc..e707f0becd 100644
--- a/core/tx_pool.go
+++ b/core/tx_pool.go
@@ -550,11 +550,7 @@ func (pool *TxPool) validateTx(tx *types.Transaction, local bool) error {
}
// Transactor should have enough funds to cover the costs
// cost == V + GP * GL
- cost, overflow := tx.CostU64()
- if overflow {
- return ErrInsufficientFunds
- }
- if pool.currentState.GetBalance(from).Uint64() < cost {
+ if pool.currentState.GetBalance(from).Cmp(tx.Cost()) < 0 {
return ErrInsufficientFunds
}
// Ensure the transaction has more gas than the basic tx fee.
@@ -1070,8 +1066,8 @@ func (pool *TxPool) runReorg(done chan struct{}, reset *txpoolResetRequest, dirt
// Update all accounts to the latest known pending nonce
for addr, list := range pool.pending {
- highestPending := list.LastElement()
- pool.pendingNonces.set(addr, highestPending.Nonce()+1)
+ txs := list.Flatten() // Heavy but will be cached and is needed by the miner anyway
+ pool.pendingNonces.set(addr, txs[len(txs)-1].Nonce()+1)
}
pool.mu.Unlock()
@@ -1201,7 +1197,7 @@ func (pool *TxPool) promoteExecutables(accounts []common.Address) []*types.Trans
}
log.Trace("Removed old queued transactions", "count", len(forwards))
// Drop all transactions that are too costly (low balance or out of gas)
- drops, _ := list.Filter(pool.currentState.GetBalance(addr).Uint64(), pool.currentMaxGas)
+ drops, _ := list.Filter(pool.currentState.GetBalance(addr), pool.currentMaxGas)
for _, tx := range drops {
hash := tx.Hash()
pool.all.Remove(hash)
@@ -1393,7 +1389,7 @@ func (pool *TxPool) demoteUnexecutables() {
log.Trace("Removed old pending transaction", "hash", hash)
}
// Drop all transactions that are too costly (low balance or out of gas), and queue any invalids back for later
- drops, invalids := list.Filter(pool.currentState.GetBalance(addr).Uint64(), pool.currentMaxGas)
+ drops, invalids := list.Filter(pool.currentState.GetBalance(addr), pool.currentMaxGas)
for _, tx := range drops {
hash := tx.Hash()
log.Trace("Removed unpayable pending transaction", "hash", hash)
@@ -1468,10 +1464,6 @@ func (as *accountSet) contains(addr common.Address) bool {
return exist
}
-func (as *accountSet) empty() bool {
- return len(as.accounts) == 0
-}
-
// containsTx checks if the sender of a given tx is within the set. If the sender
// cannot be derived, this method returns false.
func (as *accountSet) containsTx(tx *types.Transaction) bool {
diff --git a/core/tx_pool_test.go b/core/tx_pool_test.go
index 36b83e82ec..95ade74252 100644
--- a/core/tx_pool_test.go
+++ b/core/tx_pool_test.go
@@ -1891,15 +1891,11 @@ func benchmarkFuturePromotion(b *testing.B, size int) {
}
// Benchmarks the speed of batched transaction insertion.
-func BenchmarkPoolBatchInsert100(b *testing.B) { benchmarkPoolBatchInsert(b, 100, false) }
-func BenchmarkPoolBatchInsert1000(b *testing.B) { benchmarkPoolBatchInsert(b, 1000, false) }
-func BenchmarkPoolBatchInsert10000(b *testing.B) { benchmarkPoolBatchInsert(b, 10000, false) }
+func BenchmarkPoolBatchInsert100(b *testing.B) { benchmarkPoolBatchInsert(b, 100) }
+func BenchmarkPoolBatchInsert1000(b *testing.B) { benchmarkPoolBatchInsert(b, 1000) }
+func BenchmarkPoolBatchInsert10000(b *testing.B) { benchmarkPoolBatchInsert(b, 10000) }
-func BenchmarkPoolBatchLocalInsert100(b *testing.B) { benchmarkPoolBatchInsert(b, 100, true) }
-func BenchmarkPoolBatchLocalInsert1000(b *testing.B) { benchmarkPoolBatchInsert(b, 1000, true) }
-func BenchmarkPoolBatchLocalInsert10000(b *testing.B) { benchmarkPoolBatchInsert(b, 10000, true) }
-
-func benchmarkPoolBatchInsert(b *testing.B, size int, local bool) {
+func benchmarkPoolBatchInsert(b *testing.B, size int) {
// Generate a batch of transactions to enqueue into the pool
pool, key := setupTxPool()
defer pool.Stop()
@@ -1917,10 +1913,6 @@ func benchmarkPoolBatchInsert(b *testing.B, size int, local bool) {
// Benchmark importing the transactions into the queue
b.ResetTimer()
for _, batch := range batches {
- if local {
- pool.AddLocals(batch)
- } else {
- pool.AddRemotes(batch)
- }
+ pool.AddRemotes(batch)
}
}
diff --git a/core/types/transaction.go b/core/types/transaction.go
index 347db2da0b..da691bb03f 100644
--- a/core/types/transaction.go
+++ b/core/types/transaction.go
@@ -25,7 +25,6 @@ import (
"github.com/ethereum/go-ethereum/common"
"github.com/ethereum/go-ethereum/common/hexutil"
- "github.com/ethereum/go-ethereum/common/math"
"github.com/ethereum/go-ethereum/crypto"
"github.com/ethereum/go-ethereum/rlp"
)
@@ -42,7 +41,6 @@ type Transaction struct {
hash atomic.Value
size atomic.Value
from atomic.Value
- cost atomic.Value
}
type txdata struct {
@@ -260,15 +258,6 @@ func (tx *Transaction) Cost() *big.Int {
return total
}
-func (tx *Transaction) CostU64() (uint64, bool) {
- if tx.data.Price.BitLen() > 63 || tx.data.Amount.BitLen() > 63 {
- return 0, false
- }
- cost, overflowMul := math.SafeMul(tx.data.Price.Uint64(), tx.data.GasLimit)
- total, overflowAdd := math.SafeAdd(cost, tx.data.Amount.Uint64())
- return total, overflowMul || overflowAdd
-}
-
// RawSignatureValues returns the V, R, S signature values of the transaction.
// The return values should not be modified by the caller.
func (tx *Transaction) RawSignatureValues() (v, r, s *big.Int) {
diff --git a/eth/downloader/downloader.go b/eth/downloader/downloader.go
index 0d35ff7082..9fde6f196b 100644
--- a/eth/downloader/downloader.go
+++ b/eth/downloader/downloader.go
@@ -330,7 +330,9 @@ func (d *Downloader) Synchronise(id string, head common.Hash, td *big.Int, mode
return err
}
- if errors.Is(err, errInvalidChain) {
+ if errors.Is(err, errInvalidChain) || errors.Is(err, errBadPeer) || errors.Is(err, errTimeout) ||
+ errors.Is(err, errStallingPeer) || errors.Is(err, errUnsyncedPeer) || errors.Is(err, errEmptyHeaderSet) ||
+ errors.Is(err, errPeersUnavailable) || errors.Is(err, errTooOld) || errors.Is(err, errInvalidAncestor) {
log.Warn("Synchronisation failed, dropping peer", "peer", id, "err", err)
if d.dropPeer == nil {
// The dropPeer method is nil when `--copydb` is used for a local copy.
@@ -341,22 +343,7 @@ func (d *Downloader) Synchronise(id string, head common.Hash, td *big.Int, mode
}
return err
}
-
- switch err {
- case errTimeout, errBadPeer, errStallingPeer, errUnsyncedPeer,
- errEmptyHeaderSet, errPeersUnavailable, errTooOld,
- errInvalidAncestor:
- log.Warn("Synchronisation failed, dropping peer", "peer", id, "err", err)
- if d.dropPeer == nil {
- // The dropPeer method is nil when `--copydb` is used for a local copy.
- // Timeouts can occur if e.g. compaction hits at the wrong time, and can be ignored
- log.Warn("Downloader wants to drop peer, but peerdrop-function is not set", "peer", id)
- } else {
- d.dropPeer(id)
- }
- default:
- log.Warn("Synchronisation failed, retrying", "err", err)
- }
+ log.Warn("Synchronisation failed, retrying", "err", err)
return err
}
@@ -651,7 +638,7 @@ func (d *Downloader) fetchHeight(p *peerConnection) (*types.Header, error) {
headers := packet.(*headerPack).headers
if len(headers) != 1 {
p.log.Debug("Multiple headers for single request", "headers", len(headers))
- return nil, errBadPeer
+ return nil, fmt.Errorf("%w: multiple headers (%d) for single request", errBadPeer, len(headers))
}
head := headers[0]
if (mode == FastSync || mode == LightSync) && head.Number.Uint64() < d.checkpoint {
@@ -921,7 +908,7 @@ func (d *Downloader) findAncestorBinarySearch(p *peerConnection, remoteHeight ui
headers := packer.(*headerPack).headers
if len(headers) != 1 {
p.log.Debug("Multiple headers for single request", "headers", len(headers))
- return 0, errBadPeer
+ return 0, fmt.Errorf("%w: multiple headers (%d) for single request", errBadPeer, len(headers))
}
arrived = true
@@ -947,7 +934,7 @@ func (d *Downloader) findAncestorBinarySearch(p *peerConnection, remoteHeight ui
header := d.lightchain.GetHeaderByHash(hash) // Independent of sync mode, header surely exists
if header.Number.Uint64() != check {
p.log.Debug("Received non requested header", "number", header.Number, "hash", header.Hash(), "request", check)
- return 0, errBadPeer
+ return 0, fmt.Errorf("%w: non-requested header (%d)", errBadPeer, header.Number)
}
start = check
@@ -1138,7 +1125,7 @@ func (d *Downloader) fetchHeaders(p *peerConnection, from uint64, pivot uint64)
case d.headerProcCh <- nil:
case <-d.cancelCh:
}
- return errBadPeer
+ return fmt.Errorf("%w: header request timed out", errBadPeer)
}
}
}
@@ -1566,7 +1553,7 @@ func (d *Downloader) processHeaders(origin uint64, pivot uint64, td *big.Int) er
inserts := d.queue.Schedule(chunk, origin)
if len(inserts) != len(chunk) {
log.Debug("Stale headers")
- return errBadPeer
+ return fmt.Errorf("%w: stale headers", errBadPeer)
}
}
headers = headers[limit:]
diff --git a/eth/downloader/statesync.go b/eth/downloader/statesync.go
index f875b3a84c..b022617bbc 100644
--- a/eth/downloader/statesync.go
+++ b/eth/downloader/statesync.go
@@ -63,6 +63,10 @@ func (d *Downloader) syncState(root common.Hash) *stateSync {
s := newStateSync(d, root)
select {
case d.stateSyncStart <- s:
+ // If we tell the statesync to restart with a new root, we also need
+ // to wait for it to actually also start -- when old requests have timed
+ // out or been delivered
+ <-s.started
case <-d.quitCh:
s.err = errCancelStateFetch
close(s.done)
@@ -95,15 +99,9 @@ func (d *Downloader) runStateSync(s *stateSync) *stateSync {
finished []*stateReq // Completed or failed requests
timeout = make(chan *stateReq) // Timed out active requests
)
- defer func() {
- // Cancel active request timers on exit. Also set peers to idle so they're
- // available for the next sync.
- for _, req := range active {
- req.timer.Stop()
- req.peer.SetNodeDataIdle(len(req.items))
- }
- }()
+
// Run the state sync.
+ log.Trace("State sync starting", "root", s.root)
go s.run()
defer s.Cancel()
@@ -126,9 +124,11 @@ func (d *Downloader) runStateSync(s *stateSync) *stateSync {
select {
// The stateSync lifecycle:
case next := <-d.stateSyncStart:
+ d.spindownStateSync(active, finished, timeout, peerDrop)
return next
case <-s.done:
+ d.spindownStateSync(active, finished, timeout, peerDrop)
return nil
// Send the next finished request to the current sync:
@@ -189,11 +189,9 @@ func (d *Downloader) runStateSync(s *stateSync) *stateSync {
// causes valid requests to go missing and sync to get stuck.
if old := active[req.peer.id]; old != nil {
log.Warn("Busy peer assigned new state fetch", "peer", old.peer.id)
-
- // Make sure the previous one doesn't get siletly lost
+ // Move the previous request to the finished set
old.timer.Stop()
old.dropped = true
-
finished = append(finished, old)
}
// Start a timer to notify the sync loop if the peer stalled.
@@ -210,6 +208,46 @@ func (d *Downloader) runStateSync(s *stateSync) *stateSync {
}
}
+// spindownStateSync 'drains' the outstanding requests; some will be delivered and other
+// will time out. This is to ensure that when the next stateSync starts working, all peers
+// are marked as idle and de facto _are_ idle.
+func (d *Downloader) spindownStateSync(active map[string]*stateReq, finished []*stateReq, timeout chan *stateReq, peerDrop chan *peerConnection) {
+ log.Trace("State sync spinning down", "active", len(active), "finished", len(finished))
+
+ for len(active) > 0 {
+ var (
+ req *stateReq
+ reason string
+ )
+ select {
+ // Handle (drop) incoming state packs:
+ case pack := <-d.stateCh:
+ req = active[pack.PeerId()]
+ reason = "delivered"
+ // Handle dropped peer connections:
+ case p := <-peerDrop:
+ req = active[p.id]
+ reason = "peerdrop"
+ // Handle timed-out requests:
+ case req = <-timeout:
+ reason = "timeout"
+ }
+ if req == nil {
+ continue
+ }
+ req.peer.log.Trace("State peer marked idle (spindown)", "req.items", len(req.items), "reason", reason)
+ req.timer.Stop()
+ delete(active, req.peer.id)
+ req.peer.SetNodeDataIdle(len(req.items))
+ }
+ // The 'finished' set contains deliveries that we were going to pass to processing.
+ // Those are now moot, but we still need to set those peers as idle, which would
+ // otherwise have been done after processing
+ for _, req := range finished {
+ req.peer.SetNodeDataIdle(len(req.items))
+ }
+}
+
// stateSync schedules requests for downloading a particular state trie defined
// by a given state root.
type stateSync struct {
@@ -222,11 +260,15 @@ type stateSync struct {
numUncommitted int
bytesUncommitted int
+ started chan struct{} // Started is signalled once the sync loop starts
+
deliver chan *stateReq // Delivery channel multiplexing peer responses
cancel chan struct{} // Channel to signal a termination request
cancelOnce sync.Once // Ensures cancel only ever gets called once
done chan struct{} // Channel to signal termination completion
err error // Any error hit during sync (set before completion)
+
+ root common.Hash
}
// stateTask represents a single trie node download task, containing a set of
@@ -246,6 +288,8 @@ func newStateSync(d *Downloader, root common.Hash) *stateSync {
deliver: make(chan *stateReq),
cancel: make(chan struct{}),
done: make(chan struct{}),
+ started: make(chan struct{}),
+ root: root,
}
}
@@ -276,6 +320,7 @@ func (s *stateSync) Cancel() error {
// pushed here async. The reason is to decouple processing from data receipt
// and timeouts.
func (s *stateSync) loop() (err error) {
+ close(s.started)
// Listen for new peer events to assign tasks to them
newPeer := make(chan *peerConnection, 1024)
peerSub := s.d.peers.SubscribeNewPeers(newPeer)
@@ -331,11 +376,11 @@ func (s *stateSync) loop() (err error) {
}
// Process all the received blobs and check for stale delivery
delivered, err := s.process(req)
+ req.peer.SetNodeDataIdle(delivered)
if err != nil {
log.Warn("Node data write error", "err", err)
return err
}
- req.peer.SetNodeDataIdle(delivered)
}
}
return nil
@@ -372,7 +417,7 @@ func (s *stateSync) assignTasks() {
// If the peer was assigned tasks to fetch, send the network request
if len(req.items) > 0 {
- req.peer.log.Trace("Requesting new batch of data", "type", "state", "count", len(req.items))
+ req.peer.log.Trace("Requesting new batch of data", "type", "state", "count", len(req.items), "root", s.root)
select {
case s.d.trackStateReq <- req:
req.peer.FetchNodeData(req.items)
diff --git a/eth/handler_test.go b/eth/handler_test.go
index c508f04f0c..b40cced79f 100644
--- a/eth/handler_test.go
+++ b/eth/handler_test.go
@@ -617,13 +617,16 @@ func testBroadcastBlock(t *testing.T, totalPeers, broadcastExpected int) {
select {
case <-doneCh:
received++
-
- case <-time.After(time.Second):
+ if received > broadcastExpected {
+ // We can bail early here
+ t.Errorf("broadcast count mismatch: have %d > want %d", received, broadcastExpected)
+ return
+ }
+ case <-time.After(2 * time.Second):
if received != broadcastExpected {
t.Errorf("broadcast count mismatch: have %d, want %d", received, broadcastExpected)
}
return
-
case err = <-errCh:
t.Fatalf("broadcast failed: %v", err)
}
diff --git a/go.mod b/go.mod
index f05369ad0c..40e011203c 100644
--- a/go.mod
+++ b/go.mod
@@ -27,7 +27,7 @@ require (
github.com/go-stack/stack v1.8.0
github.com/go-test/deep v1.0.5
github.com/golang/protobuf v1.3.2
- github.com/golang/snappy v0.0.2-0.20190904063534-ff6b7dc882cf
+ github.com/golang/snappy v0.0.2-0.20200707131729-196ae77b8a26
github.com/google/go-cmp v0.3.1 // indirect
github.com/gorilla/websocket v1.4.2
github.com/graph-gophers/graphql-go v0.0.0-20191115155744-f33e81362277
diff --git a/go.sum b/go.sum
index 56d10bcff8..5b34a7b1a4 100644
--- a/go.sum
+++ b/go.sum
@@ -126,10 +126,9 @@ github.com/golang/protobuf v1.2.0/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5y
github.com/golang/protobuf v1.3.1/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U=
github.com/golang/protobuf v1.3.2 h1:6nsPYzhq5kReh6QImI3k5qWzO4PEbvbIW2cwSfR/6xs=
github.com/golang/protobuf v1.3.2/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U=
-github.com/golang/snappy v0.0.1 h1:Qgr9rKW7uDUkrbSmQeiDsGa8SjGyCOGtuasMWwvp2P4=
github.com/golang/snappy v0.0.1/go.mod h1:/XxbfmMg8lxefKM7IXC3fBNl/7bRcc72aCRzEWrmP2Q=
-github.com/golang/snappy v0.0.2-0.20190904063534-ff6b7dc882cf h1:gFVkHXmVAhEbxZVDln5V9GKrLaluNoFHDbrZwAWZgws=
-github.com/golang/snappy v0.0.2-0.20190904063534-ff6b7dc882cf/go.mod h1:/XxbfmMg8lxefKM7IXC3fBNl/7bRcc72aCRzEWrmP2Q=
+github.com/golang/snappy v0.0.2-0.20200707131729-196ae77b8a26 h1:lMm2hD9Fy0ynom5+85/pbdkiYcBqM1JWmhpAXLmy0fw=
+github.com/golang/snappy v0.0.2-0.20200707131729-196ae77b8a26/go.mod h1:/XxbfmMg8lxefKM7IXC3fBNl/7bRcc72aCRzEWrmP2Q=
github.com/google/btree v0.0.0-20180813153112-4030bb1f1f0c/go.mod h1:lNA+9X1NB3Zf8V7Ke586lFgjr2dZNuvo3lPJSGZ5JPQ=
github.com/google/btree v1.0.0/go.mod h1:lNA+9X1NB3Zf8V7Ke586lFgjr2dZNuvo3lPJSGZ5JPQ=
github.com/google/go-cmp v0.2.0/go.mod h1:oXzfMopK8JAjlY9xF4vHSVASa0yLyX7SntLO5aqRK0M=
@@ -400,9 +399,7 @@ golang.org/x/oauth2 v0.0.0-20180821212333-d2e6202438be/go.mod h1:N/0e6XlmueqKjAG
golang.org/x/oauth2 v0.0.0-20190226205417-e64efc72b421/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw=
golang.org/x/oauth2 v0.0.0-20190604053449-0f29369cfe45/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw=
golang.org/x/sync v0.0.0-20180314180146-1d60e4601c6f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
-golang.org/x/sync v0.0.0-20181108010431-42b317875d0f h1:Bl/8QSvNqXvPGPGXa2z5xUTmV7VDcZyvRZ+QQXkXTZQ=
golang.org/x/sync v0.0.0-20181108010431-42b317875d0f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
-golang.org/x/sync v0.0.0-20181221193216-37e7f081c4d4 h1:YUO/7uOKsKeq9UokNS62b8FYywz3ker1l1vDZRCRefw=
golang.org/x/sync v0.0.0-20181221193216-37e7f081c4d4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.0.0-20190227155943-e225da77a7e6/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.0.0-20190423024810-112230192c58 h1:8gQV6CLnAEikrhgkHFbMAEhagSSnXWGV915qUMm9mrU=
diff --git a/internal/utesting/utesting.go b/internal/utesting/utesting.go
new file mode 100644
index 0000000000..23c748cae9
--- /dev/null
+++ b/internal/utesting/utesting.go
@@ -0,0 +1,190 @@
+// Copyright 2020 The go-ethereum Authors
+// This file is part of the go-ethereum library.
+//
+// The go-ethereum library is free software: you can redistribute it and/or modify
+// it under the terms of the GNU Lesser General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+//
+// The go-ethereum library is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU Lesser General Public License for more details.
+//
+// You should have received a copy of the GNU Lesser General Public License
+// along with the go-ethereum library. If not, see .
+
+// Package utesting provides a standalone replacement for package testing.
+//
+// This package exists because package testing cannot easily be embedded into a
+// standalone go program. It provides an API that mirrors the standard library
+// testing API.
+package utesting
+
+import (
+ "bytes"
+ "fmt"
+ "io"
+ "regexp"
+ "runtime"
+ "sync"
+ "time"
+)
+
+// Test represents a single test.
+type Test struct {
+ Name string
+ Fn func(*T)
+}
+
+// Result is the result of a test execution.
+type Result struct {
+ Name string
+ Failed bool
+ Output string
+ Duration time.Duration
+}
+
+// MatchTests returns the tests whose name matches a regular expression.
+func MatchTests(tests []Test, expr string) []Test {
+ var results []Test
+ re, err := regexp.Compile(expr)
+ if err != nil {
+ return nil
+ }
+ for _, test := range tests {
+ if re.MatchString(test.Name) {
+ results = append(results, test)
+ }
+ }
+ return results
+}
+
+// RunTests executes all given tests in order and returns their results.
+// If the report writer is non-nil, a test report is written to it in real time.
+func RunTests(tests []Test, report io.Writer) []Result {
+ results := make([]Result, len(tests))
+ for i, test := range tests {
+ start := time.Now()
+ results[i].Name = test.Name
+ results[i].Failed, results[i].Output = Run(test)
+ results[i].Duration = time.Since(start)
+ if report != nil {
+ printResult(results[i], report)
+ }
+ }
+ return results
+}
+
+func printResult(r Result, w io.Writer) {
+ pd := r.Duration.Truncate(100 * time.Microsecond)
+ if r.Failed {
+ fmt.Fprintf(w, "-- FAIL %s (%v)\n", r.Name, pd)
+ fmt.Fprintln(w, r.Output)
+ } else {
+ fmt.Fprintf(w, "-- OK %s (%v)\n", r.Name, pd)
+ }
+}
+
+// CountFailures returns the number of failed tests in the result slice.
+func CountFailures(rr []Result) int {
+ count := 0
+ for _, r := range rr {
+ if r.Failed {
+ count++
+ }
+ }
+ return count
+}
+
+// Run executes a single test.
+func Run(test Test) (bool, string) {
+ t := new(T)
+ done := make(chan struct{})
+ go func() {
+ defer close(done)
+ defer func() {
+ if err := recover(); err != nil {
+ buf := make([]byte, 4096)
+ i := runtime.Stack(buf, false)
+ t.Logf("panic: %v\n\n%s", err, buf[:i])
+ t.Fail()
+ }
+ }()
+ test.Fn(t)
+ }()
+ <-done
+ return t.failed, t.output.String()
+}
+
+// T is the value given to the test function. The test can signal failures
+// and log output by calling methods on this object.
+type T struct {
+ mu sync.Mutex
+ failed bool
+ output bytes.Buffer
+}
+
+// FailNow marks the test as having failed and stops its execution by calling
+// runtime.Goexit (which then runs all deferred calls in the current goroutine).
+func (t *T) FailNow() {
+ t.Fail()
+ runtime.Goexit()
+}
+
+// Fail marks the test as having failed but continues execution.
+func (t *T) Fail() {
+ t.mu.Lock()
+ defer t.mu.Unlock()
+ t.failed = true
+}
+
+// Failed reports whether the test has failed.
+func (t *T) Failed() bool {
+ t.mu.Lock()
+ defer t.mu.Unlock()
+ return t.failed
+}
+
+// Log formats its arguments using default formatting, analogous to Println, and records
+// the text in the error log.
+func (t *T) Log(vs ...interface{}) {
+ t.mu.Lock()
+ defer t.mu.Unlock()
+ fmt.Fprintln(&t.output, vs...)
+}
+
+// Logf formats its arguments according to the format, analogous to Printf, and records
+// the text in the error log. A final newline is added if not provided.
+func (t *T) Logf(format string, vs ...interface{}) {
+ t.mu.Lock()
+ defer t.mu.Unlock()
+ if len(format) == 0 || format[len(format)-1] != '\n' {
+ format += "\n"
+ }
+ fmt.Fprintf(&t.output, format, vs...)
+}
+
+// Error is equivalent to Log followed by Fail.
+func (t *T) Error(vs ...interface{}) {
+ t.Log(vs...)
+ t.Fail()
+}
+
+// Errorf is equivalent to Logf followed by Fail.
+func (t *T) Errorf(format string, vs ...interface{}) {
+ t.Logf(format, vs...)
+ t.Fail()
+}
+
+// Fatal is equivalent to Log followed by FailNow.
+func (t *T) Fatal(vs ...interface{}) {
+ t.Log(vs...)
+ t.FailNow()
+}
+
+// Fatalf is equivalent to Logf followed by FailNow.
+func (t *T) Fatalf(format string, vs ...interface{}) {
+ t.Logf(format, vs...)
+ t.FailNow()
+}
diff --git a/internal/utesting/utesting_test.go b/internal/utesting/utesting_test.go
new file mode 100644
index 0000000000..1403a5c8f7
--- /dev/null
+++ b/internal/utesting/utesting_test.go
@@ -0,0 +1,55 @@
+// Copyright 2020 The go-ethereum Authors
+// This file is part of the go-ethereum library.
+//
+// The go-ethereum library is free software: you can redistribute it and/or modify
+// it under the terms of the GNU Lesser General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+//
+// The go-ethereum library is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU Lesser General Public License for more details.
+//
+// You should have received a copy of the GNU Lesser General Public License
+// along with the go-ethereum library. If not, see .
+
+package utesting
+
+import (
+ "strings"
+ "testing"
+)
+
+func TestTest(t *testing.T) {
+ tests := []Test{
+ {
+ Name: "successful test",
+ Fn: func(t *T) {},
+ },
+ {
+ Name: "failing test",
+ Fn: func(t *T) {
+ t.Log("output")
+ t.Error("failed")
+ },
+ },
+ {
+ Name: "panicking test",
+ Fn: func(t *T) {
+ panic("oh no")
+ },
+ },
+ }
+ results := RunTests(tests, nil)
+
+ if results[0].Failed || results[0].Output != "" {
+ t.Fatalf("wrong result for successful test: %#v", results[0])
+ }
+ if !results[1].Failed || results[1].Output != "output\nfailed\n" {
+ t.Fatalf("wrong result for failing test: %#v", results[1])
+ }
+ if !results[2].Failed || !strings.HasPrefix(results[2].Output, "panic: oh no\n") {
+ t.Fatalf("wrong result for panicking test: %#v", results[2])
+ }
+}