Skip to content

Commit

Permalink
limitCountRedis: add limit quota headers
Browse files Browse the repository at this point in the history
Signed-off-by: spacewander <[email protected]>
  • Loading branch information
spacewander committed Feb 6, 2024
1 parent a2f6122 commit 1c4e0a1
Show file tree
Hide file tree
Showing 6 changed files with 145 additions and 14 deletions.
10 changes: 8 additions & 2 deletions plugins/limit_count_redis/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ package limit_count_redis
import (
"fmt"
"net"
"strings"

"github.com/google/cel-go/cel"
"github.com/google/uuid"
Expand Down Expand Up @@ -60,8 +61,9 @@ func (p *plugin) Config() api.PluginConfig {
type config struct {
Config

client *redis.Client
limiters []*Limiter
client *redis.Client
limiters []*Limiter
quotaPolicy string
}

type Limiter struct {
Expand Down Expand Up @@ -108,18 +110,22 @@ func (conf *config) Init(cb api.ConfigCallbackHandler) error {
api.LogInfof("limitCountRedis filter uses %s as prefix, config: %v", prefix, &conf.Config)

conf.limiters = make([]*Limiter, len(conf.Rules))
quotaPolicy := make([]string, len(conf.Rules))
for i, rule := range conf.Rules {
conf.limiters[i] = &Limiter{
count: rule.Count,
timeWindow: rule.TimeWindow.Seconds,
prefix: fmt.Sprintf("%s|%d", prefix, i),
}
quotaPolicy[i] = fmt.Sprintf("%d;w=%d", rule.Count, rule.TimeWindow.Seconds)

if rule.Key == "" {
continue
}
script, _ := expr.CompileCel(rule.Key, cel.StringType)
conf.limiters[i].script = script
}
conf.quotaPolicy = strings.Join(quotaPolicy, ", ")

return nil
}
28 changes: 20 additions & 8 deletions plugins/limit_count_redis/config.pb.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 2 additions & 0 deletions plugins/limit_count_redis/config.pb.validate.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions plugins/limit_count_redis/config.proto
Original file line number Diff line number Diff line change
Expand Up @@ -39,4 +39,5 @@ message Config {
// put a max limit as the rules are sent as one lua script
repeated Rule rules = 2 [(validate.rules).repeated = {min_items: 1, max_items: 8}];
bool failure_mode_deny = 3;
bool enable_limit_quota_headers = 4;
}
45 changes: 41 additions & 4 deletions plugins/limit_count_redis/filter.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,9 @@ package limit_count_redis
import (
"context"
"fmt"
"math"
"net/http"
"strconv"

"mosn.io/htnn/pkg/expr"
"mosn.io/htnn/pkg/filtermanager/api"
Expand All @@ -40,6 +42,8 @@ type filter struct {

callbacks api.FilterCallbackHandler
config *config

ress []interface{}
}

func (f *filter) getKey(script expr.Script, headers api.RequestHeaderMap) string {
Expand All @@ -65,7 +69,7 @@ var (
if ttl<0 then
redis.call('set',KEYS[i],ARGV[i*2-1]-1,'EX',ARGV[i*2])
res[i*2-1]=ARGV[i*2-1]-1
res[i*2]=ARGV[i*2]
res[i*2]=tonumber(ARGV[i*2])
else
res[i*2-1]=redis.call('incrby',KEYS[i],-1)
res[i*2]=ttl
Expand Down Expand Up @@ -103,16 +107,49 @@ func (f *filter) DecodeHeaders(headers api.RequestHeaderMap, endStream bool) api
}

ress := res.([]interface{})
for i := 0; i < len(config.limiters); i++ {
f.ress = ress

for i := range config.limiters {
remain := ress[2*i].(int64)
if remain < 0 {
// TODO: add X-RateLimit headers
hdr := http.Header{}
// TODO: add option to disable x-envoy-ratelimited
hdr.Set("X-Envoy-Ratelimited", "true")
hdr.Set("x-envoy-ratelimited", "true")
return &api.LocalResponse{Code: 429, Header: hdr}
}
}

return api.Continue
}

func (f *filter) EncodeHeaders(headers api.ResponseHeaderMap, endStream bool) api.ResultAction {
config := f.config
if !config.EnableLimitQuotaHeaders {
return api.Continue
}

var minCount uint32
var minRemain int64 = math.MaxUint32
var minTTL int64
for i, lim := range f.config.limiters {
remain := f.ress[2*i].(int64)
ttl := f.ress[2*i+1].(int64)

if remain < minRemain {
minRemain = remain
minCount = lim.count
minTTL = ttl
}
}

// According to the RFC, these headers MUST NOT occur multiple times.
headers.Add("x-ratelimit-limit", fmt.Sprintf("%d, %s", minCount, config.quotaPolicy))
if minRemain <= 0 {
headers.Add("x-ratelimit-remaining", "0")
} else {
headers.Add("x-ratelimit-remaining", strconv.FormatInt(minRemain, 10))
}
headers.Add("x-ratelimit-remaining", strconv.FormatInt(minRemain, 10))
headers.Add("x-ratelimit-reset", strconv.FormatInt(minTTL, 10))
return api.Continue
}
73 changes: 73 additions & 0 deletions plugins/tests/integration/limit_count_redis_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,79 @@ func TestLimitCountRedis(t *testing.T) {
assert.Equal(t, 200, resp.StatusCode)
},
},
{
name: "single rule, with limit quota headers enabled",
config: control_plane.NewSinglePluinConfig("limitCountRedis", map[string]interface{}{
"address": "redis:6379",
"enableLimitQuotaHeaders": true,
"rules": []interface{}{
map[string]interface{}{
"count": 1,
"timeWindow": "1s",
"key": `request.header("x-key")`,
},
},
}),
run: func(t *testing.T) {
hdr := http.Header{}
hdr.Add("x-key", "1")
resp, _ := dp.Head("/echo", hdr)
assert.Equal(t, 200, resp.StatusCode)
assert.Equal(t, "1, 1;w=1", resp.Header.Get("X-Ratelimit-Limit"))
assert.Equal(t, "0", resp.Header.Get("X-Ratelimit-Remaining"))
assert.Equal(t, "1", resp.Header.Get("X-Ratelimit-Reset"))
resp, _ = dp.Head("/echo", hdr)
assert.Equal(t, 429, resp.StatusCode)
assert.Equal(t, "1, 1;w=1", resp.Header.Get("X-Ratelimit-Limit"))
assert.Equal(t, "0", resp.Header.Get("X-Ratelimit-Remaining"))
assert.Equal(t, "1", resp.Header.Get("X-Ratelimit-Reset"))
},
},
{
name: "multiple rules, with limit quota headers enabled",
config: control_plane.NewSinglePluinConfig("limitCountRedis", map[string]interface{}{
"address": "redis:6379",
"enableLimitQuotaHeaders": true,
"rules": []interface{}{
map[string]interface{}{
"count": 2,
"timeWindow": "10s",
"key": `request.header("x-key")`,
},
map[string]interface{}{
"count": 2,
"timeWindow": "1s",
},
map[string]interface{}{
"count": 3,
"timeWindow": "1s",
},
},
}),
run: func(t *testing.T) {
hdr := http.Header{}
hdr.Add("x-key", "1")
resp, _ := dp.Head("/echo", hdr)
assert.Equal(t, 200, resp.StatusCode)
assert.Equal(t, "2, 2;w=10, 2;w=1, 3;w=1", resp.Header.Get("X-Ratelimit-Limit"))
assert.Equal(t, "1", resp.Header.Get("X-Ratelimit-Remaining"))
assert.Equal(t, "10", resp.Header.Get("X-Ratelimit-Reset"))

hdr2 := http.Header{}
hdr2.Add("x-key", "2")
resp, _ = dp.Head("/echo", hdr2)
assert.Equal(t, 200, resp.StatusCode)
assert.Equal(t, "2, 2;w=10, 2;w=1, 3;w=1", resp.Header.Get("X-Ratelimit-Limit"))
assert.Equal(t, "0", resp.Header.Get("X-Ratelimit-Remaining"))
assert.Equal(t, "1", resp.Header.Get("X-Ratelimit-Reset"))

resp, _ = dp.Head("/echo", nil)
assert.Equal(t, 429, resp.StatusCode)
assert.Equal(t, "2, 2;w=10, 2;w=1, 3;w=1", resp.Header.Get("X-Ratelimit-Limit"))
assert.Equal(t, "0", resp.Header.Get("X-Ratelimit-Remaining"))
assert.Equal(t, "1", resp.Header.Get("X-Ratelimit-Reset"))
},
},
}

for _, tt := range tests {
Expand Down

0 comments on commit 1c4e0a1

Please sign in to comment.