From aa408bffe1cf64f485822d07514b50c719228471 Mon Sep 17 00:00:00 2001 From: Jesse Haka Date: Wed, 30 Nov 2022 10:31:26 +0200 Subject: [PATCH 1/2] check nil rdb client --- ratelimit.go | 54 ++++++++++++++++++++++++----------------------- ratelimit_test.go | 29 +++++++++++++++++++++++++ 2 files changed, 57 insertions(+), 26 deletions(-) diff --git a/ratelimit.go b/ratelimit.go index 2f5688d..55c2e27 100644 --- a/ratelimit.go +++ b/ratelimit.go @@ -24,32 +24,34 @@ const ( // RedisRateLimiter ... func RedisRateLimiter(rdb *redis.Client, key KeyFunc, errFunc ErrFunc) gin.HandlerFunc { return func(c *gin.Context) { - ctx := c.Request.Context() - limiter := redis_rate.NewLimiter(rdb) - key, limit, err := key(c) - if err != nil { - c.JSON(400, ErrorResponse{Code: 400, Message: err.Error()}) - c.Abort() - return - } - if limit != nil { - res, err := limiter.Allow(ctx, key, redis_rate.PerMinute(PtrValue(limit))) - if err == nil { - reset := time.Now().Add(res.ResetAfter) - c.Header(ratelimitReset, strconv.Itoa(int(reset.Unix()))) - c.Header(ratelimitLimit, strconv.Itoa(PtrValue(limit))) - c.Header(ratelimitRemaining, strconv.Itoa(res.Remaining)) - if res.Allowed <= 0 { - c.JSON(http.StatusTooManyRequests, - ErrorResponse{Code: http.StatusTooManyRequests, Message: "rate limit exceeded"}, - ) - c.Abort() - return - } - } else { - shouldReturn := errFunc(c, err) - if shouldReturn { - return + if rdb != nil { + ctx := c.Request.Context() + limiter := redis_rate.NewLimiter(rdb) + key, limit, err := key(c) + if err != nil { + c.JSON(400, ErrorResponse{Code: 400, Message: err.Error()}) + c.Abort() + return + } + if limit != nil { + res, err := limiter.Allow(ctx, key, redis_rate.PerMinute(PtrValue(limit))) + if err == nil { + reset := time.Now().Add(res.ResetAfter) + c.Header(ratelimitReset, strconv.Itoa(int(reset.Unix()))) + c.Header(ratelimitLimit, strconv.Itoa(PtrValue(limit))) + c.Header(ratelimitRemaining, strconv.Itoa(res.Remaining)) + if res.Allowed <= 0 { + c.JSON(http.StatusTooManyRequests, + ErrorResponse{Code: http.StatusTooManyRequests, Message: "rate limit exceeded"}, + ) + c.Abort() + return + } + } else { + shouldReturn := errFunc(c, err) + if shouldReturn { + return + } } } } diff --git a/ratelimit_test.go b/ratelimit_test.go index c8b16d3..83ff049 100644 --- a/ratelimit_test.go +++ b/ratelimit_test.go @@ -143,3 +143,32 @@ func TestRedisRateLimiterForce(t *testing.T) { require.Equal(t, "", w2.Result().Header.Get(ratelimitLimit)) require.Equal(t, "", w2.Result().Header.Get(ratelimitRemaining)) } + +//nolint:bodyclose +func TestRedisRateLimiterNil(t *testing.T) { + nilLimiter := RedisRateLimiter(nil, + func(c *gin.Context) (key string, limit *int, err error) { + return "test-user", Int(2), nil + }, + func(c *gin.Context, err error) bool { + if err != nil { + t.Log(err) + } + c.JSON(http.StatusBadRequest, + ErrorResponse{Code: http.StatusBadRequest, Message: err.Error()}, + ) + c.Abort() + return true + }, + ) + router := setupRouter(nilLimiter) + + w := httptest.NewRecorder() + req, err := http.NewRequest("GET", "/healthz", nil) + require.Equal(t, err, nil) + router.ServeHTTP(w, req) + require.Equal(t, 200, w.Code) + require.Equal(t, "ok", w.Body.String()) + require.Equal(t, "", w.Result().Header.Get(ratelimitLimit)) + require.Equal(t, "", w.Result().Header.Get(ratelimitRemaining)) +} From 67a8584d8d61e24f4e11b3d42856ff49e5adee4e Mon Sep 17 00:00:00 2001 From: Jesse Haka Date: Wed, 30 Nov 2022 10:46:43 +0200 Subject: [PATCH 2/2] fix nested ifs --- ratelimit.go | 55 ++++++++++++++++++++++++++-------------------------- 1 file changed, 27 insertions(+), 28 deletions(-) diff --git a/ratelimit.go b/ratelimit.go index 55c2e27..20ffad8 100644 --- a/ratelimit.go +++ b/ratelimit.go @@ -24,37 +24,36 @@ const ( // RedisRateLimiter ... func RedisRateLimiter(rdb *redis.Client, key KeyFunc, errFunc ErrFunc) gin.HandlerFunc { return func(c *gin.Context) { - if rdb != nil { - ctx := c.Request.Context() - limiter := redis_rate.NewLimiter(rdb) - key, limit, err := key(c) - if err != nil { - c.JSON(400, ErrorResponse{Code: 400, Message: err.Error()}) - c.Abort() - return - } - if limit != nil { - res, err := limiter.Allow(ctx, key, redis_rate.PerMinute(PtrValue(limit))) - if err == nil { - reset := time.Now().Add(res.ResetAfter) - c.Header(ratelimitReset, strconv.Itoa(int(reset.Unix()))) - c.Header(ratelimitLimit, strconv.Itoa(PtrValue(limit))) - c.Header(ratelimitRemaining, strconv.Itoa(res.Remaining)) - if res.Allowed <= 0 { - c.JSON(http.StatusTooManyRequests, - ErrorResponse{Code: http.StatusTooManyRequests, Message: "rate limit exceeded"}, - ) - c.Abort() - return - } - } else { - shouldReturn := errFunc(c, err) - if shouldReturn { - return - } + ctx := c.Request.Context() + limiter := redis_rate.NewLimiter(rdb) + key, limit, err := key(c) + if err != nil { + c.JSON(400, ErrorResponse{Code: 400, Message: err.Error()}) + c.Abort() + return + } + if limit != nil && rdb != nil { + res, err := limiter.Allow(ctx, key, redis_rate.PerMinute(PtrValue(limit))) + if err == nil { + reset := time.Now().Add(res.ResetAfter) + c.Header(ratelimitReset, strconv.Itoa(int(reset.Unix()))) + c.Header(ratelimitLimit, strconv.Itoa(PtrValue(limit))) + c.Header(ratelimitRemaining, strconv.Itoa(res.Remaining)) + if res.Allowed <= 0 { + c.JSON(http.StatusTooManyRequests, + ErrorResponse{Code: http.StatusTooManyRequests, Message: "rate limit exceeded"}, + ) + c.Abort() + return + } + } else { + shouldReturn := errFunc(c, err) + if shouldReturn { + return } } } + c.Next() } }