diff --git a/cmd/devp2p/internal/ethtest/helpers.go b/cmd/devp2p/internal/ethtest/helpers.go index eeeb4f93cabf..b57649ade99d 100644 --- a/cmd/devp2p/internal/ethtest/helpers.go +++ b/cmd/devp2p/internal/ethtest/helpers.go @@ -357,9 +357,13 @@ func (s *Suite) waitAnnounce(conn *Conn, blockAnnouncement *NewBlock) error { return fmt.Errorf("wrong block hash in announcement: expected %v, got %v", blockAnnouncement.Block.Hash(), hashes[0].Hash) } return nil + + // ignore tx announcements from previous tests case *NewPooledTransactionHashes: - // ignore tx announcements from previous tests continue + case *Transactions: + continue + default: return fmt.Errorf("unexpected: %s", pretty.Sdump(msg)) } diff --git a/cmd/devp2p/internal/ethtest/suite.go b/cmd/devp2p/internal/ethtest/suite.go index 7059b4ba738c..4497478d72d6 100644 --- a/cmd/devp2p/internal/ethtest/suite.go +++ b/cmd/devp2p/internal/ethtest/suite.go @@ -544,9 +544,13 @@ func (s *Suite) TestNewPooledTxs(t *utesting.T) { t.Fatalf("unexpected number of txs requested: wanted %d, got %d", len(hashes), len(msg.GetPooledTransactionsPacket)) } return + // ignore propagated txs from previous tests case *NewPooledTransactionHashes: continue + case *Transactions: + continue + // ignore block announcements from previous tests case *NewBlockHashes: continue diff --git a/cmd/devp2p/internal/ethtest/transaction.go b/cmd/devp2p/internal/ethtest/transaction.go index c4748bf8f7d8..baa55bd49268 100644 --- a/cmd/devp2p/internal/ethtest/transaction.go +++ b/cmd/devp2p/internal/ethtest/transaction.go @@ -29,7 +29,7 @@ import ( "github.com/ethereum/go-ethereum/params" ) -//var faucetAddr = common.HexToAddress("0x71562b71999873DB5b286dF957af199Ec94617F7") +// var faucetAddr = common.HexToAddress("0x71562b71999873DB5b286dF957af199Ec94617F7") var faucetKey, _ = crypto.HexToECDSA("b71c71a67e1177ad4e901695e1b4b9ee17ae16c6668d313eac2f96dbcda3f291") func (s *Suite) sendSuccessfulTxs(t *utesting.T) error { @@ -192,10 +192,10 @@ func sendMultipleSuccessfulTxs(t *utesting.T, s *Suite, txs []*types.Transaction nonce = txs[len(txs)-1].Nonce() // Wait for the transaction announcement(s) and make sure all sent txs are being propagated. - // all txs should be announced within 3 announcements. + // all txs should be announced within a couple announcements. recvHashes := make([]common.Hash, 0) - for i := 0; i < 3; i++ { + for i := 0; i < 20; i++ { switch msg := recvConn.readAndServe(s.chain, timeout).(type) { case *Transactions: for _, tx := range *msg { diff --git a/eth/fetcher/tx_fetcher.go b/eth/fetcher/tx_fetcher.go index 035e0c2ec7d8..677a6422b011 100644 --- a/eth/fetcher/tx_fetcher.go +++ b/eth/fetcher/tx_fetcher.go @@ -262,54 +262,72 @@ func (f *TxFetcher) Notify(peer string, hashes []common.Hash) error { // direct request replies. The differentiation is important so the fetcher can // re-schedule missing transactions as soon as possible. func (f *TxFetcher) Enqueue(peer string, txs []*types.Transaction, direct bool) error { - // Keep track of all the propagated transactions - if direct { - txReplyInMeter.Mark(int64(len(txs))) - } else { - txBroadcastInMeter.Mark(int64(len(txs))) + var ( + inMeter = txReplyInMeter + knownMeter = txReplyKnownMeter + underpricedMeter = txReplyUnderpricedMeter + otherRejectMeter = txReplyOtherRejectMeter + ) + if !direct { + inMeter = txBroadcastInMeter + knownMeter = txBroadcastKnownMeter + underpricedMeter = txBroadcastUnderpricedMeter + otherRejectMeter = txBroadcastOtherRejectMeter } + // Keep track of all the propagated transactions + inMeter.Mark(int64(len(txs))) + // Push all the transactions into the pool, tracking underpriced ones to avoid // re-requesting them and dropping the peer in case of malicious transfers. var ( - added = make([]common.Hash, 0, len(txs)) - duplicate int64 - underpriced int64 - otherreject int64 + added = make([]common.Hash, 0, len(txs)) ) - errs := f.addTxs(txs) - for i, err := range errs { - // Track the transaction hash if the price is too low for us. - // Avoid re-request this transaction when we receive another - // announcement. - if errors.Is(err, core.ErrUnderpriced) || errors.Is(err, core.ErrReplaceUnderpriced) { - for f.underpriced.Cardinality() >= maxTxUnderpricedSetSize { - f.underpriced.Pop() - } - f.underpriced.Add(txs[i].Hash()) + // proceed in batches + for i := 0; i < len(txs); i += 128 { + end := i + 128 + if end > len(txs) { + end = len(txs) } - // Track a few interesting failure types - switch { - case err == nil: // Noop, but need to handle to not count these + var ( + duplicate int64 + underpriced int64 + otherreject int64 + ) + batch := txs[i:end] + for j, err := range f.addTxs(batch) { + // Track the transaction hash if the price is too low for us. + // Avoid re-request this transaction when we receive another + // announcement. + if errors.Is(err, core.ErrUnderpriced) || errors.Is(err, core.ErrReplaceUnderpriced) { + for f.underpriced.Cardinality() >= maxTxUnderpricedSetSize { + f.underpriced.Pop() + } + f.underpriced.Add(batch[j].Hash()) + } + // Track a few interesting failure types + switch { + case err == nil: // Noop, but need to handle to not count these - case errors.Is(err, core.ErrAlreadyKnown): - duplicate++ + case errors.Is(err, core.ErrAlreadyKnown): + duplicate++ - case errors.Is(err, core.ErrUnderpriced) || errors.Is(err, core.ErrReplaceUnderpriced): - underpriced++ + case errors.Is(err, core.ErrUnderpriced) || errors.Is(err, core.ErrReplaceUnderpriced): + underpriced++ - default: - otherreject++ + default: + otherreject++ + } + added = append(added, batch[j].Hash()) + } + knownMeter.Mark(duplicate) + underpricedMeter.Mark(underpriced) + otherRejectMeter.Mark(otherreject) + + // If 'other reject' is >25% of the deliveries in any batch, sleep a bit. + if otherreject > 128/4 { + time.Sleep(200 * time.Millisecond) + log.Warn("Peer delivering stale transactions", "peer", peer, "rejected", otherreject) } - added = append(added, txs[i].Hash()) - } - if direct { - txReplyKnownMeter.Mark(duplicate) - txReplyUnderpricedMeter.Mark(underpriced) - txReplyOtherRejectMeter.Mark(otherreject) - } else { - txBroadcastKnownMeter.Mark(duplicate) - txBroadcastUnderpricedMeter.Mark(underpriced) - txBroadcastOtherRejectMeter.Mark(otherreject) } select { case f.cleanup <- &txDelivery{origin: peer, hashes: added, direct: direct}: diff --git a/node/config.go b/node/config.go index 2047299fb5d7..49959d5ec5de 100644 --- a/node/config.go +++ b/node/config.go @@ -201,7 +201,7 @@ type Config struct { // AllowUnprotectedTxs allows non EIP-155 protected transactions to be send over RPC. AllowUnprotectedTxs bool `toml:",omitempty"` - // JWTSecret is the hex-encoded jwt secret. + // JWTSecret is the path to the hex-encoded jwt secret. JWTSecret string `toml:",omitempty"` } diff --git a/node/jwt_auth.go b/node/jwt_auth.go new file mode 100644 index 000000000000..d4f8193ca7f2 --- /dev/null +++ b/node/jwt_auth.go @@ -0,0 +1,45 @@ +// Copyright 2022 The go-ethereum Authors +// This file is part of the go-ethereum library. +// +// The go-ethereum library is free software: you can redistribute it and/or modify +// it under the terms of the GNU Lesser General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// The go-ethereum library is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Lesser General Public License for more details. +// +// You should have received a copy of the GNU Lesser General Public License +// along with the go-ethereum library. If not, see . + +package node + +import ( + "fmt" + "net/http" + "time" + + "github.com/ethereum/go-ethereum/rpc" + "github.com/golang-jwt/jwt/v4" +) + +// NewJWTAuth creates an rpc client authentication provider that uses JWT. The +// secret MUST be 32 bytes (256 bits) as defined by the Engine-API authentication spec. +// +// See https://github.com/ethereum/execution-apis/blob/main/src/engine/authentication.md +// for more details about this authentication scheme. +func NewJWTAuth(jwtsecret [32]byte) rpc.HTTPAuth { + return func(h http.Header) error { + token := jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.MapClaims{ + "iat": &jwt.NumericDate{Time: time.Now()}, + }) + s, err := token.SignedString(jwtsecret[:]) + if err != nil { + return fmt.Errorf("failed to create JWT token: %w", err) + } + h.Set("Authorization", "Bearer "+s) + return nil + } +} diff --git a/node/node.go b/node/node.go index b60e32f22fd2..3cbefef022e5 100644 --- a/node/node.go +++ b/node/node.go @@ -668,6 +668,19 @@ func (n *Node) WSEndpoint() string { return "ws://" + n.ws.listenAddr() + n.ws.wsConfig.prefix } +// HTTPAuthEndpoint returns the URL of the authenticated HTTP server. +func (n *Node) HTTPAuthEndpoint() string { + return "http://" + n.httpAuth.listenAddr() +} + +// WSAuthEndpoint returns the current authenticated JSON-RPC over WebSocket endpoint. +func (n *Node) WSAuthEndpoint() string { + if n.httpAuth.wsAllowed() { + return "ws://" + n.httpAuth.listenAddr() + n.httpAuth.wsConfig.prefix + } + return "ws://" + n.wsAuth.listenAddr() + n.wsAuth.wsConfig.prefix +} + // EventMux retrieves the event multiplexer used by all the network services in // the current protocol stack. func (n *Node) EventMux() *event.TypeMux { diff --git a/node/node_auth_test.go b/node/node_auth_test.go new file mode 100644 index 000000000000..597cd8531f79 --- /dev/null +++ b/node/node_auth_test.go @@ -0,0 +1,237 @@ +// Copyright 2022 The go-ethereum Authors +// This file is part of the go-ethereum library. +// +// The go-ethereum library is free software: you can redistribute it and/or modify +// it under the terms of the GNU Lesser General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// The go-ethereum library is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Lesser General Public License for more details. +// +// You should have received a copy of the GNU Lesser General Public License +// along with the go-ethereum library. If not, see . + +package node + +import ( + "context" + crand "crypto/rand" + "fmt" + "net/http" + "os" + "path" + "testing" + "time" + + "github.com/ethereum/go-ethereum/common/hexutil" + "github.com/ethereum/go-ethereum/rpc" + "github.com/golang-jwt/jwt/v4" +) + +type helloRPC string + +func (ta helloRPC) HelloWorld() (string, error) { + return string(ta), nil +} + +type authTest struct { + name string + endpoint string + prov rpc.HTTPAuth + expectDialFail bool + expectCall1Fail bool + expectCall2Fail bool +} + +func (at *authTest) Run(t *testing.T) { + ctx := context.Background() + cl, err := rpc.DialOptions(ctx, at.endpoint, rpc.WithHTTPAuth(at.prov)) + if at.expectDialFail { + if err == nil { + t.Fatal("expected initial dial to fail") + } else { + return + } + } + if err != nil { + t.Fatalf("failed to dial rpc endpoint: %v", err) + } + + var x string + err = cl.CallContext(ctx, &x, "engine_helloWorld") + if at.expectCall1Fail { + if err == nil { + t.Fatal("expected call 1 to fail") + } else { + return + } + } + if err != nil { + t.Fatalf("failed to call rpc endpoint: %v", err) + } + if x != "hello engine" { + t.Fatalf("method was silent but did not return expected value: %q", x) + } + + err = cl.CallContext(ctx, &x, "eth_helloWorld") + if at.expectCall2Fail { + if err == nil { + t.Fatal("expected call 2 to fail") + } else { + return + } + } + if err != nil { + t.Fatalf("failed to call rpc endpoint: %v", err) + } + if x != "hello eth" { + t.Fatalf("method was silent but did not return expected value: %q", x) + } +} + +func TestAuthEndpoints(t *testing.T) { + var secret [32]byte + if _, err := crand.Read(secret[:]); err != nil { + t.Fatalf("failed to create jwt secret: %v", err) + } + // Geth must read it from a file, and does not support in-memory JWT secrets, so we create a temporary file. + jwtPath := path.Join(t.TempDir(), "jwt_secret") + if err := os.WriteFile(jwtPath, []byte(hexutil.Encode(secret[:])), 0600); err != nil { + t.Fatalf("failed to prepare jwt secret file: %v", err) + } + // We get ports assigned by the node automatically + conf := &Config{ + HTTPHost: "127.0.0.1", + HTTPPort: 0, + WSHost: "127.0.0.1", + WSPort: 0, + AuthAddr: "127.0.0.1", + AuthPort: 0, + JWTSecret: jwtPath, + + WSModules: []string{"eth", "engine"}, + HTTPModules: []string{"eth", "engine"}, + } + node, err := New(conf) + if err != nil { + t.Fatalf("could not create a new node: %v", err) + } + // register dummy apis so we can test the modules are available and reachable with authentication + node.RegisterAPIs([]rpc.API{ + { + Namespace: "engine", + Version: "1.0", + Service: helloRPC("hello engine"), + Public: true, + Authenticated: true, + }, + { + Namespace: "eth", + Version: "1.0", + Service: helloRPC("hello eth"), + Public: true, + Authenticated: true, + }, + }) + if err := node.Start(); err != nil { + t.Fatalf("failed to start test node: %v", err) + } + defer node.Close() + + // sanity check we are running different endpoints + if a, b := node.WSEndpoint(), node.WSAuthEndpoint(); a == b { + t.Fatalf("expected ws and auth-ws endpoints to be different, got: %q and %q", a, b) + } + if a, b := node.HTTPEndpoint(), node.HTTPAuthEndpoint(); a == b { + t.Fatalf("expected http and auth-http endpoints to be different, got: %q and %q", a, b) + } + + goodAuth := NewJWTAuth(secret) + var otherSecret [32]byte + if _, err := crand.Read(otherSecret[:]); err != nil { + t.Fatalf("failed to create jwt secret: %v", err) + } + badAuth := NewJWTAuth(otherSecret) + + notTooLong := time.Second * 57 + tooLong := time.Second * 60 + requestDelay := time.Second + + testCases := []authTest{ + // Auth works + {name: "ws good", endpoint: node.WSAuthEndpoint(), prov: goodAuth, expectCall1Fail: false}, + {name: "http good", endpoint: node.HTTPAuthEndpoint(), prov: goodAuth, expectCall1Fail: false}, + + // Try a bad auth + {name: "ws bad", endpoint: node.WSAuthEndpoint(), prov: badAuth, expectDialFail: true}, // ws auth is immediate + {name: "http bad", endpoint: node.HTTPAuthEndpoint(), prov: badAuth, expectCall1Fail: true}, // http auth is on first call + + // A common mistake with JWT is to allow the "none" algorithm, which is a valid JWT but not secure. + {name: "ws none", endpoint: node.WSAuthEndpoint(), prov: noneAuth(secret), expectDialFail: true}, + {name: "http none", endpoint: node.HTTPAuthEndpoint(), prov: noneAuth(secret), expectCall1Fail: true}, + + // claims of 5 seconds or more, older or newer, are not allowed + {name: "ws too old", endpoint: node.WSAuthEndpoint(), prov: offsetTimeAuth(secret, -tooLong), expectDialFail: true}, + {name: "http too old", endpoint: node.HTTPAuthEndpoint(), prov: offsetTimeAuth(secret, -tooLong), expectCall1Fail: true}, + // note: for it to be too long we need to add a delay, so that once we receive the request, the difference has not dipped below the "tooLong" + {name: "ws too new", endpoint: node.WSAuthEndpoint(), prov: offsetTimeAuth(secret, tooLong+requestDelay), expectDialFail: true}, + {name: "http too new", endpoint: node.HTTPAuthEndpoint(), prov: offsetTimeAuth(secret, tooLong+requestDelay), expectCall1Fail: true}, + + // Try offset the time, but stay just within bounds + {name: "ws old", endpoint: node.WSAuthEndpoint(), prov: offsetTimeAuth(secret, -notTooLong)}, + {name: "http old", endpoint: node.HTTPAuthEndpoint(), prov: offsetTimeAuth(secret, -notTooLong)}, + {name: "ws new", endpoint: node.WSAuthEndpoint(), prov: offsetTimeAuth(secret, notTooLong)}, + {name: "http new", endpoint: node.HTTPAuthEndpoint(), prov: offsetTimeAuth(secret, notTooLong)}, + + // ws only authenticates on initial dial, then continues communication + {name: "ws single auth", endpoint: node.WSAuthEndpoint(), prov: changingAuth(goodAuth, badAuth)}, + {name: "http call fail auth", endpoint: node.HTTPAuthEndpoint(), prov: changingAuth(goodAuth, badAuth), expectCall2Fail: true}, + {name: "http call fail time", endpoint: node.HTTPAuthEndpoint(), prov: changingAuth(goodAuth, offsetTimeAuth(secret, tooLong+requestDelay)), expectCall2Fail: true}, + } + + for _, testCase := range testCases { + t.Run(testCase.name, testCase.Run) + } +} + +func noneAuth(secret [32]byte) rpc.HTTPAuth { + return func(header http.Header) error { + token := jwt.NewWithClaims(jwt.SigningMethodNone, jwt.MapClaims{ + "iat": &jwt.NumericDate{Time: time.Now()}, + }) + s, err := token.SignedString(secret[:]) + if err != nil { + return fmt.Errorf("failed to create JWT token: %w", err) + } + header.Set("Authorization", "Bearer "+s) + return nil + } +} + +func changingAuth(provs ...rpc.HTTPAuth) rpc.HTTPAuth { + i := 0 + return func(header http.Header) error { + i += 1 + if i > len(provs) { + i = len(provs) + } + return provs[i-1](header) + } +} + +func offsetTimeAuth(secret [32]byte, offset time.Duration) rpc.HTTPAuth { + return func(header http.Header) error { + token := jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.MapClaims{ + "iat": &jwt.NumericDate{Time: time.Now().Add(offset)}, + }) + s, err := token.SignedString(secret[:]) + if err != nil { + return fmt.Errorf("failed to create JWT token: %w", err) + } + header.Set("Authorization", "Bearer "+s) + return nil + } +} diff --git a/rpc/client.go b/rpc/client.go index d3ce0297754c..8288f976ebeb 100644 --- a/rpc/client.go +++ b/rpc/client.go @@ -22,6 +22,7 @@ import ( "errors" "fmt" "net/url" + "os" "reflect" "strconv" "sync/atomic" @@ -99,7 +100,7 @@ type Client struct { reqTimeout chan *requestOp // removes response IDs when call timeout expires } -type reconnectFunc func(ctx context.Context) (ServerCodec, error) +type reconnectFunc func(context.Context) (ServerCodec, error) type clientContextKey struct{} @@ -153,14 +154,16 @@ func (op *requestOp) wait(ctx context.Context, c *Client) (*jsonrpcMessage, erro // // The currently supported URL schemes are "http", "https", "ws" and "wss". If rawurl is a // file name with no URL scheme, a local socket connection is established using UNIX -// domain sockets on supported platforms and named pipes on Windows. If you want to -// configure transport options, use DialHTTP, DialWebsocket or DialIPC instead. +// domain sockets on supported platforms and named pipes on Windows. +// +// If you want to further configure the transport, use DialOptions instead of this +// function. // // For websocket connections, the origin is set to the local host name. // -// The client reconnects automatically if the connection is lost. +// The client reconnects automatically when the connection is lost. func Dial(rawurl string) (*Client, error) { - return DialContext(context.Background(), rawurl) + return DialOptions(context.Background(), rawurl) } // DialContext creates a new RPC client, just like Dial. @@ -168,22 +171,46 @@ func Dial(rawurl string) (*Client, error) { // The context is used to cancel or time out the initial connection establishment. It does // not affect subsequent interactions with the client. func DialContext(ctx context.Context, rawurl string) (*Client, error) { + return DialOptions(ctx, rawurl) +} + +// DialOptions creates a new RPC client for the given URL. You can supply any of the +// pre-defined client options to configure the underlying transport. +// +// The context is used to cancel or time out the initial connection establishment. It does +// not affect subsequent interactions with the client. +// +// The client reconnects automatically when the connection is lost. +func DialOptions(ctx context.Context, rawurl string, options ...ClientOption) (*Client, error) { u, err := url.Parse(rawurl) if err != nil { return nil, err } + + cfg := new(clientConfig) + for _, opt := range options { + opt.applyOption(cfg) + } + + var reconnect reconnectFunc switch u.Scheme { case "http", "https": - return DialHTTP(rawurl) + reconnect = newClientTransportHTTP(rawurl, cfg) case "ws", "wss": - return DialWebsocket(ctx, rawurl, "") + rc, err := newClientTransportWS(rawurl, cfg) + if err != nil { + return nil, err + } + reconnect = rc case "stdio": - return DialStdIO(ctx) + reconnect = newClientTransportIO(os.Stdin, os.Stdout) case "": - return DialIPC(ctx, rawurl) + reconnect = newClientTransportIPC(rawurl) default: return nil, fmt.Errorf("no known transport for URL scheme %q", u.Scheme) } + + return newClient(ctx, reconnect) } // ClientFromContext retrieves the client from the context, if any. This can be used to perform diff --git a/rpc/client_opt.go b/rpc/client_opt.go new file mode 100644 index 000000000000..5ad7c22b3ce7 --- /dev/null +++ b/rpc/client_opt.go @@ -0,0 +1,106 @@ +// Copyright 2022 The go-ethereum Authors +// This file is part of the go-ethereum library. +// +// The go-ethereum library is free software: you can redistribute it and/or modify +// it under the terms of the GNU Lesser General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// The go-ethereum library is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Lesser General Public License for more details. +// +// You should have received a copy of the GNU Lesser General Public License +// along with the go-ethereum library. If not, see . + +package rpc + +import ( + "net/http" + + "github.com/gorilla/websocket" +) + +// ClientOption is a configuration option for the RPC client. +type ClientOption interface { + applyOption(*clientConfig) +} + +type clientConfig struct { + httpClient *http.Client + httpHeaders http.Header + httpAuth HTTPAuth + + wsDialer *websocket.Dialer +} + +func (cfg *clientConfig) initHeaders() { + if cfg.httpHeaders == nil { + cfg.httpHeaders = make(http.Header) + } +} + +func (cfg *clientConfig) setHeader(key, value string) { + cfg.initHeaders() + cfg.httpHeaders.Set(key, value) +} + +type optionFunc func(*clientConfig) + +func (fn optionFunc) applyOption(opt *clientConfig) { + fn(opt) +} + +// WithWebsocketDialer configures the websocket.Dialer used by the RPC client. +func WithWebsocketDialer(dialer websocket.Dialer) ClientOption { + return optionFunc(func(cfg *clientConfig) { + cfg.wsDialer = &dialer + }) +} + +// WithHeader configures HTTP headers set by the RPC client. Headers set using this option +// will be used for both HTTP and WebSocket connections. +func WithHeader(key, value string) ClientOption { + return optionFunc(func(cfg *clientConfig) { + cfg.initHeaders() + cfg.httpHeaders.Set(key, value) + }) +} + +// WithHeaders configures HTTP headers set by the RPC client. Headers set using this +// option will be used for both HTTP and WebSocket connections. +func WithHeaders(headers http.Header) ClientOption { + return optionFunc(func(cfg *clientConfig) { + cfg.initHeaders() + for k, vs := range headers { + cfg.httpHeaders[k] = vs + } + }) +} + +// WithHTTPClient configures the http.Client used by the RPC client. +func WithHTTPClient(c *http.Client) ClientOption { + return optionFunc(func(cfg *clientConfig) { + cfg.httpClient = c + }) +} + +// WithHTTPAuth configures HTTP request authentication. The given provider will be called +// whenever a request is made. Note that only one authentication provider can be active at +// any time. +func WithHTTPAuth(a HTTPAuth) ClientOption { + if a == nil { + panic("nil auth") + } + return optionFunc(func(cfg *clientConfig) { + cfg.httpAuth = a + }) +} + +// A HTTPAuth function is called by the client whenever a HTTP request is sent. +// The function must be safe for concurrent use. +// +// Usually, HTTPAuth functions will call h.Set("authorization", "...") to add +// auth information to the request. +type HTTPAuth func(h http.Header) error diff --git a/rpc/client_opt_test.go b/rpc/client_opt_test.go new file mode 100644 index 000000000000..d7cc2572a776 --- /dev/null +++ b/rpc/client_opt_test.go @@ -0,0 +1,25 @@ +package rpc_test + +import ( + "context" + "net/http" + "time" + + "github.com/ethereum/go-ethereum/rpc" +) + +// This example configures a HTTP-based RPC client with two options - one setting the +// overall request timeout, the other adding a custom HTTP header to all requests. +func ExampleDialOptions() { + tokenHeader := rpc.WithHeader("x-token", "foo") + httpClient := rpc.WithHTTPClient(&http.Client{ + Timeout: 10 * time.Second, + }) + + ctx := context.Background() + c, err := rpc.DialOptions(ctx, "http://rpc.example.com", httpClient, tokenHeader) + if err != nil { + panic(err) + } + c.Close() +} diff --git a/rpc/http.go b/rpc/http.go index 858d80858652..8595959afb66 100644 --- a/rpc/http.go +++ b/rpc/http.go @@ -45,6 +45,7 @@ type httpConn struct { closeCh chan interface{} mu sync.Mutex // protects headers headers http.Header + auth HTTPAuth } // httpConn implements ServerCodec, but it is treated specially by Client @@ -117,8 +118,15 @@ var DefaultHTTPTimeouts = HTTPTimeouts{ IdleTimeout: 120 * time.Second, } +// DialHTTP creates a new RPC client that connects to an RPC server over HTTP. +func DialHTTP(endpoint string) (*Client, error) { + return DialHTTPWithClient(endpoint, new(http.Client)) +} + // DialHTTPWithClient creates a new RPC client that connects to an RPC server over HTTP // using the provided HTTP Client. +// +// Deprecated: use DialOptions and the WithHTTPClient option. func DialHTTPWithClient(endpoint string, client *http.Client) (*Client, error) { // Sanity check URL so we don't end up with a client that will fail every request. _, err := url.Parse(endpoint) @@ -126,24 +134,35 @@ func DialHTTPWithClient(endpoint string, client *http.Client) (*Client, error) { return nil, err } - initctx := context.Background() - headers := make(http.Header, 2) + var cfg clientConfig + fn := newClientTransportHTTP(endpoint, &cfg) + return newClient(context.Background(), fn) +} + +func newClientTransportHTTP(endpoint string, cfg *clientConfig) reconnectFunc { + headers := make(http.Header, 2+len(cfg.httpHeaders)) headers.Set("accept", contentType) headers.Set("content-type", contentType) - return newClient(initctx, func(context.Context) (ServerCodec, error) { - hc := &httpConn{ - client: client, - headers: headers, - url: endpoint, - closeCh: make(chan interface{}), - } - return hc, nil - }) -} + for key, values := range cfg.httpHeaders { + headers[key] = values + } -// DialHTTP creates a new RPC client that connects to an RPC server over HTTP. -func DialHTTP(endpoint string) (*Client, error) { - return DialHTTPWithClient(endpoint, new(http.Client)) + client := cfg.httpClient + if client == nil { + client = new(http.Client) + } + + hc := &httpConn{ + client: client, + headers: headers, + url: endpoint, + auth: cfg.httpAuth, + closeCh: make(chan interface{}), + } + + return func(ctx context.Context) (ServerCodec, error) { + return hc, nil + } } func (c *Client) sendHTTP(ctx context.Context, op *requestOp, msg interface{}) error { @@ -195,6 +214,11 @@ func (hc *httpConn) doRequest(ctx context.Context, msg interface{}) (io.ReadClos hc.mu.Lock() req.Header = hc.headers.Clone() hc.mu.Unlock() + if hc.auth != nil { + if err := hc.auth(req.Header); err != nil { + return nil, err + } + } // do request resp, err := hc.client.Do(req) diff --git a/rpc/ipc.go b/rpc/ipc.go index 07a211c6277c..d9e0de62e877 100644 --- a/rpc/ipc.go +++ b/rpc/ipc.go @@ -46,11 +46,15 @@ func (s *Server) ServeListener(l net.Listener) error { // The context is used for the initial connection establishment. It does not // affect subsequent interactions with the client. func DialIPC(ctx context.Context, endpoint string) (*Client, error) { - return newClient(ctx, func(ctx context.Context) (ServerCodec, error) { + return newClient(ctx, newClientTransportIPC(endpoint)) +} + +func newClientTransportIPC(endpoint string) reconnectFunc { + return func(ctx context.Context) (ServerCodec, error) { conn, err := newIPCConnection(ctx, endpoint) if err != nil { return nil, err } return NewCodec(conn), err - }) + } } diff --git a/rpc/json.go b/rpc/json.go index 6024f1e7dc9b..6b2ac2d52a7b 100644 --- a/rpc/json.go +++ b/rpc/json.go @@ -58,21 +58,25 @@ type jsonrpcMessage struct { } func (msg *jsonrpcMessage) isNotification() bool { - return msg.ID == nil && msg.Method != "" + return msg.hasValidVersion() && msg.ID == nil && msg.Method != "" } func (msg *jsonrpcMessage) isCall() bool { - return msg.hasValidID() && msg.Method != "" + return msg.hasValidVersion() && msg.hasValidID() && msg.Method != "" } func (msg *jsonrpcMessage) isResponse() bool { - return msg.hasValidID() && msg.Method == "" && msg.Params == nil && (msg.Result != nil || msg.Error != nil) + return msg.hasValidVersion() && msg.hasValidID() && msg.Method == "" && msg.Params == nil && (msg.Result != nil || msg.Error != nil) } func (msg *jsonrpcMessage) hasValidID() bool { return len(msg.ID) > 0 && msg.ID[0] != '{' && msg.ID[0] != '[' } +func (msg *jsonrpcMessage) hasValidVersion() bool { + return msg.Version == vsn +} + func (msg *jsonrpcMessage) isSubscribe() bool { return strings.HasSuffix(msg.Method, subscribeMethodSuffix) } diff --git a/rpc/stdio.go b/rpc/stdio.go index be2bab1c98bd..ae32db26ef1c 100644 --- a/rpc/stdio.go +++ b/rpc/stdio.go @@ -32,12 +32,16 @@ func DialStdIO(ctx context.Context) (*Client, error) { // DialIO creates a client which uses the given IO channels func DialIO(ctx context.Context, in io.Reader, out io.Writer) (*Client, error) { - return newClient(ctx, func(_ context.Context) (ServerCodec, error) { + return newClient(ctx, newClientTransportIO(in, out)) +} + +func newClientTransportIO(in io.Reader, out io.Writer) reconnectFunc { + return func(context.Context) (ServerCodec, error) { return NewCodec(stdioConn{ in: in, out: out, }), nil - }) + } } type stdioConn struct { diff --git a/rpc/subscription_test.go b/rpc/subscription_test.go index a920205c001f..b2704578291e 100644 --- a/rpc/subscription_test.go +++ b/rpc/subscription_test.go @@ -79,7 +79,7 @@ func TestSubscriptions(t *testing.T) { request := map[string]interface{}{ "id": i, "method": fmt.Sprintf("%s_subscribe", namespace), - "version": "2.0", + "jsonrpc": "2.0", "params": []interface{}{"someSubscription", notificationCount, i}, } if err := out.Encode(&request); err != nil { diff --git a/rpc/testdata/invalid-badversion.js b/rpc/testdata/invalid-badversion.js new file mode 100644 index 000000000000..75b5291dc3f0 --- /dev/null +++ b/rpc/testdata/invalid-badversion.js @@ -0,0 +1,19 @@ +// This test checks processing of messages with invalid Version. + +--> {"jsonrpc":"2.0","id":1,"method":"test_echo","params":["x", 3]} +<-- {"jsonrpc":"2.0","id":1,"result":{"String":"x","Int":3,"Args":null}} + +--> {"jsonrpc":"2.1","id":1,"method":"test_echo","params":["x", 3]} +<-- {"jsonrpc":"2.0","id":1,"error":{"code":-32600,"message":"invalid request"}} + +--> {"jsonrpc":"go-ethereum","id":1,"method":"test_echo","params":["x", 3]} +<-- {"jsonrpc":"2.0","id":1,"error":{"code":-32600,"message":"invalid request"}} + +--> {"jsonrpc":1,"id":1,"method":"test_echo","params":["x", 3]} +<-- {"jsonrpc":"2.0","id":1,"error":{"code":-32600,"message":"invalid request"}} + +--> {"jsonrpc":2.0,"id":1,"method":"test_echo","params":["x", 3]} +<-- {"jsonrpc":"2.0","id":1,"error":{"code":-32600,"message":"invalid request"}} + +--> {"id":1,"method":"test_echo","params":["x", 3]} +<-- {"jsonrpc":"2.0","id":1,"error":{"code":-32600,"message":"invalid request"}} diff --git a/rpc/websocket.go b/rpc/websocket.go index 28380d8aa4ae..f2a923446cac 100644 --- a/rpc/websocket.go +++ b/rpc/websocket.go @@ -181,24 +181,23 @@ func parseOriginURL(origin string) (string, string, string, error) { return scheme, hostname, port, nil } -// DialWebsocketWithDialer creates a new RPC client that communicates with a JSON-RPC server -// that is listening on the given endpoint using the provided dialer. +// DialWebsocketWithDialer creates a new RPC client using WebSocket. +// +// The context is used for the initial connection establishment. It does not +// affect subsequent interactions with the client. +// +// Deprecated: use DialOptions and the WithWebsocketDialer option. func DialWebsocketWithDialer(ctx context.Context, endpoint, origin string, dialer websocket.Dialer) (*Client, error) { - endpoint, header, err := wsClientHeaders(endpoint, origin) + cfg := new(clientConfig) + cfg.wsDialer = &dialer + if origin != "" { + cfg.setHeader("origin", origin) + } + connect, err := newClientTransportWS(endpoint, cfg) if err != nil { return nil, err } - return newClient(ctx, func(ctx context.Context) (ServerCodec, error) { - conn, resp, err := dialer.DialContext(ctx, endpoint, header) - if err != nil { - hErr := wsHandshakeError{err: err} - if resp != nil { - hErr.status = resp.Status - } - return nil, hErr - } - return newWebsocketCodec(conn, endpoint, header), nil - }) + return newClient(ctx, connect) } // DialWebsocket creates a new RPC client that communicates with a JSON-RPC server @@ -207,12 +206,53 @@ func DialWebsocketWithDialer(ctx context.Context, endpoint, origin string, diale // The context is used for the initial connection establishment. It does not // affect subsequent interactions with the client. func DialWebsocket(ctx context.Context, endpoint, origin string) (*Client, error) { - dialer := websocket.Dialer{ - ReadBufferSize: wsReadBuffer, - WriteBufferSize: wsWriteBuffer, - WriteBufferPool: wsBufferPool, + cfg := new(clientConfig) + if origin != "" { + cfg.setHeader("origin", origin) + } + connect, err := newClientTransportWS(endpoint, cfg) + if err != nil { + return nil, err + } + return newClient(ctx, connect) +} + +func newClientTransportWS(endpoint string, cfg *clientConfig) (reconnectFunc, error) { + dialer := cfg.wsDialer + if dialer == nil { + dialer = &websocket.Dialer{ + ReadBufferSize: wsReadBuffer, + WriteBufferSize: wsWriteBuffer, + WriteBufferPool: wsBufferPool, + } + } + + dialURL, header, err := wsClientHeaders(endpoint, "") + if err != nil { + return nil, err + } + for key, values := range cfg.httpHeaders { + header[key] = values + } + + connect := func(ctx context.Context) (ServerCodec, error) { + header := header.Clone() + if cfg.httpAuth != nil { + if err := cfg.httpAuth(header); err != nil { + return nil, err + } + } + conn, resp, err := dialer.DialContext(ctx, dialURL, header) + if err != nil { + hErr := wsHandshakeError{err: err} + if resp != nil { + hErr.status = resp.Status + } + return nil, hErr + } + return newWebsocketCodec(conn, dialURL, header), nil } - return DialWebsocketWithDialer(ctx, endpoint, origin, dialer) + return connect, nil } func wsClientHeaders(endpoint, origin string) (string, http.Header, error) { diff --git a/trie/secure_trie.go b/trie/secure_trie.go index 1da152477d38..e26c22465504 100644 --- a/trie/secure_trie.go +++ b/trie/secure_trie.go @@ -58,7 +58,7 @@ type StateTrie struct { // and returns MissingNodeError if the root node cannot be found. func NewStateTrie(owner common.Hash, root common.Hash, db *Database) (*StateTrie, error) { if db == nil { - panic("trie.NewSecure called without a database") + panic("trie.NewStateTrie called without a database") } trie, err := New(owner, root, db) if err != nil { diff --git a/trie/util_test.go b/trie/util_test.go index 252dc09e0804..cf6758e63d4a 100644 --- a/trie/util_test.go +++ b/trie/util_test.go @@ -69,7 +69,9 @@ func TestTrieTracer(t *testing.T) { // Commit the changes and re-create with new root root, nodes, _ := trie.Commit(false) - db.Update(NewWithNodeSet(nodes)) + if err := db.Update(NewWithNodeSet(nodes)); err != nil { + t.Fatal(err) + } trie, _ = New(common.Hash{}, root, db) trie.tracer = newTracer()