diff --git a/cmd/devp2p/internal/ethtest/helpers.go b/cmd/devp2p/internal/ethtest/helpers.go
index eeeb4f93cabf..b57649ade99d 100644
--- a/cmd/devp2p/internal/ethtest/helpers.go
+++ b/cmd/devp2p/internal/ethtest/helpers.go
@@ -357,9 +357,13 @@ func (s *Suite) waitAnnounce(conn *Conn, blockAnnouncement *NewBlock) error {
return fmt.Errorf("wrong block hash in announcement: expected %v, got %v", blockAnnouncement.Block.Hash(), hashes[0].Hash)
}
return nil
+
+ // ignore tx announcements from previous tests
case *NewPooledTransactionHashes:
- // ignore tx announcements from previous tests
continue
+ case *Transactions:
+ continue
+
default:
return fmt.Errorf("unexpected: %s", pretty.Sdump(msg))
}
diff --git a/cmd/devp2p/internal/ethtest/suite.go b/cmd/devp2p/internal/ethtest/suite.go
index 7059b4ba738c..4497478d72d6 100644
--- a/cmd/devp2p/internal/ethtest/suite.go
+++ b/cmd/devp2p/internal/ethtest/suite.go
@@ -544,9 +544,13 @@ func (s *Suite) TestNewPooledTxs(t *utesting.T) {
t.Fatalf("unexpected number of txs requested: wanted %d, got %d", len(hashes), len(msg.GetPooledTransactionsPacket))
}
return
+
// ignore propagated txs from previous tests
case *NewPooledTransactionHashes:
continue
+ case *Transactions:
+ continue
+
// ignore block announcements from previous tests
case *NewBlockHashes:
continue
diff --git a/cmd/devp2p/internal/ethtest/transaction.go b/cmd/devp2p/internal/ethtest/transaction.go
index c4748bf8f7d8..baa55bd49268 100644
--- a/cmd/devp2p/internal/ethtest/transaction.go
+++ b/cmd/devp2p/internal/ethtest/transaction.go
@@ -29,7 +29,7 @@ import (
"github.com/ethereum/go-ethereum/params"
)
-//var faucetAddr = common.HexToAddress("0x71562b71999873DB5b286dF957af199Ec94617F7")
+// var faucetAddr = common.HexToAddress("0x71562b71999873DB5b286dF957af199Ec94617F7")
var faucetKey, _ = crypto.HexToECDSA("b71c71a67e1177ad4e901695e1b4b9ee17ae16c6668d313eac2f96dbcda3f291")
func (s *Suite) sendSuccessfulTxs(t *utesting.T) error {
@@ -192,10 +192,10 @@ func sendMultipleSuccessfulTxs(t *utesting.T, s *Suite, txs []*types.Transaction
nonce = txs[len(txs)-1].Nonce()
// Wait for the transaction announcement(s) and make sure all sent txs are being propagated.
- // all txs should be announced within 3 announcements.
+ // all txs should be announced within a couple announcements.
recvHashes := make([]common.Hash, 0)
- for i := 0; i < 3; i++ {
+ for i := 0; i < 20; i++ {
switch msg := recvConn.readAndServe(s.chain, timeout).(type) {
case *Transactions:
for _, tx := range *msg {
diff --git a/eth/fetcher/tx_fetcher.go b/eth/fetcher/tx_fetcher.go
index 035e0c2ec7d8..677a6422b011 100644
--- a/eth/fetcher/tx_fetcher.go
+++ b/eth/fetcher/tx_fetcher.go
@@ -262,54 +262,72 @@ func (f *TxFetcher) Notify(peer string, hashes []common.Hash) error {
// direct request replies. The differentiation is important so the fetcher can
// re-schedule missing transactions as soon as possible.
func (f *TxFetcher) Enqueue(peer string, txs []*types.Transaction, direct bool) error {
- // Keep track of all the propagated transactions
- if direct {
- txReplyInMeter.Mark(int64(len(txs)))
- } else {
- txBroadcastInMeter.Mark(int64(len(txs)))
+ var (
+ inMeter = txReplyInMeter
+ knownMeter = txReplyKnownMeter
+ underpricedMeter = txReplyUnderpricedMeter
+ otherRejectMeter = txReplyOtherRejectMeter
+ )
+ if !direct {
+ inMeter = txBroadcastInMeter
+ knownMeter = txBroadcastKnownMeter
+ underpricedMeter = txBroadcastUnderpricedMeter
+ otherRejectMeter = txBroadcastOtherRejectMeter
}
+ // Keep track of all the propagated transactions
+ inMeter.Mark(int64(len(txs)))
+
// Push all the transactions into the pool, tracking underpriced ones to avoid
// re-requesting them and dropping the peer in case of malicious transfers.
var (
- added = make([]common.Hash, 0, len(txs))
- duplicate int64
- underpriced int64
- otherreject int64
+ added = make([]common.Hash, 0, len(txs))
)
- errs := f.addTxs(txs)
- for i, err := range errs {
- // Track the transaction hash if the price is too low for us.
- // Avoid re-request this transaction when we receive another
- // announcement.
- if errors.Is(err, core.ErrUnderpriced) || errors.Is(err, core.ErrReplaceUnderpriced) {
- for f.underpriced.Cardinality() >= maxTxUnderpricedSetSize {
- f.underpriced.Pop()
- }
- f.underpriced.Add(txs[i].Hash())
+ // proceed in batches
+ for i := 0; i < len(txs); i += 128 {
+ end := i + 128
+ if end > len(txs) {
+ end = len(txs)
}
- // Track a few interesting failure types
- switch {
- case err == nil: // Noop, but need to handle to not count these
+ var (
+ duplicate int64
+ underpriced int64
+ otherreject int64
+ )
+ batch := txs[i:end]
+ for j, err := range f.addTxs(batch) {
+ // Track the transaction hash if the price is too low for us.
+ // Avoid re-request this transaction when we receive another
+ // announcement.
+ if errors.Is(err, core.ErrUnderpriced) || errors.Is(err, core.ErrReplaceUnderpriced) {
+ for f.underpriced.Cardinality() >= maxTxUnderpricedSetSize {
+ f.underpriced.Pop()
+ }
+ f.underpriced.Add(batch[j].Hash())
+ }
+ // Track a few interesting failure types
+ switch {
+ case err == nil: // Noop, but need to handle to not count these
- case errors.Is(err, core.ErrAlreadyKnown):
- duplicate++
+ case errors.Is(err, core.ErrAlreadyKnown):
+ duplicate++
- case errors.Is(err, core.ErrUnderpriced) || errors.Is(err, core.ErrReplaceUnderpriced):
- underpriced++
+ case errors.Is(err, core.ErrUnderpriced) || errors.Is(err, core.ErrReplaceUnderpriced):
+ underpriced++
- default:
- otherreject++
+ default:
+ otherreject++
+ }
+ added = append(added, batch[j].Hash())
+ }
+ knownMeter.Mark(duplicate)
+ underpricedMeter.Mark(underpriced)
+ otherRejectMeter.Mark(otherreject)
+
+ // If 'other reject' is >25% of the deliveries in any batch, sleep a bit.
+ if otherreject > 128/4 {
+ time.Sleep(200 * time.Millisecond)
+ log.Warn("Peer delivering stale transactions", "peer", peer, "rejected", otherreject)
}
- added = append(added, txs[i].Hash())
- }
- if direct {
- txReplyKnownMeter.Mark(duplicate)
- txReplyUnderpricedMeter.Mark(underpriced)
- txReplyOtherRejectMeter.Mark(otherreject)
- } else {
- txBroadcastKnownMeter.Mark(duplicate)
- txBroadcastUnderpricedMeter.Mark(underpriced)
- txBroadcastOtherRejectMeter.Mark(otherreject)
}
select {
case f.cleanup <- &txDelivery{origin: peer, hashes: added, direct: direct}:
diff --git a/node/config.go b/node/config.go
index 2047299fb5d7..49959d5ec5de 100644
--- a/node/config.go
+++ b/node/config.go
@@ -201,7 +201,7 @@ type Config struct {
// AllowUnprotectedTxs allows non EIP-155 protected transactions to be send over RPC.
AllowUnprotectedTxs bool `toml:",omitempty"`
- // JWTSecret is the hex-encoded jwt secret.
+ // JWTSecret is the path to the hex-encoded jwt secret.
JWTSecret string `toml:",omitempty"`
}
diff --git a/node/jwt_auth.go b/node/jwt_auth.go
new file mode 100644
index 000000000000..d4f8193ca7f2
--- /dev/null
+++ b/node/jwt_auth.go
@@ -0,0 +1,45 @@
+// Copyright 2022 The go-ethereum Authors
+// This file is part of the go-ethereum library.
+//
+// The go-ethereum library is free software: you can redistribute it and/or modify
+// it under the terms of the GNU Lesser General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+//
+// The go-ethereum library is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU Lesser General Public License for more details.
+//
+// You should have received a copy of the GNU Lesser General Public License
+// along with the go-ethereum library. If not, see .
+
+package node
+
+import (
+ "fmt"
+ "net/http"
+ "time"
+
+ "github.com/ethereum/go-ethereum/rpc"
+ "github.com/golang-jwt/jwt/v4"
+)
+
+// NewJWTAuth creates an rpc client authentication provider that uses JWT. The
+// secret MUST be 32 bytes (256 bits) as defined by the Engine-API authentication spec.
+//
+// See https://github.com/ethereum/execution-apis/blob/main/src/engine/authentication.md
+// for more details about this authentication scheme.
+func NewJWTAuth(jwtsecret [32]byte) rpc.HTTPAuth {
+ return func(h http.Header) error {
+ token := jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.MapClaims{
+ "iat": &jwt.NumericDate{Time: time.Now()},
+ })
+ s, err := token.SignedString(jwtsecret[:])
+ if err != nil {
+ return fmt.Errorf("failed to create JWT token: %w", err)
+ }
+ h.Set("Authorization", "Bearer "+s)
+ return nil
+ }
+}
diff --git a/node/node.go b/node/node.go
index b60e32f22fd2..3cbefef022e5 100644
--- a/node/node.go
+++ b/node/node.go
@@ -668,6 +668,19 @@ func (n *Node) WSEndpoint() string {
return "ws://" + n.ws.listenAddr() + n.ws.wsConfig.prefix
}
+// HTTPAuthEndpoint returns the URL of the authenticated HTTP server.
+func (n *Node) HTTPAuthEndpoint() string {
+ return "http://" + n.httpAuth.listenAddr()
+}
+
+// WSAuthEndpoint returns the current authenticated JSON-RPC over WebSocket endpoint.
+func (n *Node) WSAuthEndpoint() string {
+ if n.httpAuth.wsAllowed() {
+ return "ws://" + n.httpAuth.listenAddr() + n.httpAuth.wsConfig.prefix
+ }
+ return "ws://" + n.wsAuth.listenAddr() + n.wsAuth.wsConfig.prefix
+}
+
// EventMux retrieves the event multiplexer used by all the network services in
// the current protocol stack.
func (n *Node) EventMux() *event.TypeMux {
diff --git a/node/node_auth_test.go b/node/node_auth_test.go
new file mode 100644
index 000000000000..597cd8531f79
--- /dev/null
+++ b/node/node_auth_test.go
@@ -0,0 +1,237 @@
+// Copyright 2022 The go-ethereum Authors
+// This file is part of the go-ethereum library.
+//
+// The go-ethereum library is free software: you can redistribute it and/or modify
+// it under the terms of the GNU Lesser General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+//
+// The go-ethereum library is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU Lesser General Public License for more details.
+//
+// You should have received a copy of the GNU Lesser General Public License
+// along with the go-ethereum library. If not, see .
+
+package node
+
+import (
+ "context"
+ crand "crypto/rand"
+ "fmt"
+ "net/http"
+ "os"
+ "path"
+ "testing"
+ "time"
+
+ "github.com/ethereum/go-ethereum/common/hexutil"
+ "github.com/ethereum/go-ethereum/rpc"
+ "github.com/golang-jwt/jwt/v4"
+)
+
+type helloRPC string
+
+func (ta helloRPC) HelloWorld() (string, error) {
+ return string(ta), nil
+}
+
+type authTest struct {
+ name string
+ endpoint string
+ prov rpc.HTTPAuth
+ expectDialFail bool
+ expectCall1Fail bool
+ expectCall2Fail bool
+}
+
+func (at *authTest) Run(t *testing.T) {
+ ctx := context.Background()
+ cl, err := rpc.DialOptions(ctx, at.endpoint, rpc.WithHTTPAuth(at.prov))
+ if at.expectDialFail {
+ if err == nil {
+ t.Fatal("expected initial dial to fail")
+ } else {
+ return
+ }
+ }
+ if err != nil {
+ t.Fatalf("failed to dial rpc endpoint: %v", err)
+ }
+
+ var x string
+ err = cl.CallContext(ctx, &x, "engine_helloWorld")
+ if at.expectCall1Fail {
+ if err == nil {
+ t.Fatal("expected call 1 to fail")
+ } else {
+ return
+ }
+ }
+ if err != nil {
+ t.Fatalf("failed to call rpc endpoint: %v", err)
+ }
+ if x != "hello engine" {
+ t.Fatalf("method was silent but did not return expected value: %q", x)
+ }
+
+ err = cl.CallContext(ctx, &x, "eth_helloWorld")
+ if at.expectCall2Fail {
+ if err == nil {
+ t.Fatal("expected call 2 to fail")
+ } else {
+ return
+ }
+ }
+ if err != nil {
+ t.Fatalf("failed to call rpc endpoint: %v", err)
+ }
+ if x != "hello eth" {
+ t.Fatalf("method was silent but did not return expected value: %q", x)
+ }
+}
+
+func TestAuthEndpoints(t *testing.T) {
+ var secret [32]byte
+ if _, err := crand.Read(secret[:]); err != nil {
+ t.Fatalf("failed to create jwt secret: %v", err)
+ }
+ // Geth must read it from a file, and does not support in-memory JWT secrets, so we create a temporary file.
+ jwtPath := path.Join(t.TempDir(), "jwt_secret")
+ if err := os.WriteFile(jwtPath, []byte(hexutil.Encode(secret[:])), 0600); err != nil {
+ t.Fatalf("failed to prepare jwt secret file: %v", err)
+ }
+ // We get ports assigned by the node automatically
+ conf := &Config{
+ HTTPHost: "127.0.0.1",
+ HTTPPort: 0,
+ WSHost: "127.0.0.1",
+ WSPort: 0,
+ AuthAddr: "127.0.0.1",
+ AuthPort: 0,
+ JWTSecret: jwtPath,
+
+ WSModules: []string{"eth", "engine"},
+ HTTPModules: []string{"eth", "engine"},
+ }
+ node, err := New(conf)
+ if err != nil {
+ t.Fatalf("could not create a new node: %v", err)
+ }
+ // register dummy apis so we can test the modules are available and reachable with authentication
+ node.RegisterAPIs([]rpc.API{
+ {
+ Namespace: "engine",
+ Version: "1.0",
+ Service: helloRPC("hello engine"),
+ Public: true,
+ Authenticated: true,
+ },
+ {
+ Namespace: "eth",
+ Version: "1.0",
+ Service: helloRPC("hello eth"),
+ Public: true,
+ Authenticated: true,
+ },
+ })
+ if err := node.Start(); err != nil {
+ t.Fatalf("failed to start test node: %v", err)
+ }
+ defer node.Close()
+
+ // sanity check we are running different endpoints
+ if a, b := node.WSEndpoint(), node.WSAuthEndpoint(); a == b {
+ t.Fatalf("expected ws and auth-ws endpoints to be different, got: %q and %q", a, b)
+ }
+ if a, b := node.HTTPEndpoint(), node.HTTPAuthEndpoint(); a == b {
+ t.Fatalf("expected http and auth-http endpoints to be different, got: %q and %q", a, b)
+ }
+
+ goodAuth := NewJWTAuth(secret)
+ var otherSecret [32]byte
+ if _, err := crand.Read(otherSecret[:]); err != nil {
+ t.Fatalf("failed to create jwt secret: %v", err)
+ }
+ badAuth := NewJWTAuth(otherSecret)
+
+ notTooLong := time.Second * 57
+ tooLong := time.Second * 60
+ requestDelay := time.Second
+
+ testCases := []authTest{
+ // Auth works
+ {name: "ws good", endpoint: node.WSAuthEndpoint(), prov: goodAuth, expectCall1Fail: false},
+ {name: "http good", endpoint: node.HTTPAuthEndpoint(), prov: goodAuth, expectCall1Fail: false},
+
+ // Try a bad auth
+ {name: "ws bad", endpoint: node.WSAuthEndpoint(), prov: badAuth, expectDialFail: true}, // ws auth is immediate
+ {name: "http bad", endpoint: node.HTTPAuthEndpoint(), prov: badAuth, expectCall1Fail: true}, // http auth is on first call
+
+ // A common mistake with JWT is to allow the "none" algorithm, which is a valid JWT but not secure.
+ {name: "ws none", endpoint: node.WSAuthEndpoint(), prov: noneAuth(secret), expectDialFail: true},
+ {name: "http none", endpoint: node.HTTPAuthEndpoint(), prov: noneAuth(secret), expectCall1Fail: true},
+
+ // claims of 5 seconds or more, older or newer, are not allowed
+ {name: "ws too old", endpoint: node.WSAuthEndpoint(), prov: offsetTimeAuth(secret, -tooLong), expectDialFail: true},
+ {name: "http too old", endpoint: node.HTTPAuthEndpoint(), prov: offsetTimeAuth(secret, -tooLong), expectCall1Fail: true},
+ // note: for it to be too long we need to add a delay, so that once we receive the request, the difference has not dipped below the "tooLong"
+ {name: "ws too new", endpoint: node.WSAuthEndpoint(), prov: offsetTimeAuth(secret, tooLong+requestDelay), expectDialFail: true},
+ {name: "http too new", endpoint: node.HTTPAuthEndpoint(), prov: offsetTimeAuth(secret, tooLong+requestDelay), expectCall1Fail: true},
+
+ // Try offset the time, but stay just within bounds
+ {name: "ws old", endpoint: node.WSAuthEndpoint(), prov: offsetTimeAuth(secret, -notTooLong)},
+ {name: "http old", endpoint: node.HTTPAuthEndpoint(), prov: offsetTimeAuth(secret, -notTooLong)},
+ {name: "ws new", endpoint: node.WSAuthEndpoint(), prov: offsetTimeAuth(secret, notTooLong)},
+ {name: "http new", endpoint: node.HTTPAuthEndpoint(), prov: offsetTimeAuth(secret, notTooLong)},
+
+ // ws only authenticates on initial dial, then continues communication
+ {name: "ws single auth", endpoint: node.WSAuthEndpoint(), prov: changingAuth(goodAuth, badAuth)},
+ {name: "http call fail auth", endpoint: node.HTTPAuthEndpoint(), prov: changingAuth(goodAuth, badAuth), expectCall2Fail: true},
+ {name: "http call fail time", endpoint: node.HTTPAuthEndpoint(), prov: changingAuth(goodAuth, offsetTimeAuth(secret, tooLong+requestDelay)), expectCall2Fail: true},
+ }
+
+ for _, testCase := range testCases {
+ t.Run(testCase.name, testCase.Run)
+ }
+}
+
+func noneAuth(secret [32]byte) rpc.HTTPAuth {
+ return func(header http.Header) error {
+ token := jwt.NewWithClaims(jwt.SigningMethodNone, jwt.MapClaims{
+ "iat": &jwt.NumericDate{Time: time.Now()},
+ })
+ s, err := token.SignedString(secret[:])
+ if err != nil {
+ return fmt.Errorf("failed to create JWT token: %w", err)
+ }
+ header.Set("Authorization", "Bearer "+s)
+ return nil
+ }
+}
+
+func changingAuth(provs ...rpc.HTTPAuth) rpc.HTTPAuth {
+ i := 0
+ return func(header http.Header) error {
+ i += 1
+ if i > len(provs) {
+ i = len(provs)
+ }
+ return provs[i-1](header)
+ }
+}
+
+func offsetTimeAuth(secret [32]byte, offset time.Duration) rpc.HTTPAuth {
+ return func(header http.Header) error {
+ token := jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.MapClaims{
+ "iat": &jwt.NumericDate{Time: time.Now().Add(offset)},
+ })
+ s, err := token.SignedString(secret[:])
+ if err != nil {
+ return fmt.Errorf("failed to create JWT token: %w", err)
+ }
+ header.Set("Authorization", "Bearer "+s)
+ return nil
+ }
+}
diff --git a/rpc/client.go b/rpc/client.go
index d3ce0297754c..8288f976ebeb 100644
--- a/rpc/client.go
+++ b/rpc/client.go
@@ -22,6 +22,7 @@ import (
"errors"
"fmt"
"net/url"
+ "os"
"reflect"
"strconv"
"sync/atomic"
@@ -99,7 +100,7 @@ type Client struct {
reqTimeout chan *requestOp // removes response IDs when call timeout expires
}
-type reconnectFunc func(ctx context.Context) (ServerCodec, error)
+type reconnectFunc func(context.Context) (ServerCodec, error)
type clientContextKey struct{}
@@ -153,14 +154,16 @@ func (op *requestOp) wait(ctx context.Context, c *Client) (*jsonrpcMessage, erro
//
// The currently supported URL schemes are "http", "https", "ws" and "wss". If rawurl is a
// file name with no URL scheme, a local socket connection is established using UNIX
-// domain sockets on supported platforms and named pipes on Windows. If you want to
-// configure transport options, use DialHTTP, DialWebsocket or DialIPC instead.
+// domain sockets on supported platforms and named pipes on Windows.
+//
+// If you want to further configure the transport, use DialOptions instead of this
+// function.
//
// For websocket connections, the origin is set to the local host name.
//
-// The client reconnects automatically if the connection is lost.
+// The client reconnects automatically when the connection is lost.
func Dial(rawurl string) (*Client, error) {
- return DialContext(context.Background(), rawurl)
+ return DialOptions(context.Background(), rawurl)
}
// DialContext creates a new RPC client, just like Dial.
@@ -168,22 +171,46 @@ func Dial(rawurl string) (*Client, error) {
// The context is used to cancel or time out the initial connection establishment. It does
// not affect subsequent interactions with the client.
func DialContext(ctx context.Context, rawurl string) (*Client, error) {
+ return DialOptions(ctx, rawurl)
+}
+
+// DialOptions creates a new RPC client for the given URL. You can supply any of the
+// pre-defined client options to configure the underlying transport.
+//
+// The context is used to cancel or time out the initial connection establishment. It does
+// not affect subsequent interactions with the client.
+//
+// The client reconnects automatically when the connection is lost.
+func DialOptions(ctx context.Context, rawurl string, options ...ClientOption) (*Client, error) {
u, err := url.Parse(rawurl)
if err != nil {
return nil, err
}
+
+ cfg := new(clientConfig)
+ for _, opt := range options {
+ opt.applyOption(cfg)
+ }
+
+ var reconnect reconnectFunc
switch u.Scheme {
case "http", "https":
- return DialHTTP(rawurl)
+ reconnect = newClientTransportHTTP(rawurl, cfg)
case "ws", "wss":
- return DialWebsocket(ctx, rawurl, "")
+ rc, err := newClientTransportWS(rawurl, cfg)
+ if err != nil {
+ return nil, err
+ }
+ reconnect = rc
case "stdio":
- return DialStdIO(ctx)
+ reconnect = newClientTransportIO(os.Stdin, os.Stdout)
case "":
- return DialIPC(ctx, rawurl)
+ reconnect = newClientTransportIPC(rawurl)
default:
return nil, fmt.Errorf("no known transport for URL scheme %q", u.Scheme)
}
+
+ return newClient(ctx, reconnect)
}
// ClientFromContext retrieves the client from the context, if any. This can be used to perform
diff --git a/rpc/client_opt.go b/rpc/client_opt.go
new file mode 100644
index 000000000000..5ad7c22b3ce7
--- /dev/null
+++ b/rpc/client_opt.go
@@ -0,0 +1,106 @@
+// Copyright 2022 The go-ethereum Authors
+// This file is part of the go-ethereum library.
+//
+// The go-ethereum library is free software: you can redistribute it and/or modify
+// it under the terms of the GNU Lesser General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+//
+// The go-ethereum library is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU Lesser General Public License for more details.
+//
+// You should have received a copy of the GNU Lesser General Public License
+// along with the go-ethereum library. If not, see .
+
+package rpc
+
+import (
+ "net/http"
+
+ "github.com/gorilla/websocket"
+)
+
+// ClientOption is a configuration option for the RPC client.
+type ClientOption interface {
+ applyOption(*clientConfig)
+}
+
+type clientConfig struct {
+ httpClient *http.Client
+ httpHeaders http.Header
+ httpAuth HTTPAuth
+
+ wsDialer *websocket.Dialer
+}
+
+func (cfg *clientConfig) initHeaders() {
+ if cfg.httpHeaders == nil {
+ cfg.httpHeaders = make(http.Header)
+ }
+}
+
+func (cfg *clientConfig) setHeader(key, value string) {
+ cfg.initHeaders()
+ cfg.httpHeaders.Set(key, value)
+}
+
+type optionFunc func(*clientConfig)
+
+func (fn optionFunc) applyOption(opt *clientConfig) {
+ fn(opt)
+}
+
+// WithWebsocketDialer configures the websocket.Dialer used by the RPC client.
+func WithWebsocketDialer(dialer websocket.Dialer) ClientOption {
+ return optionFunc(func(cfg *clientConfig) {
+ cfg.wsDialer = &dialer
+ })
+}
+
+// WithHeader configures HTTP headers set by the RPC client. Headers set using this option
+// will be used for both HTTP and WebSocket connections.
+func WithHeader(key, value string) ClientOption {
+ return optionFunc(func(cfg *clientConfig) {
+ cfg.initHeaders()
+ cfg.httpHeaders.Set(key, value)
+ })
+}
+
+// WithHeaders configures HTTP headers set by the RPC client. Headers set using this
+// option will be used for both HTTP and WebSocket connections.
+func WithHeaders(headers http.Header) ClientOption {
+ return optionFunc(func(cfg *clientConfig) {
+ cfg.initHeaders()
+ for k, vs := range headers {
+ cfg.httpHeaders[k] = vs
+ }
+ })
+}
+
+// WithHTTPClient configures the http.Client used by the RPC client.
+func WithHTTPClient(c *http.Client) ClientOption {
+ return optionFunc(func(cfg *clientConfig) {
+ cfg.httpClient = c
+ })
+}
+
+// WithHTTPAuth configures HTTP request authentication. The given provider will be called
+// whenever a request is made. Note that only one authentication provider can be active at
+// any time.
+func WithHTTPAuth(a HTTPAuth) ClientOption {
+ if a == nil {
+ panic("nil auth")
+ }
+ return optionFunc(func(cfg *clientConfig) {
+ cfg.httpAuth = a
+ })
+}
+
+// A HTTPAuth function is called by the client whenever a HTTP request is sent.
+// The function must be safe for concurrent use.
+//
+// Usually, HTTPAuth functions will call h.Set("authorization", "...") to add
+// auth information to the request.
+type HTTPAuth func(h http.Header) error
diff --git a/rpc/client_opt_test.go b/rpc/client_opt_test.go
new file mode 100644
index 000000000000..d7cc2572a776
--- /dev/null
+++ b/rpc/client_opt_test.go
@@ -0,0 +1,25 @@
+package rpc_test
+
+import (
+ "context"
+ "net/http"
+ "time"
+
+ "github.com/ethereum/go-ethereum/rpc"
+)
+
+// This example configures a HTTP-based RPC client with two options - one setting the
+// overall request timeout, the other adding a custom HTTP header to all requests.
+func ExampleDialOptions() {
+ tokenHeader := rpc.WithHeader("x-token", "foo")
+ httpClient := rpc.WithHTTPClient(&http.Client{
+ Timeout: 10 * time.Second,
+ })
+
+ ctx := context.Background()
+ c, err := rpc.DialOptions(ctx, "http://rpc.example.com", httpClient, tokenHeader)
+ if err != nil {
+ panic(err)
+ }
+ c.Close()
+}
diff --git a/rpc/http.go b/rpc/http.go
index 858d80858652..8595959afb66 100644
--- a/rpc/http.go
+++ b/rpc/http.go
@@ -45,6 +45,7 @@ type httpConn struct {
closeCh chan interface{}
mu sync.Mutex // protects headers
headers http.Header
+ auth HTTPAuth
}
// httpConn implements ServerCodec, but it is treated specially by Client
@@ -117,8 +118,15 @@ var DefaultHTTPTimeouts = HTTPTimeouts{
IdleTimeout: 120 * time.Second,
}
+// DialHTTP creates a new RPC client that connects to an RPC server over HTTP.
+func DialHTTP(endpoint string) (*Client, error) {
+ return DialHTTPWithClient(endpoint, new(http.Client))
+}
+
// DialHTTPWithClient creates a new RPC client that connects to an RPC server over HTTP
// using the provided HTTP Client.
+//
+// Deprecated: use DialOptions and the WithHTTPClient option.
func DialHTTPWithClient(endpoint string, client *http.Client) (*Client, error) {
// Sanity check URL so we don't end up with a client that will fail every request.
_, err := url.Parse(endpoint)
@@ -126,24 +134,35 @@ func DialHTTPWithClient(endpoint string, client *http.Client) (*Client, error) {
return nil, err
}
- initctx := context.Background()
- headers := make(http.Header, 2)
+ var cfg clientConfig
+ fn := newClientTransportHTTP(endpoint, &cfg)
+ return newClient(context.Background(), fn)
+}
+
+func newClientTransportHTTP(endpoint string, cfg *clientConfig) reconnectFunc {
+ headers := make(http.Header, 2+len(cfg.httpHeaders))
headers.Set("accept", contentType)
headers.Set("content-type", contentType)
- return newClient(initctx, func(context.Context) (ServerCodec, error) {
- hc := &httpConn{
- client: client,
- headers: headers,
- url: endpoint,
- closeCh: make(chan interface{}),
- }
- return hc, nil
- })
-}
+ for key, values := range cfg.httpHeaders {
+ headers[key] = values
+ }
-// DialHTTP creates a new RPC client that connects to an RPC server over HTTP.
-func DialHTTP(endpoint string) (*Client, error) {
- return DialHTTPWithClient(endpoint, new(http.Client))
+ client := cfg.httpClient
+ if client == nil {
+ client = new(http.Client)
+ }
+
+ hc := &httpConn{
+ client: client,
+ headers: headers,
+ url: endpoint,
+ auth: cfg.httpAuth,
+ closeCh: make(chan interface{}),
+ }
+
+ return func(ctx context.Context) (ServerCodec, error) {
+ return hc, nil
+ }
}
func (c *Client) sendHTTP(ctx context.Context, op *requestOp, msg interface{}) error {
@@ -195,6 +214,11 @@ func (hc *httpConn) doRequest(ctx context.Context, msg interface{}) (io.ReadClos
hc.mu.Lock()
req.Header = hc.headers.Clone()
hc.mu.Unlock()
+ if hc.auth != nil {
+ if err := hc.auth(req.Header); err != nil {
+ return nil, err
+ }
+ }
// do request
resp, err := hc.client.Do(req)
diff --git a/rpc/ipc.go b/rpc/ipc.go
index 07a211c6277c..d9e0de62e877 100644
--- a/rpc/ipc.go
+++ b/rpc/ipc.go
@@ -46,11 +46,15 @@ func (s *Server) ServeListener(l net.Listener) error {
// The context is used for the initial connection establishment. It does not
// affect subsequent interactions with the client.
func DialIPC(ctx context.Context, endpoint string) (*Client, error) {
- return newClient(ctx, func(ctx context.Context) (ServerCodec, error) {
+ return newClient(ctx, newClientTransportIPC(endpoint))
+}
+
+func newClientTransportIPC(endpoint string) reconnectFunc {
+ return func(ctx context.Context) (ServerCodec, error) {
conn, err := newIPCConnection(ctx, endpoint)
if err != nil {
return nil, err
}
return NewCodec(conn), err
- })
+ }
}
diff --git a/rpc/json.go b/rpc/json.go
index 6024f1e7dc9b..6b2ac2d52a7b 100644
--- a/rpc/json.go
+++ b/rpc/json.go
@@ -58,21 +58,25 @@ type jsonrpcMessage struct {
}
func (msg *jsonrpcMessage) isNotification() bool {
- return msg.ID == nil && msg.Method != ""
+ return msg.hasValidVersion() && msg.ID == nil && msg.Method != ""
}
func (msg *jsonrpcMessage) isCall() bool {
- return msg.hasValidID() && msg.Method != ""
+ return msg.hasValidVersion() && msg.hasValidID() && msg.Method != ""
}
func (msg *jsonrpcMessage) isResponse() bool {
- return msg.hasValidID() && msg.Method == "" && msg.Params == nil && (msg.Result != nil || msg.Error != nil)
+ return msg.hasValidVersion() && msg.hasValidID() && msg.Method == "" && msg.Params == nil && (msg.Result != nil || msg.Error != nil)
}
func (msg *jsonrpcMessage) hasValidID() bool {
return len(msg.ID) > 0 && msg.ID[0] != '{' && msg.ID[0] != '['
}
+func (msg *jsonrpcMessage) hasValidVersion() bool {
+ return msg.Version == vsn
+}
+
func (msg *jsonrpcMessage) isSubscribe() bool {
return strings.HasSuffix(msg.Method, subscribeMethodSuffix)
}
diff --git a/rpc/stdio.go b/rpc/stdio.go
index be2bab1c98bd..ae32db26ef1c 100644
--- a/rpc/stdio.go
+++ b/rpc/stdio.go
@@ -32,12 +32,16 @@ func DialStdIO(ctx context.Context) (*Client, error) {
// DialIO creates a client which uses the given IO channels
func DialIO(ctx context.Context, in io.Reader, out io.Writer) (*Client, error) {
- return newClient(ctx, func(_ context.Context) (ServerCodec, error) {
+ return newClient(ctx, newClientTransportIO(in, out))
+}
+
+func newClientTransportIO(in io.Reader, out io.Writer) reconnectFunc {
+ return func(context.Context) (ServerCodec, error) {
return NewCodec(stdioConn{
in: in,
out: out,
}), nil
- })
+ }
}
type stdioConn struct {
diff --git a/rpc/subscription_test.go b/rpc/subscription_test.go
index a920205c001f..b2704578291e 100644
--- a/rpc/subscription_test.go
+++ b/rpc/subscription_test.go
@@ -79,7 +79,7 @@ func TestSubscriptions(t *testing.T) {
request := map[string]interface{}{
"id": i,
"method": fmt.Sprintf("%s_subscribe", namespace),
- "version": "2.0",
+ "jsonrpc": "2.0",
"params": []interface{}{"someSubscription", notificationCount, i},
}
if err := out.Encode(&request); err != nil {
diff --git a/rpc/testdata/invalid-badversion.js b/rpc/testdata/invalid-badversion.js
new file mode 100644
index 000000000000..75b5291dc3f0
--- /dev/null
+++ b/rpc/testdata/invalid-badversion.js
@@ -0,0 +1,19 @@
+// This test checks processing of messages with invalid Version.
+
+--> {"jsonrpc":"2.0","id":1,"method":"test_echo","params":["x", 3]}
+<-- {"jsonrpc":"2.0","id":1,"result":{"String":"x","Int":3,"Args":null}}
+
+--> {"jsonrpc":"2.1","id":1,"method":"test_echo","params":["x", 3]}
+<-- {"jsonrpc":"2.0","id":1,"error":{"code":-32600,"message":"invalid request"}}
+
+--> {"jsonrpc":"go-ethereum","id":1,"method":"test_echo","params":["x", 3]}
+<-- {"jsonrpc":"2.0","id":1,"error":{"code":-32600,"message":"invalid request"}}
+
+--> {"jsonrpc":1,"id":1,"method":"test_echo","params":["x", 3]}
+<-- {"jsonrpc":"2.0","id":1,"error":{"code":-32600,"message":"invalid request"}}
+
+--> {"jsonrpc":2.0,"id":1,"method":"test_echo","params":["x", 3]}
+<-- {"jsonrpc":"2.0","id":1,"error":{"code":-32600,"message":"invalid request"}}
+
+--> {"id":1,"method":"test_echo","params":["x", 3]}
+<-- {"jsonrpc":"2.0","id":1,"error":{"code":-32600,"message":"invalid request"}}
diff --git a/rpc/websocket.go b/rpc/websocket.go
index 28380d8aa4ae..f2a923446cac 100644
--- a/rpc/websocket.go
+++ b/rpc/websocket.go
@@ -181,24 +181,23 @@ func parseOriginURL(origin string) (string, string, string, error) {
return scheme, hostname, port, nil
}
-// DialWebsocketWithDialer creates a new RPC client that communicates with a JSON-RPC server
-// that is listening on the given endpoint using the provided dialer.
+// DialWebsocketWithDialer creates a new RPC client using WebSocket.
+//
+// The context is used for the initial connection establishment. It does not
+// affect subsequent interactions with the client.
+//
+// Deprecated: use DialOptions and the WithWebsocketDialer option.
func DialWebsocketWithDialer(ctx context.Context, endpoint, origin string, dialer websocket.Dialer) (*Client, error) {
- endpoint, header, err := wsClientHeaders(endpoint, origin)
+ cfg := new(clientConfig)
+ cfg.wsDialer = &dialer
+ if origin != "" {
+ cfg.setHeader("origin", origin)
+ }
+ connect, err := newClientTransportWS(endpoint, cfg)
if err != nil {
return nil, err
}
- return newClient(ctx, func(ctx context.Context) (ServerCodec, error) {
- conn, resp, err := dialer.DialContext(ctx, endpoint, header)
- if err != nil {
- hErr := wsHandshakeError{err: err}
- if resp != nil {
- hErr.status = resp.Status
- }
- return nil, hErr
- }
- return newWebsocketCodec(conn, endpoint, header), nil
- })
+ return newClient(ctx, connect)
}
// DialWebsocket creates a new RPC client that communicates with a JSON-RPC server
@@ -207,12 +206,53 @@ func DialWebsocketWithDialer(ctx context.Context, endpoint, origin string, diale
// The context is used for the initial connection establishment. It does not
// affect subsequent interactions with the client.
func DialWebsocket(ctx context.Context, endpoint, origin string) (*Client, error) {
- dialer := websocket.Dialer{
- ReadBufferSize: wsReadBuffer,
- WriteBufferSize: wsWriteBuffer,
- WriteBufferPool: wsBufferPool,
+ cfg := new(clientConfig)
+ if origin != "" {
+ cfg.setHeader("origin", origin)
+ }
+ connect, err := newClientTransportWS(endpoint, cfg)
+ if err != nil {
+ return nil, err
+ }
+ return newClient(ctx, connect)
+}
+
+func newClientTransportWS(endpoint string, cfg *clientConfig) (reconnectFunc, error) {
+ dialer := cfg.wsDialer
+ if dialer == nil {
+ dialer = &websocket.Dialer{
+ ReadBufferSize: wsReadBuffer,
+ WriteBufferSize: wsWriteBuffer,
+ WriteBufferPool: wsBufferPool,
+ }
+ }
+
+ dialURL, header, err := wsClientHeaders(endpoint, "")
+ if err != nil {
+ return nil, err
+ }
+ for key, values := range cfg.httpHeaders {
+ header[key] = values
+ }
+
+ connect := func(ctx context.Context) (ServerCodec, error) {
+ header := header.Clone()
+ if cfg.httpAuth != nil {
+ if err := cfg.httpAuth(header); err != nil {
+ return nil, err
+ }
+ }
+ conn, resp, err := dialer.DialContext(ctx, dialURL, header)
+ if err != nil {
+ hErr := wsHandshakeError{err: err}
+ if resp != nil {
+ hErr.status = resp.Status
+ }
+ return nil, hErr
+ }
+ return newWebsocketCodec(conn, dialURL, header), nil
}
- return DialWebsocketWithDialer(ctx, endpoint, origin, dialer)
+ return connect, nil
}
func wsClientHeaders(endpoint, origin string) (string, http.Header, error) {
diff --git a/trie/secure_trie.go b/trie/secure_trie.go
index 1da152477d38..e26c22465504 100644
--- a/trie/secure_trie.go
+++ b/trie/secure_trie.go
@@ -58,7 +58,7 @@ type StateTrie struct {
// and returns MissingNodeError if the root node cannot be found.
func NewStateTrie(owner common.Hash, root common.Hash, db *Database) (*StateTrie, error) {
if db == nil {
- panic("trie.NewSecure called without a database")
+ panic("trie.NewStateTrie called without a database")
}
trie, err := New(owner, root, db)
if err != nil {
diff --git a/trie/util_test.go b/trie/util_test.go
index 252dc09e0804..cf6758e63d4a 100644
--- a/trie/util_test.go
+++ b/trie/util_test.go
@@ -69,7 +69,9 @@ func TestTrieTracer(t *testing.T) {
// Commit the changes and re-create with new root
root, nodes, _ := trie.Commit(false)
- db.Update(NewWithNodeSet(nodes))
+ if err := db.Update(NewWithNodeSet(nodes)); err != nil {
+ t.Fatal(err)
+ }
trie, _ = New(common.Hash{}, root, db)
trie.tracer = newTracer()