Skip to content

Commit

Permalink
Fix limit middleware skip options (#1568)
Browse files Browse the repository at this point in the history
* fix limit middleware skip options

* fix limiter middleware remaining count

* used constant StatusBadRequest instead of int 400
  • Loading branch information
aliereno authored Oct 11, 2021
1 parent 9eaa8b0 commit 9c37b4c
Show file tree
Hide file tree
Showing 2 changed files with 48 additions and 10 deletions.
25 changes: 15 additions & 10 deletions middleware/limiter/limiter.go
Original file line number Diff line number Diff line change
Expand Up @@ -50,10 +50,6 @@ func New(config ...Config) fiber.Handler {
return c.Next()
}

// Continue stack for reaching c.Response().StatusCode()
// Store err for returning
err := c.Next()

// Get key from request
key := cfg.KeyGenerator(c)

Expand All @@ -76,12 +72,8 @@ func New(config ...Config) fiber.Handler {
e.exp = ts + expiration
}

// Check for SkipFailedRequests and SkipSuccessfulRequests
if (!cfg.SkipSuccessfulRequests || c.Response().StatusCode() >= 400) &&
(!cfg.SkipFailedRequests || c.Response().StatusCode() < 400) {
// Increment hits
e.hits++
}
// Increment hits
e.hits++

// Calculate when it resets in seconds
expire := e.exp - ts
Expand All @@ -105,6 +97,19 @@ func New(config ...Config) fiber.Handler {
return cfg.LimitReached(c)
}

// Continue stack for reaching c.Response().StatusCode()
// Store err for returning
err := c.Next()

// Check for SkipFailedRequests and SkipSuccessfulRequests
if (cfg.SkipSuccessfulRequests && c.Response().StatusCode() < fiber.StatusBadRequest) ||
(cfg.SkipFailedRequests && c.Response().StatusCode() >= fiber.StatusBadRequest) {
mux.Lock()
e.hits--
remaining++
mux.Unlock()
}

// We can continue, update RateLimit headers
c.Set(xRateLimitLimit, max)
c.Set(xRateLimitRemaining, strconv.Itoa(remaining))
Expand Down
33 changes: 33 additions & 0 deletions middleware/limiter/limiter_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,39 @@ func Test_Limiter_Concurrency(t *testing.T) {

}

// go test -run Test_Limiter_No_Skip_Choices -v
func Test_Limiter_No_Skip_Choices(t *testing.T) {

app := fiber.New()

app.Use(New(Config{
Max: 2,
Expiration: 2 * time.Second,
SkipFailedRequests: false,
SkipSuccessfulRequests: false,
}))

app.Get("/:status", func(c *fiber.Ctx) error {
if c.Params("status") == "fail" {
return c.SendStatus(400)
}
return c.SendStatus(200)
})

resp, err := app.Test(httptest.NewRequest(http.MethodGet, "/fail", nil))
utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, 400, resp.StatusCode)

resp, err = app.Test(httptest.NewRequest(http.MethodGet, "/success", nil))
utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, 200, resp.StatusCode)

resp, err = app.Test(httptest.NewRequest(http.MethodGet, "/success", nil))
utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, 429, resp.StatusCode)

}

// go test -run Test_Limiter_Skip_Failed_Requests -v
func Test_Limiter_Skip_Failed_Requests(t *testing.T) {

Expand Down

0 comments on commit 9c37b4c

Please sign in to comment.