Skip to content

Commit

Permalink
added Websocket support and authentication to Engine API (#3752)
Browse files Browse the repository at this point in the history
* added ws support and auth

* fixed lint
  • Loading branch information
Giulio2002 authored Mar 23, 2022
1 parent f8668da commit 904674e
Show file tree
Hide file tree
Showing 6 changed files with 79 additions and 59 deletions.
74 changes: 22 additions & 52 deletions cmd/rpcdaemon/cli/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@ import (
"strings"
"time"

"github.com/golang-jwt/jwt/v4"
"github.com/ledgerwatch/erigon-lib/direct"
"github.com/ledgerwatch/erigon-lib/gointerfaces"
"github.com/ledgerwatch/erigon-lib/gointerfaces/grpcutil"
Expand Down Expand Up @@ -53,7 +52,6 @@ var rootCmd = &cobra.Command{
Short: "rpcdaemon is JSON RPC server that connects to Erigon node for remote DB access",
}

const JwtTokenExpiry = 5 * time.Second
const JwtDefaultFile = "jwt.hex"

func RootCommand() (*cobra.Command, *httpcfg.HttpCfg) {
Expand Down Expand Up @@ -475,10 +473,10 @@ func StartRpcServer(ctx context.Context, cfg httpcfg.HttpCfg, rpcAPI []rpc.API)
httpHandler := node.NewHTTPHandlerStack(srv, cfg.HttpCORSDomain, cfg.HttpVirtualHost, cfg.HttpCompression)
var wsHandler http.Handler
if cfg.WebsocketEnabled {
wsHandler = srv.WebsocketHandler([]string{"*"}, cfg.WebsocketCompression)
wsHandler = srv.WebsocketHandler([]string{"*"}, nil, cfg.WebsocketCompression)
}

apiHandler, err := createHandler(cfg, defaultAPIList, httpHandler, wsHandler, false)
apiHandler, err := createHandler(cfg, defaultAPIList, httpHandler, wsHandler, nil)
if err != nil {
return err
}
Expand Down Expand Up @@ -584,13 +582,7 @@ func obtainJWTSecret(cfg httpcfg.HttpCfg) ([]byte, error) {
return jwtSecret, nil
}

func createHandler(cfg httpcfg.HttpCfg, apiList []rpc.API, httpHandler http.Handler, wsHandler http.Handler, isAuth bool) (http.Handler, error) {
// Finds jwt secret
jwtVerificationKey, err := obtainJWTSecret(cfg)
if err != nil {
return nil, err
}

func createHandler(cfg httpcfg.HttpCfg, apiList []rpc.API, httpHandler http.Handler, wsHandler http.Handler, jwtSecret []byte) (http.Handler, error) {
var handler http.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// adding a healthcheck here
if health.ProcessHealthcheckIfNeeded(w, r, apiList) {
Expand All @@ -601,43 +593,8 @@ func createHandler(cfg httpcfg.HttpCfg, apiList []rpc.API, httpHandler http.Hand
return
}

if isAuth {
var tokenStr string
// Check if JWT signature is correct
if auth := r.Header.Get("Authorization"); strings.HasPrefix(auth, "Bearer ") {
tokenStr = strings.TrimPrefix(auth, "Bearer ")
}

if len(tokenStr) == 0 {
http.Error(w, "missing token", http.StatusForbidden)
return
}

keyFunc := func(token *jwt.Token) (interface{}, error) {
return jwtVerificationKey, nil
}
claims := jwt.RegisteredClaims{}
// We explicitly set only HS256 allowed, and also disables the
// claim-check: the RegisteredClaims internally requires 'iat' to
// be no later than 'now', but we allow for a bit of drift.
token, err := jwt.ParseWithClaims(tokenStr, &claims, keyFunc,
jwt.WithValidMethods([]string{"HS256"}),
jwt.WithoutClaimsValidation())

switch {
case err != nil:
http.Error(w, err.Error(), http.StatusForbidden)
case !token.Valid:
http.Error(w, "invalid token", http.StatusForbidden)
case !claims.VerifyExpiresAt(time.Now(), false): // optional
http.Error(w, "token is expired", http.StatusForbidden)
case claims.IssuedAt == nil:
http.Error(w, "missing issued-at", http.StatusForbidden)
case time.Since(claims.IssuedAt.Time) > JwtTokenExpiry:
http.Error(w, "stale token", http.StatusForbidden)
case time.Until(claims.IssuedAt.Time) > JwtTokenExpiry:
http.Error(w, "future token", http.StatusForbidden)
}
if jwtSecret != nil && !rpc.CheckJwtSecret(w, r, jwtSecret) {
return
}

httpHandler.ServeHTTP(w, r)
Expand All @@ -662,13 +619,26 @@ func createEngineListener(cfg httpcfg.HttpCfg, engineApi []rpc.API) (*http.Serve
return nil, nil, nil, "", fmt.Errorf("could not start register RPC engine api: %w", err)
}

jwtSecret, err := obtainJWTSecret(cfg)
if err != nil {
return nil, nil, nil, "", err
}

var wsHandlerNonAuth http.Handler
var wsHandlerAuth http.Handler

if cfg.WebsocketEnabled {
wsHandlerNonAuth = engineSrv.WebsocketHandler([]string{"*"}, nil, cfg.WebsocketCompression)
wsHandlerAuth = engineSrv.WebsocketHandler([]string{"*"}, jwtSecret, cfg.WebsocketCompression)
}

engineHttpHandler := node.NewHTTPHandlerStack(engineSrv, cfg.HttpCORSDomain, cfg.HttpVirtualHost, cfg.HttpCompression)
engineApiHandler, err := createHandler(cfg, engineApi, engineHttpHandler, nil, false)
engineApiHandler, err := createHandler(cfg, engineApi, engineHttpHandler, wsHandlerNonAuth, nil)
if err != nil {
return nil, nil, nil, "", err
}

engineApiHandlerAuth, err := createHandler(cfg, engineApi, engineHttpHandler, nil, true)
engineApiHandlerAuth, err := createHandler(cfg, engineApi, engineHttpHandler, wsHandlerAuth, jwtSecret)
if err != nil {
return nil, nil, nil, "", err
}
Expand All @@ -683,9 +653,9 @@ func createEngineListener(cfg httpcfg.HttpCfg, engineApi []rpc.API) (*http.Serve
return nil, nil, nil, "", fmt.Errorf("could not start RPC api: %w", err)
}

engineInfo := []interface{}{"url", engineHttpEndpoint}
engineInfo := []interface{}{"url", engineHttpEndpoint, "ws", cfg.WebsocketEnabled}
log.Info("HTTP endpoint opened for engine", engineInfo...)
engineInfoAuth := []interface{}{"url", engineHttpEndpointAuth}
engineInfoAuth := []interface{}{"url", engineHttpEndpointAuth, "ws", cfg.WebsocketEnabled}
log.Info("HTTP endpoint opened for auth engine", engineInfoAuth...)

return engineListener, engineListenerAuth, engineSrv, engineHttpEndpoint, nil
Expand Down
2 changes: 1 addition & 1 deletion node/rpcstack.go
Original file line number Diff line number Diff line change
Expand Up @@ -322,7 +322,7 @@ func (h *httpServer) enableWS(apis []rpc.API, config wsConfig, allowList rpc.All
}
h.wsConfig = config
h.wsHandler.Store(&rpcHandler{
Handler: srv.WebsocketHandler(config.Origins, false),
Handler: srv.WebsocketHandler(config.Origins, nil, false),
server: srv,
})
return nil
Expand Down
4 changes: 2 additions & 2 deletions rpc/client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -508,7 +508,7 @@ func TestClientReconnect(t *testing.T) {
if err != nil {
t.Fatal("can't listen:", err)
}
go http.Serve(l, srv.WebsocketHandler([]string{"*"}, false))
go http.Serve(l, srv.WebsocketHandler([]string{"*"}, nil, false))
return srv, l
}

Expand Down Expand Up @@ -573,7 +573,7 @@ func httpTestClient(srv *Server, transport string, fl *flakeyListener) (*Client,
var hs *httptest.Server
switch transport {
case "ws":
hs = httptest.NewUnstartedServer(srv.WebsocketHandler([]string{"*"}, false))
hs = httptest.NewUnstartedServer(srv.WebsocketHandler([]string{"*"}, nil, false))
case "http":
hs = httptest.NewUnstartedServer(srv)
default:
Expand Down
47 changes: 47 additions & 0 deletions rpc/http.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,13 +27,17 @@ import (
"mime"
"net/http"
"net/url"
"strings"
"sync"
"time"

"github.com/golang-jwt/jwt/v4"
)

const (
maxRequestContentLength = 1024 * 1024 * 5
contentType = "application/json"
jwtTokenExpiry = 5 * time.Second
)

// https://www.jsonrpc.org/historical/json-rpc-over-http.html#id13
Expand Down Expand Up @@ -280,3 +284,46 @@ func validateRequest(r *http.Request) (int, error) {
err := fmt.Errorf("invalid content type, only %s is supported", contentType)
return http.StatusUnsupportedMediaType, err
}

func CheckJwtSecret(w http.ResponseWriter, r *http.Request, jwtSecret []byte) bool {
var tokenStr string
// Check if JWT signature is correct
if auth := r.Header.Get("Authorization"); strings.HasPrefix(auth, "Bearer ") {
tokenStr = strings.TrimPrefix(auth, "Bearer ")
}

if len(tokenStr) == 0 {
http.Error(w, "missing token", http.StatusForbidden)
return false
}

keyFunc := func(token *jwt.Token) (interface{}, error) {
return jwtSecret, nil
}
claims := jwt.RegisteredClaims{}
// We explicitly set only HS256 allowed, and also disables the
// claim-check: the RegisteredClaims internally requires 'iat' to
// be no later than 'now', but we allow for a bit of drift.
token, err := jwt.ParseWithClaims(tokenStr, &claims, keyFunc,
jwt.WithValidMethods([]string{"HS256"}),
jwt.WithoutClaimsValidation())

switch {
case err != nil:
http.Error(w, err.Error(), http.StatusForbidden)
case !token.Valid:
http.Error(w, "invalid token", http.StatusForbidden)
case !claims.VerifyExpiresAt(time.Now(), false): // optional
http.Error(w, "token is expired", http.StatusForbidden)
case claims.IssuedAt == nil:
http.Error(w, "missing issued-at", http.StatusForbidden)
case time.Since(claims.IssuedAt.Time) > jwtTokenExpiry:
http.Error(w, "stale token", http.StatusForbidden)
case time.Until(claims.IssuedAt.Time) > jwtTokenExpiry:
http.Error(w, "future token", http.StatusForbidden)
default:
return true
}

return false
}
5 changes: 4 additions & 1 deletion rpc/websocket.go
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ var wsBufferPool = new(sync.Pool)
//
// allowedOrigins should be a comma-separated list of allowed origin URLs.
// To allow connections with any origin, pass "*".
func (s *Server) WebsocketHandler(allowedOrigins []string, compression bool) http.Handler {
func (s *Server) WebsocketHandler(allowedOrigins []string, jwtSecret []byte, compression bool) http.Handler {
upgrader := websocket.Upgrader{
EnableCompression: compression,
ReadBufferSize: wsReadBuffer,
Expand All @@ -55,6 +55,9 @@ func (s *Server) WebsocketHandler(allowedOrigins []string, compression bool) htt
CheckOrigin: wsHandshakeValidator(allowedOrigins),
}
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if jwtSecret != nil && !CheckJwtSecret(w, r, jwtSecret) {
return
}
conn, err := upgrader.Upgrade(w, r, nil)
if err != nil {
log.Warn("WebSocket upgrade failed", "err", err)
Expand Down
6 changes: 3 additions & 3 deletions rpc/websocket_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ func TestWebsocketOriginCheck(t *testing.T) {

var (
srv = newTestServer()
httpsrv = httptest.NewServer(srv.WebsocketHandler([]string{"http://example.com"}, false))
httpsrv = httptest.NewServer(srv.WebsocketHandler([]string{"http://example.com"}, nil, false))
wsURL = "ws:" + strings.TrimPrefix(httpsrv.URL, "http:")
)
defer srv.Stop()
Expand Down Expand Up @@ -83,7 +83,7 @@ func TestWebsocketLargeCall(t *testing.T) {

var (
srv = newTestServer()
httpsrv = httptest.NewServer(srv.WebsocketHandler([]string{"*"}, false))
httpsrv = httptest.NewServer(srv.WebsocketHandler([]string{"*"}, nil, false))
wsURL = "ws:" + strings.TrimPrefix(httpsrv.URL, "http:")
)
defer srv.Stop()
Expand Down Expand Up @@ -164,7 +164,7 @@ func TestClientWebsocketPing(t *testing.T) {
func TestClientWebsocketLargeMessage(t *testing.T) {
var (
srv = NewServer(50)
httpsrv = httptest.NewServer(srv.WebsocketHandler(nil, false))
httpsrv = httptest.NewServer(srv.WebsocketHandler(nil, nil, false))
wsURL = "ws:" + strings.TrimPrefix(httpsrv.URL, "http:")
)
defer srv.Stop()
Expand Down

0 comments on commit 904674e

Please sign in to comment.