diff --git a/go.mod b/go.mod index 7a9e862..a592f31 100644 --- a/go.mod +++ b/go.mod @@ -5,6 +5,7 @@ go 1.19 replace golang.org/x/crypto => golang.org/x/crypto v0.0.0-20220427172511-eb4f295cb31f require ( + github.com/alicebob/miniredis/v2 v2.23.1 github.com/gin-gonic/gin v1.8.1 github.com/go-redis/redis/v8 v8.11.5 github.com/go-redis/redis_rate/v9 v9.1.2 @@ -14,6 +15,7 @@ require ( ) require ( + github.com/alicebob/gopher-json v0.0.0-20200520072559-a9ecdc9d1d3a // indirect github.com/cespare/xxhash/v2 v2.1.2 // indirect github.com/davecgh/go-spew v1.1.1 // indirect github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f // indirect @@ -32,6 +34,7 @@ require ( github.com/pelletier/go-toml/v2 v2.0.1 // indirect github.com/pmezard/go-difflib v1.0.0 // indirect github.com/ugorji/go/codec v1.2.7 // indirect + github.com/yuin/gopher-lua v0.0.0-20220504180219-658193537a64 // indirect golang.org/x/crypto v0.0.0-20210711020723-a769d52b0f97 // indirect golang.org/x/net v0.0.0-20220722155237-a158d28d115b // indirect golang.org/x/sys v0.0.0-20220722155257-8c9f86f7a55f // indirect diff --git a/go.sum b/go.sum index 7a68d74..2e77316 100644 --- a/go.sum +++ b/go.sum @@ -1,5 +1,12 @@ +github.com/alicebob/gopher-json v0.0.0-20200520072559-a9ecdc9d1d3a h1:HbKu58rmZpUGpz5+4FfNmIU+FmZg2P3Xaj2v2bfNWmk= +github.com/alicebob/gopher-json v0.0.0-20200520072559-a9ecdc9d1d3a/go.mod h1:SGnFV6hVsYE877CKEZ6tDNTjaSXYUk6QqoIK6PrAtcc= +github.com/alicebob/miniredis/v2 v2.23.1 h1:jR6wZggBxwWygeXcdNyguCOCIjPsZyNUNlAkTx2fu0U= +github.com/alicebob/miniredis/v2 v2.23.1/go.mod h1:84TWKZlxYkfgMucPBf5SOQBYJceZeQRFIaQgNMiCX6Q= github.com/cespare/xxhash/v2 v2.1.2 h1:YRXhKfTDauu4ajMg1TPgFO5jnlC2HCbmLXMcTG5cbYE= github.com/cespare/xxhash/v2 v2.1.2/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs= +github.com/chzyer/logex v1.1.10/go.mod h1:+Ywpsq7O8HXn0nuIou7OrIPyXbp3wmkHB+jjWRnGsAI= +github.com/chzyer/readline v0.0.0-20180603132655-2972be24d48e/go.mod h1:nSuG5e5PlCu98SY8svDHJxuZscDgtXS6KTTbou5AhLI= +github.com/chzyer/test v0.0.0-20180213035817-a1ea475d72b1/go.mod h1:Q3SI9o4m/ZMnBNeIyt5eFwwo7qiLfzFZmjNmxjkiQlU= github.com/coreos/go-systemd/v22 v22.3.3-0.20220203105225-a9a7ef127534/go.mod h1:Y58oyj3AT4RCenI/lSvhwexgC+NSVTIJ3seZv2GcEnc= github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= @@ -82,11 +89,14 @@ github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o github.com/ugorji/go v1.2.7/go.mod h1:nF9osbDWLy6bDVv/Rtoh6QgnvNDpmCalQV5urGCCS6M= github.com/ugorji/go/codec v1.2.7 h1:YPXUKf7fYbp/y8xloBqZOw2qaVggbfwMlI8WM3wZUJ0= github.com/ugorji/go/codec v1.2.7/go.mod h1:WGN1fab3R1fzQlVQTkfxVtIBhWDRqOviHU95kRgeqEY= +github.com/yuin/gopher-lua v0.0.0-20220504180219-658193537a64 h1:5mLPGnFdSsevFRFc9q3yYbBkB6tsm4aCwwQV/j1JQAQ= +github.com/yuin/gopher-lua v0.0.0-20220504180219-658193537a64/go.mod h1:GBR0iDaNXjAgGg9zfCvksxSRnQx76gclCIb7kdAd1Pw= golang.org/x/crypto v0.0.0-20220427172511-eb4f295cb31f h1:OeJjE6G4dgCY4PIXvIRQbE8+RX+uXZyGhUy/ksMGJoc= golang.org/x/crypto v0.0.0-20220427172511-eb4f295cb31f/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4= golang.org/x/net v0.0.0-20211112202133-69e39bad7dc2/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y= golang.org/x/net v0.0.0-20220722155237-a158d28d115b h1:PxfKdU9lEEDYjdIzOtC4qFWgkU2rGHdKlKowJSMN9h0= golang.org/x/net v0.0.0-20220722155237-a158d28d115b/go.mod h1:XRhObCWvk6IyKnWLug+ECip1KBveYUHfp+8e9klMJ9c= +golang.org/x/sys v0.0.0-20190204203706-41f3e6584952/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20210423082822-04245dca01da/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= diff --git a/ratelimit.go b/ratelimit.go index 9e98202..2f5688d 100644 --- a/ratelimit.go +++ b/ratelimit.go @@ -15,11 +15,16 @@ type ( ErrFunc func(*gin.Context, error) (shouldReturn bool) ) +const ( + ratelimitReset = "X-Ratelimit-Reset" + ratelimitLimit = "X-Ratelimit-Limit" + ratelimitRemaining = "X-Ratelimit-Remaining" +) + // RedisRateLimiter ... -func RedisRateLimiter(opts *redis.Options, key KeyFunc, errFunc ErrFunc) gin.HandlerFunc { +func RedisRateLimiter(rdb *redis.Client, key KeyFunc, errFunc ErrFunc) gin.HandlerFunc { return func(c *gin.Context) { ctx := c.Request.Context() - rdb := redis.NewClient(opts) limiter := redis_rate.NewLimiter(rdb) key, limit, err := key(c) if err != nil { @@ -31,9 +36,9 @@ func RedisRateLimiter(opts *redis.Options, key KeyFunc, errFunc ErrFunc) gin.Han res, err := limiter.Allow(ctx, key, redis_rate.PerMinute(PtrValue(limit))) if err == nil { reset := time.Now().Add(res.ResetAfter) - c.Header("X-Ratelimit-Reset", reset.String()) - c.Header("X-Ratelimit-Limit", strconv.Itoa(PtrValue(limit))) - c.Header("X-Ratelimit-Remaining", strconv.Itoa(res.Remaining)) + c.Header(ratelimitReset, strconv.Itoa(int(reset.Unix()))) + c.Header(ratelimitLimit, strconv.Itoa(PtrValue(limit))) + c.Header(ratelimitRemaining, strconv.Itoa(res.Remaining)) if res.Allowed <= 0 { c.JSON(http.StatusTooManyRequests, ErrorResponse{Code: http.StatusTooManyRequests, Message: "rate limit exceeded"}, diff --git a/ratelimit_test.go b/ratelimit_test.go new file mode 100644 index 0000000..c8b16d3 --- /dev/null +++ b/ratelimit_test.go @@ -0,0 +1,145 @@ +package common + +import ( + "net/http" + "net/http/httptest" + "testing" + + "github.com/alicebob/miniredis/v2" + "github.com/gin-gonic/gin" + "github.com/go-redis/redis/v8" + "github.com/stretchr/testify/require" +) + +func setupRouter(mw gin.HandlerFunc) *gin.Engine { + r := gin.New() + r.GET("/healthz", mw, func(c *gin.Context) { + c.String(http.StatusOK, "ok") + }) + return r +} + +//nolint:bodyclose +func TestRedisRateLimiterAlways(t *testing.T) { + s, err := miniredis.Run() + require.Equal(t, err, nil) + redisClient := redis.NewClient(&redis.Options{ + Addr: s.Addr(), + }) + alwaysRateLimiter := RedisRateLimiter(redisClient, + func(c *gin.Context) (key string, limit *int, err error) { + return "test-user", Int(2), nil + }, + func(c *gin.Context, err error) bool { + if err != nil { + t.Log(err) + } + return false + }, + ) + + router := setupRouter(alwaysRateLimiter) + require.Equal(t, err, nil) + + w := httptest.NewRecorder() + req, err := http.NewRequest("GET", "/healthz", nil) + require.Equal(t, err, nil) + router.ServeHTTP(w, req) + require.Equal(t, 200, w.Code) + require.Equal(t, "ok", w.Body.String()) + require.Equal(t, "2", w.Result().Header.Get(ratelimitLimit)) + require.Equal(t, "1", w.Result().Header.Get(ratelimitRemaining)) + + w2 := httptest.NewRecorder() + router.ServeHTTP(w2, req) + require.Equal(t, 200, w2.Code) + require.Equal(t, "ok", w2.Body.String()) + require.Equal(t, "2", w2.Result().Header.Get(ratelimitLimit)) + require.Equal(t, "0", w2.Result().Header.Get(ratelimitRemaining)) + + w3 := httptest.NewRecorder() + router.ServeHTTP(w3, req) + require.Equal(t, 429, w3.Code) + require.Equal(t, `{"code":429,"message":"rate limit exceeded"}`, w3.Body.String()) + require.Equal(t, "2", w3.Result().Header.Get(ratelimitLimit)) + require.Equal(t, "0", w3.Result().Header.Get(ratelimitRemaining)) +} + +//nolint:bodyclose +func TestRedisRateLimiterSkip(t *testing.T) { + s, err := miniredis.Run() + require.Equal(t, err, nil) + redisClient := redis.NewClient(&redis.Options{ + Addr: s.Addr(), + }) + + skipRateLimiter := RedisRateLimiter(redisClient, + func(c *gin.Context) (key string, limit *int, err error) { + return "", nil, nil + }, + func(c *gin.Context, err error) bool { + if err != nil { + t.Log(err) + } + return false + }, + ) + + router := setupRouter(skipRateLimiter) + require.Equal(t, err, nil) + for i := 1; i < 5; i++ { + w := httptest.NewRecorder() + req, err := http.NewRequest("GET", "/healthz", nil) + require.Equal(t, err, nil) + router.ServeHTTP(w, req) + require.Equal(t, 200, w.Code) + require.Equal(t, "ok", w.Body.String()) + + require.Equal(t, "", w.Result().Header.Get(ratelimitLimit)) + require.Equal(t, "", w.Result().Header.Get(ratelimitRemaining)) + } +} + +//nolint:bodyclose +func TestRedisRateLimiterForce(t *testing.T) { + s, err := miniredis.Run() + require.Equal(t, err, nil) + redisClient := redis.NewClient(&redis.Options{ + Addr: s.Addr(), + }) + + forceRateLimiter := RedisRateLimiter(redisClient, + func(c *gin.Context) (key string, limit *int, err error) { + return "test-user", Int(2), nil + }, + func(c *gin.Context, err error) bool { + if err != nil { + t.Log(err) + } + c.JSON(http.StatusBadRequest, + ErrorResponse{Code: http.StatusBadRequest, Message: err.Error()}, + ) + c.Abort() + return true + }, + ) + router := setupRouter(forceRateLimiter) + + w := httptest.NewRecorder() + req, err := http.NewRequest("GET", "/healthz", nil) + require.Equal(t, err, nil) + router.ServeHTTP(w, req) + require.Equal(t, 200, w.Code) + require.Equal(t, "ok", w.Body.String()) + require.Equal(t, "2", w.Result().Header.Get(ratelimitLimit)) + require.Equal(t, "1", w.Result().Header.Get(ratelimitRemaining)) + + s.SetError("server is unavailable") + + w2 := httptest.NewRecorder() + router.ServeHTTP(w2, req) + require.Equal(t, 400, w2.Code) + require.Equal(t, `{"code":400,"message":"server is unavailable"}`, w2.Body.String()) + require.Equal(t, "", w2.Result().Header.Get(ratelimitLimit)) + require.Equal(t, "", w2.Result().Header.Get(ratelimitRemaining)) +}