diff --git a/cmd/clef/README.md b/cmd/clef/README.md index 1180f369b6..bffa36fcce 100644 --- a/cmd/clef/README.md +++ b/cmd/clef/README.md @@ -115,9 +115,9 @@ Some snags and todos Clef listens to HTTP requests on `rpcaddr`:`rpcport` (or to IPC on `ipcpath`), with the same JSON-RPC standard as Geth. The messages are expected to be [JSON-RPC 2.0 standard](https://www.jsonrpc.org/specification). -Some of these call can require user interaction. Clients must be aware that responses may be delayed significantly or may never be received if a users decides to ignore the confirmation request. +Some of these calls can require user interaction. Clients must be aware that responses may be delayed significantly or may never be received if a user decides to ignore the confirmation request. -The External API is **untrusted**: it does not accept credentials over this API, nor does it expect that requests have any authority. +The External API is **untrusted**: it does not accept credentials, nor does it expect that requests have any authority. ### Internal UI API @@ -172,9 +172,9 @@ None Response ```json { - "id": 0, - "jsonrpc": "2.0", - "result": "0xbea9183f8f4f03d427f6bcea17388bdff1cab133" + "id": 0, + "jsonrpc": "2.0", + "result": "0xbea9183f8f4f03d427f6bcea17388bdff1cab133" } ``` @@ -370,7 +370,7 @@ Response ### account_signTypedData #### Sign data - Signs a chunk of structured data conformant to [EIP712]([EIP-712](https://github.com/ethereum/EIPs/blob/master/EIPS/eip-712.md)) and returns the calculated signature. + Signs a chunk of structured data conformant to [EIP-712](https://github.com/ethereum/EIPs/blob/master/EIPS/eip-712.md) and returns the calculated signature. #### Arguments - account [address]: account to sign with diff --git a/cmd/devp2p/discv4cmd.go b/cmd/devp2p/discv4cmd.go index 8580c61216..99b0957ab3 100644 --- a/cmd/devp2p/discv4cmd.go +++ b/cmd/devp2p/discv4cmd.go @@ -19,11 +19,14 @@ package main import ( "fmt" "net" + "os" "strings" "time" + "github.com/ethereum/go-ethereum/cmd/devp2p/internal/v4test" "github.com/ethereum/go-ethereum/common" "github.com/ethereum/go-ethereum/crypto" + "github.com/ethereum/go-ethereum/internal/utesting" "github.com/ethereum/go-ethereum/p2p/discover" "github.com/ethereum/go-ethereum/p2p/enode" "github.com/ethereum/go-ethereum/params" @@ -40,6 +43,7 @@ var ( discv4ResolveCommand, discv4ResolveJSONCommand, discv4CrawlCommand, + discv4TestCommand, }, } discv4PingCommand = cli.Command{ @@ -74,6 +78,12 @@ var ( Action: discv4Crawl, Flags: []cli.Flag{bootnodesFlag, crawlTimeoutFlag}, } + discv4TestCommand = cli.Command{ + Name: "test", + Usage: "Runs tests against a node", + Action: discv4Test, + Flags: []cli.Flag{remoteEnodeFlag, testPatternFlag, testListen1Flag, testListen2Flag}, + } ) var ( @@ -98,6 +108,25 @@ var ( Usage: "Time limit for the crawl.", Value: 30 * time.Minute, } + remoteEnodeFlag = cli.StringFlag{ + Name: "remote", + Usage: "Enode of the remote node under test", + EnvVar: "REMOTE_ENODE", + } + testPatternFlag = cli.StringFlag{ + Name: "run", + Usage: "Pattern of test suite(s) to run", + } + testListen1Flag = cli.StringFlag{ + Name: "listen1", + Usage: "IP address of the first tester", + Value: v4test.Listen1, + } + testListen2Flag = cli.StringFlag{ + Name: "listen2", + Usage: "IP address of the second tester", + Value: v4test.Listen2, + } ) func discv4Ping(ctx *cli.Context) error { @@ -184,6 +213,28 @@ func discv4Crawl(ctx *cli.Context) error { return nil } +func discv4Test(ctx *cli.Context) error { + // Configure test package globals. + if !ctx.IsSet(remoteEnodeFlag.Name) { + return fmt.Errorf("Missing -%v", remoteEnodeFlag.Name) + } + v4test.Remote = ctx.String(remoteEnodeFlag.Name) + v4test.Listen1 = ctx.String(testListen1Flag.Name) + v4test.Listen2 = ctx.String(testListen2Flag.Name) + + // Filter and run test cases. + tests := v4test.AllTests + if ctx.IsSet(testPatternFlag.Name) { + tests = utesting.MatchTests(tests, ctx.String(testPatternFlag.Name)) + } + results := utesting.RunTests(tests, os.Stdout) + if fails := utesting.CountFailures(results); fails > 0 { + return fmt.Errorf("%v/%v tests passed.", len(tests)-fails, len(tests)) + } + fmt.Printf("%v/%v passed\n", len(tests), len(tests)) + return nil +} + // startV4 starts an ephemeral discovery V4 node. func startV4(ctx *cli.Context) *discover.UDPv4 { ln, config := makeDiscoveryConfig(ctx) diff --git a/cmd/devp2p/internal/v4test/discv4tests.go b/cmd/devp2p/internal/v4test/discv4tests.go new file mode 100644 index 0000000000..140b96bfa5 --- /dev/null +++ b/cmd/devp2p/internal/v4test/discv4tests.go @@ -0,0 +1,467 @@ +// Copyright 2020 The go-ethereum Authors +// This file is part of go-ethereum. +// +// go-ethereum is free software: you can redistribute it and/or modify +// it under the terms of the GNU General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// go-ethereum is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU General Public License for more details. +// +// You should have received a copy of the GNU General Public License +// along with go-ethereum. If not, see . + +package v4test + +import ( + "bytes" + "crypto/rand" + "fmt" + "net" + "reflect" + "time" + + "github.com/ethereum/go-ethereum/crypto" + "github.com/ethereum/go-ethereum/internal/utesting" + "github.com/ethereum/go-ethereum/p2p/discover/v4wire" +) + +const ( + expiration = 20 * time.Second + wrongPacket = 66 + macSize = 256 / 8 +) + +var ( + // Remote node under test + Remote string + // IP where the first tester is listening, port will be assigned + Listen1 string = "127.0.0.1" + // IP where the second tester is listening, port will be assigned + // Before running the test, you may have to `sudo ifconfig lo0 add 127.0.0.2` (on MacOS at least) + Listen2 string = "127.0.0.2" +) + +type pingWithJunk struct { + Version uint + From, To v4wire.Endpoint + Expiration uint64 + JunkData1 uint + JunkData2 []byte +} + +func (req *pingWithJunk) Name() string { return "PING/v4" } +func (req *pingWithJunk) Kind() byte { return v4wire.PingPacket } + +type pingWrongType struct { + Version uint + From, To v4wire.Endpoint + Expiration uint64 +} + +func (req *pingWrongType) Name() string { return "WRONG/v4" } +func (req *pingWrongType) Kind() byte { return wrongPacket } + +func futureExpiration() uint64 { + return uint64(time.Now().Add(expiration).Unix()) +} + +// This test just sends a PING packet and expects a response. +func BasicPing(t *utesting.T) { + te := newTestEnv(Remote, Listen1, Listen2) + defer te.close() + + pingHash := te.send(te.l1, &v4wire.Ping{ + Version: 4, + From: te.localEndpoint(te.l1), + To: te.remoteEndpoint(), + Expiration: futureExpiration(), + }) + + reply, _, _ := te.read(te.l1) + if err := te.checkPong(reply, pingHash); err != nil { + t.Fatal(err) + } +} + +// checkPong verifies that reply is a valid PONG matching the given ping hash. +func (te *testenv) checkPong(reply v4wire.Packet, pingHash []byte) error { + if reply == nil || reply.Kind() != v4wire.PongPacket { + return fmt.Errorf("expected PONG reply, got %v", reply) + } + pong := reply.(*v4wire.Pong) + if !bytes.Equal(pong.ReplyTok, pingHash) { + return fmt.Errorf("PONG reply token mismatch: got %x, want %x", pong.ReplyTok, pingHash) + } + wantEndpoint := te.localEndpoint(te.l1) + if !reflect.DeepEqual(pong.To, wantEndpoint) { + return fmt.Errorf("PONG 'to' endpoint mismatch: got %+v, want %+v", pong.To, wantEndpoint) + } + if v4wire.Expired(pong.Expiration) { + return fmt.Errorf("PONG is expired (%v)", pong.Expiration) + } + return nil +} + +// This test sends a PING packet with wrong 'to' field and expects a PONG response. +func PingWrongTo(t *utesting.T) { + te := newTestEnv(Remote, Listen1, Listen2) + defer te.close() + + wrongEndpoint := v4wire.Endpoint{IP: net.ParseIP("192.0.2.0")} + pingHash := te.send(te.l1, &v4wire.Ping{ + Version: 4, + From: te.localEndpoint(te.l1), + To: wrongEndpoint, + Expiration: futureExpiration(), + }) + + reply, _, _ := te.read(te.l1) + if err := te.checkPong(reply, pingHash); err != nil { + t.Fatal(err) + } +} + +// This test sends a PING packet with wrong 'from' field and expects a PONG response. +func PingWrongFrom(t *utesting.T) { + te := newTestEnv(Remote, Listen1, Listen2) + defer te.close() + + wrongEndpoint := v4wire.Endpoint{IP: net.ParseIP("192.0.2.0")} + pingHash := te.send(te.l1, &v4wire.Ping{ + Version: 4, + From: wrongEndpoint, + To: te.remoteEndpoint(), + Expiration: futureExpiration(), + }) + + reply, _, _ := te.read(te.l1) + if err := te.checkPong(reply, pingHash); err != nil { + t.Fatal(err) + } +} + +// This test sends a PING packet with additional data at the end and expects a PONG +// response. The remote node should respond because EIP-8 mandates ignoring additional +// trailing data. +func PingExtraData(t *utesting.T) { + te := newTestEnv(Remote, Listen1, Listen2) + defer te.close() + + pingHash := te.send(te.l1, &pingWithJunk{ + Version: 4, + From: te.localEndpoint(te.l1), + To: te.remoteEndpoint(), + Expiration: futureExpiration(), + JunkData1: 42, + JunkData2: []byte{9, 8, 7, 6, 5, 4, 3, 2, 1}, + }) + + reply, _, _ := te.read(te.l1) + if err := te.checkPong(reply, pingHash); err != nil { + t.Fatal(err) + } +} + +// This test sends a PING packet with additional data and wrong 'from' field +// and expects a PONG response. +func PingExtraDataWrongFrom(t *utesting.T) { + te := newTestEnv(Remote, Listen1, Listen2) + defer te.close() + + wrongEndpoint := v4wire.Endpoint{IP: net.ParseIP("192.0.2.0")} + req := pingWithJunk{ + Version: 4, + From: wrongEndpoint, + To: te.remoteEndpoint(), + Expiration: futureExpiration(), + JunkData1: 42, + JunkData2: []byte{9, 8, 7, 6, 5, 4, 3, 2, 1}, + } + pingHash := te.send(te.l1, &req) + reply, _, _ := te.read(te.l1) + if err := te.checkPong(reply, pingHash); err != nil { + t.Fatal(err) + } +} + +// This test sends a PING packet with an expiration in the past. +// The remote node should not respond. +func PingPastExpiration(t *utesting.T) { + te := newTestEnv(Remote, Listen1, Listen2) + defer te.close() + + te.send(te.l1, &v4wire.Ping{ + Version: 4, + From: te.localEndpoint(te.l1), + To: te.remoteEndpoint(), + Expiration: -futureExpiration(), + }) + + reply, _, _ := te.read(te.l1) + if reply != nil { + t.Fatal("Expected no reply, got", reply) + } +} + +// This test sends an invalid packet. The remote node should not respond. +func WrongPacketType(t *utesting.T) { + te := newTestEnv(Remote, Listen1, Listen2) + defer te.close() + + te.send(te.l1, &pingWrongType{ + Version: 4, + From: te.localEndpoint(te.l1), + To: te.remoteEndpoint(), + Expiration: futureExpiration(), + }) + + reply, _, _ := te.read(te.l1) + if reply != nil { + t.Fatal("Expected no reply, got", reply) + } +} + +// This test verifies that the default behaviour of ignoring 'from' fields is unaffected by +// the bonding process. After bonding, it pings the target with a different from endpoint. +func BondThenPingWithWrongFrom(t *utesting.T) { + te := newTestEnv(Remote, Listen1, Listen2) + defer te.close() + bond(t, te) + + wrongEndpoint := v4wire.Endpoint{IP: net.ParseIP("192.0.2.0")} + pingHash := te.send(te.l1, &v4wire.Ping{ + Version: 4, + From: wrongEndpoint, + To: te.remoteEndpoint(), + Expiration: futureExpiration(), + }) + + reply, _, _ := te.read(te.l1) + if err := te.checkPong(reply, pingHash); err != nil { + t.Fatal(err) + } +} + +// This test just sends FINDNODE. The remote node should not reply +// because the endpoint proof has not completed. +func FindnodeWithoutEndpointProof(t *utesting.T) { + te := newTestEnv(Remote, Listen1, Listen2) + defer te.close() + + req := v4wire.Findnode{Expiration: futureExpiration()} + rand.Read(req.Target[:]) + te.send(te.l1, &req) + + reply, _, _ := te.read(te.l1) + if reply != nil { + t.Fatal("Expected no response, got", reply) + } +} + +// BasicFindnode sends a FINDNODE request after performing the endpoint +// proof. The remote node should respond. +func BasicFindnode(t *utesting.T) { + te := newTestEnv(Remote, Listen1, Listen2) + defer te.close() + bond(t, te) + + findnode := v4wire.Findnode{Expiration: futureExpiration()} + rand.Read(findnode.Target[:]) + te.send(te.l1, &findnode) + + reply, _, err := te.read(te.l1) + if err != nil { + t.Fatal("read find nodes", err) + } + if reply.Kind() != v4wire.NeighborsPacket { + t.Fatal("Expected neighbors, got", reply.Name()) + } +} + +// This test sends an unsolicited NEIGHBORS packet after the endpoint proof, then sends +// FINDNODE to read the remote table. The remote node should not return the node contained +// in the unsolicited NEIGHBORS packet. +func UnsolicitedNeighbors(t *utesting.T) { + te := newTestEnv(Remote, Listen1, Listen2) + defer te.close() + bond(t, te) + + // Send unsolicited NEIGHBORS response. + fakeKey, _ := crypto.GenerateKey() + encFakeKey := v4wire.EncodePubkey(&fakeKey.PublicKey) + neighbors := v4wire.Neighbors{ + Expiration: futureExpiration(), + Nodes: []v4wire.Node{{ + ID: encFakeKey, + IP: net.IP{1, 2, 3, 4}, + UDP: 30303, + TCP: 30303, + }}, + } + te.send(te.l1, &neighbors) + + // Check if the remote node included the fake node. + te.send(te.l1, &v4wire.Findnode{ + Expiration: futureExpiration(), + Target: encFakeKey, + }) + + reply, _, err := te.read(te.l1) + if err != nil { + t.Fatal("read find nodes", err) + } + if reply.Kind() != v4wire.NeighborsPacket { + t.Fatal("Expected neighbors, got", reply.Name()) + } + nodes := reply.(*v4wire.Neighbors).Nodes + if contains(nodes, encFakeKey) { + t.Fatal("neighbors response contains node from earlier unsolicited neighbors response") + } +} + +// This test sends FINDNODE with an expiration timestamp in the past. +// The remote node should not respond. +func FindnodePastExpiration(t *utesting.T) { + te := newTestEnv(Remote, Listen1, Listen2) + defer te.close() + bond(t, te) + + findnode := v4wire.Findnode{Expiration: -futureExpiration()} + rand.Read(findnode.Target[:]) + te.send(te.l1, &findnode) + + for { + reply, _, _ := te.read(te.l1) + if reply == nil { + return + } else if reply.Kind() == v4wire.NeighborsPacket { + t.Fatal("Unexpected NEIGHBORS response for expired FINDNODE request") + } + } +} + +// bond performs the endpoint proof with the remote node. +func bond(t *utesting.T, te *testenv) { + te.send(te.l1, &v4wire.Ping{ + Version: 4, + From: te.localEndpoint(te.l1), + To: te.remoteEndpoint(), + Expiration: futureExpiration(), + }) + + var gotPing, gotPong bool + for !gotPing || !gotPong { + req, hash, err := te.read(te.l1) + if err != nil { + t.Fatal(err) + } + switch req.(type) { + case *v4wire.Ping: + te.send(te.l1, &v4wire.Pong{ + To: te.remoteEndpoint(), + ReplyTok: hash, + Expiration: futureExpiration(), + }) + gotPing = true + case *v4wire.Pong: + // TODO: maybe verify pong data here + gotPong = true + } + } +} + +// This test attempts to perform a traffic amplification attack against a +// 'victim' endpoint using FINDNODE. In this attack scenario, the attacker +// attempts to complete the endpoint proof non-interactively by sending a PONG +// with mismatching reply token from the 'victim' endpoint. The attack works if +// the remote node does not verify the PONG reply token field correctly. The +// attacker could then perform traffic amplification by sending many FINDNODE +// requests to the discovery node, which would reply to the 'victim' address. +func FindnodeAmplificationInvalidPongHash(t *utesting.T) { + te := newTestEnv(Remote, Listen1, Listen2) + defer te.close() + + // Send PING to start endpoint verification. + te.send(te.l1, &v4wire.Ping{ + Version: 4, + From: te.localEndpoint(te.l1), + To: te.remoteEndpoint(), + Expiration: futureExpiration(), + }) + + var gotPing, gotPong bool + for !gotPing || !gotPong { + req, _, err := te.read(te.l1) + if err != nil { + t.Fatal(err) + } + switch req.(type) { + case *v4wire.Ping: + // Send PONG from this node ID, but with invalid ReplyTok. + te.send(te.l1, &v4wire.Pong{ + To: te.remoteEndpoint(), + ReplyTok: make([]byte, macSize), + Expiration: futureExpiration(), + }) + gotPing = true + case *v4wire.Pong: + gotPong = true + } + } + + // Now send FINDNODE. The remote node should not respond because our + // PONG did not reference the PING hash. + findnode := v4wire.Findnode{Expiration: futureExpiration()} + rand.Read(findnode.Target[:]) + te.send(te.l1, &findnode) + + // If we receive a NEIGHBORS response, the attack worked and the test fails. + reply, _, _ := te.read(te.l1) + if reply != nil && reply.Kind() == v4wire.NeighborsPacket { + t.Error("Got neighbors") + } +} + +// This test attempts to perform a traffic amplification attack using FINDNODE. +// The attack works if the remote node does not verify the IP address of FINDNODE +// against the endpoint verification proof done by PING/PONG. +func FindnodeAmplificationWrongIP(t *utesting.T) { + te := newTestEnv(Remote, Listen1, Listen2) + defer te.close() + + // Do the endpoint proof from the l1 IP. + bond(t, te) + + // Now send FINDNODE from the same node ID, but different IP address. + // The remote node should not respond. + findnode := v4wire.Findnode{Expiration: futureExpiration()} + rand.Read(findnode.Target[:]) + te.send(te.l2, &findnode) + + // If we receive a NEIGHBORS response, the attack worked and the test fails. + reply, _, _ := te.read(te.l2) + if reply != nil { + t.Error("Got NEIGHORS response for FINDNODE from wrong IP") + } +} + +var AllTests = []utesting.Test{ + {Name: "Ping/Basic", Fn: BasicPing}, + {Name: "Ping/WrongTo", Fn: PingWrongTo}, + {Name: "Ping/WrongFrom", Fn: PingWrongFrom}, + {Name: "Ping/ExtraData", Fn: PingExtraData}, + {Name: "Ping/ExtraDataWrongFrom", Fn: PingExtraDataWrongFrom}, + {Name: "Ping/PastExpiration", Fn: PingPastExpiration}, + {Name: "Ping/WrongPacketType", Fn: WrongPacketType}, + {Name: "Ping/BondThenPingWithWrongFrom", Fn: BondThenPingWithWrongFrom}, + {Name: "Findnode/WithoutEndpointProof", Fn: FindnodeWithoutEndpointProof}, + {Name: "Findnode/BasicFindnode", Fn: BasicFindnode}, + {Name: "Findnode/UnsolicitedNeighbors", Fn: UnsolicitedNeighbors}, + {Name: "Findnode/PastExpiration", Fn: FindnodePastExpiration}, + {Name: "Amplification/InvalidPongHash", Fn: FindnodeAmplificationInvalidPongHash}, + {Name: "Amplification/WrongIP", Fn: FindnodeAmplificationWrongIP}, +} diff --git a/cmd/devp2p/internal/v4test/framework.go b/cmd/devp2p/internal/v4test/framework.go new file mode 100644 index 0000000000..9286594181 --- /dev/null +++ b/cmd/devp2p/internal/v4test/framework.go @@ -0,0 +1,123 @@ +// Copyright 2020 The go-ethereum Authors +// This file is part of go-ethereum. +// +// go-ethereum is free software: you can redistribute it and/or modify +// it under the terms of the GNU General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// go-ethereum is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU General Public License for more details. +// +// You should have received a copy of the GNU General Public License +// along with go-ethereum. If not, see . + +package v4test + +import ( + "crypto/ecdsa" + "fmt" + "net" + "time" + + "github.com/ethereum/go-ethereum/crypto" + "github.com/ethereum/go-ethereum/p2p/discover/v4wire" + "github.com/ethereum/go-ethereum/p2p/enode" +) + +const waitTime = 300 * time.Millisecond + +type testenv struct { + l1, l2 net.PacketConn + key *ecdsa.PrivateKey + remote *enode.Node + remoteAddr *net.UDPAddr +} + +func newTestEnv(remote string, listen1, listen2 string) *testenv { + l1, err := net.ListenPacket("udp", fmt.Sprintf("%v:0", listen1)) + if err != nil { + panic(err) + } + l2, err := net.ListenPacket("udp", fmt.Sprintf("%v:0", listen2)) + if err != nil { + panic(err) + } + key, err := crypto.GenerateKey() + if err != nil { + panic(err) + } + node, err := enode.Parse(enode.ValidSchemes, remote) + if err != nil { + panic(err) + } + if node.IP() == nil || node.UDP() == 0 { + var ip net.IP + var tcpPort, udpPort int + if ip = node.IP(); ip == nil { + ip = net.ParseIP("127.0.0.1") + } + if tcpPort = node.TCP(); tcpPort == 0 { + tcpPort = 30303 + } + if udpPort = node.TCP(); udpPort == 0 { + udpPort = 30303 + } + node = enode.NewV4(node.Pubkey(), ip, tcpPort, udpPort) + } + addr := &net.UDPAddr{IP: node.IP(), Port: node.UDP()} + return &testenv{l1, l2, key, node, addr} +} + +func (te *testenv) close() { + te.l1.Close() + te.l2.Close() +} + +func (te *testenv) send(c net.PacketConn, req v4wire.Packet) []byte { + packet, hash, err := v4wire.Encode(te.key, req) + if err != nil { + panic(fmt.Errorf("can't encode %v packet: %v", req.Name(), err)) + } + if _, err := c.WriteTo(packet, te.remoteAddr); err != nil { + panic(fmt.Errorf("can't send %v: %v", req.Name(), err)) + } + return hash +} + +func (te *testenv) read(c net.PacketConn) (v4wire.Packet, []byte, error) { + buf := make([]byte, 2048) + if err := c.SetReadDeadline(time.Now().Add(waitTime)); err != nil { + return nil, nil, err + } + n, _, err := c.ReadFrom(buf) + if err != nil { + return nil, nil, err + } + p, _, hash, err := v4wire.Decode(buf[:n]) + return p, hash, err +} + +func (te *testenv) localEndpoint(c net.PacketConn) v4wire.Endpoint { + addr := c.LocalAddr().(*net.UDPAddr) + return v4wire.Endpoint{ + IP: addr.IP.To4(), + UDP: uint16(addr.Port), + TCP: 0, + } +} + +func (te *testenv) remoteEndpoint() v4wire.Endpoint { + return v4wire.NewEndpoint(te.remoteAddr, 0) +} + +func contains(ns []v4wire.Node, key v4wire.Pubkey) bool { + for _, n := range ns { + if n.ID == key { + return true + } + } + return false +} diff --git a/common/math/integer.go b/common/math/integer.go index 46d91ab2a4..93b1d036dd 100644 --- a/common/math/integer.go +++ b/common/math/integer.go @@ -18,7 +18,6 @@ package math import ( "fmt" - "math/bits" "strconv" ) @@ -88,12 +87,13 @@ func SafeSub(x, y uint64) (uint64, bool) { // SafeAdd returns the result and whether overflow occurred. func SafeAdd(x, y uint64) (uint64, bool) { - sum, carry := bits.Add64(x, y, 0) - return sum, carry != 0 + return x + y, y > MaxUint64-x } // SafeMul returns multiplication result and whether overflow occurred. func SafeMul(x, y uint64) (uint64, bool) { - hi, lo := bits.Mul64(x, y) - return lo, hi != 0 + if x == 0 || y == 0 { + return 0, false + } + return x * y, y > MaxUint64/x } diff --git a/core/tx_list.go b/core/tx_list.go index 8beb28bba9..164c73006b 100644 --- a/core/tx_list.go +++ b/core/tx_list.go @@ -99,30 +99,7 @@ func (m *txSortedMap) Forward(threshold uint64) types.Transactions { // Filter iterates over the list of transactions and removes all of them for which // the specified function evaluates to true. -// Filter, as opposed to 'filter', re-initialises the heap after the operation is done. -// If you want to do several consecutive filterings, it's therefore better to first -// do a .filter(func1) followed by .Filter(func2) or reheap() func (m *txSortedMap) Filter(filter func(*types.Transaction) bool) types.Transactions { - removed := m.filter(filter) - // If transactions were removed, the heap and cache are ruined - if len(removed) > 0 { - m.reheap() - } - return removed -} - -func (m *txSortedMap) reheap() { - *m.index = make([]uint64, 0, len(m.items)) - for nonce := range m.items { - *m.index = append(*m.index, nonce) - } - heap.Init(m.index) - m.cache = nil -} - -// filter is identical to Filter, but **does not** regenerate the heap. This method -// should only be used if followed immediately by a call to Filter or reheap() -func (m *txSortedMap) filter(filter func(*types.Transaction) bool) types.Transactions { var removed types.Transactions // Collect all the transactions to filter out @@ -132,7 +109,14 @@ func (m *txSortedMap) filter(filter func(*types.Transaction) bool) types.Transac delete(m.items, nonce) } } + // If transactions were removed, the heap and cache are ruined if len(removed) > 0 { + *m.index = make([]uint64, 0, len(m.items)) + for nonce := range m.items { + *m.index = append(*m.index, nonce) + } + heap.Init(m.index) + m.cache = nil } return removed @@ -213,7 +197,10 @@ func (m *txSortedMap) Len() int { return len(m.items) } -func (m *txSortedMap) flatten() types.Transactions { +// Flatten creates a nonce-sorted slice of transactions based on the loosely +// sorted internal representation. The result of the sorting is cached in case +// it's requested again before any modifications are made to the contents. +func (m *txSortedMap) Flatten() types.Transactions { // If the sorting was not cached yet, create and cache it if m.cache == nil { m.cache = make(types.Transactions, 0, len(m.items)) @@ -222,27 +209,12 @@ func (m *txSortedMap) flatten() types.Transactions { } sort.Sort(types.TxByNonce(m.cache)) } - return m.cache -} - -// Flatten creates a nonce-sorted slice of transactions based on the loosely -// sorted internal representation. The result of the sorting is cached in case -// it's requested again before any modifications are made to the contents. -func (m *txSortedMap) Flatten() types.Transactions { // Copy the cache to prevent accidental modifications - cache := m.flatten() - txs := make(types.Transactions, len(cache)) - copy(txs, cache) + txs := make(types.Transactions, len(m.cache)) + copy(txs, m.cache) return txs } -// LastElement returns the last element of a flattened list, thus, the -// transaction with the highest nonce -func (m *txSortedMap) LastElement() *types.Transaction { - cache := m.flatten() - return cache[len(cache)-1] -} - // txList is a "list" of transactions belonging to an account, sorted by account // nonce. The same type can be used both for storing contiguous transactions for // the executable/pending queue; and for storing gapped transactions for the non- @@ -251,16 +223,17 @@ type txList struct { strict bool // Whether nonces are strictly continuous or not txs *txSortedMap // Heap indexed sorted hash map of the transactions - costcap uint64 // Price of the highest costing transaction (reset only if exceeds balance) - gascap uint64 // Gas limit of the highest spending transaction (reset only if exceeds block limit) + costcap *big.Int // Price of the highest costing transaction (reset only if exceeds balance) + gascap uint64 // Gas limit of the highest spending transaction (reset only if exceeds block limit) } // newTxList create a new transaction list for maintaining nonce-indexable fast, // gapped, sortable transaction lists. func newTxList(strict bool) *txList { return &txList{ - strict: strict, - txs: newTxSortedMap(), + strict: strict, + txs: newTxSortedMap(), + costcap: new(big.Int), } } @@ -279,11 +252,7 @@ func (l *txList) Add(tx *types.Transaction, priceBump uint64) (bool, *types.Tran // If there's an older better transaction, abort old := l.txs.Get(tx.Nonce()) if old != nil { - // threshold = oldGP * (100 + priceBump) / 100 - a := big.NewInt(100 + int64(priceBump)) - a = a.Mul(a, old.GasPrice()) - b := big.NewInt(100) - threshold := a.Div(a, b) + threshold := new(big.Int).Div(new(big.Int).Mul(old.GasPrice(), big.NewInt(100+int64(priceBump))), big.NewInt(100)) // Have to ensure that the new gas price is higher than the old gas // price as well as checking the percentage threshold to ensure that // this is accurate for low (Wei-level) gas price replacements @@ -291,14 +260,9 @@ func (l *txList) Add(tx *types.Transaction, priceBump uint64) (bool, *types.Tran return false, nil } } - cost, overflow := tx.CostU64() - if overflow { - log.Warn("transaction cost overflown, txHash: %v txCost: %v", tx.Hash(), cost) - return false, nil - } // Otherwise overwrite the old transaction with the current one l.txs.Put(tx) - if l.costcap < cost { + if cost := tx.Cost(); l.costcap.Cmp(cost) < 0 { l.costcap = cost } if gas := tx.Gas(); l.gascap < gas { @@ -323,35 +287,29 @@ func (l *txList) Forward(threshold uint64) types.Transactions { // a point in calculating all the costs or if the balance covers all. If the threshold // is lower than the costgas cap, the caps will be reset to a new high after removing // the newly invalidated transactions. -func (l *txList) Filter(costLimit uint64, gasLimit uint64) (types.Transactions, types.Transactions) { +func (l *txList) Filter(costLimit *big.Int, gasLimit uint64) (types.Transactions, types.Transactions) { // If all transactions are below the threshold, short circuit - if l.costcap <= costLimit && l.gascap <= gasLimit { + if l.costcap.Cmp(costLimit) <= 0 && l.gascap <= gasLimit { return nil, nil } - l.costcap = costLimit // Lower the caps to the thresholds + l.costcap = new(big.Int).Set(costLimit) // Lower the caps to the thresholds l.gascap = gasLimit // Filter out all the transactions above the account's funds - removed := l.txs.filter(func(tx *types.Transaction) bool { - cost, _ := tx.CostU64() - return cost > costLimit || tx.Gas() > gasLimit - }) + removed := l.txs.Filter(func(tx *types.Transaction) bool { return tx.Cost().Cmp(costLimit) > 0 || tx.Gas() > gasLimit }) - if len(removed) == 0 { - return nil, nil - } - var invalids types.Transactions // If the list was strict, filter anything above the lowest nonce - if l.strict { + var invalids types.Transactions + + if l.strict && len(removed) > 0 { lowest := uint64(math.MaxUint64) for _, tx := range removed { if nonce := tx.Nonce(); lowest > nonce { lowest = nonce } } - invalids = l.txs.filter(func(tx *types.Transaction) bool { return tx.Nonce() > lowest }) + invalids = l.txs.Filter(func(tx *types.Transaction) bool { return tx.Nonce() > lowest }) } - l.txs.reheap() return removed, invalids } @@ -405,12 +363,6 @@ func (l *txList) Flatten() types.Transactions { return l.txs.Flatten() } -// LastElement returns the last element of a flattened list, thus, the -// transaction with the highest nonce -func (l *txList) LastElement() *types.Transaction { - return l.txs.LastElement() -} - // priceHeap is a heap.Interface implementation over transactions for retrieving // price-sorted transactions to discard when the pool fills up. type priceHeap []*types.Transaction @@ -543,29 +495,8 @@ func (l *txPricedList) Underpriced(tx *types.Transaction, local *accountSet) boo // Discard finds a number of most underpriced transactions, removes them from the // priced list and returns them for further removal from the entire pool. func (l *txPricedList) Discard(slots int, local *accountSet) types.Transactions { - // If we have some local accountset, those will not be discarded - if !local.empty() { - // In case the list is filled to the brim with 'local' txs, we do this - // little check to avoid unpacking / repacking the heap later on, which - // is very expensive - discardable := 0 - for _, tx := range *l.items { - if !local.containsTx(tx) { - discardable++ - } - if discardable >= slots { - break - } - } - if slots > discardable { - slots = discardable - } - } - if slots == 0 { - return nil - } - drop := make(types.Transactions, 0, slots) // Remote underpriced transactions to drop - save := make(types.Transactions, 0, len(*l.items)-slots) // Local underpriced transactions to keep + drop := make(types.Transactions, 0, slots) // Remote underpriced transactions to drop + save := make(types.Transactions, 0, 64) // Local underpriced transactions to keep for len(*l.items) > 0 && slots > 0 { // Discard stale transactions if found during cleanup diff --git a/core/tx_list_test.go b/core/tx_list_test.go index d9f4eba267..3a5842d2e8 100644 --- a/core/tx_list_test.go +++ b/core/tx_list_test.go @@ -17,6 +17,7 @@ package core import ( + "math/big" "math/rand" "testing" @@ -50,22 +51,20 @@ func TestStrictTxListAdd(t *testing.T) { } } -func BenchmarkTxListAdd(b *testing.B) { +func BenchmarkTxListAdd(t *testing.B) { // Generate a list of transactions to insert key, _ := crypto.GenerateKey() - txs := make(types.Transactions, 2000) + txs := make(types.Transactions, 100000) for i := 0; i < len(txs); i++ { txs[i] = transaction(uint64(i), 0, key) } // Insert the transactions in a random order - b.ResetTimer() - priceLimit := DefaultTxPoolConfig.PriceLimit - for i := 0; i < b.N; i++ { - list := newTxList(true) - for _, v := range rand.Perm(len(txs)) { - list.Add(txs[v], DefaultTxPoolConfig.PriceBump) - list.Filter(priceLimit, DefaultTxPoolConfig.PriceBump) - } + list := newTxList(true) + priceLimit := big.NewInt(int64(DefaultTxPoolConfig.PriceLimit)) + t.ResetTimer() + for _, v := range rand.Perm(len(txs)) { + list.Add(txs[v], DefaultTxPoolConfig.PriceBump) + list.Filter(priceLimit, DefaultTxPoolConfig.PriceBump) } } diff --git a/core/tx_pool.go b/core/tx_pool.go index 6ad1679cfc..e707f0becd 100644 --- a/core/tx_pool.go +++ b/core/tx_pool.go @@ -550,11 +550,7 @@ func (pool *TxPool) validateTx(tx *types.Transaction, local bool) error { } // Transactor should have enough funds to cover the costs // cost == V + GP * GL - cost, overflow := tx.CostU64() - if overflow { - return ErrInsufficientFunds - } - if pool.currentState.GetBalance(from).Uint64() < cost { + if pool.currentState.GetBalance(from).Cmp(tx.Cost()) < 0 { return ErrInsufficientFunds } // Ensure the transaction has more gas than the basic tx fee. @@ -1070,8 +1066,8 @@ func (pool *TxPool) runReorg(done chan struct{}, reset *txpoolResetRequest, dirt // Update all accounts to the latest known pending nonce for addr, list := range pool.pending { - highestPending := list.LastElement() - pool.pendingNonces.set(addr, highestPending.Nonce()+1) + txs := list.Flatten() // Heavy but will be cached and is needed by the miner anyway + pool.pendingNonces.set(addr, txs[len(txs)-1].Nonce()+1) } pool.mu.Unlock() @@ -1201,7 +1197,7 @@ func (pool *TxPool) promoteExecutables(accounts []common.Address) []*types.Trans } log.Trace("Removed old queued transactions", "count", len(forwards)) // Drop all transactions that are too costly (low balance or out of gas) - drops, _ := list.Filter(pool.currentState.GetBalance(addr).Uint64(), pool.currentMaxGas) + drops, _ := list.Filter(pool.currentState.GetBalance(addr), pool.currentMaxGas) for _, tx := range drops { hash := tx.Hash() pool.all.Remove(hash) @@ -1393,7 +1389,7 @@ func (pool *TxPool) demoteUnexecutables() { log.Trace("Removed old pending transaction", "hash", hash) } // Drop all transactions that are too costly (low balance or out of gas), and queue any invalids back for later - drops, invalids := list.Filter(pool.currentState.GetBalance(addr).Uint64(), pool.currentMaxGas) + drops, invalids := list.Filter(pool.currentState.GetBalance(addr), pool.currentMaxGas) for _, tx := range drops { hash := tx.Hash() log.Trace("Removed unpayable pending transaction", "hash", hash) @@ -1468,10 +1464,6 @@ func (as *accountSet) contains(addr common.Address) bool { return exist } -func (as *accountSet) empty() bool { - return len(as.accounts) == 0 -} - // containsTx checks if the sender of a given tx is within the set. If the sender // cannot be derived, this method returns false. func (as *accountSet) containsTx(tx *types.Transaction) bool { diff --git a/core/tx_pool_test.go b/core/tx_pool_test.go index 36b83e82ec..95ade74252 100644 --- a/core/tx_pool_test.go +++ b/core/tx_pool_test.go @@ -1891,15 +1891,11 @@ func benchmarkFuturePromotion(b *testing.B, size int) { } // Benchmarks the speed of batched transaction insertion. -func BenchmarkPoolBatchInsert100(b *testing.B) { benchmarkPoolBatchInsert(b, 100, false) } -func BenchmarkPoolBatchInsert1000(b *testing.B) { benchmarkPoolBatchInsert(b, 1000, false) } -func BenchmarkPoolBatchInsert10000(b *testing.B) { benchmarkPoolBatchInsert(b, 10000, false) } +func BenchmarkPoolBatchInsert100(b *testing.B) { benchmarkPoolBatchInsert(b, 100) } +func BenchmarkPoolBatchInsert1000(b *testing.B) { benchmarkPoolBatchInsert(b, 1000) } +func BenchmarkPoolBatchInsert10000(b *testing.B) { benchmarkPoolBatchInsert(b, 10000) } -func BenchmarkPoolBatchLocalInsert100(b *testing.B) { benchmarkPoolBatchInsert(b, 100, true) } -func BenchmarkPoolBatchLocalInsert1000(b *testing.B) { benchmarkPoolBatchInsert(b, 1000, true) } -func BenchmarkPoolBatchLocalInsert10000(b *testing.B) { benchmarkPoolBatchInsert(b, 10000, true) } - -func benchmarkPoolBatchInsert(b *testing.B, size int, local bool) { +func benchmarkPoolBatchInsert(b *testing.B, size int) { // Generate a batch of transactions to enqueue into the pool pool, key := setupTxPool() defer pool.Stop() @@ -1917,10 +1913,6 @@ func benchmarkPoolBatchInsert(b *testing.B, size int, local bool) { // Benchmark importing the transactions into the queue b.ResetTimer() for _, batch := range batches { - if local { - pool.AddLocals(batch) - } else { - pool.AddRemotes(batch) - } + pool.AddRemotes(batch) } } diff --git a/core/types/transaction.go b/core/types/transaction.go index 347db2da0b..da691bb03f 100644 --- a/core/types/transaction.go +++ b/core/types/transaction.go @@ -25,7 +25,6 @@ import ( "github.com/ethereum/go-ethereum/common" "github.com/ethereum/go-ethereum/common/hexutil" - "github.com/ethereum/go-ethereum/common/math" "github.com/ethereum/go-ethereum/crypto" "github.com/ethereum/go-ethereum/rlp" ) @@ -42,7 +41,6 @@ type Transaction struct { hash atomic.Value size atomic.Value from atomic.Value - cost atomic.Value } type txdata struct { @@ -260,15 +258,6 @@ func (tx *Transaction) Cost() *big.Int { return total } -func (tx *Transaction) CostU64() (uint64, bool) { - if tx.data.Price.BitLen() > 63 || tx.data.Amount.BitLen() > 63 { - return 0, false - } - cost, overflowMul := math.SafeMul(tx.data.Price.Uint64(), tx.data.GasLimit) - total, overflowAdd := math.SafeAdd(cost, tx.data.Amount.Uint64()) - return total, overflowMul || overflowAdd -} - // RawSignatureValues returns the V, R, S signature values of the transaction. // The return values should not be modified by the caller. func (tx *Transaction) RawSignatureValues() (v, r, s *big.Int) { diff --git a/eth/downloader/downloader.go b/eth/downloader/downloader.go index 0d35ff7082..9fde6f196b 100644 --- a/eth/downloader/downloader.go +++ b/eth/downloader/downloader.go @@ -330,7 +330,9 @@ func (d *Downloader) Synchronise(id string, head common.Hash, td *big.Int, mode return err } - if errors.Is(err, errInvalidChain) { + if errors.Is(err, errInvalidChain) || errors.Is(err, errBadPeer) || errors.Is(err, errTimeout) || + errors.Is(err, errStallingPeer) || errors.Is(err, errUnsyncedPeer) || errors.Is(err, errEmptyHeaderSet) || + errors.Is(err, errPeersUnavailable) || errors.Is(err, errTooOld) || errors.Is(err, errInvalidAncestor) { log.Warn("Synchronisation failed, dropping peer", "peer", id, "err", err) if d.dropPeer == nil { // The dropPeer method is nil when `--copydb` is used for a local copy. @@ -341,22 +343,7 @@ func (d *Downloader) Synchronise(id string, head common.Hash, td *big.Int, mode } return err } - - switch err { - case errTimeout, errBadPeer, errStallingPeer, errUnsyncedPeer, - errEmptyHeaderSet, errPeersUnavailable, errTooOld, - errInvalidAncestor: - log.Warn("Synchronisation failed, dropping peer", "peer", id, "err", err) - if d.dropPeer == nil { - // The dropPeer method is nil when `--copydb` is used for a local copy. - // Timeouts can occur if e.g. compaction hits at the wrong time, and can be ignored - log.Warn("Downloader wants to drop peer, but peerdrop-function is not set", "peer", id) - } else { - d.dropPeer(id) - } - default: - log.Warn("Synchronisation failed, retrying", "err", err) - } + log.Warn("Synchronisation failed, retrying", "err", err) return err } @@ -651,7 +638,7 @@ func (d *Downloader) fetchHeight(p *peerConnection) (*types.Header, error) { headers := packet.(*headerPack).headers if len(headers) != 1 { p.log.Debug("Multiple headers for single request", "headers", len(headers)) - return nil, errBadPeer + return nil, fmt.Errorf("%w: multiple headers (%d) for single request", errBadPeer, len(headers)) } head := headers[0] if (mode == FastSync || mode == LightSync) && head.Number.Uint64() < d.checkpoint { @@ -921,7 +908,7 @@ func (d *Downloader) findAncestorBinarySearch(p *peerConnection, remoteHeight ui headers := packer.(*headerPack).headers if len(headers) != 1 { p.log.Debug("Multiple headers for single request", "headers", len(headers)) - return 0, errBadPeer + return 0, fmt.Errorf("%w: multiple headers (%d) for single request", errBadPeer, len(headers)) } arrived = true @@ -947,7 +934,7 @@ func (d *Downloader) findAncestorBinarySearch(p *peerConnection, remoteHeight ui header := d.lightchain.GetHeaderByHash(hash) // Independent of sync mode, header surely exists if header.Number.Uint64() != check { p.log.Debug("Received non requested header", "number", header.Number, "hash", header.Hash(), "request", check) - return 0, errBadPeer + return 0, fmt.Errorf("%w: non-requested header (%d)", errBadPeer, header.Number) } start = check @@ -1138,7 +1125,7 @@ func (d *Downloader) fetchHeaders(p *peerConnection, from uint64, pivot uint64) case d.headerProcCh <- nil: case <-d.cancelCh: } - return errBadPeer + return fmt.Errorf("%w: header request timed out", errBadPeer) } } } @@ -1566,7 +1553,7 @@ func (d *Downloader) processHeaders(origin uint64, pivot uint64, td *big.Int) er inserts := d.queue.Schedule(chunk, origin) if len(inserts) != len(chunk) { log.Debug("Stale headers") - return errBadPeer + return fmt.Errorf("%w: stale headers", errBadPeer) } } headers = headers[limit:] diff --git a/eth/downloader/statesync.go b/eth/downloader/statesync.go index f875b3a84c..b022617bbc 100644 --- a/eth/downloader/statesync.go +++ b/eth/downloader/statesync.go @@ -63,6 +63,10 @@ func (d *Downloader) syncState(root common.Hash) *stateSync { s := newStateSync(d, root) select { case d.stateSyncStart <- s: + // If we tell the statesync to restart with a new root, we also need + // to wait for it to actually also start -- when old requests have timed + // out or been delivered + <-s.started case <-d.quitCh: s.err = errCancelStateFetch close(s.done) @@ -95,15 +99,9 @@ func (d *Downloader) runStateSync(s *stateSync) *stateSync { finished []*stateReq // Completed or failed requests timeout = make(chan *stateReq) // Timed out active requests ) - defer func() { - // Cancel active request timers on exit. Also set peers to idle so they're - // available for the next sync. - for _, req := range active { - req.timer.Stop() - req.peer.SetNodeDataIdle(len(req.items)) - } - }() + // Run the state sync. + log.Trace("State sync starting", "root", s.root) go s.run() defer s.Cancel() @@ -126,9 +124,11 @@ func (d *Downloader) runStateSync(s *stateSync) *stateSync { select { // The stateSync lifecycle: case next := <-d.stateSyncStart: + d.spindownStateSync(active, finished, timeout, peerDrop) return next case <-s.done: + d.spindownStateSync(active, finished, timeout, peerDrop) return nil // Send the next finished request to the current sync: @@ -189,11 +189,9 @@ func (d *Downloader) runStateSync(s *stateSync) *stateSync { // causes valid requests to go missing and sync to get stuck. if old := active[req.peer.id]; old != nil { log.Warn("Busy peer assigned new state fetch", "peer", old.peer.id) - - // Make sure the previous one doesn't get siletly lost + // Move the previous request to the finished set old.timer.Stop() old.dropped = true - finished = append(finished, old) } // Start a timer to notify the sync loop if the peer stalled. @@ -210,6 +208,46 @@ func (d *Downloader) runStateSync(s *stateSync) *stateSync { } } +// spindownStateSync 'drains' the outstanding requests; some will be delivered and other +// will time out. This is to ensure that when the next stateSync starts working, all peers +// are marked as idle and de facto _are_ idle. +func (d *Downloader) spindownStateSync(active map[string]*stateReq, finished []*stateReq, timeout chan *stateReq, peerDrop chan *peerConnection) { + log.Trace("State sync spinning down", "active", len(active), "finished", len(finished)) + + for len(active) > 0 { + var ( + req *stateReq + reason string + ) + select { + // Handle (drop) incoming state packs: + case pack := <-d.stateCh: + req = active[pack.PeerId()] + reason = "delivered" + // Handle dropped peer connections: + case p := <-peerDrop: + req = active[p.id] + reason = "peerdrop" + // Handle timed-out requests: + case req = <-timeout: + reason = "timeout" + } + if req == nil { + continue + } + req.peer.log.Trace("State peer marked idle (spindown)", "req.items", len(req.items), "reason", reason) + req.timer.Stop() + delete(active, req.peer.id) + req.peer.SetNodeDataIdle(len(req.items)) + } + // The 'finished' set contains deliveries that we were going to pass to processing. + // Those are now moot, but we still need to set those peers as idle, which would + // otherwise have been done after processing + for _, req := range finished { + req.peer.SetNodeDataIdle(len(req.items)) + } +} + // stateSync schedules requests for downloading a particular state trie defined // by a given state root. type stateSync struct { @@ -222,11 +260,15 @@ type stateSync struct { numUncommitted int bytesUncommitted int + started chan struct{} // Started is signalled once the sync loop starts + deliver chan *stateReq // Delivery channel multiplexing peer responses cancel chan struct{} // Channel to signal a termination request cancelOnce sync.Once // Ensures cancel only ever gets called once done chan struct{} // Channel to signal termination completion err error // Any error hit during sync (set before completion) + + root common.Hash } // stateTask represents a single trie node download task, containing a set of @@ -246,6 +288,8 @@ func newStateSync(d *Downloader, root common.Hash) *stateSync { deliver: make(chan *stateReq), cancel: make(chan struct{}), done: make(chan struct{}), + started: make(chan struct{}), + root: root, } } @@ -276,6 +320,7 @@ func (s *stateSync) Cancel() error { // pushed here async. The reason is to decouple processing from data receipt // and timeouts. func (s *stateSync) loop() (err error) { + close(s.started) // Listen for new peer events to assign tasks to them newPeer := make(chan *peerConnection, 1024) peerSub := s.d.peers.SubscribeNewPeers(newPeer) @@ -331,11 +376,11 @@ func (s *stateSync) loop() (err error) { } // Process all the received blobs and check for stale delivery delivered, err := s.process(req) + req.peer.SetNodeDataIdle(delivered) if err != nil { log.Warn("Node data write error", "err", err) return err } - req.peer.SetNodeDataIdle(delivered) } } return nil @@ -372,7 +417,7 @@ func (s *stateSync) assignTasks() { // If the peer was assigned tasks to fetch, send the network request if len(req.items) > 0 { - req.peer.log.Trace("Requesting new batch of data", "type", "state", "count", len(req.items)) + req.peer.log.Trace("Requesting new batch of data", "type", "state", "count", len(req.items), "root", s.root) select { case s.d.trackStateReq <- req: req.peer.FetchNodeData(req.items) diff --git a/eth/handler_test.go b/eth/handler_test.go index c508f04f0c..b40cced79f 100644 --- a/eth/handler_test.go +++ b/eth/handler_test.go @@ -617,13 +617,16 @@ func testBroadcastBlock(t *testing.T, totalPeers, broadcastExpected int) { select { case <-doneCh: received++ - - case <-time.After(time.Second): + if received > broadcastExpected { + // We can bail early here + t.Errorf("broadcast count mismatch: have %d > want %d", received, broadcastExpected) + return + } + case <-time.After(2 * time.Second): if received != broadcastExpected { t.Errorf("broadcast count mismatch: have %d, want %d", received, broadcastExpected) } return - case err = <-errCh: t.Fatalf("broadcast failed: %v", err) } diff --git a/go.mod b/go.mod index f05369ad0c..40e011203c 100644 --- a/go.mod +++ b/go.mod @@ -27,7 +27,7 @@ require ( github.com/go-stack/stack v1.8.0 github.com/go-test/deep v1.0.5 github.com/golang/protobuf v1.3.2 - github.com/golang/snappy v0.0.2-0.20190904063534-ff6b7dc882cf + github.com/golang/snappy v0.0.2-0.20200707131729-196ae77b8a26 github.com/google/go-cmp v0.3.1 // indirect github.com/gorilla/websocket v1.4.2 github.com/graph-gophers/graphql-go v0.0.0-20191115155744-f33e81362277 diff --git a/go.sum b/go.sum index 56d10bcff8..5b34a7b1a4 100644 --- a/go.sum +++ b/go.sum @@ -126,10 +126,9 @@ github.com/golang/protobuf v1.2.0/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5y github.com/golang/protobuf v1.3.1/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= github.com/golang/protobuf v1.3.2 h1:6nsPYzhq5kReh6QImI3k5qWzO4PEbvbIW2cwSfR/6xs= github.com/golang/protobuf v1.3.2/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= -github.com/golang/snappy v0.0.1 h1:Qgr9rKW7uDUkrbSmQeiDsGa8SjGyCOGtuasMWwvp2P4= github.com/golang/snappy v0.0.1/go.mod h1:/XxbfmMg8lxefKM7IXC3fBNl/7bRcc72aCRzEWrmP2Q= -github.com/golang/snappy v0.0.2-0.20190904063534-ff6b7dc882cf h1:gFVkHXmVAhEbxZVDln5V9GKrLaluNoFHDbrZwAWZgws= -github.com/golang/snappy v0.0.2-0.20190904063534-ff6b7dc882cf/go.mod h1:/XxbfmMg8lxefKM7IXC3fBNl/7bRcc72aCRzEWrmP2Q= +github.com/golang/snappy v0.0.2-0.20200707131729-196ae77b8a26 h1:lMm2hD9Fy0ynom5+85/pbdkiYcBqM1JWmhpAXLmy0fw= +github.com/golang/snappy v0.0.2-0.20200707131729-196ae77b8a26/go.mod h1:/XxbfmMg8lxefKM7IXC3fBNl/7bRcc72aCRzEWrmP2Q= github.com/google/btree v0.0.0-20180813153112-4030bb1f1f0c/go.mod h1:lNA+9X1NB3Zf8V7Ke586lFgjr2dZNuvo3lPJSGZ5JPQ= github.com/google/btree v1.0.0/go.mod h1:lNA+9X1NB3Zf8V7Ke586lFgjr2dZNuvo3lPJSGZ5JPQ= github.com/google/go-cmp v0.2.0/go.mod h1:oXzfMopK8JAjlY9xF4vHSVASa0yLyX7SntLO5aqRK0M= @@ -400,9 +399,7 @@ golang.org/x/oauth2 v0.0.0-20180821212333-d2e6202438be/go.mod h1:N/0e6XlmueqKjAG golang.org/x/oauth2 v0.0.0-20190226205417-e64efc72b421/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw= golang.org/x/oauth2 v0.0.0-20190604053449-0f29369cfe45/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw= golang.org/x/sync v0.0.0-20180314180146-1d60e4601c6f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= -golang.org/x/sync v0.0.0-20181108010431-42b317875d0f h1:Bl/8QSvNqXvPGPGXa2z5xUTmV7VDcZyvRZ+QQXkXTZQ= golang.org/x/sync v0.0.0-20181108010431-42b317875d0f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= -golang.org/x/sync v0.0.0-20181221193216-37e7f081c4d4 h1:YUO/7uOKsKeq9UokNS62b8FYywz3ker1l1vDZRCRefw= golang.org/x/sync v0.0.0-20181221193216-37e7f081c4d4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20190227155943-e225da77a7e6/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20190423024810-112230192c58 h1:8gQV6CLnAEikrhgkHFbMAEhagSSnXWGV915qUMm9mrU= diff --git a/internal/utesting/utesting.go b/internal/utesting/utesting.go new file mode 100644 index 0000000000..23c748cae9 --- /dev/null +++ b/internal/utesting/utesting.go @@ -0,0 +1,190 @@ +// Copyright 2020 The go-ethereum Authors +// This file is part of the go-ethereum library. +// +// The go-ethereum library is free software: you can redistribute it and/or modify +// it under the terms of the GNU Lesser General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// The go-ethereum library is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Lesser General Public License for more details. +// +// You should have received a copy of the GNU Lesser General Public License +// along with the go-ethereum library. If not, see . + +// Package utesting provides a standalone replacement for package testing. +// +// This package exists because package testing cannot easily be embedded into a +// standalone go program. It provides an API that mirrors the standard library +// testing API. +package utesting + +import ( + "bytes" + "fmt" + "io" + "regexp" + "runtime" + "sync" + "time" +) + +// Test represents a single test. +type Test struct { + Name string + Fn func(*T) +} + +// Result is the result of a test execution. +type Result struct { + Name string + Failed bool + Output string + Duration time.Duration +} + +// MatchTests returns the tests whose name matches a regular expression. +func MatchTests(tests []Test, expr string) []Test { + var results []Test + re, err := regexp.Compile(expr) + if err != nil { + return nil + } + for _, test := range tests { + if re.MatchString(test.Name) { + results = append(results, test) + } + } + return results +} + +// RunTests executes all given tests in order and returns their results. +// If the report writer is non-nil, a test report is written to it in real time. +func RunTests(tests []Test, report io.Writer) []Result { + results := make([]Result, len(tests)) + for i, test := range tests { + start := time.Now() + results[i].Name = test.Name + results[i].Failed, results[i].Output = Run(test) + results[i].Duration = time.Since(start) + if report != nil { + printResult(results[i], report) + } + } + return results +} + +func printResult(r Result, w io.Writer) { + pd := r.Duration.Truncate(100 * time.Microsecond) + if r.Failed { + fmt.Fprintf(w, "-- FAIL %s (%v)\n", r.Name, pd) + fmt.Fprintln(w, r.Output) + } else { + fmt.Fprintf(w, "-- OK %s (%v)\n", r.Name, pd) + } +} + +// CountFailures returns the number of failed tests in the result slice. +func CountFailures(rr []Result) int { + count := 0 + for _, r := range rr { + if r.Failed { + count++ + } + } + return count +} + +// Run executes a single test. +func Run(test Test) (bool, string) { + t := new(T) + done := make(chan struct{}) + go func() { + defer close(done) + defer func() { + if err := recover(); err != nil { + buf := make([]byte, 4096) + i := runtime.Stack(buf, false) + t.Logf("panic: %v\n\n%s", err, buf[:i]) + t.Fail() + } + }() + test.Fn(t) + }() + <-done + return t.failed, t.output.String() +} + +// T is the value given to the test function. The test can signal failures +// and log output by calling methods on this object. +type T struct { + mu sync.Mutex + failed bool + output bytes.Buffer +} + +// FailNow marks the test as having failed and stops its execution by calling +// runtime.Goexit (which then runs all deferred calls in the current goroutine). +func (t *T) FailNow() { + t.Fail() + runtime.Goexit() +} + +// Fail marks the test as having failed but continues execution. +func (t *T) Fail() { + t.mu.Lock() + defer t.mu.Unlock() + t.failed = true +} + +// Failed reports whether the test has failed. +func (t *T) Failed() bool { + t.mu.Lock() + defer t.mu.Unlock() + return t.failed +} + +// Log formats its arguments using default formatting, analogous to Println, and records +// the text in the error log. +func (t *T) Log(vs ...interface{}) { + t.mu.Lock() + defer t.mu.Unlock() + fmt.Fprintln(&t.output, vs...) +} + +// Logf formats its arguments according to the format, analogous to Printf, and records +// the text in the error log. A final newline is added if not provided. +func (t *T) Logf(format string, vs ...interface{}) { + t.mu.Lock() + defer t.mu.Unlock() + if len(format) == 0 || format[len(format)-1] != '\n' { + format += "\n" + } + fmt.Fprintf(&t.output, format, vs...) +} + +// Error is equivalent to Log followed by Fail. +func (t *T) Error(vs ...interface{}) { + t.Log(vs...) + t.Fail() +} + +// Errorf is equivalent to Logf followed by Fail. +func (t *T) Errorf(format string, vs ...interface{}) { + t.Logf(format, vs...) + t.Fail() +} + +// Fatal is equivalent to Log followed by FailNow. +func (t *T) Fatal(vs ...interface{}) { + t.Log(vs...) + t.FailNow() +} + +// Fatalf is equivalent to Logf followed by FailNow. +func (t *T) Fatalf(format string, vs ...interface{}) { + t.Logf(format, vs...) + t.FailNow() +} diff --git a/internal/utesting/utesting_test.go b/internal/utesting/utesting_test.go new file mode 100644 index 0000000000..1403a5c8f7 --- /dev/null +++ b/internal/utesting/utesting_test.go @@ -0,0 +1,55 @@ +// Copyright 2020 The go-ethereum Authors +// This file is part of the go-ethereum library. +// +// The go-ethereum library is free software: you can redistribute it and/or modify +// it under the terms of the GNU Lesser General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// The go-ethereum library is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Lesser General Public License for more details. +// +// You should have received a copy of the GNU Lesser General Public License +// along with the go-ethereum library. If not, see . + +package utesting + +import ( + "strings" + "testing" +) + +func TestTest(t *testing.T) { + tests := []Test{ + { + Name: "successful test", + Fn: func(t *T) {}, + }, + { + Name: "failing test", + Fn: func(t *T) { + t.Log("output") + t.Error("failed") + }, + }, + { + Name: "panicking test", + Fn: func(t *T) { + panic("oh no") + }, + }, + } + results := RunTests(tests, nil) + + if results[0].Failed || results[0].Output != "" { + t.Fatalf("wrong result for successful test: %#v", results[0]) + } + if !results[1].Failed || results[1].Output != "output\nfailed\n" { + t.Fatalf("wrong result for failing test: %#v", results[1]) + } + if !results[2].Failed || !strings.HasPrefix(results[2].Output, "panic: oh no\n") { + t.Fatalf("wrong result for panicking test: %#v", results[2]) + } +}