Skip to content

Commit

Permalink
feat(service): implement trigger rate-limit (#321)
Browse files Browse the repository at this point in the history
Because

- we need to rate limit the pipeline trigger

This commit

- implement trigger rate-limit, if the value is not set in Redis, we'll
ignore that.
  • Loading branch information
donch1989 authored Dec 9, 2023
1 parent 466fe2b commit 91a9706
Show file tree
Hide file tree
Showing 4 changed files with 40 additions and 18 deletions.
2 changes: 0 additions & 2 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -1183,8 +1183,6 @@ github.com/influxdata/influxdb-client-go/v2 v2.12.3 h1:28nRlNMRIV4QbtIUvxhWqaxn0
github.com/influxdata/influxdb-client-go/v2 v2.12.3/go.mod h1:IrrLUbCjjfkmRuaCiGQg4m2GbkaeJDcuWoxiWdQEbA0=
github.com/influxdata/line-protocol v0.0.0-20200327222509-2487e7298839 h1:W9WBk7wlPfJLvMCdtV4zPulc4uCPrlywQOmbFOhgQNU=
github.com/influxdata/line-protocol v0.0.0-20200327222509-2487e7298839/go.mod h1:xaLFMmpvUxqXtVkUJfg9QmT88cDaCJ3ZKgdZ78oO8Qo=
github.com/instill-ai/component v0.7.1-alpha.0.20231206035822-12eee341c80e h1:nhjzSXK71aB5WQEoSW3sR9OLCkKU1YlGal9lhdQJ9ao=
github.com/instill-ai/component v0.7.1-alpha.0.20231206035822-12eee341c80e/go.mod h1:fWyVPJVJ4WyFSQMgWCc7KvcS7m9QpdS3VXCL2CZE8OY=
github.com/instill-ai/component v0.7.1-alpha.0.20231208045032-bafec3495571 h1:LdlHF/NN65GrM9kTnDMp3Ycxls3nU08/65UmHhIXjag=
github.com/instill-ai/component v0.7.1-alpha.0.20231208045032-bafec3495571/go.mod h1:fWyVPJVJ4WyFSQMgWCc7KvcS7m9QpdS3VXCL2CZE8OY=
github.com/instill-ai/connector v0.7.0-alpha.0.20231206040111-5a57a09f2adc h1:i34jk4bQi1jEcN66UNzK6UoPmZuCWrldV60qrZXa/H4=
Expand Down
4 changes: 4 additions & 0 deletions pkg/middleware/interceptor.go
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,10 @@ func InjectErrCode(err error) error {
errors.Is(err, service.ErrUnauthenticated):
return status.Error(codes.Unauthenticated, err.Error())

case
errors.Is(err, service.ErrRateLimiting):
return status.Error(codes.ResourceExhausted, err.Error())

case
errors.Is(err, acl.ErrMembershipNotFound):
return status.Error(codes.NotFound, err.Error())
Expand Down
1 change: 1 addition & 0 deletions pkg/service/errors.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,3 +5,4 @@ import "errors"
var ErrNoPermission = errors.New("no permission")
var ErrNotFound = errors.New("not found")
var ErrUnauthenticated = errors.New("unauthenticated")
var ErrRateLimiting = errors.New("rate limiting")
51 changes: 35 additions & 16 deletions pkg/service/service.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,10 @@ import (
"database/sql"
"encoding/base64"
"encoding/json"
"errors"
"fmt"
"math/rand"
"strconv"
"strings"
"time"

Expand Down Expand Up @@ -221,6 +223,14 @@ func (a AuthUser) GetACLType() string {
}
}

func (a AuthUser) Permalink() string {
if a.IsVisitor {
return fmt.Sprintf("visitors/%s", a.UID)
} else {
return fmt.Sprintf("users/%s", a.UID)
}
}

func (s *service) AuthenticateUser(ctx context.Context, allowVisitor bool) (authUser *AuthUser, err error) {
// Verify if "jwt-sub" is in the header
headerCtxUserUID := resource.GetRequestSingleHeader(ctx, constant.HeaderUserUIDKey)
Expand Down Expand Up @@ -751,10 +761,21 @@ func (s *service) UpdateNamespacePipelineIDByID(ctx context.Context, ns resource
return s.DBToPBPipeline(ctx, dbPipeline, VIEW_FULL)
}

func (s *service) preTriggerPipeline(recipe *datamodel.Recipe, pipelineInputs []*structpb.Struct) error {
func (s *service) preTriggerPipeline(authUser *AuthUser, recipe *datamodel.Recipe, pipelineInputs []*structpb.Struct) error {

value, err := s.redisClient.Get(context.Background(), fmt.Sprintf("user_rate_limit:user:%s", authUser.UID)).Result()

// TODO: use a more robust way to check key exist
if !errors.Is(err, redis.Nil) {
requestLeft, _ := strconv.ParseInt(value, 10, 64)
if requestLeft <= 0 {
return ErrRateLimiting
} else {
_ = s.redisClient.Decr(context.Background(), fmt.Sprintf("user_rate_limit:user:%s", authUser.UID))
}
}

var metadata []byte
var err error

instillFormatMap := map[string]string{}
for _, comp := range recipe.Components {
Expand Down Expand Up @@ -1269,7 +1290,7 @@ func (s *service) SetDefaultNamespacePipelineReleaseByID(ctx context.Context, ns
func (s *service) triggerPipeline(
ctx context.Context,
ownerPermalink string,
userPermalink string,
authUser *AuthUser,
recipe *datamodel.Recipe,
pipelineId string,
pipelineUid uuid.UUID,
Expand All @@ -1281,7 +1302,7 @@ func (s *service) triggerPipeline(

logger, _ := logger.GetZapLogger(ctx)

err := s.preTriggerPipeline(recipe, pipelineInputs)
err := s.preTriggerPipeline(authUser, recipe, pipelineInputs)
if err != nil {
return nil, nil, err
}
Expand All @@ -1301,6 +1322,7 @@ func (s *service) triggerPipeline(
time.Duration(config.Config.Server.Workflow.MaxWorkflowTimeout)*time.Second,
)
inputBlobRedisKeys = append(inputBlobRedisKeys, inputBlobRedisKey)
defer s.redisClient.Del(context.Background(), inputBlobRedisKey)
}
memo := map[string]interface{}{}
memo["number_of_data"] = len(inputBlobRedisKeys)
Expand All @@ -1327,7 +1349,7 @@ func (s *service) triggerPipeline(
PipelineReleaseUid: pipelineReleaseUid,
PipelineRecipe: recipe,
OwnerPermalink: ownerPermalink,
UserPermalink: userPermalink,
UserPermalink: authUser.Permalink(),
ReturnTraces: returnTraces,
})
if err != nil {
Expand All @@ -1346,6 +1368,7 @@ func (s *service) triggerPipeline(
if err != nil {
return nil, nil, err
}
s.redisClient.Del(context.Background(), result.OutputBlobRedisKey)

err = protojson.Unmarshal(blob, pipelineResp)
if err != nil {
Expand All @@ -1358,7 +1381,7 @@ func (s *service) triggerPipeline(
func (s *service) triggerAsyncPipeline(
ctx context.Context,
ownerPermalink string,
userPermalink string,
authUser *AuthUser,
recipe *datamodel.Recipe,
pipelineId string,
pipelineUid uuid.UUID,
Expand All @@ -1368,7 +1391,7 @@ func (s *service) triggerAsyncPipeline(
pipelineTriggerId string,
returnTraces bool) (*longrunningpb.Operation, error) {

err := s.preTriggerPipeline(recipe, pipelineInputs)
err := s.preTriggerPipeline(authUser, recipe, pipelineInputs)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -1415,7 +1438,7 @@ func (s *service) triggerAsyncPipeline(
PipelineReleaseUid: pipelineReleaseUid,
PipelineRecipe: recipe,
OwnerPermalink: ownerPermalink,
UserPermalink: userPermalink,
UserPermalink: authUser.Permalink(),
ReturnTraces: returnTraces,
})
if err != nil {
Expand All @@ -1435,7 +1458,6 @@ func (s *service) triggerAsyncPipeline(
func (s *service) TriggerNamespacePipelineByID(ctx context.Context, ns resource.Namespace, authUser *AuthUser, id string, inputs []*structpb.Struct, pipelineTriggerId string, returnTraces bool) ([]*structpb.Struct, *pipelinePB.TriggerMetadata, error) {

ownerPermalink := ns.String()
userPermalink := fmt.Sprintf("users/%s", authUser.UID)

dbPipeline, err := s.repository.GetNamespacePipelineByID(ctx, ownerPermalink, id, false)
if err != nil {
Expand All @@ -1454,14 +1476,13 @@ func (s *service) TriggerNamespacePipelineByID(ctx context.Context, ns resource.
return nil, nil, ErrNoPermission
}

return s.triggerPipeline(ctx, ownerPermalink, userPermalink, dbPipeline.Recipe, dbPipeline.ID, dbPipeline.UID, "", uuid.Nil, inputs, pipelineTriggerId, returnTraces)
return s.triggerPipeline(ctx, ownerPermalink, authUser, dbPipeline.Recipe, dbPipeline.ID, dbPipeline.UID, "", uuid.Nil, inputs, pipelineTriggerId, returnTraces)

}

func (s *service) TriggerAsyncNamespacePipelineByID(ctx context.Context, ns resource.Namespace, authUser *AuthUser, id string, inputs []*structpb.Struct, pipelineTriggerId string, returnTraces bool) (*longrunningpb.Operation, error) {

ownerPermalink := ns.String()
userPermalink := fmt.Sprintf("users/%s", authUser.UID)

dbPipeline, err := s.repository.GetNamespacePipelineByID(ctx, ownerPermalink, id, false)
if err != nil {
Expand All @@ -1479,14 +1500,13 @@ func (s *service) TriggerAsyncNamespacePipelineByID(ctx context.Context, ns reso
return nil, ErrNoPermission
}

return s.triggerAsyncPipeline(ctx, ownerPermalink, userPermalink, dbPipeline.Recipe, dbPipeline.ID, dbPipeline.UID, "", uuid.Nil, inputs, pipelineTriggerId, returnTraces)
return s.triggerAsyncPipeline(ctx, ownerPermalink, authUser, dbPipeline.Recipe, dbPipeline.ID, dbPipeline.UID, "", uuid.Nil, inputs, pipelineTriggerId, returnTraces)

}

func (s *service) TriggerNamespacePipelineReleaseByID(ctx context.Context, ns resource.Namespace, authUser *AuthUser, pipelineUid uuid.UUID, id string, inputs []*structpb.Struct, pipelineTriggerId string, returnTraces bool) ([]*structpb.Struct, *pipelinePB.TriggerMetadata, error) {

ownerPermalink := ns.String()
userPermalink := fmt.Sprintf("users/%s", authUser.UID)

dbPipeline, err := s.repository.GetPipelineByUID(ctx, pipelineUid, false)
if err != nil {
Expand All @@ -1509,13 +1529,12 @@ func (s *service) TriggerNamespacePipelineReleaseByID(ctx context.Context, ns re
return nil, nil, err
}

return s.triggerPipeline(ctx, ownerPermalink, userPermalink, dbPipelineRelease.Recipe, dbPipeline.ID, dbPipeline.UID, dbPipelineRelease.ID, dbPipelineRelease.UID, inputs, pipelineTriggerId, returnTraces)
return s.triggerPipeline(ctx, ownerPermalink, authUser, dbPipelineRelease.Recipe, dbPipeline.ID, dbPipeline.UID, dbPipelineRelease.ID, dbPipelineRelease.UID, inputs, pipelineTriggerId, returnTraces)
}

func (s *service) TriggerAsyncNamespacePipelineReleaseByID(ctx context.Context, ns resource.Namespace, authUser *AuthUser, pipelineUid uuid.UUID, id string, inputs []*structpb.Struct, pipelineTriggerId string, returnTraces bool) (*longrunningpb.Operation, error) {

ownerPermalink := ns.String()
userPermalink := fmt.Sprintf("users/%s", authUser.UID)

dbPipeline, err := s.repository.GetPipelineByUID(ctx, pipelineUid, false)
if err != nil {
Expand All @@ -1538,7 +1557,7 @@ func (s *service) TriggerAsyncNamespacePipelineReleaseByID(ctx context.Context,
return nil, err
}

return s.triggerAsyncPipeline(ctx, ownerPermalink, userPermalink, dbPipelineRelease.Recipe, dbPipeline.ID, dbPipeline.UID, dbPipelineRelease.ID, dbPipelineRelease.UID, inputs, pipelineTriggerId, returnTraces)
return s.triggerAsyncPipeline(ctx, ownerPermalink, authUser, dbPipelineRelease.Recipe, dbPipeline.ID, dbPipeline.UID, dbPipelineRelease.ID, dbPipelineRelease.UID, inputs, pipelineTriggerId, returnTraces)
}

func (s *service) RemoveCredentialFieldsWithMaskString(dbConnDefID string, config *structpb.Struct) {
Expand Down

0 comments on commit 91a9706

Please sign in to comment.