diff --git a/app/app.go b/app/app.go index 6cf76ef7707..b0c37fa912c 100644 --- a/app/app.go +++ b/app/app.go @@ -74,18 +74,23 @@ func (a *app) initCPUProfiling() { } func (a *app) initFeatures() { + enterpriseLogger := logger.NewLogger().Child("enterprise") a.features = &Features{ SuppressUser: &suppression.Factory{ EnterpriseToken: a.options.EnterpriseToken, + Log: enterpriseLogger.Child("suppress-user"), }, Reporting: &reporting.Factory{ EnterpriseToken: a.options.EnterpriseToken, + Log: enterpriseLogger.Child("reporting"), }, Replay: &replay.Factory{ EnterpriseToken: a.options.EnterpriseToken, + Log: enterpriseLogger.Child("replay"), }, ConfigEnv: &configenv.Factory{ EnterpriseToken: a.options.EnterpriseToken, + Log: enterpriseLogger.Child("config-env"), }, } } diff --git a/app/apphandlers/embeddedAppHandler.go b/app/apphandlers/embeddedAppHandler.go index fd7488802f6..d7484186c5c 100644 --- a/app/apphandlers/embeddedAppHandler.go +++ b/app/apphandlers/embeddedAppHandler.go @@ -209,6 +209,7 @@ func (embedded *EmbeddedApp) StartRudderCore(ctx context.Context, options *app.O gw.SetReadonlyDBs(&readonlyGatewayDB, &readonlyRouterDB, &readonlyBatchRouterDB) err = gw.Setup( + ctx, embedded.App, backendconfig.DefaultBackendConfig, gatewayDB, &rateLimiter, embedded.VersionHandler, rsourcesService, ) diff --git a/app/apphandlers/gatewayAppHandler.go b/app/apphandlers/gatewayAppHandler.go index 0e2e2b9d4de..664eb5be784 100644 --- a/app/apphandlers/gatewayAppHandler.go +++ b/app/apphandlers/gatewayAppHandler.go @@ -101,6 +101,7 @@ func (gatewayApp *GatewayApp) StartRudderCore(ctx context.Context, options *app. return err } err = gw.Setup( + ctx, gatewayApp.App, backendconfig.DefaultBackendConfig, gatewayDB, &rateLimiter, gatewayApp.VersionHandler, rsourcesService, ) diff --git a/app/features.go b/app/features.go index 6ceea88407c..5d39279dc20 100644 --- a/app/features.go +++ b/app/features.go @@ -12,7 +12,7 @@ import ( // SuppressUserFeature handles webhook event requests type SuppressUserFeature interface { - Setup(backendConfig backendconfig.BackendConfig) (types.UserSuppression, error) + Setup(ctx context.Context, backendConfig backendconfig.BackendConfig) (types.UserSuppression, error) } /********************************* diff --git a/enterprise/config-env/configEnv.go b/enterprise/config-env/configEnv.go index 6c4d39b3cf6..7bc15c341a1 100644 --- a/enterprise/config-env/configEnv.go +++ b/enterprise/config-env/configEnv.go @@ -14,30 +14,29 @@ import ( "github.com/rudderlabs/rudder-server/utils/logger" ) -type HandleT struct{} +type HandleT struct { + Log logger.Logger +} -var ( - configEnvReplacer string - pkgLogger logger.Logger -) +var configEnvReplacer string func loadConfig() { configEnvReplacer = config.GetString("BackendConfig.configEnvReplacer", "env.") } // ReplaceConfigWithEnvVariables : Replaces all env variables in the config -func (*HandleT) ReplaceConfigWithEnvVariables(workspaceConfig []byte) (updatedConfig []byte) { +func (h *HandleT) ReplaceConfigWithEnvVariables(workspaceConfig []byte) (updatedConfig []byte) { configMap := make(map[string]interface{}, 0) err := json.Unmarshal(workspaceConfig, &configMap) if err != nil { - pkgLogger.Error("[ConfigEnv] Error while parsing request", err, string(workspaceConfig)) + h.Log.Error("[ConfigEnv] Error while parsing request", err, string(workspaceConfig)) return workspaceConfig } flattenedConfig, err := flatten.Flatten(configMap, "", flatten.DotStyle) if err != nil { - pkgLogger.Errorf("[ConfigEnv] Failed to flatten workspace config: %v", err) + h.Log.Errorf("[ConfigEnv] Failed to flatten workspace config: %v", err) return workspaceConfig } @@ -51,12 +50,12 @@ func (*HandleT) ReplaceConfigWithEnvVariables(workspaceConfig []byte) (updatedCo envVarValue := os.Getenv(envVariable) if envVarValue == "" { errorMessage := fmt.Sprintf("[ConfigEnv] Missing envVariable: %s. Either set it as envVariable or remove %s from the destination config.", envVariable, configEnvReplacer) - pkgLogger.Errorf(errorMessage) + h.Log.Errorf(errorMessage) continue } workspaceConfig, err = sjson.SetBytes(workspaceConfig, configKey, envVarValue) if err != nil { - pkgLogger.Error("[ConfigEnv] Failed to set config for %s", configKey) + h.Log.Error("[ConfigEnv] Failed to set config for %s", configKey) } } } diff --git a/enterprise/config-env/setup.go b/enterprise/config-env/setup.go index 6663f38fc78..c6af6f5afcc 100644 --- a/enterprise/config-env/setup.go +++ b/enterprise/config-env/setup.go @@ -7,19 +7,23 @@ import ( type Factory struct { EnterpriseToken string + Log logger.Logger } // Setup initializes Suppress User feature func (m *Factory) Setup() types.ConfigEnvI { + if m.Log == nil { + m.Log = logger.NewLogger().Child("enterprise").Child("config-env") + } if m.EnterpriseToken == "" { return &NOOP{} } loadConfig() - pkgLogger = logger.NewLogger().Child("enterprise").Child("config-env") + m.Log = logger.NewLogger().Child("enterprise").Child("config-env") - pkgLogger.Info("[[ ConfigEnv ]] Setting up config env handler") - handle := &HandleT{} + m.Log.Info("[[ ConfigEnv ]] Setting up config env handler") + handle := &HandleT{Log: m.Log} return handle } diff --git a/enterprise/replay/dumpsloader.go b/enterprise/replay/dumpsloader.go index 74e76c5060d..b88c2b63de7 100644 --- a/enterprise/replay/dumpsloader.go +++ b/enterprise/replay/dumpsloader.go @@ -18,12 +18,9 @@ import ( "github.com/tidwall/gjson" ) -func init() { - pkgLogger = logger.NewLogger().Child("enterprise").Child("replay").Child("dumpsLoader") -} - // DumpsLoaderHandleT - dumps-loader handle type dumpsLoaderHandleT struct { + log logger.Logger dbHandle *jobsdb.HandleT prefix string bucket string @@ -78,7 +75,7 @@ type OrderedJobs struct { Job *jobsdb.JobT } -func storeJobs(ctx context.Context, objects []OrderedJobs, dbHandle *jobsdb.HandleT) { +func storeJobs(ctx context.Context, objects []OrderedJobs, dbHandle *jobsdb.HandleT, log logger.Logger) { // sorting dumps list on index sort.Slice(objects, func(i, j int) bool { return objects[i].SortIndex < objects[j].SortIndex @@ -89,7 +86,7 @@ func storeJobs(ctx context.Context, objects []OrderedJobs, dbHandle *jobsdb.Hand jobs = append(jobs, object.Job) } - pkgLogger.Info("Total dumps count : ", len(objects)) + log.Info("Total dumps count : ", len(objects)) err := dbHandle.Store(ctx, jobs) if err != nil { panic(fmt.Errorf("Failed to write dumps locations to DB with error: %w", err)) @@ -103,7 +100,7 @@ func (gwHandle *GWReplayRequestHandler) fetchDumpsList(ctx context.Context) { maxItems := config.GetInt64("MAX_ITEMS", 1000) // MAX_ITEMS is the max number of files to be fetched in one iteration from object storage uploadMaxItems := config.GetInt64("UPLOAD_MAX_ITEMS", 1) // UPLOAD_MAX_ITEMS is the max number of objects to be uploaded to postgres - pkgLogger.Info("Fetching gw dump files list") + gwHandle.handle.log.Info("Fetching gw dump files list") objects := make([]OrderedJobs, 0) iter := filemanager.IterateFilesWithPrefix(ctx, @@ -131,8 +128,8 @@ func (gwHandle *GWReplayRequestHandler) fetchDumpsList(ctx context.Context) { if err == nil { pass = maxJobCreatedAt >= startTimeMilli && minJobCreatedAt <= endTimeMilli } else { - pkgLogger.Infof("gw dump name(%s) is not of the expected format. Parse failed with error %w", object.Key, err) - pkgLogger.Info("Falling back to comparing start and end time stamps with gw dump last modified.") + gwHandle.handle.log.Infof("gw dump name(%s) is not of the expected format. Parse failed with error %w", object.Key, err) + gwHandle.handle.log.Info("Falling back to comparing start and end time stamps with gw dump last modified.") pass = object.LastModified.After(gwHandle.handle.startTime) && object.LastModified.Before(gwHandle.handle.endTime) } @@ -148,7 +145,7 @@ func (gwHandle *GWReplayRequestHandler) fetchDumpsList(ctx context.Context) { } } if len(objects) >= int(uploadMaxItems) { - storeJobs(ctx, objects, gwHandle.handle.dbHandle) + storeJobs(ctx, objects, gwHandle.handle.dbHandle, gwHandle.handle.log) objects = nil } } @@ -156,17 +153,17 @@ func (gwHandle *GWReplayRequestHandler) fetchDumpsList(ctx context.Context) { panic(fmt.Errorf("Failed to iterate gw dump files with error: %w", iter.Err())) } if len(objects) != 0 { - storeJobs(ctx, objects, gwHandle.handle.dbHandle) + storeJobs(ctx, objects, gwHandle.handle.dbHandle, gwHandle.handle.log) objects = nil } - pkgLogger.Info("Dumps loader job is done") + gwHandle.handle.log.Info("Dumps loader job is done") gwHandle.handle.done = true } func (procHandle *ProcErrorRequestHandler) fetchDumpsList(ctx context.Context) { objects := make([]OrderedJobs, 0) - pkgLogger.Info("Fetching proc err files list") + procHandle.handle.log.Info("Fetching proc err files list") var err error maxItems := config.GetInt64("MAX_ITEMS", 1000) // MAX_ITEMS is the max number of files to be fetched in one iteration from object storage uploadMaxItems := config.GetInt64("UPLOAD_MAX_ITEMS", 1) // UPLOAD_MAX_ITEMS is the max number of objects to be uploaded to postgres @@ -181,7 +178,7 @@ func (procHandle *ProcErrorRequestHandler) fetchDumpsList(ctx context.Context) { object := iter.Get() if strings.Contains(object.Key, "rudder-proc-err-logs") { if object.LastModified.Before(procHandle.handle.startTime) || (object.LastModified.Sub(procHandle.handle.endTime).Hours() > 1) { - pkgLogger.Debugf("Skipping object: %v ObjectLastModifiedTime: %v", object.Key, object.LastModified) + procHandle.handle.log.Debugf("Skipping object: %v ObjectLastModifiedTime: %v", object.Key, object.LastModified) continue } key := object.Key @@ -204,7 +201,7 @@ func (procHandle *ProcErrorRequestHandler) fetchDumpsList(ctx context.Context) { objects = append(objects, OrderedJobs{Job: &job, SortIndex: idx}) } if len(objects) >= int(uploadMaxItems) { - storeJobs(ctx, objects, procHandle.handle.dbHandle) + storeJobs(ctx, objects, procHandle.handle.dbHandle, procHandle.handle.log) objects = nil } @@ -213,10 +210,10 @@ func (procHandle *ProcErrorRequestHandler) fetchDumpsList(ctx context.Context) { panic(fmt.Errorf("Failed to iterate proc err files with error: %w", iter.Err())) } if len(objects) != 0 { - storeJobs(ctx, objects, procHandle.handle.dbHandle) + storeJobs(ctx, objects, procHandle.handle.dbHandle, procHandle.handle.log) } - pkgLogger.Info("Dumps loader job is done") + procHandle.handle.log.Info("Dumps loader job is done") procHandle.handle.done = true } @@ -226,8 +223,9 @@ func (handle *dumpsLoaderHandleT) handleRecovery() { } // Setup sets up dumps-loader. -func (handle *dumpsLoaderHandleT) Setup(ctx context.Context, db *jobsdb.HandleT, tablePrefix string, uploader filemanager.FileManager, bucket string) { +func (handle *dumpsLoaderHandleT) Setup(ctx context.Context, db *jobsdb.HandleT, tablePrefix string, uploader filemanager.FileManager, bucket string, log logger.Logger) { var err error + handle.log = log handle.dbHandle = db handle.handleRecovery() diff --git a/enterprise/replay/replay.go b/enterprise/replay/replay.go index c82412f0d15..88cbe1a8c00 100644 --- a/enterprise/replay/replay.go +++ b/enterprise/replay/replay.go @@ -10,9 +10,11 @@ import ( "github.com/rudderlabs/rudder-server/jobsdb" "github.com/rudderlabs/rudder-server/processor/transformer" "github.com/rudderlabs/rudder-server/services/filemanager" + "github.com/rudderlabs/rudder-server/utils/logger" ) type Handler struct { + log logger.Logger bucket string db *jobsdb.HandleT toDB *jobsdb.HandleT @@ -26,11 +28,11 @@ type Handler struct { } func (handle *Handler) generatorLoop(ctx context.Context) { - pkgLogger.Infof("generator reading from replay_jobs_* started") + handle.log.Infof("generator reading from replay_jobs_* started") var breakLoop bool select { case <-ctx.Done(): - pkgLogger.Infof("generator reading from replay_jobs_* stopped:Context cancelled") + handle.log.Infof("generator reading from replay_jobs_* stopped:Context cancelled") return case <-handle.initSourceWorkersChannel: } @@ -41,7 +43,7 @@ func (handle *Handler) generatorLoop(ctx context.Context) { } toRetry, err := handle.db.GetToRetry(context.TODO(), queryParams) if err != nil { - pkgLogger.Errorf("Error getting to retry jobs: %v", err) + handle.log.Errorf("Error getting to retry jobs: %v", err) panic(err) } combinedList := toRetry.Jobs @@ -50,21 +52,21 @@ func (handle *Handler) generatorLoop(ctx context.Context) { queryParams.JobsLimit -= len(combinedList) unprocessed, err := handle.db.GetUnprocessed(context.TODO(), queryParams) if err != nil { - pkgLogger.Errorf("Error getting unprocessed jobs: %v", err) + handle.log.Errorf("Error getting unprocessed jobs: %v", err) panic(err) } combinedList = append(combinedList, unprocessed.Jobs...) } - pkgLogger.Infof("length of combinedList : %d", len(combinedList)) + handle.log.Infof("length of combinedList : %d", len(combinedList)) if len(combinedList) == 0 { if breakLoop { executingList, err := handle.db.GetExecuting(context.TODO(), jobsdb.GetQueryParamsT{CustomValFilters: []string{"replay"}, JobsLimit: handle.dbReadSize}) if err != nil { - pkgLogger.Errorf("Error getting executing jobs: %v", err) + handle.log.Errorf("Error getting executing jobs: %v", err) panic(err) } - pkgLogger.Infof("breakLoop is set. Pending executing: %d", len(executingList.Jobs)) + handle.log.Infof("breakLoop is set. Pending executing: %d", len(executingList.Jobs)) if len(executingList.Jobs) == 0 { break } @@ -74,7 +76,7 @@ func (handle *Handler) generatorLoop(ctx context.Context) { breakLoop = true } - pkgLogger.Debugf("DB Read Complete. No Jobs to process") + handle.log.Debugf("DB Read Complete. No Jobs to process") time.Sleep(5 * time.Second) continue } @@ -121,7 +123,7 @@ func (handle *Handler) generatorLoop(ctx context.Context) { // Since generator read is done, closing worker channels for _, worker := range handle.workers { - pkgLogger.Infof("Closing worker channels") + handle.log.Infof("Closing worker channels") close(worker.channel) } } @@ -130,6 +132,7 @@ func (handle *Handler) initSourceWorkers(ctx context.Context) { handle.workers = make([]*SourceWorkerT, handle.noOfWorkers) for i := 0; i < handle.noOfWorkers; i++ { worker := &SourceWorkerT{ + log: handle.log, channel: make(chan *jobsdb.JobT, handle.dbReadSize), workerID: i, replayHandler: handle, @@ -144,7 +147,8 @@ func (handle *Handler) initSourceWorkers(ctx context.Context) { handle.initSourceWorkersChannel <- true } -func (handle *Handler) Setup(ctx context.Context, dumpsLoader *dumpsLoaderHandleT, db, toDB *jobsdb.HandleT, tablePrefix string, uploader filemanager.FileManager, bucket string) { +func (handle *Handler) Setup(ctx context.Context, dumpsLoader *dumpsLoaderHandleT, db, toDB *jobsdb.HandleT, tablePrefix string, uploader filemanager.FileManager, bucket string, log logger.Logger) { + handle.log = log handle.db = db handle.toDB = toDB handle.bucket = bucket diff --git a/enterprise/replay/setup.go b/enterprise/replay/setup.go index 7ed02632e5c..35bf101eec1 100644 --- a/enterprise/replay/setup.go +++ b/enterprise/replay/setup.go @@ -12,20 +12,17 @@ import ( "github.com/rudderlabs/rudder-server/utils/types" ) -var ( - pkgLogger logger.Logger - replayEnabled bool -) +var replayEnabled bool func loadConfig() { replayEnabled = config.GetBool("Replay.enabled", types.DEFAULT_REPLAY_ENABLED) config.RegisterIntConfigVariable(200, &userTransformBatchSize, true, 1, "Processor.userTransformBatchSize") } -func initFileManager() (filemanager.FileManager, string, error) { +func initFileManager(log logger.Logger) (filemanager.FileManager, string, error) { bucket := strings.TrimSpace(config.GetString("JOBS_REPLAY_BACKUP_BUCKET", "")) if bucket == "" { - pkgLogger.Error("[[ Replay ]] JOBS_REPLAY_BACKUP_BUCKET is not set") + log.Error("[[ Replay ]] JOBS_REPLAY_BACKUP_BUCKET is not set") panic("Bucket is not configured.") } @@ -44,24 +41,24 @@ func initFileManager() (filemanager.FileManager, string, error) { }), }) if err != nil { - pkgLogger.Errorf("[[ Replay ]] Error creating file manager: %s", err.Error()) + log.Errorf("[[ Replay ]] Error creating file manager: %s", err.Error()) return nil, "", err } return uploader, bucket, nil } -func setup(ctx context.Context, replayDB, gwDB, routerDB, batchRouterDB *jobsdb.HandleT) error { +func setup(ctx context.Context, replayDB, gwDB, routerDB, batchRouterDB *jobsdb.HandleT, log logger.Logger) error { tablePrefix := config.GetString("TO_REPLAY", "gw") replayToDB := config.GetString("REPLAY_TO_DB", "gw") - pkgLogger.Infof("TO_REPLAY=%s and REPLAY_TO_DB=%s", tablePrefix, replayToDB) + log.Infof("TO_REPLAY=%s and REPLAY_TO_DB=%s", tablePrefix, replayToDB) var dumpsLoader dumpsLoaderHandleT - uploader, bucket, err := initFileManager() + uploader, bucket, err := initFileManager(log) if err != nil { return err } - dumpsLoader.Setup(ctx, replayDB, tablePrefix, uploader, bucket) + dumpsLoader.Setup(ctx, replayDB, tablePrefix, uploader, bucket, log) var replayer Handler var toDB *jobsdb.HandleT @@ -76,25 +73,28 @@ func setup(ctx context.Context, replayDB, gwDB, routerDB, batchRouterDB *jobsdb. toDB = routerDB } _ = toDB.Start() - replayer.Setup(ctx, &dumpsLoader, replayDB, toDB, tablePrefix, uploader, bucket) + replayer.Setup(ctx, &dumpsLoader, replayDB, toDB, tablePrefix, uploader, bucket, log) return nil } type Factory struct { EnterpriseToken string + Log logger.Logger } // Setup initializes Replay feature func (m *Factory) Setup(ctx context.Context, replayDB, gwDB, routerDB, batchRouterDB *jobsdb.HandleT) { + if m.Log == nil { + m.Log = logger.NewLogger().Child("enterprise").Child("replay") + } if m.EnterpriseToken == "" { return } loadConfig() - pkgLogger = logger.NewLogger().Child("enterprise").Child("replay") if replayEnabled { - pkgLogger.Info("[[ Replay ]] Setting up Replay") - err := setup(ctx, replayDB, gwDB, routerDB, batchRouterDB) + m.Log.Info("[[ Replay ]] Setting up Replay") + err := setup(ctx, replayDB, gwDB, routerDB, batchRouterDB, m.Log) if err != nil { panic(err) } diff --git a/enterprise/replay/sourceWorker.go b/enterprise/replay/sourceWorker.go index f7c173c7fdf..33a8f7964f6 100644 --- a/enterprise/replay/sourceWorker.go +++ b/enterprise/replay/sourceWorker.go @@ -11,6 +11,7 @@ import ( "strings" "time" + "github.com/rudderlabs/rudder-server/utils/logger" "github.com/rudderlabs/rudder-server/utils/misc" "github.com/gofrs/uuid" @@ -24,6 +25,7 @@ import ( ) type SourceWorkerT struct { + log logger.Logger channel chan *jobsdb.JobT workerID int replayHandler *Handler @@ -35,9 +37,9 @@ type SourceWorkerT struct { var userTransformBatchSize int func (worker *SourceWorkerT) workerProcess(ctx context.Context) { - pkgLogger.Debugf("worker started %d", worker.workerID) + worker.log.Debugf("worker started %d", worker.workerID) for job := range worker.channel { - pkgLogger.Debugf("job received: %s", job.EventPayload) + worker.log.Debugf("job received: %s", job.EventPayload) worker.replayJobsInFile(ctx, gjson.GetBytes(job.EventPayload, "location").String()) @@ -85,7 +87,7 @@ func (worker *SourceWorkerT) replayJobsInFile(ctx context.Context, filePath stri if err != nil { panic(err) // failed to download } - pkgLogger.Debugf("file downloaded at %s", path) + worker.log.Debugf("file downloaded at %s", path) defer func() { _ = file.Close() }() rawf, err := os.Open(path) @@ -120,7 +122,7 @@ func (worker *SourceWorkerT) replayJobsInFile(ctx context.Context, filePath stri if transformationVersionID == "" { createdAt, err := time.Parse(misc.POSTGRESTIMEFORMATPARSE, gjson.GetBytes(copyLineBytes, worker.getFieldIdentifier(createdAt)).String()) if err != nil { - pkgLogger.Errorf("failed to parse created at: %s", err) + worker.log.Errorf("failed to parse created at: %s", err) continue } if !(worker.replayHandler.dumpsLoader.startTime.Before(createdAt) && worker.replayHandler.dumpsLoader.endTime.After(createdAt)) { @@ -140,7 +142,7 @@ func (worker *SourceWorkerT) replayJobsInFile(ctx context.Context, filePath stri message, ok := gjson.ParseBytes(copyLineBytes).Value().(map[string]interface{}) if !ok { - pkgLogger.Errorf("EventPayload not a json: %v", copyLineBytes) + worker.log.Errorf("EventPayload not a json: %v", copyLineBytes) continue } @@ -168,17 +170,17 @@ func (worker *SourceWorkerT) replayJobsInFile(ctx context.Context, filePath stri for _, ev := range response.Events { destEventJSON, err := json.Marshal(ev.Output[worker.getFieldIdentifier(eventPayload)]) if err != nil { - pkgLogger.Errorf("Error unmarshalling transformer output: %v", err) + worker.log.Errorf("Error unmarshalling transformer output: %v", err) continue } createdAtString, ok := ev.Output[worker.getFieldIdentifier(createdAt)].(string) if !ok { - pkgLogger.Errorf("Error getting created at from transformer output: %v", err) + worker.log.Errorf("Error getting created at from transformer output: %v", err) continue } createdAt, err := time.Parse(misc.POSTGRESTIMEFORMATPARSE, createdAtString) if err != nil { - pkgLogger.Errorf("failed to parse created at: %s", err) + worker.log.Errorf("failed to parse created at: %s", err) continue } if !(worker.replayHandler.dumpsLoader.startTime.Before(createdAt) && worker.replayHandler.dumpsLoader.endTime.After(createdAt)) { @@ -186,7 +188,7 @@ func (worker *SourceWorkerT) replayJobsInFile(ctx context.Context, filePath stri } params, err := json.Marshal(ev.Output[worker.getFieldIdentifier(parameters)]) if err != nil { - pkgLogger.Errorf("Error unmarshalling transformer output: %v", err) + worker.log.Errorf("Error unmarshalling transformer output: %v", err) continue } job := jobsdb.JobT{ @@ -200,11 +202,11 @@ func (worker *SourceWorkerT) replayJobsInFile(ctx context.Context, filePath stri } for _, failedEv := range response.FailedEvents { - pkgLogger.Errorf(`Event failed in transformer with err: %v`, failedEv.Error) + worker.log.Errorf(`Event failed in transformer with err: %v`, failedEv.Error) } } - pkgLogger.Infof("brt-debug: TO_DB=%s", worker.replayHandler.toDB.Identifier()) + worker.log.Infof("brt-debug: TO_DB=%s", worker.replayHandler.toDB.Identifier()) err = worker.replayHandler.toDB.Store(ctx, jobs) if err != nil { @@ -213,7 +215,7 @@ func (worker *SourceWorkerT) replayJobsInFile(ctx context.Context, filePath stri err = os.Remove(path) if err != nil { - pkgLogger.Errorf("[%s]: failed to remove file with error: %w", err) + worker.log.Errorf("[%s]: failed to remove file with error: %w", err) } } diff --git a/enterprise/reporting/reporting.go b/enterprise/reporting/reporting.go index 258f9895f26..ae943785a79 100644 --- a/enterprise/reporting/reporting.go +++ b/enterprise/reporting/reporting.go @@ -47,7 +47,7 @@ type HandleT struct { onceInit sync.Once clients map[string]*types.Client clientsMapLock sync.RWMutex - logger logger.Logger + log logger.Logger reportingServiceURL string namespace string workspaceID string @@ -63,23 +63,22 @@ type HandleT struct { requestLatency stats.Measurement } -func NewFromEnvConfig() *HandleT { +func NewFromEnvConfig(log logger.Logger) *HandleT { var sleepInterval, mainLoopSleepInterval time.Duration reportingServiceURL := config.GetString("REPORTING_URL", "https://reporting.rudderstack.com/") reportingServiceURL = strings.TrimSuffix(reportingServiceURL, "/") config.RegisterDurationConfigVariable(5, &mainLoopSleepInterval, true, time.Second, "Reporting.mainLoopSleepInterval") config.RegisterDurationConfigVariable(30, &sleepInterval, true, time.Second, "Reporting.sleepInterval") config.RegisterIntConfigVariable(32, &maxConcurrentRequests, true, 1, "Reporting.maxConcurrentRequests") - reportingLogger := logger.NewLogger().Child("enterprise").Child("reporting") // only send reports for wh actions sources if whActionsOnly is configured whActionsOnly := config.GetBool("REPORTING_WH_ACTIONS_ONLY", false) if whActionsOnly { - reportingLogger.Info("REPORTING_WH_ACTIONS_ONLY enabled.only sending reports relevant to wh actions.") + log.Info("REPORTING_WH_ACTIONS_ONLY enabled.only sending reports relevant to wh actions.") } return &HandleT{ init: make(chan struct{}), - logger: reportingLogger, + log: log, clients: make(map[string]*types.Client), reportingServiceURL: reportingServiceURL, namespace: config.GetKubeNamespace(), @@ -93,7 +92,7 @@ func NewFromEnvConfig() *HandleT { } func (handle *HandleT) setup(beConfigHandle backendconfig.BackendConfig) { - handle.logger.Info("[[ Reporting ]] Setting up reporting handler") + handle.log.Info("[[ Reporting ]] Setting up reporting handler") ch := beConfigHandle.Subscribe(context.TODO(), backendconfig.TopicBackendConfig) @@ -321,7 +320,7 @@ func (handle *HandleT) mainLoop(ctx context.Context, clientName string) { handle.requestLatency = stats.Default.NewTaggedStat(STAT_REPORTING_HTTP_REQ_LATENCY, stats.TimerType, tags) for { if ctx.Err() != nil { - handle.logger.Infof("stopping mainLoop for client %s : %s", clientName, ctx.Err()) + handle.log.Infof("stopping mainLoop for client %s : %s", clientName, ctx.Err()) return } requestChan := make(chan struct{}, maxConcurrentRequests) @@ -335,7 +334,7 @@ func (handle *HandleT) mainLoop(ctx context.Context, clientName string) { if len(reports) == 0 { select { case <-ctx.Done(): - handle.logger.Infof("stopping mainLoop for client %s : %s", clientName, ctx.Err()) + handle.log.Infof("stopping mainLoop for client %s : %s", clientName, ctx.Err()) return case <-time.After(handle.sleepInterval): } @@ -376,7 +375,7 @@ func (handle *HandleT) mainLoop(ctx context.Context, clientName string) { } _, err = dbHandle.Exec(sqlStatement) if err != nil { - handle.logger.Errorf(`[ Reporting ]: Error deleting local reports from %s: %v`, REPORTS_TABLE, err) + handle.log.Errorf(`[ Reporting ]: Error deleting local reports from %s: %v`, REPORTS_TABLE, err) } } @@ -405,7 +404,7 @@ func (handle *HandleT) sendMetric(ctx context.Context, netClient *http.Client, c httpRequestStart := time.Now() resp, err := netClient.Do(req) if err != nil { - handle.logger.Error(err.Error()) + handle.log.Error(err.Error()) return err } @@ -417,7 +416,7 @@ func (handle *HandleT) sendMetric(ctx context.Context, netClient *http.Client, c defer resp.Body.Close() respBody, err := io.ReadAll(resp.Body) if err != nil { - handle.logger.Error(err.Error()) + handle.log.Error(err.Error()) return err } @@ -429,10 +428,10 @@ func (handle *HandleT) sendMetric(ctx context.Context, netClient *http.Client, c b := backoff.WithContext(backoff.NewExponentialBackOff(), ctx) err = backoff.RetryNotify(operation, b, func(err error, t time.Duration) { - handle.logger.Errorf(`[ Reporting ]: Error reporting to service: %v`, err) + handle.log.Errorf(`[ Reporting ]: Error reporting to service: %v`, err) }) if err != nil { - handle.logger.Errorf(`[ Reporting ]: Error making request to reporting service: %v`, err) + handle.log.Errorf(`[ Reporting ]: Error making request to reporting service: %v`, err) } return err } diff --git a/enterprise/reporting/setup.go b/enterprise/reporting/setup.go index beb0a4ba7ef..21afc7e83a6 100644 --- a/enterprise/reporting/setup.go +++ b/enterprise/reporting/setup.go @@ -7,21 +7,22 @@ import ( "github.com/rudderlabs/rudder-server/config" backendconfig "github.com/rudderlabs/rudder-server/config/backend-config" "github.com/rudderlabs/rudder-server/rruntime" + "github.com/rudderlabs/rudder-server/utils/logger" "github.com/rudderlabs/rudder-server/utils/types" ) type Factory struct { - EnterpriseToken string - + EnterpriseToken string + Log logger.Logger once sync.Once reportingInstance types.ReportingI - - // for debug purposes, to be removed - init uint32 } // Setup initializes Suppress User feature func (m *Factory) Setup(backendConfig backendconfig.BackendConfig) types.ReportingI { + if m.Log == nil { + m.Log = logger.NewLogger().Child("enterprise").Child("reporting") + } m.once.Do(func() { reportingEnabled := config.GetBool("Reporting.enabled", types.DEFAULT_REPORTING_ENABLED) if !reportingEnabled { @@ -34,7 +35,7 @@ func (m *Factory) Setup(backendConfig backendconfig.BackendConfig) types.Reporti return } - h := NewFromEnvConfig() + h := NewFromEnvConfig(m.Log) rruntime.Go(func() { h.setup(backendConfig) }) diff --git a/enterprise/suppress-user/factory.go b/enterprise/suppress-user/factory.go new file mode 100644 index 00000000000..096a0bce350 --- /dev/null +++ b/enterprise/suppress-user/factory.go @@ -0,0 +1,82 @@ +package suppression + +import ( + "context" + "fmt" + "io" + "net/http" + "path" + "time" + + "github.com/rudderlabs/rudder-server/config" + backendconfig "github.com/rudderlabs/rudder-server/config/backend-config" + "github.com/rudderlabs/rudder-server/rruntime" + "github.com/rudderlabs/rudder-server/utils/logger" + "github.com/rudderlabs/rudder-server/utils/misc" + "github.com/rudderlabs/rudder-server/utils/types" +) + +type Factory struct { + EnterpriseToken string + Log logger.Logger +} + +// Setup initializes the user suppression feature +func (m *Factory) Setup(ctx context.Context, backendConfig backendconfig.BackendConfig) (types.UserSuppression, error) { + if m.Log == nil { + m.Log = logger.NewLogger().Child("enterprise").Child("suppress-user") + } + if m.EnterpriseToken == "" { + m.Log.Info("Suppress User feature is enterprise only") + return &NOOP{}, nil + } + m.Log.Info("Setting up Suppress User Feature") + backendConfig.WaitForConfig(ctx) + var repository Repository + if config.GetBool("BackendConfig.Regulations.useBadgerDB", true) { + tmpDir, err := misc.CreateTMPDIR() + if err != nil { + return nil, fmt.Errorf("could not create tmp dir: %w", err) + } + path := path.Join(tmpDir, "suppression") + + // TODO: implement seeder source, to retrieve the initial state from a persisted backup + var seederSource func() (io.Reader, error) + + repository, err = NewBadgerRepository( + path, + m.Log, + WithSeederSource(seederSource), + WithMaxSeedWait(config.GetDuration("BackendConfig.Regulations.maxSeedWait", 5, time.Second))) + if err != nil { + return nil, fmt.Errorf("could not create badger repository: %w", err) + } + } else { + repository = NewMemoryRepository(m.Log) + } + + var pollInterval time.Duration + config.RegisterDurationConfigVariable(300, &pollInterval, true, time.Second, "BackendConfig.Regulations.pollInterval") + + syncer, err := NewSyncer( + config.GetString("SUPPRESS_USER_BACKEND_URL", "https://api.rudderstack.com"), + backendConfig.Identity(), + repository, + WithLogger(m.Log), + WithHttpClient(&http.Client{Timeout: config.GetDuration("HttpClient.suppressUser.timeout", 30, time.Second)}), + WithPageSize(config.GetInt("BackendConfig.Regulations.pageSize", 5000)), + WithPollIntervalFn(func() time.Duration { return pollInterval }), + ) + if err != nil { + return nil, err + } + + h := newHandler(ctx, repository, m.Log) + + rruntime.Go(func() { + syncer.SyncLoop(ctx) + _ = repository.Stop() + }) + + return h, nil +} diff --git a/enterprise/suppress-user/handler.go b/enterprise/suppress-user/handler.go new file mode 100644 index 00000000000..b571c98b9f1 --- /dev/null +++ b/enterprise/suppress-user/handler.go @@ -0,0 +1,57 @@ +package suppression + +import ( + "context" + "fmt" + "sync" + "time" + + "github.com/rudderlabs/rudder-server/utils/logger" + "github.com/samber/lo" +) + +// newHandler creates a new handler for the suppression feature +func newHandler(ctx context.Context, r Repository, log logger.Logger) *handler { + h := &handler{ + r: r, + log: log, + } + // we don't want to flood logs if, e.g. the suppression repository is restoring, + // so we are debouncing the logging + h.errLog.debounceLog, h.errLog.cancel = lo.NewDebounce(1*time.Second, func() { + h.errLog.errMu.Lock() + defer h.errLog.errMu.Unlock() + if h.errLog.err != nil { + h.log.Warn(h.errLog.err.Error()) + } + }) + go func() { + <-ctx.Done() + h.errLog.cancel() + }() + return h +} + +// handler is a handle to this object +type handler struct { + log logger.Logger + r Repository + errLog struct { + debounceLog func() // logs suppression failures with a debounce, once every 1 second + cancel func() // cancels the debounce timer + errMu sync.Mutex + err error // the last error + } +} + +func (h *handler) IsSuppressedUser(workspaceID, userID, sourceID string) bool { + h.log.Debugf("IsSuppressedUser called for workspace: %s, user %s, source %s", workspaceID, userID, sourceID) + suppressed, err := h.r.Suppress(workspaceID, userID, sourceID) + if err != nil { + h.errLog.errMu.Lock() + h.errLog.err = fmt.Errorf("suppression check failed for workspace: %s, user: %s, source: %s: %w", workspaceID, userID, sourceID, err) + h.errLog.debounceLog() + h.errLog.errMu.Unlock() + } + return suppressed +} diff --git a/enterprise/suppress-user/handler_test.go b/enterprise/suppress-user/handler_test.go new file mode 100644 index 00000000000..431427cedc9 --- /dev/null +++ b/enterprise/suppress-user/handler_test.go @@ -0,0 +1,45 @@ +package suppression + +import ( + "context" + "fmt" + "sync" + "testing" + + "github.com/rudderlabs/rudder-server/utils/logger" + "github.com/stretchr/testify/require" +) + +func TestDebounceLogConcurrency(t *testing.T) { + ctx := context.Background() + log := &tLog{Logger: logger.NOP} + h := newHandler(ctx, &fakeSuppresser{}, log) + + var wg sync.WaitGroup + wg.Add(1000) + for i := 0; i < 1000; i++ { + go func() { + defer wg.Done() + _ = h.IsSuppressedUser("workspaceID", "userID", "sourceID") + }() + } + wg.Wait() + require.Less(t, log.times, 1000) +} + +type fakeSuppresser struct { + Repository +} + +func (*fakeSuppresser) Suppress(_, _, _ string) (bool, error) { + return false, fmt.Errorf("some error") +} + +type tLog struct { + times int + logger.Logger +} + +func (t *tLog) Warn(_ ...interface{}) { + t.times++ +} diff --git a/enterprise/suppress-user/internal/badgerdb/badgerdb.go b/enterprise/suppress-user/internal/badgerdb/badgerdb.go new file mode 100644 index 00000000000..8593cbe6715 --- /dev/null +++ b/enterprise/suppress-user/internal/badgerdb/badgerdb.go @@ -0,0 +1,307 @@ +package badgerdb + +import ( + "errors" + "fmt" + "io" + "os" + "path" + "sync" + "time" + + "github.com/dgraph-io/badger/v3" + "github.com/dgraph-io/badger/v3/options" + "github.com/rudderlabs/rudder-server/enterprise/suppress-user/model" + "github.com/rudderlabs/rudder-server/utils/logger" + "github.com/samber/lo" +) + +// the key used in badgerdb to store the current token +const tokenKey = "__token__" + +// Opt is a function that configures a badgerdb repository +type Opt func(*Repository) + +// WithSeederSource sets the source of the seed data +func WithSeederSource(seederSource func() (io.Reader, error)) Opt { + return func(r *Repository) { + r.seederSource = seederSource + } +} + +// WithMaxSeedWait sets the maximum time to wait for the seed to complete. +// If the seed takes longer than this, the repository will be started in restoring state and all +// repository methods will return [ErrRestoring] until the seed completes. The default wait time is 10 seconds. +func WithMaxSeedWait(maxSeedWait time.Duration) Opt { + return func(r *Repository) { + r.maxSeedWait = maxSeedWait + } +} + +// Repository is a repository backed by badgerdb +type Repository struct { + // logger to use + log logger.Logger + // path to the badger db directory + path string + // max number of goroutines to use (badger config) + maxGoroutines int + + maxSeedWait time.Duration + seederSource func() (io.Reader, error) + + db *badger.DB + + // lock to prevent concurrent access to db during restore + restoringLock sync.RWMutex + restoring bool + closeOnce sync.Once + closed chan struct{} +} + +// NewRepository returns a new repository backed by badgerdb. +func NewRepository(basePath string, log logger.Logger, opts ...Opt) (*Repository, error) { + b := &Repository{ + log: log, + path: path.Join(basePath, "badgerdbv3"), + maxGoroutines: 1, + maxSeedWait: 10 * time.Second, + } + for _, opt := range opts { + opt(b) + } + + return b, b.start() +} + +// GetToken returns the current token +func (b *Repository) GetToken() ([]byte, error) { + b.restoringLock.RLock() + defer b.restoringLock.RUnlock() // release the read lock at the end of the operation + if b.restoring { + return nil, model.ErrRestoring + } + if b.db.IsClosed() { + return nil, badger.ErrDBClosed + } + + var token []byte + err := b.db.View(func(txn *badger.Txn) error { + item, err := txn.Get([]byte(tokenKey)) + if err != nil { + return fmt.Errorf("could not get token: %w", err) + } + if err = item.Value(func(val []byte) error { + token = val + return nil + }); err != nil { + return fmt.Errorf("could not get token value: %w", err) + } + return nil + }) + if err != nil && !errors.Is(err, badger.ErrKeyNotFound) { + return nil, err + } + return token, nil +} + +// Suppress returns true if the given user is suppressed, false otherwise +func (b *Repository) Suppress(workspaceID, userID, sourceID string) (bool, error) { + b.restoringLock.RLock() + defer b.restoringLock.RUnlock() + if b.restoring { + return false, model.ErrRestoring + } + if b.db.IsClosed() { + return false, badger.ErrDBClosed + } + + keyPrefix := fmt.Sprintf("%s:%s:", workspaceID, userID) + err := b.db.View(func(txn *badger.Txn) error { + wildcardKey := keyPrefix + model.Wildcard + _, err := txn.Get([]byte(wildcardKey)) + if err == nil { + return nil + } + if !errors.Is(err, badger.ErrKeyNotFound) { + return fmt.Errorf("could not get wildcard key %s: %w", wildcardKey, err) + } + sourceKey := keyPrefix + sourceID + if _, err = txn.Get([]byte(sourceKey)); err != nil { + return fmt.Errorf("could not get sourceID key %s: %w", sourceKey, err) + } + return err + }) + if err != nil { + if errors.Is(err, badger.ErrKeyNotFound) { + return false, nil + } + return false, err + } + return true, nil +} + +// Add adds the given suppressions to the repository +func (b *Repository) Add(suppressions []model.Suppression, token []byte) error { + b.restoringLock.RLock() + defer b.restoringLock.RUnlock() + if b.restoring { + return model.ErrRestoring + } + if b.db.IsClosed() { + return badger.ErrDBClosed + } + wb := b.db.NewWriteBatch() + defer wb.Cancel() + + for i := range suppressions { + suppression := suppressions[i] + keyPrefix := fmt.Sprintf("%s:%s:", suppression.WorkspaceID, suppression.UserID) + var keys []string + if len(suppression.SourceIDs) == 0 { + keys = []string{keyPrefix + model.Wildcard} + } else { + keys = make([]string, len(suppression.SourceIDs)) + for i, sourceID := range suppression.SourceIDs { + keys[i] = keyPrefix + sourceID + } + } + for _, key := range keys { + var err error + if suppression.Canceled { + err = wb.Delete([]byte(key)) + } else { + err = wb.Set([]byte(key), []byte("")) + } + if err != nil { + return fmt.Errorf("could not add key %s (canceled:%t) in write batch: %w", key, suppression.Canceled, err) + } + } + + } + if err := wb.Set([]byte(tokenKey), token); err != nil { + return fmt.Errorf("could not add token key %s in write batch: %w", tokenKey, err) + } + if err := wb.Flush(); err != nil { + return fmt.Errorf("could not flush write batch: %w", err) + } + return nil +} + +// start the repository +func (b *Repository) start() error { + b.closed = make(chan struct{}) + var seeder io.Reader + if _, err := os.Stat(b.path); os.IsNotExist(err) && b.seederSource != nil { + if seeder, err = b.seederSource(); err != nil { + return err + } + } + + opts := badger. + DefaultOptions(b.path). + WithLogger(blogger{b.log}). + WithCompression(options.None). + WithIndexCacheSize(16 << 20). // 16mb + WithNumGoroutines(b.maxGoroutines) + var err error + b.db, err = badger.Open(opts) + if err != nil { + return err + } + + if seeder != nil { + restoreDone := lo.Async(func() error { + if err := b.Restore(seeder); err != nil { + b.log.Error("Failed to restore badgerdb", "error", err) + return err + } + return nil + }) + select { + case <-restoreDone: + case <-time.After(b.maxSeedWait): + b.log.Warn("Badgerdb still restoring after %s, proceeding...", b.maxSeedWait) + } + } + + go func() { + ticker := time.NewTicker(5 * time.Minute) + defer ticker.Stop() + for { + select { + case <-b.closed: + return + case <-ticker.C: + } + again: // see https://dgraph.io/docs/badger/get-started/#garbage-collection + err := b.db.RunValueLogGC(0.7) + if err == nil { + goto again + } + } + }() + return nil +} + +// Stop stops the repository +func (b *Repository) Stop() error { + var err error + b.closeOnce.Do(func() { + close(b.closed) + err = b.db.Close() + }) + return err +} + +// Backup writes a backup of the repository to the given writer +func (b *Repository) Backup(w io.Writer) error { + b.restoringLock.RLock() + defer b.restoringLock.RUnlock() + if b.restoring { + return model.ErrRestoring + } + if b.db.IsClosed() { + return badger.ErrDBClosed + } + _, err := b.db.Backup(w, 0) + return err +} + +// Restore restores the repository from the given reader +func (b *Repository) Restore(r io.Reader) (err error) { + if b.isRestoring() { + return model.ErrRestoring + } + if b.db.IsClosed() { + return badger.ErrDBClosed + } + b.setRestoring(true) + defer b.setRestoring(false) + defer func() { + if r := recover(); r != nil { + err = fmt.Errorf("panic during restore: %v", r) + } + }() + return b.db.Load(r, b.maxGoroutines) +} + +func (b *Repository) setRestoring(restoring bool) { + b.restoringLock.Lock() + b.restoring = restoring + b.restoringLock.Unlock() +} + +func (b *Repository) isRestoring() bool { + b.restoringLock.RLock() + defer b.restoringLock.RUnlock() + return b.restoring +} + +type blogger struct { + logger.Logger +} + +func (l blogger) Warningf(fmt string, args ...interface{}) { + l.Warnf(fmt, args...) +} diff --git a/enterprise/suppress-user/internal/badgerdb/badgerdb_backup_bench_test.go b/enterprise/suppress-user/internal/badgerdb/badgerdb_backup_bench_test.go new file mode 100644 index 00000000000..b42b6aaf8eb --- /dev/null +++ b/enterprise/suppress-user/internal/badgerdb/badgerdb_backup_bench_test.go @@ -0,0 +1,91 @@ +package badgerdb_test + +import ( + "compress/gzip" + "fmt" + "os" + "path" + "testing" + "time" + + "github.com/rudderlabs/rudder-server/enterprise/suppress-user/internal/badgerdb" + "github.com/rudderlabs/rudder-server/enterprise/suppress-user/model" + "github.com/rudderlabs/rudder-server/utils/logger" + "github.com/stretchr/testify/require" +) + +// BenchmarkBackupRestore benchmarks the backup and restore time of the badger repository +// after seeding it with a very large number of suppressions. +func BenchmarkBackupRestore(b *testing.B) { + b.StopTimer() + totalSuppressions := 40_000_000 + batchSize := 5000 + backupFilename := path.Join(b.TempDir(), "backup.badger") + + repo1Path := path.Join(b.TempDir(), "repo-1") + repo1, err := badgerdb.NewRepository(repo1Path, logger.NOP) + require.NoError(b, err) + + for i := 0; i < totalSuppressions/batchSize; i++ { + suppressions := generateSuppressions(i*batchSize, batchSize) + token := []byte(fmt.Sprintf("token%d", i)) + require.NoError(b, repo1.Add(suppressions, token)) + } + + b.Run("backup", func(b *testing.B) { + b.StopTimer() + start := time.Now() + f, err := os.Create(backupFilename) + w, _ := gzip.NewWriterLevel(f, gzip.BestSpeed) + defer func() { _ = f.Close() }() + require.NoError(b, err) + b.StartTimer() + require.NoError(b, repo1.Backup(w)) + w.Close() + b.StopTimer() + dur := time.Since(start) + fileInfo, err := f.Stat() + require.NoError(b, err) + b.ReportMetric(float64(fileInfo.Size()/1024/1024), "filesize(MB)") + b.ReportMetric(dur.Seconds(), "duration(sec)") + require.NoError(b, repo1.Stop()) + }) + + b.Run("restore", func(b *testing.B) { + b.StopTimer() + repo2Path := path.Join(b.TempDir(), "repo-2") + repo2, err := badgerdb.NewRepository(repo2Path, logger.NOP) + require.NoError(b, err) + + f, err := os.Open(backupFilename) + require.NoError(b, err) + defer func() { _ = f.Close() }() + r, err := gzip.NewReader(f) + require.NoError(b, err) + fileInfo, err := f.Stat() + require.NoError(b, err) + b.ReportMetric(float64(fileInfo.Size()/1024/1024), "filesize(MB)") + + start := time.Now() + b.StartTimer() + require.NoError(b, repo2.Restore(r)) + b.StopTimer() + r.Close() + b.ReportMetric(time.Since(start).Seconds(), "duration(sec)") + + require.NoError(b, repo2.Stop()) + }) +} + +func generateSuppressions(startFrom, batchSize int) []model.Suppression { + var res []model.Suppression + + for i := startFrom; i < startFrom+batchSize; i++ { + res = append(res, model.Suppression{ + Canceled: false, + WorkspaceID: "1yaBlqltp5Y4V2NK8qePowlyaaaa", + UserID: fmt.Sprintf("client-%d", i), + }) + } + return res +} diff --git a/enterprise/suppress-user/internal/badgerdb/badgerdb_repo_test.go b/enterprise/suppress-user/internal/badgerdb/badgerdb_repo_test.go new file mode 100644 index 00000000000..188959765e2 --- /dev/null +++ b/enterprise/suppress-user/internal/badgerdb/badgerdb_repo_test.go @@ -0,0 +1,18 @@ +package badgerdb_test + +import ( + "testing" + + "github.com/rudderlabs/rudder-server/enterprise/suppress-user/internal/badgerdb" + "github.com/rudderlabs/rudder-server/enterprise/suppress-user/internal/repotest" + "github.com/rudderlabs/rudder-server/utils/logger" + "github.com/stretchr/testify/require" +) + +// TestBadgerRepoSpec tests the badgerdb repository implementation. +func TestBadgerRepoSpec(t *testing.T) { + path := t.TempDir() + repo, err := badgerdb.NewRepository(path, logger.NOP) + require.NoError(t, err) + repotest.RunRepositoryTestSuite(t, repo) +} diff --git a/enterprise/suppress-user/internal/badgerdb/badgerdb_test.go b/enterprise/suppress-user/internal/badgerdb/badgerdb_test.go new file mode 100644 index 00000000000..0f951874292 --- /dev/null +++ b/enterprise/suppress-user/internal/badgerdb/badgerdb_test.go @@ -0,0 +1,152 @@ +package badgerdb_test + +import ( + "bytes" + "context" + "io" + "path" + "strings" + "sync" + "testing" + "time" + + "github.com/dgraph-io/badger/v3" + + "github.com/gofrs/uuid" + "github.com/rudderlabs/rudder-server/enterprise/suppress-user/internal/badgerdb" + "github.com/rudderlabs/rudder-server/enterprise/suppress-user/model" + "github.com/rudderlabs/rudder-server/utils/logger" + "github.com/stretchr/testify/require" +) + +type readerFunc func(p []byte) (n int, err error) + +func (f readerFunc) Read(p []byte) (n int, err error) { + return f(p) +} + +// TestBadgerRepository contains badgerdb-specific tests. +func TestBadgerRepository(t *testing.T) { + basePath := path.Join(t.TempDir(), strings.ReplaceAll(uuid.Must(uuid.NewV4()).String(), "-", "")) + token := []byte("token") + repo, err := badgerdb.NewRepository(basePath, logger.NOP) + require.NoError(t, err) + + t.Run("trying to use a repository during restore", func(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + var reader readerFunc = func(_ []byte) (int, error) { + <-ctx.Done() + return 0, ctx.Err() + } + var wg sync.WaitGroup + wg.Add(1) + go func() { + require.Error(t, repo.Restore(reader)) + wg.Done() + }() + + time.Sleep(1 * time.Millisecond) + err := repo.Add([]model.Suppression{ + { + WorkspaceID: "workspace1", + UserID: "user1", + SourceIDs: []string{}, + }, + { + WorkspaceID: "workspace2", + UserID: "user2", + SourceIDs: []string{"source1"}, + }, + }, token) + require.Error(t, err, "it should return an error when trying to add a suppression to a repository that is restoring") + require.ErrorIs(t, model.ErrRestoring, err) + + err = repo.Restore(nil) + require.Error(t, err, "it should return an error when trying to restore a repository that is already restoring") + require.ErrorIs(t, model.ErrRestoring, err) + + err = repo.Backup(nil) + require.Error(t, err, "it should return an error when trying to backup a repository that is already restoring") + require.ErrorIs(t, model.ErrRestoring, err) + + _, err = repo.GetToken() + require.Error(t, err, "it should return an error when trying to get the token from a repository that is restoring") + require.ErrorIs(t, model.ErrRestoring, err) + + _, err = repo.Suppress("workspace2", "user2", "source2") + require.Error(t, err, "it should return an error when trying to suppress a user from a repository that is restoring") + require.ErrorIs(t, model.ErrRestoring, err) + + cancel() + wg.Wait() // wait for the restore to finish + }) + + defer func() { _ = repo.Stop() }() + + t.Run("trying to start a second repository using the same path", func(t *testing.T) { + _, err := badgerdb.NewRepository(basePath, logger.NOP) + require.Error(t, err, "it should return an error when trying to start a second repository using the same path") + }) + + backup := []byte{} + buffer := bytes.NewBuffer(backup) + t.Run("backup", func(t *testing.T) { + require.NoError(t, repo.Backup(buffer), "it should be able to backup the repository without an error") + }) + + t.Run("restore", func(t *testing.T) { + require.NoError(t, repo.Restore(buffer), "it should be able to restore the repository without an error") + }) + + t.Run("new with seeder", func(t *testing.T) { + basePath := path.Join(t.TempDir(), "badger-test-2") + _, err := badgerdb.NewRepository(basePath, logger.NOP, badgerdb.WithSeederSource(func() (io.Reader, error) { + return buffer, nil + }), badgerdb.WithMaxSeedWait(1*time.Millisecond)) + require.NoError(t, err) + }) + + t.Run("try to restore invalid data", func(t *testing.T) { + r := bytes.NewBuffer([]byte("invalid data")) + require.Error(t, repo.Restore(r), "it should return an error when trying to restore invalid data") + }) + + t.Run("badgerdb errors", func(t *testing.T) { + require.NoError(t, repo.Stop(), "it should be able to stop the badgerdb instance without an error") + + _, err := repo.Suppress("workspace1", "user1", "") + require.Error(t, err) + + _, err = repo.GetToken() + require.Error(t, err) + + require.Error(t, repo.Add([]model.Suppression{}, []byte(""))) + + require.Error(t, repo.Add([]model.Suppression{{ + WorkspaceID: "workspace1", + UserID: "user1", + SourceIDs: []string{}, + }}, []byte("token"))) + }) + + t.Run("trying to use a closed repository", func(t *testing.T) { + repo, err := badgerdb.NewRepository(basePath, logger.NOP) + require.NoError(t, err) + require.NoError(t, repo.Stop()) + + require.Equal(t, repo.Add(nil, nil), badger.ErrDBClosed) + + s, err := repo.Suppress("", "", "") + require.False(t, s) + require.Equal(t, err, badger.ErrDBClosed) + + _, err = repo.GetToken() + require.Equal(t, err, badger.ErrDBClosed) + + require.Equal(t, repo.Backup(nil), badger.ErrDBClosed) + + require.Equal(t, repo.Restore(nil), badger.ErrDBClosed) + }) +} diff --git a/enterprise/suppress-user/internal/memory/memory.go b/enterprise/suppress-user/internal/memory/memory.go new file mode 100644 index 00000000000..e3e7356facc --- /dev/null +++ b/enterprise/suppress-user/internal/memory/memory.go @@ -0,0 +1,104 @@ +package memory + +import ( + "io" + "sync" + + "github.com/rudderlabs/rudder-server/enterprise/suppress-user/model" + "github.com/rudderlabs/rudder-server/utils/logger" +) + +// Repository is a repository backed by memory. +type Repository struct { + log logger.Logger + token []byte + suppressionsMu sync.RWMutex + suppressions map[string]map[string]map[string]struct{} +} + +// NewRepository returns a new repository backed by memory. +func NewRepository(log logger.Logger) *Repository { + m := &Repository{ + log: log, + suppressions: make(map[string]map[string]map[string]struct{}), + } + return m +} + +// GetToken returns the current token +func (m *Repository) GetToken() ([]byte, error) { + return m.token, nil +} + +// Suppress returns true if the given user is suppressed, false otherwise +func (m *Repository) Suppress(workspaceID, userID, sourceID string) (bool, error) { + m.suppressionsMu.RLock() + defer m.suppressionsMu.RUnlock() + workspace, ok := m.suppressions[workspaceID] + if !ok { + return false, nil + } + sourceIDs, ok := workspace[userID] + if !ok { + return false, nil + } + if _, ok := sourceIDs[model.Wildcard]; ok { + return true, nil + } + if _, ok := sourceIDs[sourceID]; ok { + return true, nil + } + return false, nil +} + +// Add adds the given suppressions to the repository +func (m *Repository) Add(suppressions []model.Suppression, token []byte) error { + m.suppressionsMu.Lock() + defer m.suppressionsMu.Unlock() + for i := range suppressions { + suppression := suppressions[i] + var keys []string + if len(suppression.SourceIDs) == 0 { + keys = []string{model.Wildcard} + } else { + keys = make([]string, len(suppression.SourceIDs)) + copy(keys, suppression.SourceIDs) + } + workspace, ok := m.suppressions[suppression.WorkspaceID] + if !ok { + workspace = make(map[string]map[string]struct{}) + m.suppressions[suppression.WorkspaceID] = workspace + } + user, ok := workspace[suppression.UserID] + if !ok { + user = make(map[string]struct{}) + m.suppressions[suppression.WorkspaceID][suppression.UserID] = user + } + if suppression.Canceled { + for _, key := range keys { + delete(user, key) + } + } else { + for _, key := range keys { + user[key] = struct{}{} + } + } + } + m.token = token + return nil +} + +// Stop is a no-op for the memory repository. +func (*Repository) Stop() error { + return nil +} + +// Backup is not supported for the memory repository. +func (*Repository) Backup(_ io.Writer) error { + return model.ErrNotSupported +} + +// Restore is not supported for the memory repository. +func (*Repository) Restore(_ io.Reader) error { + return model.ErrNotSupported +} diff --git a/enterprise/suppress-user/internal/memory/memory_repo_test.go b/enterprise/suppress-user/internal/memory/memory_repo_test.go new file mode 100644 index 00000000000..b4de08588c1 --- /dev/null +++ b/enterprise/suppress-user/internal/memory/memory_repo_test.go @@ -0,0 +1,14 @@ +package memory_test + +import ( + "testing" + + "github.com/rudderlabs/rudder-server/enterprise/suppress-user/internal/memory" + "github.com/rudderlabs/rudder-server/enterprise/suppress-user/internal/repotest" + "github.com/rudderlabs/rudder-server/utils/logger" +) + +// TestMemoryRepoSpec tests the memory repository implementation. +func TestMemoryRepoSpec(t *testing.T) { + repotest.RunRepositoryTestSuite(t, memory.NewRepository(logger.NOP)) +} diff --git a/enterprise/suppress-user/internal/repotest/repotest.go b/enterprise/suppress-user/internal/repotest/repotest.go new file mode 100644 index 00000000000..26267f53e9e --- /dev/null +++ b/enterprise/suppress-user/internal/repotest/repotest.go @@ -0,0 +1,162 @@ +package repotest + +import ( + "testing" + + suppression "github.com/rudderlabs/rudder-server/enterprise/suppress-user" + "github.com/rudderlabs/rudder-server/enterprise/suppress-user/model" + "github.com/stretchr/testify/require" +) + +func RunRepositoryTestSuite(t *testing.T, repo suppression.Repository) { + token := []byte("token") + + t.Run("get the token before setting anything", func(t *testing.T) { + rtoken, err := repo.GetToken() + require.NoError(t, err) + require.Nil(t, rtoken, "it should return nil when trying to get the token before setting it") + }) + + t.Run("adding suppressions", func(t *testing.T) { + err := repo.Add([]model.Suppression{ + { + WorkspaceID: "workspace1", + UserID: "user1", + SourceIDs: []string{}, + }, + { + WorkspaceID: "workspace2", + UserID: "user2", + SourceIDs: []string{"source1"}, + }, + }, token) + require.NoError(t, err, "it should be able to add some suppressions without an error") + }) + + t.Run("get token after setting it", func(t *testing.T) { + rtoken, err := repo.GetToken() + require.NoError(t, err) + require.Equal(t, token, rtoken, "it should return the token that was previously set") + }) + + t.Run("wildcard suppression", func(t *testing.T) { + suppressed, err := repo.Suppress("workspace1", "user1", "source1") + require.NoError(t, err) + require.True(t, suppressed, "it should return true when trying to suppress a user that is suppressed by a wildcard suppression") + + suppressed, err = repo.Suppress("workspace1", "user1", "source2") + require.NoError(t, err) + require.True(t, suppressed, "it should return true when trying to suppress a user that is suppressed by a wildcard suppression") + }) + + t.Run("exact suppression", func(t *testing.T) { + suppressed, err := repo.Suppress("workspace2", "user2", "source1") + require.NoError(t, err) + require.True(t, suppressed, "it should return true when trying to suppress a user that is suppressed by an exact suppression") + }) + + t.Run("non matching key", func(t *testing.T) { + suppressed, err := repo.Suppress("workspace3", "user3", "source2") + require.NoError(t, err) + require.False(t, suppressed, "it should return false when trying to suppress a user that is not suppressed") + }) + + t.Run("non matching suppression", func(t *testing.T) { + suppressed, err := repo.Suppress("workspace2", "user2", "source2") + require.NoError(t, err) + require.False(t, suppressed, "it should return false when trying to suppress a user that is suppressed for a different sourceID") + }) + + t.Run("canceling a suppression", func(t *testing.T) { + suppressed, err := repo.Suppress("workspace1", "user1", "source1") + require.NoError(t, err) + require.True(t, suppressed, "it should return true when trying to suppress a user that is suppressed by a wildcard suppression") + + token2 := []byte("token2") + err = repo.Add([]model.Suppression{ + { + WorkspaceID: "workspace1", + Canceled: true, + UserID: "user1", + SourceIDs: []string{}, + }, + }, token2) + require.NoError(t, err) + rtoken, err := repo.GetToken() + require.NoError(t, err) + require.Equal(t, token2, rtoken) + + suppressed, err = repo.Suppress("workspace1", "user1", "source1") + require.NoError(t, err) + require.False(t, suppressed, "it should return false when trying to suppress a user that was suppressed by a wildcard suppression after the suppression has been canceled") + }) + + t.Run("multiple suppressions for the same userID", func(t *testing.T) { + err := repo.Add([]model.Suppression{ + { + WorkspaceID: "workspaceX", + UserID: "userX", + SourceIDs: []string{}, + }, + { + WorkspaceID: "workspaceX", + UserID: "userX", + SourceIDs: []string{"source1"}, + }, + { + WorkspaceID: "workspaceX", + UserID: "userX", + SourceIDs: []string{"source2"}, + }, + }, token) + require.NoError(t, err, "it should be able to add some suppressions without an error") + + suppressed, err := repo.Suppress("workspaceX", "userX", "sourceX") + require.NoError(t, err) + require.True(t, suppressed, "it should return true when trying to suppress a user that is suppressed by a wildcard suppression") + + require.NoError(t, repo.Add([]model.Suppression{ + { + Canceled: true, + WorkspaceID: "workspaceX", + UserID: "userX", + SourceIDs: []string{}, + }, + }, token)) + suppressed, err = repo.Suppress("workspaceX", "userX", "sourceX") + require.NoError(t, err) + require.False(t, suppressed, "it should return false when trying to suppress a user that is no longer suppressed by a wildcard suppression") + + suppressed, err = repo.Suppress("workspaceX", "userX", "source1") + require.NoError(t, err) + require.True(t, suppressed, "it should return true when trying to suppress a user that is still suppressed by an exact match suppression") + + require.NoError(t, repo.Add([]model.Suppression{ + { + Canceled: true, + WorkspaceID: "workspaceX", + UserID: "userX", + SourceIDs: []string{"source1"}, + }, + }, token)) + suppressed, err = repo.Suppress("workspaceX", "userX", "source1") + require.NoError(t, err) + require.False(t, suppressed, "it should return false when trying to suppress a user that is no longer suppressed by an exact match suppression") + + suppressed, err = repo.Suppress("workspaceX", "userX", "source2") + require.NoError(t, err) + require.True(t, suppressed, "it should return true when trying to suppress a user that is still suppressed by an exact match suppression") + + require.NoError(t, repo.Add([]model.Suppression{ + { + Canceled: true, + WorkspaceID: "workspaceX", + UserID: "userX", + SourceIDs: []string{"source2"}, + }, + }, token)) + suppressed, err = repo.Suppress("workspaceX", "userX", "source2") + require.NoError(t, err) + require.False(t, suppressed, "it should return false when trying to suppress a user that is no longer suppressed by an exact match suppression") + }) +} diff --git a/enterprise/suppress-user/model/model.go b/enterprise/suppress-user/model/model.go new file mode 100644 index 00000000000..85194b172d4 --- /dev/null +++ b/enterprise/suppress-user/model/model.go @@ -0,0 +1,16 @@ +package model + +import "errors" + +var ( + ErrRestoring = errors.New("repository is restoring") + ErrNotSupported = errors.New("operation not supported") +) +var Wildcard = "*" + +type Suppression struct { + WorkspaceID string `json:"workspaceId"` + Canceled bool `json:"canceled"` + UserID string `json:"userId"` + SourceIDs []string `json:"sourceIds"` +} diff --git a/enterprise/suppress-user/repository.go b/enterprise/suppress-user/repository.go new file mode 100644 index 00000000000..36cbafdfd87 --- /dev/null +++ b/enterprise/suppress-user/repository.go @@ -0,0 +1,46 @@ +package suppression + +import ( + "io" + + "github.com/rudderlabs/rudder-server/enterprise/suppress-user/internal/badgerdb" + "github.com/rudderlabs/rudder-server/enterprise/suppress-user/internal/memory" + "github.com/rudderlabs/rudder-server/enterprise/suppress-user/model" + "github.com/rudderlabs/rudder-server/utils/logger" +) + +// Repository provides a generic interface for managing user suppressions +type Repository interface { + // Stop stops the repository + Stop() error + + // GetToken returns the current token + GetToken() ([]byte, error) + + // Add adds the given suppressions to the repository + Add(suppressions []model.Suppression, token []byte) error + + // Suppress returns true if the given user is suppressed, false otherwise + Suppress(workspaceID, userID, sourceID string) (bool, error) + + // Backup writes a backup of the repository to the given writer + Backup(w io.Writer) error + + // Restore restores the repository from the given reader + Restore(r io.Reader) error +} + +// NewMemoryRepository returns a new repository backed by memory. +func NewMemoryRepository(log logger.Logger) Repository { + return memory.NewRepository(log) +} + +var ( + WithSeederSource = badgerdb.WithSeederSource + WithMaxSeedWait = badgerdb.WithMaxSeedWait +) + +// NewBadgerRepository returns a new repository backed by badgerDB. +func NewBadgerRepository(path string, log logger.Logger, opts ...badgerdb.Opt) (Repository, error) { + return badgerdb.NewRepository(path, log, opts...) +} diff --git a/enterprise/suppress-user/repository_bench_test.go b/enterprise/suppress-user/repository_bench_test.go new file mode 100644 index 00000000000..294946fae3c --- /dev/null +++ b/enterprise/suppress-user/repository_bench_test.go @@ -0,0 +1,94 @@ +package suppression_test + +import ( + "fmt" + "math/rand" + "testing" + "time" + + suppression "github.com/rudderlabs/rudder-server/enterprise/suppress-user" + "github.com/rudderlabs/rudder-server/enterprise/suppress-user/model" + "github.com/rudderlabs/rudder-server/utils/logger" + "github.com/stretchr/testify/require" +) + +func BenchmarkAddAndSuppress(b *testing.B) { + totalSuppressions := 20_000_000 + totalReads := 10_000 + batchSize := 5000 + + runAddSuppressionsBenchmark := func(b *testing.B, repo suppression.Repository, batchSize, totalSuppressions int) { + var totalTime time.Duration + var totalAddSuppressions int + for i := 0; i < totalSuppressions/batchSize; i++ { + suppressions1 := generateSuppressions(i*batchSize/2, batchSize/2) + suppressions2 := generateSuppressions(i*batchSize/2, batchSize/2) + token := []byte(fmt.Sprintf("token%d", i)) + start := time.Now() + require.NoError(b, repo.Add(suppressions1, token)) + require.NoError(b, repo.Add(suppressions2, token)) + totalTime += time.Since(start) + totalAddSuppressions += batchSize + } + b.ReportMetric(float64(totalSuppressions)/totalTime.Seconds(), "suppressions/s(add)") + } + + runSuppressBenchmark := func(b *testing.B, repo suppression.Repository, totalSuppressions, totalReads int) { + var totalTime time.Duration + for i := 0; i < totalReads; i++ { + start := time.Now() + idx := randomInt(totalSuppressions * 2) // multiply by 2 to include non-existing keys suppressions + _, err := repo.Suppress(fmt.Sprintf("workspace%d", idx), fmt.Sprintf("user%d", idx), fmt.Sprintf("source%d", idx)) + require.NoError(b, err) + totalTime += time.Since(start) + } + b.ReportMetric(float64(totalSuppressions)/totalTime.Seconds(), "suppressions/s(read)") + } + + runAddAndSuppressBenchmark := func(b *testing.B, repo suppression.Repository, totalSuppressions, batchSize, totalReads int) { + runAddSuppressionsBenchmark(b, repo, batchSize, totalSuppressions) + runSuppressBenchmark(b, repo, totalSuppressions, totalReads) + } + + b.Run("badger", func(b *testing.B) { + repo, err := suppression.NewBadgerRepository(b.TempDir(), logger.NOP) + require.NoError(b, err) + defer func() { _ = repo.Stop() }() + runAddAndSuppressBenchmark(b, repo, totalSuppressions, batchSize, totalReads) + }) + + b.Run("memory", func(b *testing.B) { + repo := suppression.NewMemoryRepository(logger.NOP) + defer func() { _ = repo.Stop() }() + runAddAndSuppressBenchmark(b, repo, totalSuppressions, batchSize, totalReads) + }) +} + +func generateSuppressions(startFrom, batchSize int) []model.Suppression { + var res []model.Suppression + + for i := startFrom; i < startFrom+batchSize; i++ { + var sourceIDs []string + wildcard := randomInt(2) == 0 + if wildcard { + sourceIDs = []string{} + } else { + sourceIDs = []string{fmt.Sprintf("source%d", i), "otherSource", "anotherSource"} + } + res = append(res, model.Suppression{ + Canceled: randomInt(2) == 0, + WorkspaceID: fmt.Sprintf("workspace%d", i), + UserID: fmt.Sprintf("user%d", i), + SourceIDs: sourceIDs, + }) + } + return res +} + +func init() { + rand.Seed(time.Now().UnixNano()) +} + +func randomInt(lt int) int { + return rand.Int() % lt // skipcq: GSC-G404 +} diff --git a/enterprise/suppress-user/setup.go b/enterprise/suppress-user/setup.go deleted file mode 100644 index c1ef39a7e11..00000000000 --- a/enterprise/suppress-user/setup.go +++ /dev/null @@ -1,51 +0,0 @@ -package suppression - -import ( - "context" - "strconv" - "time" - - "github.com/rudderlabs/rudder-server/config" - backendconfig "github.com/rudderlabs/rudder-server/config/backend-config" - "github.com/rudderlabs/rudder-server/utils/logger" - "github.com/rudderlabs/rudder-server/utils/types" -) - -type Factory struct { - EnterpriseToken string -} - -var ( - regulationsPollInterval time.Duration - configBackendURL string - suppressionApiPageSize int -) - -func loadConfig() { - config.RegisterDurationConfigVariable(300, ®ulationsPollInterval, true, time.Second, "BackendConfig.Regulations.pollInterval") - config.RegisterIntConfigVariable(50, &suppressionApiPageSize, false, 1, "BackendConfig.Regulations.pageSize") - configBackendURL = config.GetString("CONFIG_BACKEND_URL", "https://api.rudderstack.com") -} - -// Setup initializes Suppress User feature -func (m *Factory) Setup(backendConfig backendconfig.BackendConfig) (types.UserSuppression, error) { - pkgLogger = logger.NewLogger().Child("enterprise").Child("suppress-user") - - if m.EnterpriseToken == "" { - pkgLogger.Info("Suppress User feature is enterprise only") - return &NOOP{}, nil - } - - pkgLogger.Info("[[ SuppressUser ]] Setting up Suppress User Feature") - loadConfig() - ctx := context.TODO() - backendConfig.WaitForConfig(ctx) - suppressUser := &SuppressRegulationHandler{ - RegulationsPollInterval: regulationsPollInterval, - ID: backendConfig.Identity(), - pageSize: strconv.Itoa(suppressionApiPageSize), - } - suppressUser.setup(ctx) - - return suppressUser, nil -} diff --git a/enterprise/suppress-user/suppressUser.go b/enterprise/suppress-user/suppressUser.go deleted file mode 100644 index 96eb430ecad..00000000000 --- a/enterprise/suppress-user/suppressUser.go +++ /dev/null @@ -1,270 +0,0 @@ -package suppression - -import ( - "context" - "encoding/json" - "errors" - "fmt" - "io" - "net/http" - "net/url" - "strconv" - "sync" - "time" - - "github.com/cenkalti/backoff" - "github.com/rudderlabs/rudder-server/config" - "github.com/rudderlabs/rudder-server/services/controlplane/identity" - "github.com/rudderlabs/rudder-server/utils/misc" - "github.com/rudderlabs/rudder-server/utils/types/deployment" - - "github.com/rudderlabs/rudder-server/utils/logger" - - "github.com/rudderlabs/rudder-server/rruntime" -) - -// SuppressRegulationHandler is a handle to this object -type SuppressRegulationHandler struct { - Client *http.Client - RegulationBackendURL string - RegulationsPollInterval time.Duration - ID identity.Identifier - userSpecificSuppressedSourceMap map[string]map[string]sourceFilter - regulationsSubscriberLock sync.RWMutex - suppressAPIToken string - pageSize string - once sync.Once -} - -type sourceFilter struct { - all bool - specific map[string]struct{} -} - -var pkgLogger logger.Logger - -type apiResponse struct { - SourceRegulations []sourceRegulation `json:"items"` - Token string `json:"token"` -} - -type sourceRegulation struct { - Canceled bool `json:"canceled"` - WorkspaceID string `json:"workspaceId"` - UserID string `json:"userId"` - SourceIDs []string `json:"sourceIds"` -} - -func (suppressUser *SuppressRegulationHandler) setup(ctx context.Context) { - suppressUser.RegulationBackendURL = configBackendURL - switch suppressUser.ID.Type() { - case deployment.DedicatedType: - suppressUser.RegulationBackendURL += fmt.Sprintf("/dataplane/workspaces/%s/regulations/suppressions", suppressUser.ID.ID()) - case deployment.MultiTenantType: - suppressUser.RegulationBackendURL += fmt.Sprintf("/dataplane/namespaces/%s/regulations/suppressions", suppressUser.ID.ID()) - default: - panic("invalid deployment type") - } - rruntime.Go(func() { - suppressUser.regulationSyncLoop(ctx) - }) -} - -func (suppressUser *SuppressRegulationHandler) IsSuppressedUser(workspaceID, userID, sourceID string) bool { - suppressUser.init() - pkgLogger.Debugf("IsSuppressedUser called for %v, %v, %v", workspaceID, sourceID, userID) - suppressUser.regulationsSubscriberLock.RLock() - defer suppressUser.regulationsSubscriberLock.RUnlock() - if _, ok := suppressUser.userSpecificSuppressedSourceMap[workspaceID]; ok { - if _, ok := suppressUser.userSpecificSuppressedSourceMap[workspaceID][userID]; ok { - m := suppressUser.userSpecificSuppressedSourceMap[workspaceID][userID] - if m.all { - return true - } - if _, ok := m.specific[sourceID]; ok { - return true - } - } - } - return false -} - -// Gets the regulations from data regulation service -func (suppressUser *SuppressRegulationHandler) regulationSyncLoop(ctx context.Context) { - suppressUser.init() - pageSize, err := strconv.Atoi(suppressUser.pageSize) - if err != nil { - pkgLogger.Error("invalid page size") - suppressUser.pageSize = "" - pageSize = 0 - } - - for { - if ctx.Err() != nil { - return - } - pkgLogger.Info("Fetching Regulations") - regulations, err := suppressUser.getSourceRegulationsFromRegulationService() - if err != nil { - misc.SleepCtx(ctx, regulationsPollInterval) - continue - } - // need to discuss the correct place tp put this lock - suppressUser.regulationsSubscriberLock.Lock() - for _, sourceRegulation := range regulations { - userId := sourceRegulation.UserID - workspaceID := sourceRegulation.WorkspaceID - _, ok := suppressUser.userSpecificSuppressedSourceMap[workspaceID] - if !ok { - suppressUser.userSpecificSuppressedSourceMap[workspaceID] = make(map[string]sourceFilter) - } - if len(sourceRegulation.SourceIDs) == 0 { - if _, ok := suppressUser.userSpecificSuppressedSourceMap[workspaceID][userId]; !ok { - if !sourceRegulation.Canceled { - m := sourceFilter{ - all: true, - specific: map[string]struct{}{}, - } - suppressUser.userSpecificSuppressedSourceMap[workspaceID][userId] = m - continue - } - } - m := suppressUser.userSpecificSuppressedSourceMap[workspaceID][userId] - if sourceRegulation.Canceled { - m.all = false - } else { - m.all = true - } - suppressUser.userSpecificSuppressedSourceMap[workspaceID][userId] = m - } else { - if _, ok := suppressUser.userSpecificSuppressedSourceMap[workspaceID][userId]; !ok { - if !sourceRegulation.Canceled { - m := sourceFilter{ - specific: map[string]struct{}{}, - } - for _, srcId := range sourceRegulation.SourceIDs { - m.specific[srcId] = struct{}{} - } - suppressUser.userSpecificSuppressedSourceMap[workspaceID][userId] = m - continue - } - } - m := suppressUser.userSpecificSuppressedSourceMap[workspaceID][userId] - if sourceRegulation.Canceled { - for _, srcId := range sourceRegulation.SourceIDs { - delete(m.specific, srcId) // will be no-op if key is not there in map - } - } else { - for _, srcId := range sourceRegulation.SourceIDs { - m.specific[srcId] = struct{}{} - } - } - suppressUser.userSpecificSuppressedSourceMap[workspaceID][userId] = m - } - } - suppressUser.regulationsSubscriberLock.Unlock() - - if len(regulations) == 0 || len(regulations) < pageSize { - misc.SleepCtx(ctx, regulationsPollInterval) - } - } -} - -func (suppressUser *SuppressRegulationHandler) getSourceRegulationsFromRegulationService() ([]sourceRegulation, error) { - urlStr := suppressUser.RegulationBackendURL - urlValQuery := url.Values{} - if suppressUser.pageSize != "" { - urlValQuery.Set("pageSize", suppressUser.pageSize) - } - if suppressUser.suppressAPIToken != "" { - urlValQuery.Set("pageToken", suppressUser.suppressAPIToken) - } - if len(urlValQuery) > 0 { - urlStr += "?" + urlValQuery.Encode() - } - - var resp *http.Response - var respBody []byte - - operation := func() error { - var err error - req, err := http.NewRequest("GET", urlStr, http.NoBody) - pkgLogger.Debugf("regulation service URL: %s", urlStr) - if err != nil { - return err - } - req.SetBasicAuth(suppressUser.ID.BasicAuth()) - req.Header.Set("Content-Type", "application/json") - - resp, err = suppressUser.Client.Do(req) - if err != nil { - return err - } - // If statusCode is not 2xx, then returning empty regulations - if resp.StatusCode < 200 || resp.StatusCode >= 300 { - err = fmt.Errorf("status code %v", resp.StatusCode) - pkgLogger.Errorf("[[ Workspace-config ]] Failed to fetch source regulations. statusCode: %v, error: %v", - resp.StatusCode, err) - return err - } - - defer func(Body io.ReadCloser) { - err := Body.Close() - if err != nil { - pkgLogger.Error(err) - } - }(resp.Body) - - respBody, err = io.ReadAll(resp.Body) - if err != nil { - pkgLogger.Error(err) - return err - } - return err - } - - backoffWithMaxRetry := backoff.WithMaxRetries(backoff.NewExponentialBackOff(), 3) - err := backoff.RetryNotify(operation, backoffWithMaxRetry, func(err error, t time.Duration) { - pkgLogger.Errorf("[[ Workspace-config ]] Failed to fetch source regulations from API with error: %v, retrying after %v", err, t) - }) - if err != nil { - pkgLogger.Error("Error sending request to the server: ", err) - return []sourceRegulation{}, err - } - if respBody == nil { - pkgLogger.Error("nil response body, returning") - return []sourceRegulation{}, errors.New("nil response body") - } - var sourceRegulationsJSON apiResponse - err = json.Unmarshal(respBody, &sourceRegulationsJSON) - if err != nil { - pkgLogger.Error("Error while parsing request: ", err, resp.StatusCode) - return []sourceRegulation{}, err - } - // TODO: remove this once regulation Service is updated - for i := range sourceRegulationsJSON.SourceRegulations { - sourceRegulation := &sourceRegulationsJSON.SourceRegulations[i] - if sourceRegulation.WorkspaceID == "" { - sourceRegulation.WorkspaceID = suppressUser.ID.ID() - } - } - - if sourceRegulationsJSON.Token == "" { - pkgLogger.Errorf("[[ Workspace-config ]] No token found in the source regulations response: %v", string(respBody)) - return sourceRegulationsJSON.SourceRegulations, fmt.Errorf("no token returned in regulation API response") - } - suppressUser.suppressAPIToken = sourceRegulationsJSON.Token - return sourceRegulationsJSON.SourceRegulations, nil -} - -func (suppressUser *SuppressRegulationHandler) init() { - suppressUser.once.Do(func() { - pkgLogger.Info("init Regulations") - if len(suppressUser.userSpecificSuppressedSourceMap) == 0 { - suppressUser.userSpecificSuppressedSourceMap = map[string]map[string]sourceFilter{} - } - if suppressUser.Client == nil { - suppressUser.Client = &http.Client{Timeout: config.GetDuration("HttpClient.suppressUser.timeout", 30, time.Second)} - } - }) -} diff --git a/enterprise/suppress-user/suppressUser_test.go b/enterprise/suppress-user/suppressUser_test.go deleted file mode 100644 index aacbe6a7636..00000000000 --- a/enterprise/suppress-user/suppressUser_test.go +++ /dev/null @@ -1,396 +0,0 @@ -package suppression - -import ( - "context" - "encoding/json" - "fmt" - "net/http" - "net/http/httptest" - "testing" - "time" - - . "github.com/onsi/ginkgo/v2" - . "github.com/onsi/gomega" - "github.com/rudderlabs/rudder-server/config" - backendconfig "github.com/rudderlabs/rudder-server/config/backend-config" - "github.com/rudderlabs/rudder-server/services/controlplane/identity" - "github.com/rudderlabs/rudder-server/utils/logger" - "github.com/stretchr/testify/require" -) - -var _ = Describe("SuppressUser Test", func() { - var testSuppressUser *SuppressRegulationHandler - BeforeEach(func() { - config.Reset() - logger.Reset() - backendconfig.Init() - pkgLogger = logger.NewLogger().Child("enterprise").Child("suppress-user") - testSuppressUser = &SuppressRegulationHandler{ - Client: new(http.Client), - RegulationsPollInterval: time.Duration(100), - ID: &identity.Workspace{ - WorkspaceID: "workspace1", - WorkspaceToken: "token-1", - }, - } - }) - expectedRespRegulations := sourceRegulation{ - WorkspaceID: "workspace1", - Canceled: false, - UserID: "user-1", - SourceIDs: []string{"src-1", "src-2"}, - } - expectedResp := apiResponse{ - SourceRegulations: []sourceRegulation{expectedRespRegulations}, - Token: "tempToken123", - } - - Context("getSourceRegulationsFromRegulationService error cases", func() { - It("wrong server address", func() { - srv := createSimpleTestServer(nil) - defer srv.Close() - _, err := testSuppressUser.getSourceRegulationsFromRegulationService() - Expect(err.Error()).NotTo(Equal(nil)) - }) - - It("500 server error", func() { - srv := createSimpleTestServer(&serverInp{statusCode: 500}) - defer srv.Close() - testSuppressUser.RegulationBackendURL = srv.URL - _, err := testSuppressUser.getSourceRegulationsFromRegulationService() - Expect(err.Error()).To(Equal("status code 500")) - }) - - It("invalid data in response body", func() { - srv := createSimpleTestServer(&serverInp{statusCode: 200, respBody: []byte("")}) - defer srv.Close() - testSuppressUser.RegulationBackendURL = srv.URL - _, err := testSuppressUser.getSourceRegulationsFromRegulationService() - Expect(err.Error()).To(Equal("unexpected end of JSON input")) - }) - - It("invalid data in response body", func() { - srv := createSimpleTestServer(&serverInp{statusCode: 200, respBody: []byte("{w")}) - defer srv.Close() - testSuppressUser.RegulationBackendURL = srv.URL - _, err := testSuppressUser.getSourceRegulationsFromRegulationService() - Expect(err.Error()).To(Equal("invalid character 'w' looking for beginning of object key string")) - }) - - It("no token in response body", func() { - srv := createSimpleTestServer(&serverInp{statusCode: 200, respBody: []byte("{}")}) - defer srv.Close() - testSuppressUser.RegulationBackendURL = srv.URL - _, err := testSuppressUser.getSourceRegulationsFromRegulationService() - Expect(err.Error()).To(Equal("no token returned in regulation API response")) - }) - }) - - Context("getSourceRegulationsFromRegulationService valid response", func() { - It("no token in response body", func() { - tempResp := expectedResp - tempResp.Token = "" - expectedRespBody, _ := json.Marshal(tempResp) - srv := createSimpleTestServer(&serverInp{statusCode: 200, respBody: expectedRespBody}) - defer srv.Close() - testSuppressUser.RegulationBackendURL = srv.URL - resp, err := testSuppressUser.getSourceRegulationsFromRegulationService() - Expect(err.Error()).To(Equal("no token returned in regulation API response")) - Expect(resp).To(Equal([]sourceRegulation{expectedRespRegulations})) - }) - }) - - Context("IsSuppressedUser", func() { - It("user suppression rule added and user-id is same", func() { - r := expectedRespRegulations - expectedResp := apiResponse{ - SourceRegulations: []sourceRegulation{r}, - Token: "tempToken123", - } - srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - w.WriteHeader(200) - queryParams := r.URL.Query()["pageToken"] - if len(queryParams) != 0 { - expectedResp = apiResponse{ - Token: "tempToken123", - } - expectedRespBody, _ := json.Marshal(expectedResp) - w.Write(expectedRespBody) - } else { - expectedRespBody, _ := json.Marshal(expectedResp) - w.Write(expectedRespBody) - } - })) - defer srv.Close() - configBackendURL = srv.URL - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() - testSuppressUser.setup(ctx) - Eventually(func() bool { return testSuppressUser.IsSuppressedUser("workspace1", "user-1", "src-1") }).Should(BeTrue()) - }) - - It("user suppression cancelled after adding first", func() { - tempResp := expectedResp - tempResp.SourceRegulations = append(tempResp.SourceRegulations, sourceRegulation{ - WorkspaceID: "ws-1", - Canceled: false, - UserID: "user-2", - SourceIDs: []string{"src-1", "src-2"}, - }) - srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - w.WriteHeader(200) - queryParams := r.URL.Query()["pageToken"] - if len(queryParams) != 0 { - r := expectedRespRegulations - r.Canceled = true - expectedResp := apiResponse{ - SourceRegulations: []sourceRegulation{r}, - Token: "tempToken123", - } - expectedRespBody, _ := json.Marshal(expectedResp) - w.Write(expectedRespBody) - } else { - expectedRespBody, _ := json.Marshal(tempResp) - w.Write(expectedRespBody) - } - })) - defer srv.Close() - configBackendURL = srv.URL - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() - testSuppressUser.setup(ctx) - Eventually(func() bool { return testSuppressUser.IsSuppressedUser("ws-1", "user-2", "src-1") }).Should(BeTrue()) - Eventually(func() bool { return testSuppressUser.IsSuppressedUser("ws-1", "user-1", "src-1") }).Should(BeFalse()) - }) - - It("user suppression rule added for all the sources", func() { - r1 := sourceRegulation{ - WorkspaceID: "ws-1", - Canceled: false, - UserID: "user-1", - SourceIDs: []string{}, - } - r2 := sourceRegulation{ - WorkspaceID: "ws-1", - Canceled: false, - UserID: "user-2", - SourceIDs: []string{"src-2"}, - } - expectedResp := apiResponse{ - SourceRegulations: []sourceRegulation{r1, r2}, - Token: "tempToken123", - } - srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - w.WriteHeader(200) - queryParams := r.URL.Query()["pageToken"] - if len(queryParams) != 0 { - expectedResp = apiResponse{ - Token: "tempToken123", - } - expectedRespBody, _ := json.Marshal(expectedResp) - w.Write(expectedRespBody) - } else { - expectedRespBody, _ := json.Marshal(expectedResp) - w.Write(expectedRespBody) - } - })) - defer srv.Close() - configBackendURL = srv.URL - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() - testSuppressUser.setup(ctx) - Eventually(func() bool { return testSuppressUser.IsSuppressedUser("ws-1", "user-1", "src-1") }).Should(BeTrue()) - Eventually(func() bool { return testSuppressUser.IsSuppressedUser("ws-1", "user-2", "src-2") }).Should(BeTrue()) - Eventually(func() bool { return testSuppressUser.IsSuppressedUser("ws-1", "user-2", "src-1") }).Should(BeFalse()) - }) - - It("user suppression rule added for all the sources and then cancelled", func() { - r1 := sourceRegulation{ - WorkspaceID: "ws-1", - Canceled: false, - UserID: "user-1", - SourceIDs: []string{}, - } - r2 := sourceRegulation{ - WorkspaceID: "ws-1", - Canceled: false, - UserID: "user-2", - SourceIDs: []string{"src-2"}, - } - expectedResp := apiResponse{ - SourceRegulations: []sourceRegulation{r1, r2}, - Token: "tempToken123", - } - firstCheck := make(chan struct{}) - srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - w.WriteHeader(200) - queryParams := r.URL.Query()["pageToken"] - if len(queryParams) != 0 { - <-firstCheck - r1 = sourceRegulation{ - WorkspaceID: "ws-1", - Canceled: true, - UserID: "user-1", - SourceIDs: []string{}, - } - expectedResp = apiResponse{ - SourceRegulations: []sourceRegulation{r1}, - Token: "tempToken123", - } - expectedRespBody, _ := json.Marshal(expectedResp) - w.Write(expectedRespBody) - } else { - expectedRespBody, _ := json.Marshal(expectedResp) - w.Write(expectedRespBody) - } - })) - defer srv.Close() - - configBackendURL = srv.URL - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() - testSuppressUser.setup(ctx) - Eventually(func() bool { return testSuppressUser.IsSuppressedUser("ws-1", "user-1", "src-1") }).Should(BeTrue()) - Eventually(func() bool { return testSuppressUser.IsSuppressedUser("ws-1", "user-2", "src-2") }).Should(BeTrue()) - Eventually(func() bool { return testSuppressUser.IsSuppressedUser("ws-1", "user-2", "src-1") }).Should(BeFalse()) - close(firstCheck) - Eventually(func() bool { return testSuppressUser.IsSuppressedUser("ws-1", "user-1", "src-1") }).Should(BeFalse()) - }) - }) - - Context("adaptations for multi-tenant", func() { - It("supports older version of regulation-service", func() { - // it doesn't return workspaceID as part of the regulations - // (in a single-tenant setup) - r := sourceRegulation{ - Canceled: false, - UserID: "user-1", - SourceIDs: []string{"src-1"}, - } - - expectedResp := apiResponse{ - SourceRegulations: []sourceRegulation{r}, - Token: "tempToken1234", - } - srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - w.WriteHeader(200) - fmt.Println(r.URL.String()) - queryParams := r.URL.Query()["pageToken"] - if len(queryParams) != 0 { - expectedResp = apiResponse{ - Token: "tempToken1234", - } - expectedRespBody, _ := json.Marshal(expectedResp) - w.Write(expectedRespBody) - } else { - expectedRespBody, _ := json.Marshal(expectedResp) - w.Write(expectedRespBody) - } - })) - defer srv.Close() - configBackendURL = srv.URL - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() - testSuppressUser.setup(ctx) - Expect(testSuppressUser.RegulationBackendURL). - To(Equal( - srv.URL + "/dataplane/workspaces/workspace1/regulations/suppressions", - )) - Eventually(func() bool { - return testSuppressUser.IsSuppressedUser("workspace1", "user-1", "src-1") - }).Should(BeTrue()) - }) - }) - - It("supports fetching namespaces' suppressions - contains multiple workspaces' suppressions", func() { - r1 := sourceRegulation{ - WorkspaceID: "ws-1", - Canceled: false, - UserID: "user-1", - SourceIDs: []string{"src-1"}, - } - r2 := sourceRegulation{ - WorkspaceID: "ws-2", - Canceled: false, - UserID: "user-2", - SourceIDs: []string{"src-2"}, - } - expectedResp := apiResponse{ - SourceRegulations: []sourceRegulation{r1, r2}, - Token: "tempToken123", - } - testSuppressUser.ID = &identity.Namespace{ - Namespace: "ns-1", - HostedSecret: `secret`, - } - srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - w.WriteHeader(200) - expectedRespBody, _ := json.Marshal(expectedResp) - w.Write(expectedRespBody) - })) - defer srv.Close() - configBackendURL = srv.URL - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() - testSuppressUser.setup(ctx) - Expect(testSuppressUser.RegulationBackendURL). - To(Equal( - srv.URL + "/dataplane/namespaces/ns-1/regulations/suppressions", - )) - Eventually(func() bool { - return testSuppressUser.IsSuppressedUser("ws-1", "user-1", "src-1") - }).Should(BeTrue()) - Eventually(func() bool { - return testSuppressUser.IsSuppressedUser("ws-2", "user-2", "src-2") - }).Should(BeTrue()) - }) -}) - -type serverInp struct { - statusCode int - respBody []byte -} - -func createSimpleTestServer(inp *serverInp) *httptest.Server { - return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - if inp != nil { - w.WriteHeader(inp.statusCode) - _, err := w.Write(inp.respBody) - if err != nil { - fmt.Println("failed to write data to response body in test server") - return - } - } - })) -} - -func TestSuppressRegulationHandler_IsSuppressedUser(t *testing.T) { - config.Reset() - logger.Reset() - pkgLogger = logger.NewLogger().Child("enterprise").Child("suppress-user") - - suppressUserMap := make(map[string]map[string]sourceFilter) - suppressUserMap["ws-1"] = make(map[string]sourceFilter) - suppressUserMap["ws-1"]["user1"] = sourceFilter{ - all: true, - specific: nil, - } - specificSrc := map[string]struct{}{ - "src1": {}, - "src2": {}, - } - suppressUserMap["ws-1"]["user2"] = sourceFilter{ - all: false, - specific: specificSrc, - } - s := &SuppressRegulationHandler{ - userSpecificSuppressedSourceMap: suppressUserMap, - } - - require.True(t, s.IsSuppressedUser("ws-1", "user1", "src1")) - require.True(t, s.IsSuppressedUser("ws-1", "user1", "randomNewSrc")) - require.True(t, s.IsSuppressedUser("ws-1", "user2", "src1")) - require.True(t, s.IsSuppressedUser("ws-1", "user2", "src2")) - require.False(t, s.IsSuppressedUser("ws-1", "user2", "src3")) - require.False(t, s.IsSuppressedUser("ws-1", "user2", "randomNewSrc")) -} diff --git a/enterprise/suppress-user/suppress_user_test.go b/enterprise/suppress-user/suppress_user_test.go new file mode 100644 index 00000000000..e50965c34ed --- /dev/null +++ b/enterprise/suppress-user/suppress_user_test.go @@ -0,0 +1,350 @@ +package suppression + +import ( + "context" + "encoding/json" + "errors" + "fmt" + "net/http" + "net/http/httptest" + "path" + "strings" + "time" + + "github.com/gofrs/uuid" + . "github.com/onsi/ginkgo/v2" + . "github.com/onsi/gomega" + "github.com/rudderlabs/rudder-server/config" + backendconfig "github.com/rudderlabs/rudder-server/config/backend-config" + "github.com/rudderlabs/rudder-server/enterprise/suppress-user/model" + "github.com/rudderlabs/rudder-server/services/controlplane/identity" + "github.com/rudderlabs/rudder-server/utils/logger" +) + +var _ = Describe("Suppress user", func() { + Context("memory", func() { + generateTests(func() Repository { + return NewMemoryRepository(logger.NewLogger()) + }) + }) + + Context("badgerdb", func() { + generateTests(func() Repository { + p := path.Join(GinkgoT().TempDir(), strings.ReplaceAll(uuid.Must(uuid.NewV4()).String(), "-", "")) + r, err := NewBadgerRepository(p, logger.NOP) + Expect(err).To(BeNil()) + return r + }) + }) +}) + +func generateTests(getRepo func() Repository) { + type syncResponse struct { + expectedUrl string + statusCode int + respBody []byte + } + identifier := &identity.Workspace{ + WorkspaceID: "workspace-1", + } + defaultSuppression := model.Suppression{ + Canceled: false, + WorkspaceID: "workspace-1", + UserID: "user-1", + SourceIDs: []string{"src-1", "src-2"}, + } + defaultResponse := suppressionsResponse{ + Items: []model.Suppression{defaultSuppression}, + Token: "tempToken123", + } + + var h *handler + var serverResponse syncResponse + var server *httptest.Server + var ctx context.Context + var cancel context.CancelFunc + newTestServer := func() *httptest.Server { + var count int + var prevRespBody []byte + return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if serverResponse.expectedUrl != "" { + Expect(r.URL.Path).To(Equal(serverResponse.expectedUrl)) + } + var respBody []byte + // send the expected payload if it is the first time or the payload has changed + if count == 0 || prevRespBody != nil && string(prevRespBody) != string(serverResponse.respBody) { + respBody = serverResponse.respBody + prevRespBody = serverResponse.respBody + count++ + } else { // otherwise send an response containing no items + respBody, _ = json.Marshal(suppressionsResponse{ + Token: "tempToken123", + }) + } + + w.WriteHeader(serverResponse.statusCode) + _, _ = w.Write(respBody) + })) + } + + BeforeEach(func() { + ctx, cancel = context.WithCancel(context.Background()) + config.Reset() + backendconfig.Init() + server = newTestServer() + + r := getRepo() + h = newHandler(ctx, r, logger.NOP) + }) + AfterEach(func() { + server.Close() + cancel() + }) + Context("sync error scenarios", func() { + It("returns an error when a wrong server address is provided", func() { + _, _, err := MustNewSyncer("", identifier, h.r).sync(nil) + Expect(err.Error()).NotTo(Equal(nil)) + }) + + It("returns an error when server responds with HTTP 500", func() { + serverResponse = syncResponse{ + statusCode: 500, + respBody: []byte(""), + } + _, _, err := MustNewSyncer(server.URL, identifier, h.r).sync(nil) + Expect(err.Error()).To(Equal("status code 500")) + }) + + It("returns an error when server responds with invalid (empty) data in the response body", func() { + serverResponse = syncResponse{ + statusCode: 200, + respBody: []byte(""), + } + _, _, err := MustNewSyncer(server.URL, identifier, h.r).sync(nil) + Expect(err.Error()).To(Equal("unexpected end of JSON input")) + }) + + It("returns an error when server responds with invalid (corrupted json) data in the response body", func() { + serverResponse = syncResponse{ + statusCode: 200, + respBody: []byte("{w"), + } + _, _, err := MustNewSyncer(server.URL, identifier, h.r).sync(nil) + Expect(err.Error()).To(Equal("invalid character 'w' looking for beginning of object key string")) + }) + + It("returns an error when server responds with no token in the response body", func() { + resp := defaultResponse + resp.Token = "" + respBody, _ := json.Marshal(resp) + serverResponse = syncResponse{ + statusCode: 200, + respBody: respBody, + } + suppressions, _, err := MustNewSyncer(server.URL, identifier, h.r).sync(nil) + Expect(err.Error()).To(Equal("no token returned in regulation API response")) + Expect(suppressions).To(Equal(defaultResponse.Items)) + }) + }) + + Context("handler, repository, syncer integration", func() { + It("exact user suppression match", func() { + respBody, _ := json.Marshal(defaultResponse) + serverResponse = syncResponse{ + expectedUrl: fmt.Sprintf("/dataplane/workspaces/%s/regulations/suppressions", identifier.WorkspaceID), + statusCode: 200, + respBody: respBody, + } + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + s := MustNewSyncer(server.URL, identifier, h.r, WithPollIntervalFn(func() time.Duration { return 1 * time.Millisecond })) + go func() { + s.SyncLoop(ctx) + }() + Eventually(func() bool { return h.IsSuppressedUser("workspace-1", "user-1", "src-1") }).Should(BeTrue()) + }) + + It("user suppression added, then cancelled", func() { + resp := defaultResponse + resp.Items = []model.Suppression{ + defaultSuppression, + { + Canceled: false, + UserID: "user-2", + SourceIDs: []string{"src-1", "src-2"}, + }, + } + respBody, _ := json.Marshal(resp) + serverResponse = syncResponse{ + statusCode: 200, + respBody: respBody, + } + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + s := MustNewSyncer(server.URL, identifier, h.r, WithPollIntervalFn(func() time.Duration { return 1 * time.Millisecond })) + go func() { + s.SyncLoop(ctx) + }() + Eventually(func() bool { return h.IsSuppressedUser("workspace-1", "user-2", "src-1") }).Should(BeTrue()) + Eventually(func() bool { return h.IsSuppressedUser("workspace-1", "user-1", "src-1") }).Should(BeTrue()) + + resp.Items[0].Canceled = true + respBody, _ = json.Marshal(resp) + serverResponse = syncResponse{ + statusCode: 200, + respBody: respBody, + } + Eventually(func() bool { return h.IsSuppressedUser("workspace-1", "user-1", "src-1") }).Should(BeFalse()) + }) + + It("wildcard user suppression match", func() { + resp := defaultResponse + resp.Items = []model.Suppression{ + { + Canceled: false, + UserID: "user-1", + SourceIDs: []string{}, + }, + { + Canceled: false, + UserID: "user-2", + SourceIDs: []string{"src-2"}, + }, + } + respBody, _ := json.Marshal(resp) + serverResponse = syncResponse{ + statusCode: 200, + respBody: respBody, + } + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + s := MustNewSyncer(server.URL, identifier, h.r, WithPollIntervalFn(func() time.Duration { return 1 * time.Millisecond })) + go func() { + s.SyncLoop(ctx) + }() + Eventually(func() bool { return h.IsSuppressedUser("workspace-1", "user-1", "src-1") }).Should(BeTrue()) + Eventually(func() bool { return h.IsSuppressedUser("workspace-1", "user-2", "src-2") }).Should(BeTrue()) + Eventually(func() bool { return h.IsSuppressedUser("workspace-1", "user-2", "src-1") }).Should(BeFalse()) + }) + + It("wildcard user suppression rule added and then cancelled", func() { + resp := defaultResponse + resp.Items = []model.Suppression{ + { + Canceled: false, + UserID: "user-1", + SourceIDs: []string{}, + }, + { + Canceled: false, + UserID: "user-2", + SourceIDs: []string{"src-2"}, + }, + } + respBody, _ := json.Marshal(resp) + serverResponse = syncResponse{ + statusCode: 200, + respBody: respBody, + } + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + s := MustNewSyncer(server.URL, identifier, h.r, WithPollIntervalFn(func() time.Duration { return 1 * time.Millisecond })) + go func() { + s.SyncLoop(ctx) + }() + + Eventually(func() bool { return h.IsSuppressedUser("workspace-1", "user-1", "src-1") }).Should(BeTrue()) + Eventually(func() bool { return h.IsSuppressedUser("workspace-1", "user-2", "src-2") }).Should(BeTrue()) + Eventually(func() bool { return h.IsSuppressedUser("workspace-1", "user-2", "src-1") }).Should(BeFalse()) + + resp.Items[0].Canceled = true + respBody, _ = json.Marshal(resp) + serverResponse = syncResponse{ + statusCode: 200, + respBody: respBody, + } + Eventually(func() bool { return h.IsSuppressedUser("workspace-1", "user-1", "src-1") }).Should(BeFalse()) + }) + + It("try to sync while restoring", func() { + respBody, _ := json.Marshal(defaultResponse) + serverResponse = syncResponse{ + statusCode: 200, + respBody: respBody, + } + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + var r readerFunc = func(_ []byte) (int, error) { + time.Sleep(1 * time.Second) + return 0, errors.New("read error") + } + go func() { + err := h.r.Restore(r) + Expect(err).To(Not(BeNil())) + }() + s := MustNewSyncer(server.URL, identifier, h.r, WithPollIntervalFn(func() time.Duration { return 1 * time.Millisecond })) + go func() { + s.SyncLoop(ctx) + }() + Eventually(func() bool { return h.IsSuppressedUser("workspace-1", "user-1", "src-1") }, 5*time.Second, 100*time.Millisecond).Should(BeTrue()) + }) + }) + + Context("multi-tenant support", func() { + It("supports older version of regulation service", func() { + // older version of regulation service doesn't return workspaceID as part of the suppressions + resp := defaultResponse + sup := &resp.Items[0] + sup.WorkspaceID = "" + respBody, _ := json.Marshal(resp) + serverResponse = syncResponse{ + statusCode: 200, + respBody: respBody, + } + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + s := MustNewSyncer(server.URL, identifier, h.r, WithPollIntervalFn(func() time.Duration { return 1 * time.Millisecond })) + go func() { + s.SyncLoop(ctx) + }() + Eventually(func() bool { return h.IsSuppressedUser("workspace-1", "user-1", "src-1") }).Should(BeTrue()) + }) + }) + + It("supports syncing suppressions for namespace", func() { + namespaceIdentifier := &identity.Namespace{ + Namespace: "namespace-1", + } + + resp := defaultResponse + resp.Items = []model.Suppression{ + defaultSuppression, + { + WorkspaceID: "workspace-2", + Canceled: false, + UserID: "user-2", + SourceIDs: []string{"src-1", "src-2"}, + }, + } + respBody, _ := json.Marshal(resp) + serverResponse = syncResponse{ + expectedUrl: fmt.Sprintf("/dataplane/namespaces/%s/regulations/suppressions", namespaceIdentifier.Namespace), + statusCode: 200, + respBody: respBody, + } + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + s := MustNewSyncer(server.URL, namespaceIdentifier, h.r, WithPollIntervalFn(func() time.Duration { return 1 * time.Millisecond })) + go func() { + s.SyncLoop(ctx) + }() + Eventually(func() bool { return h.IsSuppressedUser("workspace-2", "user-2", "src-1") }).Should(BeTrue()) + Eventually(func() bool { return h.IsSuppressedUser("workspace-1", "user-1", "src-1") }).Should(BeTrue()) + }) +} + +type readerFunc func(p []byte) (n int, err error) + +func (f readerFunc) Read(p []byte) (n int, err error) { + return f(p) +} diff --git a/enterprise/suppress-user/syncer.go b/enterprise/suppress-user/syncer.go new file mode 100644 index 00000000000..4de681ac9b8 --- /dev/null +++ b/enterprise/suppress-user/syncer.go @@ -0,0 +1,230 @@ +package suppression + +import ( + "context" + "encoding/json" + "errors" + "fmt" + "io" + "net/http" + "net/url" + "strconv" + "time" + + "github.com/cenkalti/backoff" + "github.com/rudderlabs/rudder-server/config" + "github.com/rudderlabs/rudder-server/enterprise/suppress-user/model" + "github.com/rudderlabs/rudder-server/services/controlplane/identity" + "github.com/rudderlabs/rudder-server/utils/logger" + "github.com/rudderlabs/rudder-server/utils/misc" + "github.com/rudderlabs/rudder-server/utils/types/deployment" +) + +// SyncerOpt represents a configuration option for the syncer +type SyncerOpt func(*Syncer) + +// WithHttpClient sets the http client to use +func WithHttpClient(client *http.Client) SyncerOpt { + return func(c *Syncer) { + c.client = client + } +} + +// WithPageSize sets the page size for each sync request +func WithPageSize(pageSize int) SyncerOpt { + return func(c *Syncer) { + c.pageSize = pageSize + } +} + +// WithPollIntervalFn sets the interval at which the syncer will poll the backend +func WithPollIntervalFn(pollIntervalFn func() time.Duration) SyncerOpt { + return func(c *Syncer) { + c.pollIntervalFn = pollIntervalFn + } +} + +// WithLogger sets the logger to use in the syncer +func WithLogger(log logger.Logger) SyncerOpt { + return func(c *Syncer) { + c.log = log + } +} + +// MustNewSyncer creates a new syncer, panics if an error occurs +func MustNewSyncer(baseURL string, identifier identity.Identifier, r Repository, opts ...SyncerOpt) *Syncer { + s, err := NewSyncer(baseURL, identifier, r, opts...) + if err != nil { + panic(err) + } + return s +} + +// NewSyncer creates a new syncer +func NewSyncer(baseURL string, identifier identity.Identifier, r Repository, opts ...SyncerOpt) (*Syncer, error) { + var url string + switch identifier.Type() { + case deployment.DedicatedType: + url = fmt.Sprintf("%s/dataplane/workspaces/%s/regulations/suppressions", baseURL, identifier.ID()) + case deployment.MultiTenantType: + url = fmt.Sprintf("%s/dataplane/namespaces/%s/regulations/suppressions", baseURL, identifier.ID()) + default: + return nil, fmt.Errorf("unsupported deployment type: %s", identifier.Type()) + } + + s := &Syncer{ + url: url, + r: r, + log: logger.NOP, + client: &http.Client{}, + pageSize: 100, + pollIntervalFn: func() time.Duration { return 30 * time.Second }, + defaultWorkspaceID: identifier.ID(), + } + for _, opt := range opts { + opt(s) + } + return s, nil +} + +// Syncer is responsible for syncing suppressions from the backend to the repository +type Syncer struct { + url string + r Repository + + client *http.Client + log logger.Logger + pageSize int + pollIntervalFn func() time.Duration + defaultWorkspaceID string +} + +// SyncLoop runs the sync loop until the provided context is done +func (s *Syncer) SyncLoop(ctx context.Context) { + for { + select { + case <-ctx.Done(): + return + case <-time.After(s.pollIntervalFn()): + } + again: + s.log.Info("Fetching Regulations") + token, err := s.r.GetToken() + if err != nil { + if errors.Is(err, model.ErrRestoring) { + if err := misc.SleepCtx(ctx, 1*time.Second); err != nil { + return + } + goto again + } + s.log.Errorf("Failed to get token from repository: %w", err) + continue + } + s.log.Info("Fetching Regulations") + suppressions, nextToken, err := s.sync(token) + if err != nil { + continue + } + // TODO: this won't be needed once data regulation service gets updated + for i := range suppressions { + suppression := &suppressions[i] + if suppression.WorkspaceID == "" { + suppression.WorkspaceID = s.defaultWorkspaceID + } + } + err = s.r.Add(suppressions, nextToken) + if err != nil { + s.log.Errorf("Failed to add %d suppressions to repository: %w", len(suppressions), err) + continue + } + if len(suppressions) != 0 { + goto again + } + } +} + +// sync fetches suppressions from the backend +func (s *Syncer) sync(token []byte) ([]model.Suppression, []byte, error) { + urlStr := s.url + urlValQuery := url.Values{} + if s.pageSize > 0 { + urlValQuery.Set("pageSize", strconv.Itoa(s.pageSize)) + } + if len(token) > 0 { + urlValQuery.Set("pageToken", string(token)) + } + if len(urlValQuery) > 0 { + urlStr += "?" + urlValQuery.Encode() + } + + var resp *http.Response + var respBody []byte + + operation := func() error { + var err error + req, err := http.NewRequest("GET", urlStr, http.NoBody) + s.log.Debugf("regulation service URL: %s", urlStr) + if err != nil { + return err + } + workspaceToken := config.GetWorkspaceToken() + req.SetBasicAuth(workspaceToken, "") + req.Header.Set("Content-Type", "application/json") + + resp, err = s.client.Do(req) + if err != nil { + return err + } + defer func() { + err := resp.Body.Close() + if err != nil { + s.log.Error(err) + } + }() + + // If statusCode is not 2xx, then returning empty regulations + if resp.StatusCode < 200 || resp.StatusCode >= 300 { + err = fmt.Errorf("status code %v", resp.StatusCode) + s.log.Errorf("Failed to fetch source regulations. statusCode: %v, error: %v", + resp.StatusCode, err) + return err + } + + respBody, err = io.ReadAll(resp.Body) + if err != nil { + s.log.Error(err) + return err + } + return err + } + + backoffWithMaxRetry := backoff.WithMaxRetries(backoff.NewExponentialBackOff(), 3) + err := backoff.RetryNotify(operation, backoffWithMaxRetry, func(err error, t time.Duration) { + s.log.Errorf("Failed to fetch source regulations from API with error: %v, retrying after %v", err, t) + }) + if err != nil { + s.log.Error("Error sending request to the server: ", err) + return []model.Suppression{}, nil, err + } + if respBody == nil { + s.log.Error("nil response body, returning") + return []model.Suppression{}, nil, errors.New("nil response body") + } + var respJSON suppressionsResponse + err = json.Unmarshal(respBody, &respJSON) + if err != nil { + s.log.Error("Error while parsing response: ", err, resp.StatusCode) + return []model.Suppression{}, nil, err + } + + if respJSON.Token == "" { + s.log.Errorf("No token found in the source regulations response: %v", string(respBody)) + return respJSON.Items, nil, fmt.Errorf("no token returned in regulation API response") + } + return respJSON.Items, []byte(respJSON.Token), nil +} + +type suppressionsResponse struct { + Items []model.Suppression `json:"items"` + Token string `json:"token"` +} diff --git a/enterprise/suppress-user/syncer_test.go b/enterprise/suppress-user/syncer_test.go new file mode 100644 index 00000000000..c75503f7bce --- /dev/null +++ b/enterprise/suppress-user/syncer_test.go @@ -0,0 +1,36 @@ +package suppression + +import ( + "net/http" + "testing" + + "github.com/rudderlabs/rudder-server/services/controlplane/identity" + "github.com/rudderlabs/rudder-server/utils/logger" + "github.com/stretchr/testify/require" +) + +func TestSyncer(t *testing.T) { + t.Run("panic", func(t *testing.T) { + require.Panics(t, func() { + MustNewSyncer( + "", &identity.NOOP{}, nil, + WithHttpClient(&http.Client{}), + WithPageSize(1), + WithLogger(logger.NOP)) + }) + }) + t.Run("options", func(t *testing.T) { + httpClient := &http.Client{} + pageSize := 1 + log := logger.NOP + s := MustNewSyncer( + "", &identity.Workspace{}, nil, + WithHttpClient(httpClient), + WithPageSize(pageSize), + WithLogger(log)) + + require.Equal(t, httpClient, s.client) + require.Equal(t, pageSize, s.pageSize) + require.Equal(t, log, s.log) + }) +} diff --git a/gateway/gateway.go b/gateway/gateway.go index ff8f4a6794e..c40cdfb11e1 100644 --- a/gateway/gateway.go +++ b/gateway/gateway.go @@ -1624,6 +1624,7 @@ Setup initializes this module: This function will block until backend config is initially received. */ func (gateway *HandleT) Setup( + ctx context.Context, application app.App, backendConfig backendconfig.BackendConfig, jobsDB jobsdb.JobsDB, rateLimiter ratelimiter.RateLimiter, versionHandler func(w http.ResponseWriter, r *http.Request), rsourcesService rsources.JobService, @@ -1679,7 +1680,7 @@ func (gateway *HandleT) Setup( if enableSuppressUserFeature && gateway.application.Features().SuppressUser != nil { var err error - gateway.suppressUserHandler, err = application.Features().SuppressUser.Setup(gateway.backendConfig) + gateway.suppressUserHandler, err = application.Features().SuppressUser.Setup(ctx, gateway.backendConfig) if err != nil { return fmt.Errorf("could not setup suppress user feature: %w", err) } diff --git a/gateway/gateway_test.go b/gateway/gateway_test.go index f7082e0e17c..1aa473a18be 100644 --- a/gateway/gateway_test.go +++ b/gateway/gateway_test.go @@ -182,7 +182,7 @@ var _ = Describe("Gateway Enterprise", func() { c.mockSuppressUserFeature = mocksApp.NewMockSuppressUserFeature(c.mockCtrl) c.initializeEnterpriseAppFeatures() - c.mockSuppressUserFeature.EXPECT().Setup(gomock.Any()).AnyTimes().Return(c.mockSuppressUser, nil) + c.mockSuppressUserFeature.EXPECT().Setup(gomock.Any(), gomock.Any()).AnyTimes().Return(c.mockSuppressUser, nil) c.mockSuppressUser.EXPECT().IsSuppressedUser(WorkspaceID, NormalUserID, SourceIDEnabled).Return(false).AnyTimes() c.mockSuppressUser.EXPECT().IsSuppressedUser(WorkspaceID, SuppressedUserID, SourceIDEnabled).Return(true).AnyTimes() @@ -202,7 +202,7 @@ var _ = Describe("Gateway Enterprise", func() { gateway := &HandleT{} BeforeEach(func() { - err := gateway.Setup(c.mockApp, c.mockBackendConfig, c.mockJobsDB, nil, c.mockVersionHandler, rsources.NewNoOpService()) + err := gateway.Setup(context.Background(), c.mockApp, c.mockBackendConfig, c.mockJobsDB, nil, c.mockVersionHandler, rsources.NewNoOpService()) Expect(err).To(BeNil()) }) @@ -250,7 +250,7 @@ var _ = Describe("Gateway", func() { Context("Initialization", func() { It("should wait for backend config", func() { gateway := &HandleT{} - err := gateway.Setup(c.mockApp, c.mockBackendConfig, c.mockJobsDB, nil, c.mockVersionHandler, rsources.NewNoOpService()) + err := gateway.Setup(context.Background(), c.mockApp, c.mockBackendConfig, c.mockJobsDB, nil, c.mockVersionHandler, rsources.NewNoOpService()) Expect(err).To(BeNil()) err = gateway.Shutdown() Expect(err).To(BeNil()) @@ -272,7 +272,7 @@ var _ = Describe("Gateway", func() { BeforeEach(func() { gateway = &HandleT{} - err := gateway.Setup(c.mockApp, c.mockBackendConfig, c.mockJobsDB, nil, c.mockVersionHandler, rsources.NewNoOpService()) + err := gateway.Setup(context.Background(), c.mockApp, c.mockBackendConfig, c.mockJobsDB, nil, c.mockVersionHandler, rsources.NewNoOpService()) Expect(err).To(BeNil()) }) @@ -375,7 +375,7 @@ var _ = Describe("Gateway", func() { BeforeEach(func() { gateway = &HandleT{} SetEnableRateLimit(true) - err := gateway.Setup(c.mockApp, c.mockBackendConfig, c.mockJobsDB, c.mockRateLimiter, c.mockVersionHandler, rsources.NewNoOpService()) + err := gateway.Setup(context.Background(), c.mockApp, c.mockBackendConfig, c.mockJobsDB, c.mockRateLimiter, c.mockVersionHandler, rsources.NewNoOpService()) Expect(err).To(BeNil()) }) @@ -411,7 +411,7 @@ var _ = Describe("Gateway", func() { BeforeEach(func() { gateway = &HandleT{} - err := gateway.Setup(c.mockApp, c.mockBackendConfig, c.mockJobsDB, nil, c.mockVersionHandler, rsources.NewNoOpService()) + err := gateway.Setup(context.Background(), c.mockApp, c.mockBackendConfig, c.mockJobsDB, nil, c.mockVersionHandler, rsources.NewNoOpService()) Expect(err).To(BeNil()) }) @@ -538,7 +538,7 @@ var _ = Describe("Gateway", func() { BeforeEach(func() { gateway = &HandleT{} - err := gateway.Setup(c.mockApp, c.mockBackendConfig, c.mockJobsDB, nil, c.mockVersionHandler, rsources.NewNoOpService()) + err := gateway.Setup(context.Background(), c.mockApp, c.mockBackendConfig, c.mockJobsDB, nil, c.mockVersionHandler, rsources.NewNoOpService()) Expect(err).To(BeNil()) }) @@ -570,7 +570,7 @@ var _ = Describe("Gateway", func() { Expect(err).To(BeNil()) err = os.Setenv("RSERVER_WAREHOUSE_MODE", config.OffMode) Expect(err).To(BeNil()) - err = gateway.Setup(c.mockApp, c.mockBackendConfig, c.mockJobsDB, nil, c.mockVersionHandler, rsources.NewNoOpService()) + err = gateway.Setup(context.Background(), c.mockApp, c.mockBackendConfig, c.mockJobsDB, nil, c.mockVersionHandler, rsources.NewNoOpService()) Expect(err).To(BeNil()) defer func() { diff --git a/go.mod b/go.mod index 66b88ec6e6a..55bd7c19bae 100644 --- a/go.mod +++ b/go.mod @@ -32,6 +32,7 @@ require ( github.com/cenkalti/backoff/v4 v4.1.3 github.com/denisenkom/go-mssqldb v0.12.0 github.com/dgraph-io/badger/v2 v2.2007.4 + github.com/dgraph-io/badger/v3 v3.2103.3 github.com/fsnotify/fsnotify v1.5.4 github.com/go-redis/redis v6.15.7+incompatible github.com/gofrs/uuid v4.2.0+incompatible @@ -115,7 +116,7 @@ require ( github.com/coreos/go-systemd/v22 v22.3.2 // indirect github.com/cpuguy83/go-md2man/v2 v2.0.2 // indirect github.com/davecgh/go-spew v1.1.1 // indirect - github.com/dgraph-io/ristretto v0.0.3-0.20200630154024-f66de99634de // indirect + github.com/dgraph-io/ristretto v0.1.1 // indirect github.com/dgryski/go-farm v0.0.0-20200201041132-a6ae2369ad13 // indirect github.com/docker/cli v20.10.14+incompatible // indirect github.com/docker/docker v20.10.21+incompatible // indirect @@ -207,8 +208,10 @@ require ( require ( github.com/aws/aws-sdk-go-v2/service/internal/checksum v1.1.3 // indirect + github.com/cespare/xxhash/v2 v2.1.2 // indirect github.com/foxcpp/go-mockdns v1.0.1-0.20220408113050-3599dc5d2c7d github.com/golang-sql/sqlexp v0.0.0-20170517235910-f1bb20e5a188 // indirect + github.com/golang/glog v0.0.0-20160126235308-23def4e6c14b // indirect github.com/golang/protobuf v1.5.2 // indirect github.com/matttproud/golang_protobuf_extensions v1.0.2-0.20181231171920-c182affec369 // indirect github.com/pelletier/go-toml/v2 v2.0.5 // indirect diff --git a/go.sum b/go.sum index 7f0e28cb5e3..a89efe6b5dd 100644 --- a/go.sum +++ b/go.sum @@ -268,6 +268,7 @@ github.com/census-instrumentation/opencensus-proto v0.3.0/go.mod h1:f6KPmirojxKA github.com/cespare/xxhash v1.1.0 h1:a6HrQnmkObjyL+Gs60czilIUGqrzKutQD6XZog3p+ko= github.com/cespare/xxhash v1.1.0/go.mod h1:XrSqR1VqqWfGrhpAt58auRo0WTKS1nRRg3ghfAqPWnc= github.com/cespare/xxhash/v2 v2.1.1/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs= +github.com/cespare/xxhash/v2 v2.1.2 h1:YRXhKfTDauu4ajMg1TPgFO5jnlC2HCbmLXMcTG5cbYE= github.com/cespare/xxhash/v2 v2.1.2/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs= github.com/checkpoint-restore/go-criu/v5 v5.3.0/go.mod h1:E/eQpaFtUKGOOSEBZgmKAcn+zUUwWxqcaKZlF54wK8E= github.com/chzyer/logex v1.1.10/go.mod h1:+Ywpsq7O8HXn0nuIou7OrIPyXbp3wmkHB+jjWRnGsAI= @@ -321,8 +322,11 @@ github.com/denisenkom/go-mssqldb v0.12.0/go.mod h1:iiK0YP1ZeepvmBQk/QpLEhhTNJgfz github.com/devigned/tab v0.1.1/go.mod h1:XG9mPq0dFghrYvoBF3xdRrJzSTX1b7IQrvaL9mzjeJY= github.com/dgraph-io/badger/v2 v2.2007.4 h1:TRWBQg8UrlUhaFdco01nO2uXwzKS7zd+HVdwV/GHc4o= github.com/dgraph-io/badger/v2 v2.2007.4/go.mod h1:vSw/ax2qojzbN6eXHIx6KPKtCSHJN/Uz0X0VPruTIhk= -github.com/dgraph-io/ristretto v0.0.3-0.20200630154024-f66de99634de h1:t0UHb5vdojIDUqktM6+xJAfScFBsVpXZmqC9dsgJmeA= +github.com/dgraph-io/badger/v3 v3.2103.3 h1:s63J1pisDhKpzWslXFe+ChuthuZptpwTE6qEKoczPb4= +github.com/dgraph-io/badger/v3 v3.2103.3/go.mod h1:4MPiseMeDQ3FNCYwRbbcBOGJLf5jsE0PPFzRiKjtcdw= github.com/dgraph-io/ristretto v0.0.3-0.20200630154024-f66de99634de/go.mod h1:KPxhHT9ZxKefz+PCeOGsrHpl1qZ7i70dGTu2u+Ahh6E= +github.com/dgraph-io/ristretto v0.1.1 h1:6CWw5tJNgpegArSHpNHJKldNeq03FQCwYvfMVWajOK8= +github.com/dgraph-io/ristretto v0.1.1/go.mod h1:S1GPSBCYCIhmVNfcth17y2zZtQT6wzkzgwUve0VDWWA= github.com/dgryski/go-farm v0.0.0-20190423205320-6a90982ecee2/go.mod h1:SqUrOPUnsFjfmXRMNPybcSiG0BgUW2AuFH8PAnS2iTw= github.com/dgryski/go-farm v0.0.0-20200201041132-a6ae2369ad13 h1:fAjc9m62+UWV/WAFKLNi6ZS0675eEUC9y3AlwSbQu1Y= github.com/dgryski/go-farm v0.0.0-20200201041132-a6ae2369ad13/go.mod h1:SqUrOPUnsFjfmXRMNPybcSiG0BgUW2AuFH8PAnS2iTw= @@ -461,6 +465,7 @@ github.com/golang-sql/civil v0.0.0-20190719163853-cb61b32ac6fe/go.mod h1:8vg3r2V github.com/golang-sql/sqlexp v0.0.0-20170517235910-f1bb20e5a188 h1:+eHOFJl1BaXrQxKX+T06f78590z4qA2ZzBTqahsKSE4= github.com/golang-sql/sqlexp v0.0.0-20170517235910-f1bb20e5a188/go.mod h1:vXjM/+wXQnTPR4KqTKDgJukSZ6amVRtWMPEjE6sQoK8= github.com/golang/freetype v0.0.0-20170609003504-e2365dfdc4a0/go.mod h1:E/TSTwGwJL78qG/PmXZO1EjYhfJinVAhrmmHX6Z8B9k= +github.com/golang/glog v0.0.0-20160126235308-23def4e6c14b h1:VKtxabqXZkF25pY9ekfRL6a582T4P37/31XEstQ5p58= github.com/golang/glog v0.0.0-20160126235308-23def4e6c14b/go.mod h1:SBH7ygxi8pfUlaOkMMuAQtPIUF8ecWP5IEl/CR7VP2Q= github.com/golang/groupcache v0.0.0-20190702054246-869f871628b6/go.mod h1:cIg4eruTrX1D+g88fzRXU5OdNfaM+9IcxsU14FzY7Hc= github.com/golang/groupcache v0.0.0-20191227052852-215e87163ea7/go.mod h1:cIg4eruTrX1D+g88fzRXU5OdNfaM+9IcxsU14FzY7Hc= @@ -508,6 +513,7 @@ github.com/gomodule/redigo v1.8.5/go.mod h1:P9dn9mFrCBvWhGE1wpxx6fgq7BAeLBk+UUUz github.com/google/btree v0.0.0-20180813153112-4030bb1f1f0c/go.mod h1:lNA+9X1NB3Zf8V7Ke586lFgjr2dZNuvo3lPJSGZ5JPQ= github.com/google/btree v1.0.0/go.mod h1:lNA+9X1NB3Zf8V7Ke586lFgjr2dZNuvo3lPJSGZ5JPQ= github.com/google/flatbuffers v1.11.0/go.mod h1:1AeVuKshWv4vARoZatz6mlQ0JxURH0Kv5+zNeJKJCa8= +github.com/google/flatbuffers v1.12.1/go.mod h1:1AeVuKshWv4vARoZatz6mlQ0JxURH0Kv5+zNeJKJCa8= github.com/google/flatbuffers v2.0.0+incompatible h1:dicJ2oXwypfwUGnB2/TYWYEKiuk9eYQlQO/AnOHl5mI= github.com/google/flatbuffers v2.0.0+incompatible/go.mod h1:1AeVuKshWv4vARoZatz6mlQ0JxURH0Kv5+zNeJKJCa8= github.com/google/go-cmp v0.2.0/go.mod h1:oXzfMopK8JAjlY9xF4vHSVASa0yLyX7SntLO5aqRK0M= @@ -1275,6 +1281,7 @@ golang.org/x/sys v0.0.0-20220704084225-05e143d24a9e/go.mod h1:oPkhp1MJrh7nUepCBc golang.org/x/sys v0.0.0-20220715151400-c0bba94af5f8/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220722155257-8c9f86f7a55f/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220728004956-3c1f35247d10/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.0.0-20221010170243-090e33056c14/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.1.0 h1:kunALQeHf1/185U1i0GOB/fy1IPRDDpuoOOqRReG57U= golang.org/x/sys v0.1.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= diff --git a/mocks/app/mock_features.go b/mocks/app/mock_features.go index 50831153975..1855e39b5b8 100644 --- a/mocks/app/mock_features.go +++ b/mocks/app/mock_features.go @@ -5,6 +5,7 @@ package mock_app import ( + context "context" reflect "reflect" gomock "github.com/golang/mock/gomock" @@ -36,16 +37,16 @@ func (m *MockSuppressUserFeature) EXPECT() *MockSuppressUserFeatureMockRecorder } // Setup mocks base method. -func (m *MockSuppressUserFeature) Setup(arg0 backendconfig.BackendConfig) (types.UserSuppression, error) { +func (m *MockSuppressUserFeature) Setup(arg0 context.Context, arg1 backendconfig.BackendConfig) (types.UserSuppression, error) { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "Setup", arg0) + ret := m.ctrl.Call(m, "Setup", arg0, arg1) ret0, _ := ret[0].(types.UserSuppression) ret1, _ := ret[1].(error) return ret0, ret1 } // Setup indicates an expected call of Setup. -func (mr *MockSuppressUserFeatureMockRecorder) Setup(arg0 interface{}) *gomock.Call { +func (mr *MockSuppressUserFeatureMockRecorder) Setup(arg0, arg1 interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Setup", reflect.TypeOf((*MockSuppressUserFeature)(nil).Setup), arg0) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Setup", reflect.TypeOf((*MockSuppressUserFeature)(nil).Setup), arg0, arg1) } diff --git a/processor/processor_test.go b/processor/processor_test.go index 3e55f3abe49..fc127db1b94 100644 --- a/processor/processor_test.go +++ b/processor/processor_test.go @@ -2478,8 +2478,8 @@ var _ = Describe("TestConfigFilter", func() { "configFilters": ["long_config1", "long_config2"] } }` - _ = json.Unmarshal([]byte(intgConfigStr), &intgConfig) - _ = json.Unmarshal([]byte(destDefStr), &destDef) + Expect(json.Unmarshal([]byte(intgConfigStr), &intgConfig)).To(BeNil()) + Expect(json.Unmarshal([]byte(destDefStr), &destDef)).To(BeNil()) intgConfig.DestinationDefinition = destDef expectedEvent := transformer.TransformerEventT{ Message: types.SingularEventT{ diff --git a/services/debugger/destination/eventDeliveryStatusUploader_test.go b/services/debugger/destination/eventDeliveryStatusUploader_test.go index 502d1228470..e6a7b13364d 100644 --- a/services/debugger/destination/eventDeliveryStatusUploader_test.go +++ b/services/debugger/destination/eventDeliveryStatusUploader_test.go @@ -161,17 +161,25 @@ var faultyData = DeliveryStatusT{ } type eventDeliveryStatusUploaderContext struct { - asyncHelper testutils.AsyncTestHelper - mockCtrl *gomock.Controller - configInitialised bool - mockBackendConfig *mocksBackendConfig.MockBackendConfig + asyncHelper testutils.AsyncTestHelper + mockCtrl *gomock.Controller } func (c *eventDeliveryStatusUploaderContext) Setup() { c.mockCtrl = gomock.NewController(GinkgoT()) - c.mockBackendConfig = mocksBackendConfig.NewMockBackendConfig(c.mockCtrl) - c.configInitialised = false - Setup(c.mockBackendConfig) + mockBackendConfig := mocksBackendConfig.NewMockBackendConfig(c.mockCtrl) + tFunc := c.asyncHelper.ExpectAndNotifyCallback() + mockBackendConfig.EXPECT().Subscribe(gomock.Any(), backendconfig.TopicBackendConfig). + DoAndReturn(func(ctx context.Context, topic backendconfig.Topic) pubsub.DataChannel { + // on Subscribe, emulate a backend configuration event + ch := make(chan pubsub.DataEvent, 1) + ch <- pubsub.DataEvent{Data: map[string]backendconfig.ConfigT{WorkspaceID: sampleBackendConfig}, Topic: string(topic)} + close(ch) + tFunc() + return ch + }).Times(1) + Setup(mockBackendConfig) + c.asyncHelper.WaitWithTimeout(1 * time.Second) } func initEventDeliveryStatusUploader() { @@ -186,7 +194,6 @@ var _ = Describe("eventDeliveryStatusUploader", func() { var ( c *eventDeliveryStatusUploaderContext deliveryStatus DeliveryStatusT - mockCall *gomock.Call ) BeforeEach(func() { @@ -205,16 +212,6 @@ var _ = Describe("eventDeliveryStatusUploader", func() { EventType: `some_event_type`, } disableEventDeliveryStatusUploads = false - mockCall = c.mockBackendConfig.EXPECT().Subscribe(gomock.Any(), backendconfig.TopicBackendConfig). - DoAndReturn(func(ctx context.Context, topic backendconfig.Topic) pubsub.DataChannel { - // on Subscribe, emulate a backend configuration event - ch := make(chan pubsub.DataEvent, 1) - ch <- pubsub.DataEvent{Data: map[string]backendconfig.ConfigT{WorkspaceID: sampleBackendConfig}, Topic: string(topic)} - c.configInitialised = true - close(ch) - - return ch - }) }) AfterEach(func() { @@ -223,49 +220,20 @@ var _ = Describe("eventDeliveryStatusUploader", func() { Context("RecordEventDeliveryStatus", func() { It("returns false if disableEventDeliveryStatusUploads is true", func() { - tFunc := c.asyncHelper.ExpectAndNotifyCallback() - mockCall.Do(func(ctx context.Context, topic backendconfig.Topic) pubsub.DataChannel { - tFunc() - return make(pubsub.DataChannel) - }).Times(1) - - c.asyncHelper.WaitWithTimeout(5 * time.Second) disableEventDeliveryStatusUploads = true Expect(RecordEventDeliveryStatus(DestinationIDEnabledA, &deliveryStatus)).To(BeFalse()) }) It("returns false if destination_id is not in uploadEnabledDestinationIDs", func() { - tFunc := c.asyncHelper.ExpectAndNotifyCallback() - mockCall.Do(func(ctx context.Context, topic backendconfig.Topic) pubsub.DataChannel { - tFunc() - - return make(pubsub.DataChannel) - }).Times(1) - - c.asyncHelper.WaitWithTimeout(5 * time.Second) Expect(RecordEventDeliveryStatus(DestinationIDEnabledB, &deliveryStatus)).To(BeFalse()) }) It("records events", func() { - tFunc := c.asyncHelper.ExpectAndNotifyCallback() - mockCall.Do(func(ctx context.Context, topic backendconfig.Topic) pubsub.DataChannel { - tFunc() - return make(pubsub.DataChannel) - }).Times(1) - - c.asyncHelper.WaitWithTimeout(5 * time.Second) eventuallyFunc := func() bool { return RecordEventDeliveryStatus(DestinationIDEnabledA, &deliveryStatus) } Eventually(eventuallyFunc).Should(BeTrue()) }) It("transforms payload properly", func() { - tFunc := c.asyncHelper.ExpectAndNotifyCallback() - mockCall.Do(func(ctx context.Context, topic backendconfig.Topic) pubsub.DataChannel { - tFunc() - return nil - }).Times(1) - - c.asyncHelper.WaitWithTimeout(5 * time.Second) edsUploader := EventDeliveryStatusUploader{} rawJSON, err := edsUploader.Transform([]interface{}{&deliveryStatus}) Expect(err).To(BeNil()) @@ -274,13 +242,6 @@ var _ = Describe("eventDeliveryStatusUploader", func() { }) It("sends empty json if transformation fails", func() { - tFunc := c.asyncHelper.ExpectAndNotifyCallback() - mockCall.Do(func(ctx context.Context, topic backendconfig.Topic) pubsub.DataChannel { - tFunc() - return nil - }).Times(1) - - c.asyncHelper.WaitWithTimeout(5 * time.Second) edsUploader := EventDeliveryStatusUploader{} rawJSON, err := edsUploader.Transform([]interface{}{&faultyData}) Expect(err.Error()).To(ContainSubstring("error calling MarshalJSON"))