Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

added Websocket support and authentication to Engine API #3752

Merged
merged 2 commits into from
Mar 23, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
74 changes: 22 additions & 52 deletions cmd/rpcdaemon/cli/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@ import (
"strings"
"time"

"github.com/golang-jwt/jwt/v4"
"github.com/ledgerwatch/erigon-lib/direct"
"github.com/ledgerwatch/erigon-lib/gointerfaces"
"github.com/ledgerwatch/erigon-lib/gointerfaces/grpcutil"
Expand Down Expand Up @@ -53,7 +52,6 @@ var rootCmd = &cobra.Command{
Short: "rpcdaemon is JSON RPC server that connects to Erigon node for remote DB access",
}

const JwtTokenExpiry = 5 * time.Second
const JwtDefaultFile = "jwt.hex"

func RootCommand() (*cobra.Command, *httpcfg.HttpCfg) {
Expand Down Expand Up @@ -475,10 +473,10 @@ func StartRpcServer(ctx context.Context, cfg httpcfg.HttpCfg, rpcAPI []rpc.API)
httpHandler := node.NewHTTPHandlerStack(srv, cfg.HttpCORSDomain, cfg.HttpVirtualHost, cfg.HttpCompression)
var wsHandler http.Handler
if cfg.WebsocketEnabled {
wsHandler = srv.WebsocketHandler([]string{"*"}, cfg.WebsocketCompression)
wsHandler = srv.WebsocketHandler([]string{"*"}, nil, cfg.WebsocketCompression)
}

apiHandler, err := createHandler(cfg, defaultAPIList, httpHandler, wsHandler, false)
apiHandler, err := createHandler(cfg, defaultAPIList, httpHandler, wsHandler, nil)
if err != nil {
return err
}
Expand Down Expand Up @@ -584,13 +582,7 @@ func obtainJWTSecret(cfg httpcfg.HttpCfg) ([]byte, error) {
return jwtSecret, nil
}

func createHandler(cfg httpcfg.HttpCfg, apiList []rpc.API, httpHandler http.Handler, wsHandler http.Handler, isAuth bool) (http.Handler, error) {
// Finds jwt secret
jwtVerificationKey, err := obtainJWTSecret(cfg)
if err != nil {
return nil, err
}

func createHandler(cfg httpcfg.HttpCfg, apiList []rpc.API, httpHandler http.Handler, wsHandler http.Handler, jwtSecret []byte) (http.Handler, error) {
var handler http.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// adding a healthcheck here
if health.ProcessHealthcheckIfNeeded(w, r, apiList) {
Expand All @@ -601,43 +593,8 @@ func createHandler(cfg httpcfg.HttpCfg, apiList []rpc.API, httpHandler http.Hand
return
}

if isAuth {
var tokenStr string
// Check if JWT signature is correct
if auth := r.Header.Get("Authorization"); strings.HasPrefix(auth, "Bearer ") {
tokenStr = strings.TrimPrefix(auth, "Bearer ")
}

if len(tokenStr) == 0 {
http.Error(w, "missing token", http.StatusForbidden)
return
}

keyFunc := func(token *jwt.Token) (interface{}, error) {
return jwtVerificationKey, nil
}
claims := jwt.RegisteredClaims{}
// We explicitly set only HS256 allowed, and also disables the
// claim-check: the RegisteredClaims internally requires 'iat' to
// be no later than 'now', but we allow for a bit of drift.
token, err := jwt.ParseWithClaims(tokenStr, &claims, keyFunc,
jwt.WithValidMethods([]string{"HS256"}),
jwt.WithoutClaimsValidation())

switch {
case err != nil:
http.Error(w, err.Error(), http.StatusForbidden)
case !token.Valid:
http.Error(w, "invalid token", http.StatusForbidden)
case !claims.VerifyExpiresAt(time.Now(), false): // optional
http.Error(w, "token is expired", http.StatusForbidden)
case claims.IssuedAt == nil:
http.Error(w, "missing issued-at", http.StatusForbidden)
case time.Since(claims.IssuedAt.Time) > JwtTokenExpiry:
http.Error(w, "stale token", http.StatusForbidden)
case time.Until(claims.IssuedAt.Time) > JwtTokenExpiry:
http.Error(w, "future token", http.StatusForbidden)
}
if jwtSecret != nil && !rpc.CheckJwtSecret(w, r, jwtSecret) {
return
}

httpHandler.ServeHTTP(w, r)
Expand All @@ -662,13 +619,26 @@ func createEngineListener(cfg httpcfg.HttpCfg, engineApi []rpc.API) (*http.Serve
return nil, nil, nil, "", fmt.Errorf("could not start register RPC engine api: %w", err)
}

jwtSecret, err := obtainJWTSecret(cfg)
if err != nil {
return nil, nil, nil, "", err
}

var wsHandlerNonAuth http.Handler
var wsHandlerAuth http.Handler

if cfg.WebsocketEnabled {
wsHandlerNonAuth = engineSrv.WebsocketHandler([]string{"*"}, nil, cfg.WebsocketCompression)
wsHandlerAuth = engineSrv.WebsocketHandler([]string{"*"}, jwtSecret, cfg.WebsocketCompression)
}

engineHttpHandler := node.NewHTTPHandlerStack(engineSrv, cfg.HttpCORSDomain, cfg.HttpVirtualHost, cfg.HttpCompression)
engineApiHandler, err := createHandler(cfg, engineApi, engineHttpHandler, nil, false)
engineApiHandler, err := createHandler(cfg, engineApi, engineHttpHandler, wsHandlerNonAuth, nil)
if err != nil {
return nil, nil, nil, "", err
}

engineApiHandlerAuth, err := createHandler(cfg, engineApi, engineHttpHandler, nil, true)
engineApiHandlerAuth, err := createHandler(cfg, engineApi, engineHttpHandler, wsHandlerAuth, jwtSecret)
if err != nil {
return nil, nil, nil, "", err
}
Expand All @@ -683,9 +653,9 @@ func createEngineListener(cfg httpcfg.HttpCfg, engineApi []rpc.API) (*http.Serve
return nil, nil, nil, "", fmt.Errorf("could not start RPC api: %w", err)
}

engineInfo := []interface{}{"url", engineHttpEndpoint}
engineInfo := []interface{}{"url", engineHttpEndpoint, "ws", cfg.WebsocketEnabled}
log.Info("HTTP endpoint opened for engine", engineInfo...)
engineInfoAuth := []interface{}{"url", engineHttpEndpointAuth}
engineInfoAuth := []interface{}{"url", engineHttpEndpointAuth, "ws", cfg.WebsocketEnabled}
log.Info("HTTP endpoint opened for auth engine", engineInfoAuth...)

return engineListener, engineListenerAuth, engineSrv, engineHttpEndpoint, nil
Expand Down
2 changes: 1 addition & 1 deletion node/rpcstack.go
Original file line number Diff line number Diff line change
Expand Up @@ -322,7 +322,7 @@ func (h *httpServer) enableWS(apis []rpc.API, config wsConfig, allowList rpc.All
}
h.wsConfig = config
h.wsHandler.Store(&rpcHandler{
Handler: srv.WebsocketHandler(config.Origins, false),
Handler: srv.WebsocketHandler(config.Origins, nil, false),
server: srv,
})
return nil
Expand Down
4 changes: 2 additions & 2 deletions rpc/client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -508,7 +508,7 @@ func TestClientReconnect(t *testing.T) {
if err != nil {
t.Fatal("can't listen:", err)
}
go http.Serve(l, srv.WebsocketHandler([]string{"*"}, false))
go http.Serve(l, srv.WebsocketHandler([]string{"*"}, nil, false))
return srv, l
}

Expand Down Expand Up @@ -573,7 +573,7 @@ func httpTestClient(srv *Server, transport string, fl *flakeyListener) (*Client,
var hs *httptest.Server
switch transport {
case "ws":
hs = httptest.NewUnstartedServer(srv.WebsocketHandler([]string{"*"}, false))
hs = httptest.NewUnstartedServer(srv.WebsocketHandler([]string{"*"}, nil, false))
case "http":
hs = httptest.NewUnstartedServer(srv)
default:
Expand Down
47 changes: 47 additions & 0 deletions rpc/http.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,13 +27,17 @@ import (
"mime"
"net/http"
"net/url"
"strings"
"sync"
"time"

"github.com/golang-jwt/jwt/v4"
)

const (
maxRequestContentLength = 1024 * 1024 * 5
contentType = "application/json"
jwtTokenExpiry = 5 * time.Second
)

// https://www.jsonrpc.org/historical/json-rpc-over-http.html#id13
Expand Down Expand Up @@ -280,3 +284,46 @@ func validateRequest(r *http.Request) (int, error) {
err := fmt.Errorf("invalid content type, only %s is supported", contentType)
return http.StatusUnsupportedMediaType, err
}

func CheckJwtSecret(w http.ResponseWriter, r *http.Request, jwtSecret []byte) bool {
var tokenStr string
// Check if JWT signature is correct
if auth := r.Header.Get("Authorization"); strings.HasPrefix(auth, "Bearer ") {
tokenStr = strings.TrimPrefix(auth, "Bearer ")
}

if len(tokenStr) == 0 {
http.Error(w, "missing token", http.StatusForbidden)
return false
}

keyFunc := func(token *jwt.Token) (interface{}, error) {
return jwtSecret, nil
}
claims := jwt.RegisteredClaims{}
// We explicitly set only HS256 allowed, and also disables the
// claim-check: the RegisteredClaims internally requires 'iat' to
// be no later than 'now', but we allow for a bit of drift.
token, err := jwt.ParseWithClaims(tokenStr, &claims, keyFunc,
jwt.WithValidMethods([]string{"HS256"}),
jwt.WithoutClaimsValidation())

switch {
case err != nil:
http.Error(w, err.Error(), http.StatusForbidden)
case !token.Valid:
http.Error(w, "invalid token", http.StatusForbidden)
case !claims.VerifyExpiresAt(time.Now(), false): // optional
http.Error(w, "token is expired", http.StatusForbidden)
case claims.IssuedAt == nil:
http.Error(w, "missing issued-at", http.StatusForbidden)
case time.Since(claims.IssuedAt.Time) > jwtTokenExpiry:
http.Error(w, "stale token", http.StatusForbidden)
case time.Until(claims.IssuedAt.Time) > jwtTokenExpiry:
http.Error(w, "future token", http.StatusForbidden)
default:
return true
}

return false
}
5 changes: 4 additions & 1 deletion rpc/websocket.go
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ var wsBufferPool = new(sync.Pool)
//
// allowedOrigins should be a comma-separated list of allowed origin URLs.
// To allow connections with any origin, pass "*".
func (s *Server) WebsocketHandler(allowedOrigins []string, compression bool) http.Handler {
func (s *Server) WebsocketHandler(allowedOrigins []string, jwtSecret []byte, compression bool) http.Handler {
upgrader := websocket.Upgrader{
EnableCompression: compression,
ReadBufferSize: wsReadBuffer,
Expand All @@ -55,6 +55,9 @@ func (s *Server) WebsocketHandler(allowedOrigins []string, compression bool) htt
CheckOrigin: wsHandshakeValidator(allowedOrigins),
}
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if jwtSecret != nil && !CheckJwtSecret(w, r, jwtSecret) {
return
}
conn, err := upgrader.Upgrade(w, r, nil)
if err != nil {
log.Warn("WebSocket upgrade failed", "err", err)
Expand Down
6 changes: 3 additions & 3 deletions rpc/websocket_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ func TestWebsocketOriginCheck(t *testing.T) {

var (
srv = newTestServer()
httpsrv = httptest.NewServer(srv.WebsocketHandler([]string{"http://example.com"}, false))
httpsrv = httptest.NewServer(srv.WebsocketHandler([]string{"http://example.com"}, nil, false))
wsURL = "ws:" + strings.TrimPrefix(httpsrv.URL, "http:")
)
defer srv.Stop()
Expand Down Expand Up @@ -83,7 +83,7 @@ func TestWebsocketLargeCall(t *testing.T) {

var (
srv = newTestServer()
httpsrv = httptest.NewServer(srv.WebsocketHandler([]string{"*"}, false))
httpsrv = httptest.NewServer(srv.WebsocketHandler([]string{"*"}, nil, false))
wsURL = "ws:" + strings.TrimPrefix(httpsrv.URL, "http:")
)
defer srv.Stop()
Expand Down Expand Up @@ -164,7 +164,7 @@ func TestClientWebsocketPing(t *testing.T) {
func TestClientWebsocketLargeMessage(t *testing.T) {
var (
srv = NewServer(50)
httpsrv = httptest.NewServer(srv.WebsocketHandler(nil, false))
httpsrv = httptest.NewServer(srv.WebsocketHandler(nil, nil, false))
wsURL = "ws:" + strings.TrimPrefix(httpsrv.URL, "http:")
)
defer srv.Stop()
Expand Down