diff --git a/ratelimit.go b/ratelimit.go index 2f5688d..20ffad8 100644 --- a/ratelimit.go +++ b/ratelimit.go @@ -32,7 +32,7 @@ func RedisRateLimiter(rdb *redis.Client, key KeyFunc, errFunc ErrFunc) gin.Handl c.Abort() return } - if limit != nil { + if limit != nil && rdb != nil { res, err := limiter.Allow(ctx, key, redis_rate.PerMinute(PtrValue(limit))) if err == nil { reset := time.Now().Add(res.ResetAfter) @@ -53,6 +53,7 @@ func RedisRateLimiter(rdb *redis.Client, key KeyFunc, errFunc ErrFunc) gin.Handl } } } + c.Next() } } diff --git a/ratelimit_test.go b/ratelimit_test.go index c8b16d3..83ff049 100644 --- a/ratelimit_test.go +++ b/ratelimit_test.go @@ -143,3 +143,32 @@ func TestRedisRateLimiterForce(t *testing.T) { require.Equal(t, "", w2.Result().Header.Get(ratelimitLimit)) require.Equal(t, "", w2.Result().Header.Get(ratelimitRemaining)) } + +//nolint:bodyclose +func TestRedisRateLimiterNil(t *testing.T) { + nilLimiter := RedisRateLimiter(nil, + func(c *gin.Context) (key string, limit *int, err error) { + return "test-user", Int(2), nil + }, + func(c *gin.Context, err error) bool { + if err != nil { + t.Log(err) + } + c.JSON(http.StatusBadRequest, + ErrorResponse{Code: http.StatusBadRequest, Message: err.Error()}, + ) + c.Abort() + return true + }, + ) + router := setupRouter(nilLimiter) + + w := httptest.NewRecorder() + req, err := http.NewRequest("GET", "/healthz", nil) + require.Equal(t, err, nil) + router.ServeHTTP(w, req) + require.Equal(t, 200, w.Code) + require.Equal(t, "ok", w.Body.String()) + require.Equal(t, "", w.Result().Header.Get(ratelimitLimit)) + require.Equal(t, "", w.Result().Header.Get(ratelimitRemaining)) +}