diff --git a/cmd/lotus-gateway/main.go b/cmd/lotus-gateway/main.go index c49e8d53288..962bed274b4 100644 --- a/cmd/lotus-gateway/main.go +++ b/cmd/lotus-gateway/main.go @@ -133,13 +133,33 @@ var runCmd = &cli.Command{ Usage: "maximum number of blocks to search back through for message inclusion", Value: int64(gateway.DefaultStateWaitLookbackLimit), }, + &cli.Int64Flag{ + Name: "rate-limit", + Usage: "rate-limit API calls. Use 0 to disable", + Value: 0, + }, + &cli.Int64Flag{ + Name: "per-conn-rate-limit", + Usage: "rate-limit API calls per each connection. Use 0 to disable", + Value: 0, + }, + &cli.DurationFlag{ + Name: "rate-limit-timeout", + Usage: "the maximum time to wait for the rate limter before returning an error to clients", + Value: gateway.DefaultRateLimitTimeout, + }, + &cli.Int64Flag{ + Name: "conn-per-minute", + Usage: "The number of incomming connections to accept from a single IP per minute. Use 0 to disable", + Value: 0, + }, }, Action: func(cctx *cli.Context) error { log.Info("Starting lotus gateway") // Register all metric views if err := view.Register( - metrics.ChainNodeViews..., + metrics.GatewayNodeViews..., ); err != nil { log.Fatalf("Cannot register the view: %v", err) } @@ -151,9 +171,13 @@ var runCmd = &cli.Command{ defer closer() var ( - lookbackCap = cctx.Duration("api-max-lookback") - address = cctx.String("listen") - waitLookback = abi.ChainEpoch(cctx.Int64("api-wait-lookback-limit")) + lookbackCap = cctx.Duration("api-max-lookback") + address = cctx.String("listen") + waitLookback = abi.ChainEpoch(cctx.Int64("api-wait-lookback-limit")) + rateLimit = cctx.Int64("rate-limit") + perConnRateLimit = cctx.Int64("per-conn-rate-limit") + rateLimitTimeout = cctx.Duration("rate-limit-timeout") + connPerMinute = cctx.Int64("conn-per-minute") ) serverOptions := make([]jsonrpc.ServerOption, 0) @@ -173,8 +197,8 @@ var runCmd = &cli.Command{ return xerrors.Errorf("failed to convert endpoint address to multiaddr: %w", err) } - gwapi := gateway.NewNode(api, lookbackCap, waitLookback) - h, err := gateway.Handler(gwapi, api, serverOptions...) + gwapi := gateway.NewNode(api, lookbackCap, waitLookback, rateLimit, rateLimitTimeout) + h, err := gateway.Handler(gwapi, api, perConnRateLimit, connPerMinute, serverOptions...) if err != nil { return xerrors.Errorf("failed to set up gateway HTTP handler") } diff --git a/gateway/handler.go b/gateway/handler.go index 3b1553acc0e..12dded40e5b 100644 --- a/gateway/handler.go +++ b/gateway/handler.go @@ -1,7 +1,11 @@ package gateway import ( + "context" + "net" "net/http" + "sync" + "time" "contrib.go.opencensus.io/exporter/prometheus" "github.com/filecoin-project/go-jsonrpc" @@ -12,10 +16,15 @@ import ( "github.com/filecoin-project/lotus/node" "github.com/gorilla/mux" promclient "github.com/prometheus/client_golang/prometheus" + "golang.org/x/time/rate" ) +type perConnLimiterKeyType string + +const perConnLimiterKey perConnLimiterKeyType = "limiter" + // Handler returns a gateway http.Handler, to be mounted as-is on the server. -func Handler(gwapi lapi.Gateway, api lapi.FullNode, opts ...jsonrpc.ServerOption) (http.Handler, error) { +func Handler(gwapi lapi.Gateway, api lapi.FullNode, rateLimit int64, connPerMinute int64, opts ...jsonrpc.ServerOption) (http.Handler, error) { m := mux.NewRouter() serveRpc := func(path string, hnd interface{}) { @@ -49,5 +58,95 @@ func Handler(gwapi lapi.Gateway, api lapi.FullNode, opts ...jsonrpc.ServerOption Next: mux.ServeHTTP, }*/ - return m, nil + rlh := NewRateLimiterHandler(m, rateLimit) + clh := NewConnectionRateLimiterHandler(rlh, connPerMinute) + return clh, nil +} + +func NewRateLimiterHandler(handler http.Handler, rateLimit int64) *RateLimiterHandler { + limiter := limiterFromRateLimit(rateLimit) + + return &RateLimiterHandler{ + handler: handler, + limiter: limiter, + } +} + +// Adds a rate limiter to the request context for per-connection rate limiting +type RateLimiterHandler struct { + handler http.Handler + limiter *rate.Limiter +} + +func (h RateLimiterHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { + r2 := r.WithContext(context.WithValue(r.Context(), perConnLimiterKey, h.limiter)) + h.handler.ServeHTTP(w, r2) +} + +// this blocks new connections if there have already been too many. +func NewConnectionRateLimiterHandler(handler http.Handler, connPerMinute int64) *ConnectionRateLimiterHandler { + ipmap := make(map[string]int64) + return &ConnectionRateLimiterHandler{ + ipmap: ipmap, + connPerMinute: connPerMinute, + handler: handler, + } +} + +type ConnectionRateLimiterHandler struct { + mu sync.Mutex + ipmap map[string]int64 + connPerMinute int64 + handler http.Handler +} + +func (h *ConnectionRateLimiterHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { + if h.connPerMinute == 0 { + h.handler.ServeHTTP(w, r) + return + } + host, _, err := net.SplitHostPort(r.RemoteAddr) + if err != nil { + w.WriteHeader(http.StatusInternalServerError) + return + } + + h.mu.Lock() + seen, ok := h.ipmap[host] + if !ok { + h.ipmap[host] = 1 + h.mu.Unlock() + h.handler.ServeHTTP(w, r) + return + } + // rate limited + if seen > h.connPerMinute { + h.mu.Unlock() + w.WriteHeader(http.StatusTooManyRequests) + return + } + h.ipmap[host] = seen + 1 + h.mu.Unlock() + go func() { + select { + case <-time.After(time.Minute): + h.mu.Lock() + defer h.mu.Unlock() + h.ipmap[host] = h.ipmap[host] - 1 + if h.ipmap[host] <= 0 { + delete(h.ipmap, host) + } + } + }() + h.handler.ServeHTTP(w, r) +} + +func limiterFromRateLimit(rateLimit int64) *rate.Limiter { + var limit rate.Limit + if rateLimit == 0 { + limit = rate.Inf + } else { + limit = rate.Every(time.Second / time.Duration(rateLimit)) + } + return rate.NewLimiter(limit, stateRateLimitTokens) } diff --git a/gateway/node.go b/gateway/node.go index 58a2f35c958..bf54530f02b 100644 --- a/gateway/node.go +++ b/gateway/node.go @@ -6,6 +6,8 @@ import ( "time" "github.com/ipfs/go-cid" + "go.opencensus.io/stats" + "golang.org/x/time/rate" "golang.org/x/xerrors" "github.com/filecoin-project/go-address" @@ -22,12 +24,18 @@ import ( "github.com/filecoin-project/lotus/lib/sigs" _ "github.com/filecoin-project/lotus/lib/sigs/bls" _ "github.com/filecoin-project/lotus/lib/sigs/secp" + "github.com/filecoin-project/lotus/metrics" "github.com/filecoin-project/lotus/node/impl/full" ) const ( DefaultLookbackCap = time.Hour * 24 DefaultStateWaitLookbackLimit = abi.ChainEpoch(20) + DefaultRateLimitTimeout = time.Second * 5 + basicRateLimitTokens = 1 + walletRateLimitTokens = 1 + chainRateLimitTokens = 2 + stateRateLimitTokens = 3 ) // TargetAPI defines the API methods that the Node depends on @@ -85,6 +93,8 @@ type Node struct { target TargetAPI lookbackCap time.Duration stateWaitLookbackLimit abi.ChainEpoch + rateLimiter *rate.Limiter + rateLimitTimeout time.Duration errLookback error } @@ -97,11 +107,19 @@ var ( ) // NewNode creates a new gateway node. -func NewNode(api TargetAPI, lookbackCap time.Duration, stateWaitLookbackLimit abi.ChainEpoch) *Node { +func NewNode(api TargetAPI, lookbackCap time.Duration, stateWaitLookbackLimit abi.ChainEpoch, rateLimit int64, rateLimitTimeout time.Duration) *Node { + var limit rate.Limit + if rateLimit == 0 { + limit = rate.Inf + } else { + limit = rate.Every(time.Second / time.Duration(rateLimit)) + } return &Node{ target: api, lookbackCap: lookbackCap, stateWaitLookbackLimit: stateWaitLookbackLimit, + rateLimiter: rate.NewLimiter(limit, stateRateLimitTokens), + rateLimitTimeout: rateLimitTimeout, errLookback: fmt.Errorf("lookbacks of more than %s are disallowed", lookbackCap), } } @@ -145,57 +163,102 @@ func (gw *Node) checkTimestamp(at time.Time) error { return nil } +func (gw *Node) limit(ctx context.Context, tokens int) error { + ctx2, cancel := context.WithTimeout(ctx, gw.rateLimitTimeout) + defer cancel() + if perConnLimiter, ok := ctx2.Value(perConnLimiterKey).(*rate.Limiter); ok { + err := perConnLimiter.WaitN(ctx2, tokens) + if err != nil { + return fmt.Errorf("connection limited. %w", err) + } + } + + err := gw.rateLimiter.WaitN(ctx2, tokens) + if err != nil { + stats.Record(ctx, metrics.RateLimitCount.M(1)) + return fmt.Errorf("server busy. %w", err) + } + return nil +} + func (gw *Node) Discover(ctx context.Context) (apitypes.OpenRPCDocument, error) { return build.OpenRPCDiscoverJSON_Gateway(), nil } func (gw *Node) Version(ctx context.Context) (api.APIVersion, error) { + if err := gw.limit(ctx, basicRateLimitTokens); err != nil { + return api.APIVersion{}, err + } return gw.target.Version(ctx) } func (gw *Node) ChainGetParentMessages(ctx context.Context, c cid.Cid) ([]api.Message, error) { + if err := gw.limit(ctx, chainRateLimitTokens); err != nil { + return nil, err + } return gw.target.ChainGetParentMessages(ctx, c) } func (gw *Node) ChainGetParentReceipts(ctx context.Context, c cid.Cid) ([]*types.MessageReceipt, error) { + if err := gw.limit(ctx, chainRateLimitTokens); err != nil { + return nil, err + } return gw.target.ChainGetParentReceipts(ctx, c) } func (gw *Node) ChainGetBlockMessages(ctx context.Context, c cid.Cid) (*api.BlockMessages, error) { + if err := gw.limit(ctx, chainRateLimitTokens); err != nil { + return nil, err + } return gw.target.ChainGetBlockMessages(ctx, c) } func (gw *Node) ChainHasObj(ctx context.Context, c cid.Cid) (bool, error) { + if err := gw.limit(ctx, chainRateLimitTokens); err != nil { + return false, err + } return gw.target.ChainHasObj(ctx, c) } func (gw *Node) ChainHead(ctx context.Context) (*types.TipSet, error) { - // TODO: cache and invalidate cache when timestamp is up (or have internal ChainNotify) + if err := gw.limit(ctx, chainRateLimitTokens); err != nil { + return nil, err + } return gw.target.ChainHead(ctx) } func (gw *Node) ChainGetMessage(ctx context.Context, mc cid.Cid) (*types.Message, error) { + if err := gw.limit(ctx, chainRateLimitTokens); err != nil { + return nil, err + } return gw.target.ChainGetMessage(ctx, mc) } func (gw *Node) ChainGetTipSet(ctx context.Context, tsk types.TipSetKey) (*types.TipSet, error) { + if err := gw.limit(ctx, chainRateLimitTokens); err != nil { + return nil, err + } return gw.target.ChainGetTipSet(ctx, tsk) } func (gw *Node) ChainGetTipSetByHeight(ctx context.Context, h abi.ChainEpoch, tsk types.TipSetKey) (*types.TipSet, error) { + if err := gw.limit(ctx, chainRateLimitTokens); err != nil { + return nil, err + } if err := gw.checkTipSetHeight(ctx, h, tsk); err != nil { return nil, err } - return gw.target.ChainGetTipSetByHeight(ctx, h, tsk) } func (gw *Node) ChainGetTipSetAfterHeight(ctx context.Context, h abi.ChainEpoch, tsk types.TipSetKey) (*types.TipSet, error) { + if err := gw.limit(ctx, chainRateLimitTokens); err != nil { + return nil, err + } if err := gw.checkTipSetHeight(ctx, h, tsk); err != nil { return nil, err } - return gw.target.ChainGetTipSetAfterHeight(ctx, h, tsk) } @@ -229,66 +292,91 @@ func (gw *Node) checkTipSetHeight(ctx context.Context, h abi.ChainEpoch, tsk typ } func (gw *Node) ChainGetNode(ctx context.Context, p string) (*api.IpldObject, error) { + if err := gw.limit(ctx, chainRateLimitTokens); err != nil { + return nil, err + } return gw.target.ChainGetNode(ctx, p) } func (gw *Node) ChainNotify(ctx context.Context) (<-chan []*api.HeadChange, error) { + if err := gw.limit(ctx, chainRateLimitTokens); err != nil { + return nil, err + } return gw.target.ChainNotify(ctx) } func (gw *Node) ChainGetPath(ctx context.Context, from, to types.TipSetKey) ([]*api.HeadChange, error) { + if err := gw.limit(ctx, chainRateLimitTokens); err != nil { + return nil, err + } if err := gw.checkTipsetKey(ctx, from); err != nil { return nil, xerrors.Errorf("gateway: checking 'from' tipset: %w", err) } - if err := gw.checkTipsetKey(ctx, to); err != nil { return nil, xerrors.Errorf("gateway: checking 'to' tipset: %w", err) } - return gw.target.ChainGetPath(ctx, from, to) } func (gw *Node) ChainGetGenesis(ctx context.Context) (*types.TipSet, error) { + if err := gw.limit(ctx, chainRateLimitTokens); err != nil { + return nil, err + } return gw.target.ChainGetGenesis(ctx) } func (gw *Node) ChainReadObj(ctx context.Context, c cid.Cid) ([]byte, error) { + if err := gw.limit(ctx, chainRateLimitTokens); err != nil { + return nil, err + } return gw.target.ChainReadObj(ctx, c) } func (gw *Node) GasEstimateMessageGas(ctx context.Context, msg *types.Message, spec *api.MessageSendSpec, tsk types.TipSetKey) (*types.Message, error) { + if err := gw.limit(ctx, stateRateLimitTokens); err != nil { + return nil, err + } if err := gw.checkTipsetKey(ctx, tsk); err != nil { return nil, err } - return gw.target.GasEstimateMessageGas(ctx, msg, spec, tsk) } func (gw *Node) MpoolPush(ctx context.Context, sm *types.SignedMessage) (cid.Cid, error) { + if err := gw.limit(ctx, stateRateLimitTokens); err != nil { + return cid.Cid{}, err + } // TODO: additional anti-spam checks return gw.target.MpoolPushUntrusted(ctx, sm) } func (gw *Node) MsigGetAvailableBalance(ctx context.Context, addr address.Address, tsk types.TipSetKey) (types.BigInt, error) { + if err := gw.limit(ctx, walletRateLimitTokens); err != nil { + return types.BigInt{}, err + } if err := gw.checkTipsetKey(ctx, tsk); err != nil { return types.NewInt(0), err } - return gw.target.MsigGetAvailableBalance(ctx, addr, tsk) } func (gw *Node) MsigGetVested(ctx context.Context, addr address.Address, start types.TipSetKey, end types.TipSetKey) (types.BigInt, error) { + if err := gw.limit(ctx, walletRateLimitTokens); err != nil { + return types.BigInt{}, err + } if err := gw.checkTipsetKey(ctx, start); err != nil { return types.NewInt(0), err } if err := gw.checkTipsetKey(ctx, end); err != nil { return types.NewInt(0), err } - return gw.target.MsigGetVested(ctx, addr, start, end) } func (gw *Node) MsigGetVestingSchedule(ctx context.Context, addr address.Address, tsk types.TipSetKey) (api.MsigVesting, error) { + if err := gw.limit(ctx, walletRateLimitTokens); err != nil { + return api.MsigVesting{}, err + } if err := gw.checkTipsetKey(ctx, tsk); err != nil { return api.MsigVesting{}, err } @@ -296,78 +384,99 @@ func (gw *Node) MsigGetVestingSchedule(ctx context.Context, addr address.Address } func (gw *Node) MsigGetPending(ctx context.Context, addr address.Address, tsk types.TipSetKey) ([]*api.MsigTransaction, error) { + if err := gw.limit(ctx, walletRateLimitTokens); err != nil { + return nil, err + } if err := gw.checkTipsetKey(ctx, tsk); err != nil { return nil, err } - return gw.target.MsigGetPending(ctx, addr, tsk) } func (gw *Node) StateAccountKey(ctx context.Context, addr address.Address, tsk types.TipSetKey) (address.Address, error) { + if err := gw.limit(ctx, stateRateLimitTokens); err != nil { + return address.Address{}, err + } if err := gw.checkTipsetKey(ctx, tsk); err != nil { return address.Undef, err } - return gw.target.StateAccountKey(ctx, addr, tsk) } func (gw *Node) StateDealProviderCollateralBounds(ctx context.Context, size abi.PaddedPieceSize, verified bool, tsk types.TipSetKey) (api.DealCollateralBounds, error) { + if err := gw.limit(ctx, stateRateLimitTokens); err != nil { + return api.DealCollateralBounds{}, err + } if err := gw.checkTipsetKey(ctx, tsk); err != nil { return api.DealCollateralBounds{}, err } - return gw.target.StateDealProviderCollateralBounds(ctx, size, verified, tsk) } func (gw *Node) StateGetActor(ctx context.Context, actor address.Address, tsk types.TipSetKey) (*types.Actor, error) { + if err := gw.limit(ctx, stateRateLimitTokens); err != nil { + return nil, err + } if err := gw.checkTipsetKey(ctx, tsk); err != nil { return nil, err } - return gw.target.StateGetActor(ctx, actor, tsk) } func (gw *Node) StateListMiners(ctx context.Context, tsk types.TipSetKey) ([]address.Address, error) { + if err := gw.limit(ctx, stateRateLimitTokens); err != nil { + return nil, err + } if err := gw.checkTipsetKey(ctx, tsk); err != nil { return nil, err } - return gw.target.StateListMiners(ctx, tsk) } func (gw *Node) StateLookupID(ctx context.Context, addr address.Address, tsk types.TipSetKey) (address.Address, error) { + if err := gw.limit(ctx, stateRateLimitTokens); err != nil { + return address.Address{}, err + } if err := gw.checkTipsetKey(ctx, tsk); err != nil { return address.Undef, err } - return gw.target.StateLookupID(ctx, addr, tsk) } func (gw *Node) StateMarketBalance(ctx context.Context, addr address.Address, tsk types.TipSetKey) (api.MarketBalance, error) { + if err := gw.limit(ctx, stateRateLimitTokens); err != nil { + return api.MarketBalance{}, err + } if err := gw.checkTipsetKey(ctx, tsk); err != nil { return api.MarketBalance{}, err } - return gw.target.StateMarketBalance(ctx, addr, tsk) } func (gw *Node) StateMarketStorageDeal(ctx context.Context, dealId abi.DealID, tsk types.TipSetKey) (*api.MarketDeal, error) { + if err := gw.limit(ctx, stateRateLimitTokens); err != nil { + return nil, err + } if err := gw.checkTipsetKey(ctx, tsk); err != nil { return nil, err } - return gw.target.StateMarketStorageDeal(ctx, dealId, tsk) } func (gw *Node) StateNetworkVersion(ctx context.Context, tsk types.TipSetKey) (network.Version, error) { + if err := gw.limit(ctx, stateRateLimitTokens); err != nil { + return network.VersionMax, err + } if err := gw.checkTipsetKey(ctx, tsk); err != nil { return network.VersionMax, err } - return gw.target.StateNetworkVersion(ctx, tsk) } func (gw *Node) StateSearchMsg(ctx context.Context, from types.TipSetKey, msg cid.Cid, limit abi.ChainEpoch, allowReplaced bool) (*api.MsgLookup, error) { + if err := gw.limit(ctx, stateRateLimitTokens); err != nil { + return nil, err + } if limit == api.LookbackNoLimit { limit = gw.stateWaitLookbackLimit } @@ -377,22 +486,26 @@ func (gw *Node) StateSearchMsg(ctx context.Context, from types.TipSetKey, msg ci if err := gw.checkTipsetKey(ctx, from); err != nil { return nil, err } - return gw.target.StateSearchMsg(ctx, from, msg, limit, allowReplaced) } func (gw *Node) StateWaitMsg(ctx context.Context, msg cid.Cid, confidence uint64, limit abi.ChainEpoch, allowReplaced bool) (*api.MsgLookup, error) { + if err := gw.limit(ctx, stateRateLimitTokens); err != nil { + return nil, err + } if limit == api.LookbackNoLimit { limit = gw.stateWaitLookbackLimit } if gw.stateWaitLookbackLimit != api.LookbackNoLimit && limit > gw.stateWaitLookbackLimit { limit = gw.stateWaitLookbackLimit } - return gw.target.StateWaitMsg(ctx, msg, confidence, limit, allowReplaced) } func (gw *Node) StateReadState(ctx context.Context, actor address.Address, tsk types.TipSetKey) (*api.ActorState, error) { + if err := gw.limit(ctx, stateRateLimitTokens); err != nil { + return nil, err + } if err := gw.checkTipsetKey(ctx, tsk); err != nil { return nil, err } @@ -400,6 +513,9 @@ func (gw *Node) StateReadState(ctx context.Context, actor address.Address, tsk t } func (gw *Node) StateMinerPower(ctx context.Context, m address.Address, tsk types.TipSetKey) (*api.MinerPower, error) { + if err := gw.limit(ctx, stateRateLimitTokens); err != nil { + return nil, err + } if err := gw.checkTipsetKey(ctx, tsk); err != nil { return nil, err } @@ -407,12 +523,19 @@ func (gw *Node) StateMinerPower(ctx context.Context, m address.Address, tsk type } func (gw *Node) StateMinerFaults(ctx context.Context, m address.Address, tsk types.TipSetKey) (bitfield.BitField, error) { + if err := gw.limit(ctx, stateRateLimitTokens); err != nil { + return bitfield.BitField{}, err + } if err := gw.checkTipsetKey(ctx, tsk); err != nil { return bitfield.BitField{}, err } return gw.target.StateMinerFaults(ctx, m, tsk) } + func (gw *Node) StateMinerRecoveries(ctx context.Context, m address.Address, tsk types.TipSetKey) (bitfield.BitField, error) { + if err := gw.limit(ctx, stateRateLimitTokens); err != nil { + return bitfield.BitField{}, err + } if err := gw.checkTipsetKey(ctx, tsk); err != nil { return bitfield.BitField{}, err } @@ -420,6 +543,9 @@ func (gw *Node) StateMinerRecoveries(ctx context.Context, m address.Address, tsk } func (gw *Node) StateMinerInfo(ctx context.Context, m address.Address, tsk types.TipSetKey) (api.MinerInfo, error) { + if err := gw.limit(ctx, stateRateLimitTokens); err != nil { + return api.MinerInfo{}, err + } if err := gw.checkTipsetKey(ctx, tsk); err != nil { return api.MinerInfo{}, err } @@ -427,6 +553,9 @@ func (gw *Node) StateMinerInfo(ctx context.Context, m address.Address, tsk types } func (gw *Node) StateMinerDeadlines(ctx context.Context, m address.Address, tsk types.TipSetKey) ([]api.Deadline, error) { + if err := gw.limit(ctx, stateRateLimitTokens); err != nil { + return nil, err + } if err := gw.checkTipsetKey(ctx, tsk); err != nil { return nil, err } @@ -434,6 +563,9 @@ func (gw *Node) StateMinerDeadlines(ctx context.Context, m address.Address, tsk } func (gw *Node) StateMinerAvailableBalance(ctx context.Context, m address.Address, tsk types.TipSetKey) (types.BigInt, error) { + if err := gw.limit(ctx, stateRateLimitTokens); err != nil { + return types.BigInt{}, err + } if err := gw.checkTipsetKey(ctx, tsk); err != nil { return types.BigInt{}, err } @@ -441,6 +573,9 @@ func (gw *Node) StateMinerAvailableBalance(ctx context.Context, m address.Addres } func (gw *Node) StateMinerProvingDeadline(ctx context.Context, m address.Address, tsk types.TipSetKey) (*dline.Info, error) { + if err := gw.limit(ctx, stateRateLimitTokens); err != nil { + return nil, err + } if err := gw.checkTipsetKey(ctx, tsk); err != nil { return nil, err } @@ -448,13 +583,19 @@ func (gw *Node) StateMinerProvingDeadline(ctx context.Context, m address.Address } func (gw *Node) StateCirculatingSupply(ctx context.Context, tsk types.TipSetKey) (abi.TokenAmount, error) { + if err := gw.limit(ctx, stateRateLimitTokens); err != nil { + return abi.TokenAmount{}, err + } if err := gw.checkTipsetKey(ctx, tsk); err != nil { - return types.BigInt{}, err + return abi.TokenAmount{}, err } return gw.target.StateCirculatingSupply(ctx, tsk) } func (gw *Node) StateSectorGetInfo(ctx context.Context, maddr address.Address, n abi.SectorNumber, tsk types.TipSetKey) (*miner.SectorOnChainInfo, error) { + if err := gw.limit(ctx, stateRateLimitTokens); err != nil { + return nil, err + } if err := gw.checkTipsetKey(ctx, tsk); err != nil { return nil, err } @@ -462,6 +603,9 @@ func (gw *Node) StateSectorGetInfo(ctx context.Context, maddr address.Address, n } func (gw *Node) StateVerifiedClientStatus(ctx context.Context, addr address.Address, tsk types.TipSetKey) (*abi.StoragePower, error) { + if err := gw.limit(ctx, stateRateLimitTokens); err != nil { + return nil, err + } if err := gw.checkTipsetKey(ctx, tsk); err != nil { return nil, err } @@ -469,6 +613,9 @@ func (gw *Node) StateVerifiedClientStatus(ctx context.Context, addr address.Addr } func (gw *Node) StateVMCirculatingSupplyInternal(ctx context.Context, tsk types.TipSetKey) (api.CirculatingSupply, error) { + if err := gw.limit(ctx, stateRateLimitTokens); err != nil { + return api.CirculatingSupply{}, err + } if err := gw.checkTipsetKey(ctx, tsk); err != nil { return api.CirculatingSupply{}, err } @@ -476,9 +623,15 @@ func (gw *Node) StateVMCirculatingSupplyInternal(ctx context.Context, tsk types. } func (gw *Node) WalletVerify(ctx context.Context, k address.Address, msg []byte, sig *crypto.Signature) (bool, error) { + if err := gw.limit(ctx, stateRateLimitTokens); err != nil { + return false, err + } return sigs.Verify(sig, k, msg) == nil, nil } func (gw *Node) WalletBalance(ctx context.Context, k address.Address) (types.BigInt, error) { + if err := gw.limit(ctx, stateRateLimitTokens); err != nil { + return types.BigInt{}, err + } return gw.target.WalletBalance(ctx, k) } diff --git a/gateway/node_test.go b/gateway/node_test.go index aebd8ba14b7..b077d514af5 100644 --- a/gateway/node_test.go +++ b/gateway/node_test.go @@ -89,7 +89,7 @@ func TestGatewayAPIChainGetTipSetByHeight(t *testing.T) { tt := tt t.Run(tt.name, func(t *testing.T) { mock := &mockGatewayDepsAPI{} - a := NewNode(mock, DefaultLookbackCap, DefaultStateWaitLookbackLimit) + a := NewNode(mock, DefaultLookbackCap, DefaultStateWaitLookbackLimit, 0, time.Minute) // Create tipsets from genesis up to tskh and return the highest ts := mock.createTipSets(tt.args.tskh, tt.args.genesisTS) @@ -245,9 +245,33 @@ func TestGatewayVersion(t *testing.T) { //stm: @GATEWAY_NODE_GET_VERSION_001 ctx := context.Background() mock := &mockGatewayDepsAPI{} - a := NewNode(mock, DefaultLookbackCap, DefaultStateWaitLookbackLimit) + a := NewNode(mock, DefaultLookbackCap, DefaultStateWaitLookbackLimit, 0, time.Minute) v, err := a.Version(ctx) require.NoError(t, err) require.Equal(t, api.FullAPIVersion1, v.APIVersion) } + +func TestGatewayLimitTokensAvailable(t *testing.T) { + ctx := context.Background() + mock := &mockGatewayDepsAPI{} + tokens := 3 + a := NewNode(mock, DefaultLookbackCap, DefaultStateWaitLookbackLimit, int64(tokens), time.Minute) + require.NoError(t, a.limit(ctx, tokens), "requests should not be limited when there are enough tokens available") +} + +func TestGatewayLimitTokensNotAvailable(t *testing.T) { + ctx := context.Background() + mock := &mockGatewayDepsAPI{} + tokens := 3 + a := NewNode(mock, DefaultLookbackCap, DefaultStateWaitLookbackLimit, int64(1), time.Millisecond) + var err error + // try to be rate limited + for i := 0; i <= 1000; i++ { + err = a.limit(ctx, tokens) + if err != nil { + break + } + } + require.Error(t, err, "requiests should be rate limited when they hit limits") +} diff --git a/itests/gateway_test.go b/itests/gateway_test.go index 593ec17e996..ce31a0a3633 100644 --- a/itests/gateway_test.go +++ b/itests/gateway_test.go @@ -290,8 +290,8 @@ func startNodes( ens.InterconnectAll().BeginMining(blocktime) // Create a gateway server in front of the full node - gwapi := gateway.NewNode(full, lookbackCap, stateWaitLookbackLimit) - handler, err := gateway.Handler(gwapi, full) + gwapi := gateway.NewNode(full, lookbackCap, stateWaitLookbackLimit, 0, time.Minute) + handler, err := gateway.Handler(gwapi, full, 0, 0) require.NoError(t, err) l, err := net.Listen("tcp", "127.0.0.1:0") diff --git a/metrics/metrics.go b/metrics/metrics.go index 2bbba88d3fc..8a4c3aa314f 100644 --- a/metrics/metrics.go +++ b/metrics/metrics.go @@ -167,6 +167,9 @@ var ( RcmgrBlockSvcPeer = stats.Int64("rcmgr/block_svc", "Number of blocked blocked streams attached to a service for a specific peer", stats.UnitDimensionless) RcmgrAllowMem = stats.Int64("rcmgr/allow_mem", "Number of allowed memory reservations", stats.UnitDimensionless) RcmgrBlockMem = stats.Int64("rcmgr/block_mem", "Number of blocked memory reservations", stats.UnitDimensionless) + + // gateway rate limit + RateLimitCount = stats.Int64("ratelimit/limited", "rate limited connections", stats.UnitDimensionless) ) var ( @@ -599,6 +602,10 @@ var ( Measure: RcmgrBlockMem, Aggregation: view.Count(), } + RateLimitedView = &view.View{ + Measure: RateLimitCount, + Aggregation: view.Count(), + } ) // DefaultViews is an array of OpenCensus views for metric gathering purposes @@ -711,6 +718,10 @@ var MinerNodeViews = append([]*view.View{ DagStorePRSeekForwardBytesView, }, DefaultViews...) +var GatewayNodeViews = append([]*view.View{ + RateLimitedView, +}, ChainNodeViews...) + // SinceInMilliseconds returns the duration of time since the provide time as a float64. func SinceInMilliseconds(startTime time.Time) float64 { return float64(time.Since(startTime).Nanoseconds()) / 1e6