Skip to content

Commit

Permalink
feat: handle panic in rpc rate limit interceptor
Browse files Browse the repository at this point in the history
  • Loading branch information
Poonam Jadhav authored and JadhavPoonam committed Jan 23, 2023
1 parent a1e2a4f commit 81439cc
Show file tree
Hide file tree
Showing 4 changed files with 57 additions and 3 deletions.
2 changes: 1 addition & 1 deletion agent/consul/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -489,7 +489,7 @@ func NewServer(config *Config, flat Deps, externalGRPCServer *grpc.Server, incom
}

rpcServerOpts := []func(*rpc.Server){
rpc.WithPreBodyInterceptor(middleware.GetNetRPCRateLimitingInterceptor(s.incomingRPCLimiter)),
rpc.WithPreBodyInterceptor(middleware.GetNetRPCRateLimitingInterceptor(s.incomingRPCLimiter, middleware.NewPanicHandler(s.logger))),
}

if flat.GetNetRPCInterceptorFunc != nil {
Expand Down
11 changes: 9 additions & 2 deletions agent/rpc/middleware/interceptors.go
Original file line number Diff line number Diff line change
Expand Up @@ -160,9 +160,16 @@ func GetNetRPCInterceptor(recorder *RequestRecorder) rpc.ServerServiceCallInterc
}
}

func GetNetRPCRateLimitingInterceptor(requestLimitsHandler rpcRate.RequestLimitsHandler) rpc.PreBodyInterceptor {
func GetNetRPCRateLimitingInterceptor(requestLimitsHandler rpcRate.RequestLimitsHandler, panicHandler RecoveryHandlerFunc) rpc.PreBodyInterceptor {

return func(reqServiceMethod string, sourceAddr net.Addr) (retErr error) {

defer func() {
if r := recover(); r != nil {
retErr = panicHandler(r)
}
}()

return func(reqServiceMethod string, sourceAddr net.Addr) error {
op := rpcRate.Operation{
Name: reqServiceMethod,
SourceAddr: sourceAddr,
Expand Down
23 changes: 23 additions & 0 deletions agent/rpc/middleware/interceptors_test.go
Original file line number Diff line number Diff line change
@@ -1,13 +1,16 @@
package middleware

import (
"net"
"strings"
"sync"
"testing"
"time"

"github.com/armon/go-metrics"
"github.com/hashicorp/consul/agent/consul/rate"
"github.com/hashicorp/go-hclog"
"github.com/stretchr/testify/mock"
"github.com/stretchr/testify/require"
)

Expand Down Expand Up @@ -266,3 +269,23 @@ func TestRequestRecorder(t *testing.T) {
})
}
}

func TestGetNetRPCRateLimitingInterceptor(t *testing.T) {
limiter := rate.NewMockRequestLimitsHandler(t)

logger := hclog.NewNullLogger()
var rateLimitInterceptor = GetNetRPCRateLimitingInterceptor(limiter, NewPanicHandler(logger))

listener, _ := net.Listen("tcp", "127.0.0.1:0")

t.Run("Allow panics", func(t *testing.T) {
limiter.On("Allow", mock.Anything).
Panic("uh oh").
Once()

err := rateLimitInterceptor("Status.Leader", listener.Addr())

require.Error(t, err)
require.Equal(t, "rpc: panic serving request", err.Error())
})
}
24 changes: 24 additions & 0 deletions agent/rpc/middleware/recovery.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
package middleware

import (
"fmt"

"github.com/hashicorp/go-hclog"
)

// NewPanicHandler returns a RecoveryHandlerFunc type function
// to handle panic in RPC server's handlers.
func NewPanicHandler(logger hclog.Logger) RecoveryHandlerFunc {
return func(p interface{}) (err error) {
// Log the panic and the stack trace of the Goroutine that caused the panic.
stacktrace := hclog.Stacktrace()
logger.Error("panic serving rpc request",
"panic", p,
"stack", stacktrace,
)

return fmt.Errorf("rpc: panic serving request")
}
}

type RecoveryHandlerFunc func(p interface{}) (err error)

0 comments on commit 81439cc

Please sign in to comment.