connection rate limiting
Cory Schwartz committed May 20, 2022
1 parent c9d3652 commit b305483
Showing 3 changed files with 118 additions and 4 deletions.
14 changes: 13 additions & 1 deletion cmd/lotus-gateway/main.go
@@ -138,11 +138,21 @@ var runCmd = &cli.Command{
Usage: "rate-limit API calls. Use 0 to disable",
Value: 0,
},
&cli.Int64Flag{
Name: "per-conn-rate-limit",
			Usage: "rate-limit API calls per connection. Use 0 to disable",
Value: 0,
},
&cli.DurationFlag{
Name: "rate-limit-timeout",
			Usage: "the maximum time to wait for the rate limiter before returning an error to clients",
Value: gateway.DefaultRateLimitTimeout,
},
&cli.Int64Flag{
Name: "conn-per-minute",
			Usage: "the number of incoming connections to accept from a single IP per minute. Use 0 to disable",
Value: 0,
},
},
Action: func(cctx *cli.Context) error {
log.Info("Starting lotus gateway")
@@ -165,7 +175,9 @@ var runCmd = &cli.Command{
address = cctx.String("listen")
waitLookback = abi.ChainEpoch(cctx.Int64("api-wait-lookback-limit"))
rateLimit = cctx.Int64("rate-limit")
perConnRateLimit = cctx.Int64("per-conn-rate-limit")
rateLimitTimeout = cctx.Duration("rate-limit-timeout")
connPerMinute = cctx.Int64("conn-per-minute")
)

serverOptions := make([]jsonrpc.ServerOption, 0)
@@ -186,7 +198,7 @@ var runCmd = &cli.Command{
}

gwapi := gateway.NewNode(api, lookbackCap, waitLookback, rateLimit, rateLimitTimeout)
h, err := gateway.Handler(gwapi, serverOptions...)
h, err := gateway.Handler(gwapi, perConnRateLimit, connPerMinute, serverOptions...)
if err != nil {
		return xerrors.Errorf("failed to set up gateway HTTP handler: %w", err)
}
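
Together with the existing --rate-limit flag, the gateway can now be throttled at three levels: total API calls, API calls per connection, and new connections per IP address. As a purely illustrative example (the run subcommand name is inferred from runCmd above, and the values are made up), an operator might start the gateway with something like: lotus-gateway run --rate-limit 200 --per-conn-rate-limit 20 --conn-per-minute 60 --rate-limit-timeout 30s
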
99 changes: 97 additions & 2 deletions gateway/handler.go
@@ -1,7 +1,11 @@
package gateway

import (
	"context"
	"net"
	"net/http"
	"sync"
	"time"

	"contrib.go.opencensus.io/exporter/prometheus"
	"github.com/filecoin-project/go-jsonrpc"
@@ -11,10 +15,11 @@ import (
	"github.com/filecoin-project/lotus/metrics/proxy"
	"github.com/gorilla/mux"
	promclient "github.com/prometheus/client_golang/prometheus"
	"golang.org/x/time/rate"
)

// Handler returns a gateway http.Handler, to be mounted as-is on the server.
func Handler(a api.Gateway, opts ...jsonrpc.ServerOption) (http.Handler, error) {
func Handler(a api.Gateway, rateLimit int64, connPerMinute int64, opts ...jsonrpc.ServerOption) (http.Handler, error) {
m := mux.NewRouter()

serveRpc := func(path string, hnd interface{}) {
@@ -44,5 +49,95 @@ func Handler(a api.Gateway, opts ...jsonrpc.ServerOption) (http.Handler, error)
Next: mux.ServeHTTP,
}*/

return m, nil
rlh := NewRateLimiterHandler(m, rateLimit)
clh := NewConnectionRateLimiterHandler(rlh, connPerMinute)
return clh, nil
}

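// NewRateLimiterHandler returns a handler that attaches a token-bucket limiter, derived from rateLimit calls per second (0 disables the limit), to each request's context.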
func NewRateLimiterHandler(handler http.Handler, rateLimit int64) *RateLimiterHandler {
limiter := limiterFromRateLimit(rateLimit)

return &RateLimiterHandler{
handler: handler,
limiter: limiter,
}
}

// RateLimiterHandler adds a rate limiter to the request context for per-connection rate limiting.
type RateLimiterHandler struct {
handler http.Handler
limiter *rate.Limiter
}

func (h RateLimiterHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
r2 := r.WithContext(context.WithValue(r.Context(), "limiter", h.limiter))
h.handler.ServeHTTP(w, r2)
}

// NewConnectionRateLimiterHandler blocks new connections from an IP address that has already opened too many within the last minute.
func NewConnectionRateLimiterHandler(handler http.Handler, connPerMinute int64) *ConnectionRateLimiterHandler {
ipmap := make(map[string]int64)
return &ConnectionRateLimiterHandler{
ipmap: ipmap,
connPerMinute: connPerMinute,
handler: handler,
}
}

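// ConnectionRateLimiterHandler tracks how many connections each remote IP has opened over the last minute and rejects new ones once connPerMinute is reached; a connPerMinute of 0 disables the check.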
type ConnectionRateLimiterHandler struct {
mu sync.Mutex
ipmap map[string]int64
connPerMinute int64
handler http.Handler
}

func (h *ConnectionRateLimiterHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
if h.connPerMinute == 0 {
h.handler.ServeHTTP(w, r)
return
}
host, _, err := net.SplitHostPort(r.RemoteAddr)
if err != nil {
w.WriteHeader(http.StatusInternalServerError)
return
}

	h.mu.Lock()
	seen := h.ipmap[host] // a missing key yields 0, so new IPs start from zero
	// rate limited
	if seen >= h.connPerMinute {
		h.mu.Unlock()
		w.WriteHeader(http.StatusTooManyRequests)
		return
	}
	h.ipmap[host] = seen + 1
	h.mu.Unlock()

	// forget this connection after a minute so the per-IP count only reflects
	// connections accepted within the last minute
	go func() {
		<-time.After(time.Minute)
		h.mu.Lock()
		defer h.mu.Unlock()
		h.ipmap[host] = h.ipmap[host] - 1
		if h.ipmap[host] <= 0 {
			delete(h.ipmap, host)
		}
	}()
	h.handler.ServeHTTP(w, r)
}

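// limiterFromRateLimit converts a calls-per-second rate limit into a token-bucket limiter; a rateLimit of 0 means no limit.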
func limiterFromRateLimit(rateLimit int64) *rate.Limiter {
var limit rate.Limit
if rateLimit == 0 {
limit = rate.Inf
} else {
limit = rate.Every(time.Second / time.Duration(rateLimit))
}
return rate.NewLimiter(limit, stateRateLimitTokens)
}
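
For readers who want to see the two new wrappers in action, below is a minimal, hypothetical sketch (not part of the commit) that chains them the same way Handler does and drives them with net/http/httptest. The constructor signatures and the "limiter" context key come from the diff above; the module import path is inferred from the lotus import lines, and the inner handler and numeric limits are illustrative.

package main

import (
	"fmt"
	"net/http"
	"net/http/httptest"

	"github.com/filecoin-project/lotus/gateway"
	"golang.org/x/time/rate"
)

func main() {
	// Innermost handler: pull the per-connection limiter out of the request
	// context, the same way (*Node).limit does in node.go.
	inner := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		if l, ok := r.Context().Value("limiter").(*rate.Limiter); ok {
			fmt.Fprintf(w, "limiter present, limit=%v\n", l.Limit())
			return
		}
		fmt.Fprintln(w, "no per-connection limiter")
	})

	// Illustrative values: 20 API calls per second per connection,
	// at most 2 new connections per IP per minute.
	h := gateway.NewConnectionRateLimiterHandler(
		gateway.NewRateLimiterHandler(inner, 20), 2)

	srv := httptest.NewServer(h)
	defer srv.Close()

	for i := 0; i < 4; i++ {
		resp, err := http.Get(srv.URL)
		if err != nil {
			panic(err)
		}
		fmt.Println("request", i, "->", resp.StatusCode) // requests over the per-IP limit see 429
		resp.Body.Close()
	}
}

When run, the first requests should return 200 and report the limiter, while requests beyond the per-IP allowance come back as 429 Too Many Requests.
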
9 changes: 8 additions & 1 deletion gateway/node.go
@@ -165,6 +165,13 @@ func (gw *Node) checkTimestamp(at time.Time) error {
func (gw *Node) limit(ctx context.Context, tokens int) error {
ctx2, cancel := context.WithTimeout(ctx, gw.rateLimitTimeout)
defer cancel()
if perConnLimiter, ok := ctx2.Value("limiter").(*rate.Limiter); ok {
err := perConnLimiter.WaitN(ctx2, tokens)
if err != nil {
			return fmt.Errorf("connection limited: %w", err)
}
}

err := gw.rateLimiter.WaitN(ctx2, tokens)
if err != nil {
stats.Record(ctx, metrics.RateLimitCount.M(1))
@@ -212,7 +219,7 @@ func (gw *Node) ChainHead(ctx context.Context) (*types.TipSet, error) {
if err := gw.limit(ctx, chainRateLimitTokens); err != nil {
return nil, err
}
// TODO: cache and invalidate cache when timestamp is up (or have internal ChainNotify)

return gw.target.ChainHead(ctx)
}
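
To make the token arithmetic behind limit and limiterFromRateLimit concrete, here is a rough standalone sketch (not from the commit; the burst of 4 is illustrative, whereas the real burst comes from the stateRateLimitTokens constant in node.go, which this diff does not show). With a rate limit of 10, rate.Every(time.Second/10) refills one token every 100ms, so a call that needs four tokens either waits about 400ms or fails as soon as the context deadline cannot cover that wait, mirroring the timeout path in limit.

package main

import (
	"context"
	"fmt"
	"time"

	"golang.org/x/time/rate"
)

func main() {
	const rateLimit = 10 // illustrative --rate-limit / --per-conn-rate-limit value
	limiter := rate.NewLimiter(rate.Every(time.Second/time.Duration(rateLimit)), 4)

	// A timeout comparable to --rate-limit-timeout.
	ctx, cancel := context.WithTimeout(context.Background(), 250*time.Millisecond)
	defer cancel()

	// The full burst of 4 tokens is available up front, so this returns nil immediately.
	fmt.Println(limiter.WaitN(ctx, 4))

	// Refilling 4 more tokens takes ~400ms at 10 tokens/s, which the 250ms deadline
	// cannot cover, so WaitN returns an error instead of blocking.
	fmt.Println(limiter.WaitN(ctx, 4))
}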
