From 2d54ab0f4b2b21a76239926755b2bb55c00568b9 Mon Sep 17 00:00:00 2001 From: zhzhang Date: Tue, 10 Dec 2024 19:12:52 +0800 Subject: [PATCH 01/34] Fix repo permission check bug --- component/repo.go | 10 ++-------- 1 file changed, 2 insertions(+), 8 deletions(-) diff --git a/component/repo.go b/component/repo.go index 7a18fa68..95a73d01 100644 --- a/component/repo.go +++ b/component/repo.go @@ -1429,13 +1429,10 @@ func (c *repoComponentImpl) AllowReadAccess(ctx context.Context, repoType types. } func (c *repoComponentImpl) AllowWriteAccess(ctx context.Context, repoType types.RepositoryType, namespace, name, username string) (bool, error) { - repo, err := c.repoStore.FindByPath(ctx, repoType, namespace, name) + _, err := c.repoStore.FindByPath(ctx, repoType, namespace, name) if err != nil { return false, fmt.Errorf("failed to find repo, error: %w", err) } - if !repo.Private { - return true, nil - } if username == "" { return false, ErrUserNotFound @@ -1445,13 +1442,10 @@ func (c *repoComponentImpl) AllowWriteAccess(ctx context.Context, repoType types } func (c *repoComponentImpl) AllowAdminAccess(ctx context.Context, repoType types.RepositoryType, namespace, name, username string) (bool, error) { - repo, err := c.repoStore.FindByPath(ctx, repoType, namespace, name) + _, err := c.repoStore.FindByPath(ctx, repoType, namespace, name) if err != nil { return false, fmt.Errorf("failed to find repo, error: %w", err) } - if !repo.Private { - return true, nil - } if username == "" { return false, ErrUserNotFound From 9e84a565a58a28ade95b498301b9a8b5bce97c38 Mon Sep 17 00:00:00 2001 From: vincent Date: Wed, 11 Dec 2024 10:39:24 +0800 Subject: [PATCH 02/34] Move crontab job to temporal workflow (#202) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: 泽华 --- api/workflow/activity/calc_recom_score.go | 19 +++++ api/workflow/activity/sync_as_client.go | 28 ++++++++ api/workflow/cron_calc_recom_score.go | 32 +++++++++ api/workflow/cron_sync_as_client.go | 32 +++++++++ api/workflow/cron_worker.go | 88 +++++++++++++++++++++++ api/workflow/worker.go | 6 +- cmd/csghub-server/cmd/start/server.go | 11 +++ common/config/config.go | 5 ++ common/config/config.toml.example | 4 ++ component/multi_sync.go | 12 ++-- scripts/init.sh | 63 ---------------- 11 files changed, 229 insertions(+), 71 deletions(-) create mode 100644 api/workflow/activity/calc_recom_score.go create mode 100644 api/workflow/activity/sync_as_client.go create mode 100644 api/workflow/cron_calc_recom_score.go create mode 100644 api/workflow/cron_sync_as_client.go create mode 100644 api/workflow/cron_worker.go diff --git a/api/workflow/activity/calc_recom_score.go b/api/workflow/activity/calc_recom_score.go new file mode 100644 index 00000000..a2704f60 --- /dev/null +++ b/api/workflow/activity/calc_recom_score.go @@ -0,0 +1,19 @@ +package activity + +import ( + "context" + "log/slog" + + "opencsg.com/csghub-server/common/config" + "opencsg.com/csghub-server/component" +) + +func CalcRecomScore(ctx context.Context, config *config.Config) error { + c, err := component.NewRecomComponent(config) + if err != nil { + slog.Error("failed to create recom component", "err", err) + return err + } + c.CalculateRecomScore(context.Background()) + return nil +} diff --git a/api/workflow/activity/sync_as_client.go b/api/workflow/activity/sync_as_client.go new file mode 100644 index 00000000..99e4ca3e --- /dev/null +++ b/api/workflow/activity/sync_as_client.go @@ -0,0 +1,28 @@ +package activity + +import ( + "context" + "log/slog" + + "opencsg.com/csghub-server/builder/multisync" + "opencsg.com/csghub-server/builder/store/database" + "opencsg.com/csghub-server/common/config" + "opencsg.com/csghub-server/component" +) + +func SyncAsClient(ctx context.Context, config *config.Config) error { + c, err := component.NewMultiSyncComponent(config) + if err != nil { + slog.Error("failed to create multi sync component", "err", err) + return err + } + syncClientSettingStore := database.NewSyncClientSettingStore() + setting, err := syncClientSettingStore.First(ctx) + if err != nil { + slog.Error("failed to find sync client setting", "error", err) + return err + } + apiDomain := config.MultiSync.SaasAPIDomain + sc := multisync.FromOpenCSG(apiDomain, setting.Token) + return c.SyncAsClient(ctx, sc) +} diff --git a/api/workflow/cron_calc_recom_score.go b/api/workflow/cron_calc_recom_score.go new file mode 100644 index 00000000..d6c44fab --- /dev/null +++ b/api/workflow/cron_calc_recom_score.go @@ -0,0 +1,32 @@ +package workflow + +import ( + "time" + + "go.temporal.io/sdk/temporal" + "go.temporal.io/sdk/workflow" + "opencsg.com/csghub-server/api/workflow/activity" + "opencsg.com/csghub-server/common/config" +) + +func CalcRecomScoreWorkflow(ctx workflow.Context, config *config.Config) error { + logger := workflow.GetLogger(ctx) + logger.Info("calc recom score workflow started") + + retryPolicy := &temporal.RetryPolicy{ + MaximumAttempts: 3, + } + + options := workflow.ActivityOptions{ + StartToCloseTimeout: time.Hour * 1, + RetryPolicy: retryPolicy, + } + + ctx = workflow.WithActivityOptions(ctx, options) + err := workflow.ExecuteActivity(ctx, activity.CalcRecomScore, config).Get(ctx, nil) + if err != nil { + logger.Error("failed to calc recom score", "error", err) + return err + } + return nil +} diff --git a/api/workflow/cron_sync_as_client.go b/api/workflow/cron_sync_as_client.go new file mode 100644 index 00000000..54f5ab6b --- /dev/null +++ b/api/workflow/cron_sync_as_client.go @@ -0,0 +1,32 @@ +package workflow + +import ( + "time" + + "go.temporal.io/sdk/temporal" + "go.temporal.io/sdk/workflow" + "opencsg.com/csghub-server/api/workflow/activity" + "opencsg.com/csghub-server/common/config" +) + +func SyncAsClientWorkflow(ctx workflow.Context, config *config.Config) error { + logger := workflow.GetLogger(ctx) + logger.Info("sync as client workflow started") + + retryPolicy := &temporal.RetryPolicy{ + MaximumAttempts: 3, + } + + options := workflow.ActivityOptions{ + StartToCloseTimeout: time.Hour * 1, + RetryPolicy: retryPolicy, + } + + ctx = workflow.WithActivityOptions(ctx, options) + err := workflow.ExecuteActivity(ctx, activity.SyncAsClient, config).Get(ctx, nil) + if err != nil { + logger.Error("failed to sync as client", "error", err) + return err + } + return nil +} diff --git a/api/workflow/cron_worker.go b/api/workflow/cron_worker.go new file mode 100644 index 00000000..b0409c12 --- /dev/null +++ b/api/workflow/cron_worker.go @@ -0,0 +1,88 @@ +package workflow + +import ( + "context" + "fmt" + + enumspb "go.temporal.io/api/enums/v1" + "go.temporal.io/sdk/client" + "go.temporal.io/sdk/worker" + "opencsg.com/csghub-server/api/workflow/activity" + "opencsg.com/csghub-server/common/config" +) + +const ( + AlreadyScheduledMessage = "schedule with this ID is already registered" + CronJobQueueName = "workflow_cron_queue" +) + +func RegisterCronJobs(config *config.Config) error { + var err error + if wfClient == nil { + wfClient, err = client.Dial(client.Options{ + HostPort: config.WorkFLow.Endpoint, + }) + if err != nil { + return fmt.Errorf("unable to create workflow client, error:%w", err) + } + } + + if !config.Saas { + _, err = wfClient.ScheduleClient().Create(context.Background(), client.ScheduleOptions{ + ID: "sync-as-client-schedule", + Spec: client.ScheduleSpec{ + CronExpressions: []string{config.CronJob.SyncAsClientCronExpression}, + }, + Overlap: enumspb.SCHEDULE_OVERLAP_POLICY_SKIP, + Action: &client.ScheduleWorkflowAction{ + ID: "sync-as-client-workflow", + TaskQueue: CronJobQueueName, + Workflow: SyncAsClientWorkflow, + Args: []interface{}{config}, + }, + }) + if err != nil && err.Error() != AlreadyScheduledMessage { + return fmt.Errorf("unable to create schedule, error:%w", err) + } + } + + _, err = wfClient.ScheduleClient().Create(context.Background(), client.ScheduleOptions{ + ID: "calc-recom-score-schedule", + Spec: client.ScheduleSpec{ + CronExpressions: []string{config.CronJob.CalcRecomScoreCronExpression}, + }, + Overlap: enumspb.SCHEDULE_OVERLAP_POLICY_SKIP, + Action: &client.ScheduleWorkflowAction{ + ID: "calc-recom-score-workflow", + TaskQueue: CronJobQueueName, + Workflow: CalcRecomScoreWorkflow, + Args: []interface{}{config}, + }, + }) + if err != nil && err.Error() != AlreadyScheduledMessage { + return fmt.Errorf("unable to create schedule, error:%w", err) + } + + return nil +} + +func StartCronWorker(config *config.Config) error { + var err error + if wfClient == nil { + wfClient, err = client.Dial(client.Options{ + HostPort: config.WorkFLow.Endpoint, + }) + if err != nil { + return fmt.Errorf("unable to create workflow client, error:%w", err) + } + } + wfWorker = worker.New(wfClient, CronJobQueueName, worker.Options{}) + if !config.Saas { + wfWorker.RegisterWorkflow(SyncAsClientWorkflow) + wfWorker.RegisterActivity(activity.SyncAsClient) + } + wfWorker.RegisterWorkflow(CalcRecomScoreWorkflow) + wfWorker.RegisterActivity(activity.CalcRecomScore) + + return wfWorker.Start() +} diff --git a/api/workflow/worker.go b/api/workflow/worker.go index 973f523d..bad5bd62 100644 --- a/api/workflow/worker.go +++ b/api/workflow/worker.go @@ -11,8 +11,10 @@ import ( const HandlePushQueueName = "workflow_handle_push_queue" -var wfWorker worker.Worker -var wfClient client.Client +var ( + wfWorker worker.Worker + wfClient client.Client +) func StartWorker(config *config.Config) error { var err error diff --git a/cmd/csghub-server/cmd/start/server.go b/cmd/csghub-server/cmd/start/server.go index e66787af..ff24eb0d 100644 --- a/cmd/csghub-server/cmd/start/server.go +++ b/cmd/csghub-server/cmd/start/server.go @@ -91,6 +91,17 @@ var serverCmd = &cobra.Command{ if err != nil { return fmt.Errorf("failed to start worker: %w", err) } + + err = workflow.RegisterCronJobs(cfg) + if err != nil { + return fmt.Errorf("failed to register cron jobs: %w", err) + } + + err = workflow.StartCronWorker(cfg) + if err != nil { + return fmt.Errorf("failed to start cron worker: %w", err) + } + server := httpbase.NewGracefulServer( httpbase.GraceServerOpt{ Port: cfg.APIServer.Port, diff --git a/common/config/config.go b/common/config/config.go index 7d85a642..c69ccc46 100644 --- a/common/config/config.go +++ b/common/config/config.go @@ -224,6 +224,11 @@ type Config struct { // S3PublicBucket is used to store public files, should set bucket same with portal S3PublicBucket string `env:"STARHUB_SERVER_ARGO_S3_PUBLIC_BUCKET"` } + + CronJob struct { + SyncAsClientCronExpression string `env:"STARHUB_SERVER_CRON_JOB_SYNC_AS_CLIENT_CRON_EXPRESSION, default=0 * * * *"` + CalcRecomScoreCronExpression string `env:"STARHUB_SERVER_CRON_JOB_CLAC_RECOM_SCORE_CRON_EXPRESSION, default=0 1 * * *"` + } } func SetConfigFile(file string) { diff --git a/common/config/config.toml.example b/common/config/config.toml.example index b2948124..e1073ea0 100644 --- a/common/config/config.toml.example +++ b/common/config/config.toml.example @@ -161,3 +161,7 @@ encoded_sensitive_words = "5Lmg6L+R5bmzLHhpamlucGluZw==" [workflow] endpoint = "localhost:7233" + +[cron_job] +sync_as_client_cron_expression = "0 * * * *" +calc_recom_score_cron_expression = "0 1 * * *" diff --git a/component/multi_sync.go b/component/multi_sync.go index 38e2494d..edb44597 100644 --- a/component/multi_sync.go +++ b/component/multi_sync.go @@ -242,7 +242,7 @@ func (c *multiSyncComponentImpl) createLocalDataset(ctx context.Context, m *type } err = c.repo.DeleteAllTags(ctx, newDBRepo.ID) - if err != nil { + if err != nil && err != sql.ErrNoRows { slog.Error("failed to delete database tag", slog.Any("error", err)) } @@ -253,14 +253,14 @@ func (c *multiSyncComponentImpl) createLocalDataset(ctx context.Context, m *type } err = c.repo.DeleteAllFiles(ctx, newDBRepo.ID) - if err != nil { + if err != nil && err != sql.ErrNoRows { slog.Error("failed to delete database files", slog.Any("error", err)) } ctxGetFileList, cancel := context.WithTimeout(ctx, 5*time.Second) files, err := sc.FileList(ctxGetFileList, s) cancel() - if err != nil { + if err != nil && err != sql.ErrNoRows { slog.Error("failed to get all files of repo", slog.Any("sync_version", s), slog.Any("error", err)) } if len(files) > 0 { @@ -367,7 +367,7 @@ func (c *multiSyncComponentImpl) createLocalModel(ctx context.Context, m *types. }) } err = c.repo.DeleteAllTags(ctx, newDBRepo.ID) - if err != nil { + if err != nil && err != sql.ErrNoRows { slog.Error("failed to delete database tag", slog.Any("error", err)) } err = c.repo.BatchCreateRepoTags(ctx, repoTags) @@ -377,14 +377,14 @@ func (c *multiSyncComponentImpl) createLocalModel(ctx context.Context, m *types. } err = c.repo.DeleteAllFiles(ctx, newDBRepo.ID) - if err != nil { + if err != nil && err != sql.ErrNoRows { slog.Error("failed to delete all files for repo", slog.Any("error", err)) } ctxGetFileList, cancel := context.WithTimeout(ctx, 5*time.Second) files, err := sc.FileList(ctxGetFileList, s) cancel() - if err != nil { + if err != nil && err != sql.ErrNoRows { slog.Error("failed to get all files of repo", slog.Any("sync_version", s), slog.Any("error", err)) } if len(files) > 0 { diff --git a/scripts/init.sh b/scripts/init.sh index 59a8f0e6..b7cc5712 100755 --- a/scripts/init.sh +++ b/scripts/init.sh @@ -72,69 +72,6 @@ if [ "$STARHUB_SERVER_GITSERVER_TYPE" = "gitea" ]; then fi fi - -# Create cron job -cron="" -read_and_set_cron() { - env_variable=$1 - default_value=$2 - - cron=${!env_variable} - - if [[ -z $cron ]]; then - cron=$default_value - fi -} - -current_cron_jobs=$(crontab -l 2>/dev/null) - -if echo "$current_cron_jobs" | grep -qF "starhub logscan gitea"; then - echo "Gitea log scan job already exists" -else - echo "Creating cron job for gitea logscan..." - read_and_set_cron "STARHUB_SERVER_CRON_LOGSCAN" "0 23 * * *" - (crontab -l ;echo "$cron STARHUB_DATABASE_DSN=$STARHUB_DATABASE_DSN /starhub-bin/starhub logscan gitea --path /starhub-bin/logs/gitea.log >> /starhub-bin/cron.log 2>&1") | crontab - -fi - -if echo "$current_cron_jobs" | grep -qF "calc-recom-score"; then - echo "Calculate score job already exists" -else - echo "Creating cron job for repository recommendation score calculation..." - read_and_set_cron "STARHUB_SERVER_CRON_CALC_RECOM_SCORE" "0 1 * * *" - (crontab -l ;echo "$cron STARHUB_DATABASE_DSN=$STARHUB_DATABASE_DSN STARHUB_SERVER_GITSERVER_HOST=$STARHUB_SERVER_GITSERVER_HOST STARHUB_SERVER_GITSERVER_USERNAME=$STARHUB_SERVER_GITSERVER_USERNAME STARHUB_SERVER_GITSERVER_PASSWORD=$STARHUB_SERVER_GITSERVER_PASSWORD /starhub-bin/starhub cron calc-recom-score >> /starhub-bin/cron-calc-recom-score.log 2>&1") | crontab - -fi - -if echo "$current_cron_jobs" | grep -qF "create-push-mirror"; then - echo "Create push mirror job already exists" -else - echo "Creating cron job for push mirror creation..." - read_and_set_cron "STARHUB_SERVER_CRON_PUSH_MIRROR" "*/10 * * * *" - (crontab -l ;echo "$cron STARHUB_DATABASE_DSN=$STARHUB_DATABASE_DSN STARHUB_SERVER_GITSERVER_HOST=$STARHUB_SERVER_GITSERVER_HOST STARHUB_SERVER_GITSERVER_USERNAME=$STARHUB_SERVER_GITSERVER_USERNAME STARHUB_SERVER_GITSERVER_PASSWORD=$STARHUB_SERVER_GITSERVER_PASSWORD STARHUB_SERVER_MIRRORSERVER_HOST=$STARHUB_SERVER_MIRRORSERVER_HOST STARHUB_SERVER_MIRRORSERVER_USERNAME=$STARHUB_SERVER_MIRRORSERVER_USERNAME STARHUB_SERVER_MIRRORSERVER_PASSWORD=$STARHUB_SERVER_MIRRORSERVER_PASSWORD /starhub-bin/starhub cron create-push-mirror >> /starhub-bin/create-push-mirror.log 2>&1") | crontab - -fi - -if echo "$current_cron_jobs" | grep -qF "check-mirror-progress"; then - echo "Check mirror progress job already exists" -else - echo "Creating cron job for update mirror status and progress..." - read_and_set_cron "STARHUB_SERVER_CRON_PUSH_MIRROR" "*/5 * * * *" - (crontab -l ;echo "$cron STARHUB_SERVER_GITSERVER_URL=$STARHUB_SERVER_GITSERVER_URL STARHUB_SERVER_FRONTEND_URL=$STARHUB_SERVER_FRONTEND_URL STARHUB_DATABASE_DSN=$STARHUB_DATABASE_DSN STARHUB_SERVER_GITSERVER_HOST=$STARHUB_SERVER_GITSERVER_HOST STARHUB_SERVER_GITSERVER_USERNAME=$STARHUB_SERVER_GITSERVER_USERNAME STARHUB_SERVER_GITSERVER_PASSWORD=$STARHUB_SERVER_GITSERVER_PASSWORD STARHUB_SERVER_MIRRORSERVER_HOST=$STARHUB_SERVER_MIRRORSERVER_HOST STARHUB_SERVER_MIRRORSERVER_USERNAME=$STARHUB_SERVER_MIRRORSERVER_USERNAME STARHUB_SERVER_MIRRORSERVER_PASSWORD=$STARHUB_SERVER_MIRRORSERVER_PASSWORD STARHUB_SERVER_REDIS_ENDPOINT=$STARHUB_SERVER_REDIS_ENDPOINT STARHUB_SERVER_REDIS_USER=$STARHUB_SERVER_REDIS_USER STARHUB_SERVER_REDIS_PASSWORD=$STARHUB_SERVER_REDIS_PASSWORD /starhub-bin/starhub mirror check-mirror-progress >> /starhub-bin/check-mirror-progress.log 2>&1") | crontab - -fi - -if [ "$STARHUB_SERVER_SAAS" == "false" ]; then - if echo "$current_cron_jobs" | grep -qF "sync-as-client"; then - echo "Sync as client job already exists" - else - echo "Creating cron job for sync saas sync verions..." - read_and_set_cron "STARHUB_SERVER_CRON_SYNC_AS_CLIENT" "*/10 * * * *" - (crontab -l ;echo "$cron STARHUB_SERVER_REDIS_ENDPOINT=$STARHUB_SERVER_REDIS_ENDPOINT STARHUB_SERVER_REDIS_USER=$STARHUB_SERVER_REDIS_USER STARHUB_SERVER_REDIS_PASSWORD=$STARHUB_SERVER_REDIS_PASSWORD STARHUB_DATABASE_DSN=$STARHUB_DATABASE_DSN STARHUB_SERVER_GITSERVER_TYPE=$STARHUB_SERVER_GITSERVER_TYPE STARHUB_SERVER_GITALY_TOKEN=$STARHUB_SERVER_GITALY_TOKEN STARHUB_SERVER_GITALY_SERVER_SOCKET=$STARHUB_SERVER_GITALY_SERVER_SOCKET STARHUB_SERVER_GITSERVER_HOST=$STARHUB_SERVER_GITSERVER_HOST STARHUB_SERVER_GITSERVER_USERNAME=$STARHUB_SERVER_GITSERVER_USERNAME STARHUB_SERVER_GITSERVER_PASSWORD=$STARHUB_SERVER_GITSERVER_PASSWORD /starhub-bin/starhub sync sync-as-client >> /starhub-bin/cron-sync-as-client.log 2>&1") | crontab - - fi -else - echo "Saas does not need sync-as-client cron job" -fi -# Reload cron server -service cron restart -echo "Done." - echo "Database setup..." echo "Migration init" From aec96509b06a34f959bada513ccb67ddae7f476c Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Thu, 12 Dec 2024 13:15:55 +0800 Subject: [PATCH 03/34] Bump golang.org/x/crypto from 0.27.0 to 0.31.0 (#209) Bumps [golang.org/x/crypto](https://github.com/golang/crypto) from 0.27.0 to 0.31.0. - [Commits](https://github.com/golang/crypto/compare/v0.27.0...v0.31.0) --- updated-dependencies: - dependency-name: golang.org/x/crypto dependency-type: direct:production ... Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> --- go.mod | 12 ++++++------ go.sum | 20 ++++++++++---------- 2 files changed, 16 insertions(+), 16 deletions(-) diff --git a/go.mod b/go.mod index de97be62..a4f188ca 100644 --- a/go.mod +++ b/go.mod @@ -42,6 +42,7 @@ require ( github.com/uptrace/bun/driver/sqliteshim v1.1.16 github.com/uptrace/bun/extra/bundebug v1.1.16 gitlab.com/gitlab-org/gitaly/v16 v16.11.8 + go.temporal.io/api v1.40.0 go.temporal.io/sdk v1.30.0 google.golang.org/grpc v1.66.0 gopkg.in/yaml.v2 v2.4.0 @@ -140,7 +141,6 @@ require ( go.opentelemetry.io/otel v1.24.0 // indirect go.opentelemetry.io/otel/metric v1.24.0 // indirect go.opentelemetry.io/otel/trace v1.24.0 // indirect - go.temporal.io/api v1.40.0 // indirect golang.org/x/exp v0.0.0-20240119083558-1b970713d09a // indirect golang.org/x/xerrors v0.0.0-20231012003039-104605ab7028 // indirect google.golang.org/api v0.169.0 // indirect @@ -237,13 +237,13 @@ require ( go.uber.org/multierr v1.11.0 // indirect go.uber.org/zap v1.26.0 // indirect golang.org/x/arch v0.8.0 // indirect - golang.org/x/crypto v0.27.0 + golang.org/x/crypto v0.31.0 golang.org/x/net v0.28.0 // indirect golang.org/x/oauth2 v0.22.0 // indirect - golang.org/x/sync v0.8.0 // indirect - golang.org/x/sys v0.25.0 // indirect - golang.org/x/term v0.24.0 // indirect - golang.org/x/text v0.18.0 // indirect + golang.org/x/sync v0.10.0 // indirect + golang.org/x/sys v0.28.0 // indirect + golang.org/x/term v0.27.0 // indirect + golang.org/x/text v0.21.0 // indirect golang.org/x/time v0.5.0 // indirect golang.org/x/tools v0.21.1-0.20240508182429-e35e4ccd0d2d // indirect gomodules.xyz/jsonpatch/v2 v2.4.0 // indirect diff --git a/go.sum b/go.sum index 84ace3ac..e6efc3c7 100644 --- a/go.sum +++ b/go.sum @@ -717,8 +717,8 @@ golang.org/x/crypto v0.10.0/go.mod h1:o4eNf7Ede1fv+hwOwZsTHl9EsPFO6q6ZvYR8vYfY45 golang.org/x/crypto v0.13.0/go.mod h1:y6Z2r+Rw4iayiXXAIxJIDAJ1zMW4yaTpebo8fPOliYc= golang.org/x/crypto v0.14.0/go.mod h1:MVFd36DqK4CsrnJYDkBA3VC4m2GkXAM0PvzMCn4JQf4= golang.org/x/crypto v0.18.0/go.mod h1:R0j02AL6hcrfOiy9T4ZYp/rcWeMxM3L6QYxlOuEG1mg= -golang.org/x/crypto v0.27.0 h1:GXm2NjJrPaiv/h1tb2UH8QfgC/hOf/+z0p6PT8o1w7A= -golang.org/x/crypto v0.27.0/go.mod h1:1Xngt8kV6Dvbssa53Ziq6Eqn0HqbZi5Z6R0ZpwQzt70= +golang.org/x/crypto v0.31.0 h1:ihbySMvVjLAeSH1IbfcRTkD/iNscyz8rGzjF/E5hV6U= +golang.org/x/crypto v0.31.0/go.mod h1:kDsLvtWBEx7MV9tJOj9bnXsPbxwJQ6csT/x4KIN4Ssk= golang.org/x/exp v0.0.0-20180321215751-8460e604b9de/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= golang.org/x/exp v0.0.0-20180807140117-3d87b88a115f/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= golang.org/x/exp v0.0.0-20190121172915-509febef88a4/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= @@ -796,8 +796,8 @@ golang.org/x/sync v0.0.0-20220722155255-886fb9371eb4/go.mod h1:RxMgew5VJxzue5/jJ golang.org/x/sync v0.1.0/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.3.0/go.mod h1:FU7BRWz2tNW+3quACPkgCx/L+uEAv1htQ0V83Z9Rj+Y= golang.org/x/sync v0.6.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk= -golang.org/x/sync v0.8.0 h1:3NFvSEYkUoMifnESzZl15y791HH1qU2xm6eCJU5ZPXQ= -golang.org/x/sync v0.8.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk= +golang.org/x/sync v0.10.0 h1:3NQrjDixjgGwUOCaF8w2+VYHv0Ve/vGYSbdkTa98gmQ= +golang.org/x/sync v0.10.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk= golang.org/x/sys v0.0.0-20180830151530-49385e6e1522/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20180909124046-d0be0721c37e/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= @@ -838,8 +838,8 @@ golang.org/x/sys v0.12.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.13.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.15.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= golang.org/x/sys v0.16.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= -golang.org/x/sys v0.25.0 h1:r+8e+loiHxRqhXVl6ML1nO3l1+oFoWbnlu2Ehimmi34= -golang.org/x/sys v0.25.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= +golang.org/x/sys v0.28.0 h1:Fksou7UEQUWlKvIdsqzJmUmCX3cZuD2+P3XyyzwMhlA= +golang.org/x/sys v0.28.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8= golang.org/x/term v0.5.0/go.mod h1:jMB1sMXY+tzblOD4FWmEbocvup2/aLOaQEp7JmGp78k= @@ -848,8 +848,8 @@ golang.org/x/term v0.9.0/go.mod h1:M6DEAAIenWoTxdKrOltXcmDY3rSplQUkrvaDU5FcQyo= golang.org/x/term v0.12.0/go.mod h1:owVbMEjm3cBLCHdkQu9b1opXd4ETQWc3BhuQGKgXgvU= golang.org/x/term v0.13.0/go.mod h1:LTmsnFJwVN6bCy1rVCoS+qHT1HhALEFxKncY3WNNh4U= golang.org/x/term v0.16.0/go.mod h1:yn7UURbUtPyrVJPGPq404EukNFxcm/foM+bV/bfcDsY= -golang.org/x/term v0.24.0 h1:Mh5cbb+Zk2hqqXNO7S1iTjEphVL+jb8ZWaqh/g+JWkM= -golang.org/x/term v0.24.0/go.mod h1:lOBK/LVxemqiMij05LGJ0tzNr8xlmwBRJ81PX6wVLH8= +golang.org/x/term v0.27.0 h1:WP60Sv1nlK1T6SupCHbXzSaN0b9wUmsPoRS9b61A23Q= +golang.org/x/term v0.27.0/go.mod h1:iMsnZpn0cago0GOrHO2+Y7u7JPn5AylBrcoWkElMTSM= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk= golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= @@ -860,8 +860,8 @@ golang.org/x/text v0.9.0/go.mod h1:e1OnstbJyHTd6l/uOt8jFFHp6TRDWZR/bV3emEE/zU8= golang.org/x/text v0.10.0/go.mod h1:TvPlkZtksWOMsz7fbANvkp4WM8x/WCo/om8BMLbz+aE= golang.org/x/text v0.13.0/go.mod h1:TvPlkZtksWOMsz7fbANvkp4WM8x/WCo/om8BMLbz+aE= golang.org/x/text v0.14.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU= -golang.org/x/text v0.18.0 h1:XvMDiNzPAl0jr17s6W9lcaIhGUfUORdGCNsuLmPG224= -golang.org/x/text v0.18.0/go.mod h1:BuEKDfySbSR4drPmRPG/7iBdf8hvFMuRexcpahXilzY= +golang.org/x/text v0.21.0 h1:zyQAAkrwaneQ066sspRyJaG9VNi/YJ1NfzcGB3hZ/qo= +golang.org/x/text v0.21.0/go.mod h1:4IBbMaMmOPCJ8SecivzSH54+73PCFmPWxNTLm+vZkEQ= golang.org/x/time v0.0.0-20170424234030-8be79e1e0910/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= golang.org/x/time v0.5.0 h1:o7cqy6amK/52YcAKIPlM3a+Fpj35zvRj2TP+e1xFSfk= golang.org/x/time v0.5.0/go.mod h1:3BpzKBy/shNhVucY/MWOyx10tF3SFh9QdLuxbVysPQM= From 377dcafec48302d32975c00a6ee5305bae89ef17 Mon Sep 17 00:00:00 2001 From: Yiling-J Date: Mon, 16 Dec 2024 14:12:25 +0800 Subject: [PATCH 04/34] Add More Component Tests (#210) * Merge branch 'feature/component_tests' into 'main' Add dataset_view/git_http component tests See merge request product/starhub/starhub-server!700 * Merge branch 'feature/component_tests' into 'main' Add runtime_arch/mirror component tests See merge request product/starhub/starhub-server!701 * Merge branch 'feature/component_tests' into 'main' Add dataset/collection component tests See merge request product/starhub/starhub-server!703 * Merge branch 'feature/component_tests' into 'main' Add code/multi-sync component tests See merge request product/starhub/starhub-server!709 * fix test * Merge branch 'feature/component_tests' into 'main' Add some component tests See merge request product/starhub/starhub-server!714 --------- Co-authored-by: yiling.ji --- .mockery.yaml | 10 +- Makefile | 2 +- .../builder/multisync/mock_Client.go | 329 +++++++ .../builder/parquet/mock_Reader.go | 156 +++ .../component/mock_SensitiveComponent.go | 211 ++++ .../component/mock_TagComponent.go | 58 -- common/tests/stores.go | 54 ++ component/code.go | 66 +- component/code_test.go | 225 +++++ component/collection.go | 60 +- component/collection_test.go | 220 +++++ component/dataset.go | 80 +- component/dataset_test.go | 251 +++++ component/discussion.go | 52 +- component/discussion_test.go | 54 +- component/git_http.go | 64 +- component/git_http_test.go | 492 ++++++++++ component/internal.go | 28 +- component/internal_test.go | 153 +++ component/mirror_source.go | 20 +- component/mirror_source_test.go | 104 ++ component/model_test.go | 38 +- component/multi_sync.go | 82 +- component/multi_sync_test.go | 174 ++++ component/recom.go | 24 +- component/recom_test.go | 43 +- component/runtime_architecture.go | 74 +- component/runtime_architecture_test.go | 209 ++++ component/space_resource_test.go | 94 ++ component/space_sdk.go | 16 +- component/space_sdk_test.go | 64 ++ component/tag.go | 48 +- component/tag_test.go | 90 ++ component/wire.go | 240 +++++ component/wire_gen_test.go | 900 +++++++++++++++++- component/wireset.go | 233 ++++- go.sum | 1 + 37 files changed, 4581 insertions(+), 438 deletions(-) create mode 100644 _mocks/opencsg.com/csghub-server/builder/multisync/mock_Client.go create mode 100644 _mocks/opencsg.com/csghub-server/builder/parquet/mock_Reader.go create mode 100644 _mocks/opencsg.com/csghub-server/component/mock_SensitiveComponent.go create mode 100644 component/code_test.go create mode 100644 component/collection_test.go create mode 100644 component/dataset_test.go create mode 100644 component/git_http_test.go create mode 100644 component/internal_test.go create mode 100644 component/mirror_source_test.go create mode 100644 component/multi_sync_test.go create mode 100644 component/runtime_architecture_test.go create mode 100644 component/space_resource_test.go create mode 100644 component/space_sdk_test.go create mode 100644 component/tag_test.go diff --git a/.mockery.yaml b/.mockery.yaml index ec11083a..630d57a4 100644 --- a/.mockery.yaml +++ b/.mockery.yaml @@ -16,6 +16,7 @@ packages: AccountingComponent: SpaceComponent: RuntimeArchitectureComponent: + SensitiveComponent: opencsg.com/csghub-server/user/component: config: interfaces: @@ -89,4 +90,11 @@ packages: config: interfaces: AccountingClient: - + opencsg.com/csghub-server/builder/parquet: + config: + interfaces: + Reader: + opencsg.com/csghub-server/builder/multisync: + config: + interfaces: + Client: diff --git a/Makefile b/Makefile index f46e7380..21c7b05a 100644 --- a/Makefile +++ b/Makefile @@ -7,7 +7,7 @@ lint: golangci-lint run cover: - go test -coverprofile=cover.out -coverpkg=./... ./... + go test -coverprofile=cover.out ./... go tool cover -html=cover.out -o cover.html open cover.html diff --git a/_mocks/opencsg.com/csghub-server/builder/multisync/mock_Client.go b/_mocks/opencsg.com/csghub-server/builder/multisync/mock_Client.go new file mode 100644 index 00000000..22ca9225 --- /dev/null +++ b/_mocks/opencsg.com/csghub-server/builder/multisync/mock_Client.go @@ -0,0 +1,329 @@ +// Code generated by mockery v2.49.1. DO NOT EDIT. + +package multisync + +import ( + context "context" + + mock "github.com/stretchr/testify/mock" + + types "opencsg.com/csghub-server/common/types" +) + +// MockClient is an autogenerated mock type for the Client type +type MockClient struct { + mock.Mock +} + +type MockClient_Expecter struct { + mock *mock.Mock +} + +func (_m *MockClient) EXPECT() *MockClient_Expecter { + return &MockClient_Expecter{mock: &_m.Mock} +} + +// DatasetInfo provides a mock function with given fields: ctx, v +func (_m *MockClient) DatasetInfo(ctx context.Context, v types.SyncVersion) (*types.Dataset, error) { + ret := _m.Called(ctx, v) + + if len(ret) == 0 { + panic("no return value specified for DatasetInfo") + } + + var r0 *types.Dataset + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, types.SyncVersion) (*types.Dataset, error)); ok { + return rf(ctx, v) + } + if rf, ok := ret.Get(0).(func(context.Context, types.SyncVersion) *types.Dataset); ok { + r0 = rf(ctx, v) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*types.Dataset) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, types.SyncVersion) error); ok { + r1 = rf(ctx, v) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// MockClient_DatasetInfo_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'DatasetInfo' +type MockClient_DatasetInfo_Call struct { + *mock.Call +} + +// DatasetInfo is a helper method to define mock.On call +// - ctx context.Context +// - v types.SyncVersion +func (_e *MockClient_Expecter) DatasetInfo(ctx interface{}, v interface{}) *MockClient_DatasetInfo_Call { + return &MockClient_DatasetInfo_Call{Call: _e.mock.On("DatasetInfo", ctx, v)} +} + +func (_c *MockClient_DatasetInfo_Call) Run(run func(ctx context.Context, v types.SyncVersion)) *MockClient_DatasetInfo_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(types.SyncVersion)) + }) + return _c +} + +func (_c *MockClient_DatasetInfo_Call) Return(_a0 *types.Dataset, _a1 error) *MockClient_DatasetInfo_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *MockClient_DatasetInfo_Call) RunAndReturn(run func(context.Context, types.SyncVersion) (*types.Dataset, error)) *MockClient_DatasetInfo_Call { + _c.Call.Return(run) + return _c +} + +// FileList provides a mock function with given fields: ctx, v +func (_m *MockClient) FileList(ctx context.Context, v types.SyncVersion) ([]types.File, error) { + ret := _m.Called(ctx, v) + + if len(ret) == 0 { + panic("no return value specified for FileList") + } + + var r0 []types.File + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, types.SyncVersion) ([]types.File, error)); ok { + return rf(ctx, v) + } + if rf, ok := ret.Get(0).(func(context.Context, types.SyncVersion) []types.File); ok { + r0 = rf(ctx, v) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).([]types.File) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, types.SyncVersion) error); ok { + r1 = rf(ctx, v) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// MockClient_FileList_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'FileList' +type MockClient_FileList_Call struct { + *mock.Call +} + +// FileList is a helper method to define mock.On call +// - ctx context.Context +// - v types.SyncVersion +func (_e *MockClient_Expecter) FileList(ctx interface{}, v interface{}) *MockClient_FileList_Call { + return &MockClient_FileList_Call{Call: _e.mock.On("FileList", ctx, v)} +} + +func (_c *MockClient_FileList_Call) Run(run func(ctx context.Context, v types.SyncVersion)) *MockClient_FileList_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(types.SyncVersion)) + }) + return _c +} + +func (_c *MockClient_FileList_Call) Return(_a0 []types.File, _a1 error) *MockClient_FileList_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *MockClient_FileList_Call) RunAndReturn(run func(context.Context, types.SyncVersion) ([]types.File, error)) *MockClient_FileList_Call { + _c.Call.Return(run) + return _c +} + +// Latest provides a mock function with given fields: ctx, currentVersion +func (_m *MockClient) Latest(ctx context.Context, currentVersion int64) (types.SyncVersionResponse, error) { + ret := _m.Called(ctx, currentVersion) + + if len(ret) == 0 { + panic("no return value specified for Latest") + } + + var r0 types.SyncVersionResponse + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, int64) (types.SyncVersionResponse, error)); ok { + return rf(ctx, currentVersion) + } + if rf, ok := ret.Get(0).(func(context.Context, int64) types.SyncVersionResponse); ok { + r0 = rf(ctx, currentVersion) + } else { + r0 = ret.Get(0).(types.SyncVersionResponse) + } + + if rf, ok := ret.Get(1).(func(context.Context, int64) error); ok { + r1 = rf(ctx, currentVersion) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// MockClient_Latest_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Latest' +type MockClient_Latest_Call struct { + *mock.Call +} + +// Latest is a helper method to define mock.On call +// - ctx context.Context +// - currentVersion int64 +func (_e *MockClient_Expecter) Latest(ctx interface{}, currentVersion interface{}) *MockClient_Latest_Call { + return &MockClient_Latest_Call{Call: _e.mock.On("Latest", ctx, currentVersion)} +} + +func (_c *MockClient_Latest_Call) Run(run func(ctx context.Context, currentVersion int64)) *MockClient_Latest_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(int64)) + }) + return _c +} + +func (_c *MockClient_Latest_Call) Return(_a0 types.SyncVersionResponse, _a1 error) *MockClient_Latest_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *MockClient_Latest_Call) RunAndReturn(run func(context.Context, int64) (types.SyncVersionResponse, error)) *MockClient_Latest_Call { + _c.Call.Return(run) + return _c +} + +// ModelInfo provides a mock function with given fields: ctx, v +func (_m *MockClient) ModelInfo(ctx context.Context, v types.SyncVersion) (*types.Model, error) { + ret := _m.Called(ctx, v) + + if len(ret) == 0 { + panic("no return value specified for ModelInfo") + } + + var r0 *types.Model + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, types.SyncVersion) (*types.Model, error)); ok { + return rf(ctx, v) + } + if rf, ok := ret.Get(0).(func(context.Context, types.SyncVersion) *types.Model); ok { + r0 = rf(ctx, v) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*types.Model) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, types.SyncVersion) error); ok { + r1 = rf(ctx, v) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// MockClient_ModelInfo_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'ModelInfo' +type MockClient_ModelInfo_Call struct { + *mock.Call +} + +// ModelInfo is a helper method to define mock.On call +// - ctx context.Context +// - v types.SyncVersion +func (_e *MockClient_Expecter) ModelInfo(ctx interface{}, v interface{}) *MockClient_ModelInfo_Call { + return &MockClient_ModelInfo_Call{Call: _e.mock.On("ModelInfo", ctx, v)} +} + +func (_c *MockClient_ModelInfo_Call) Run(run func(ctx context.Context, v types.SyncVersion)) *MockClient_ModelInfo_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(types.SyncVersion)) + }) + return _c +} + +func (_c *MockClient_ModelInfo_Call) Return(_a0 *types.Model, _a1 error) *MockClient_ModelInfo_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *MockClient_ModelInfo_Call) RunAndReturn(run func(context.Context, types.SyncVersion) (*types.Model, error)) *MockClient_ModelInfo_Call { + _c.Call.Return(run) + return _c +} + +// ReadMeData provides a mock function with given fields: ctx, v +func (_m *MockClient) ReadMeData(ctx context.Context, v types.SyncVersion) (string, error) { + ret := _m.Called(ctx, v) + + if len(ret) == 0 { + panic("no return value specified for ReadMeData") + } + + var r0 string + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, types.SyncVersion) (string, error)); ok { + return rf(ctx, v) + } + if rf, ok := ret.Get(0).(func(context.Context, types.SyncVersion) string); ok { + r0 = rf(ctx, v) + } else { + r0 = ret.Get(0).(string) + } + + if rf, ok := ret.Get(1).(func(context.Context, types.SyncVersion) error); ok { + r1 = rf(ctx, v) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// MockClient_ReadMeData_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'ReadMeData' +type MockClient_ReadMeData_Call struct { + *mock.Call +} + +// ReadMeData is a helper method to define mock.On call +// - ctx context.Context +// - v types.SyncVersion +func (_e *MockClient_Expecter) ReadMeData(ctx interface{}, v interface{}) *MockClient_ReadMeData_Call { + return &MockClient_ReadMeData_Call{Call: _e.mock.On("ReadMeData", ctx, v)} +} + +func (_c *MockClient_ReadMeData_Call) Run(run func(ctx context.Context, v types.SyncVersion)) *MockClient_ReadMeData_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(types.SyncVersion)) + }) + return _c +} + +func (_c *MockClient_ReadMeData_Call) Return(_a0 string, _a1 error) *MockClient_ReadMeData_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *MockClient_ReadMeData_Call) RunAndReturn(run func(context.Context, types.SyncVersion) (string, error)) *MockClient_ReadMeData_Call { + _c.Call.Return(run) + return _c +} + +// NewMockClient creates a new instance of MockClient. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. +// The first argument is typically a *testing.T value. +func NewMockClient(t interface { + mock.TestingT + Cleanup(func()) +}) *MockClient { + mock := &MockClient{} + mock.Mock.Test(t) + + t.Cleanup(func() { mock.AssertExpectations(t) }) + + return mock +} diff --git a/_mocks/opencsg.com/csghub-server/builder/parquet/mock_Reader.go b/_mocks/opencsg.com/csghub-server/builder/parquet/mock_Reader.go new file mode 100644 index 00000000..6e31702e --- /dev/null +++ b/_mocks/opencsg.com/csghub-server/builder/parquet/mock_Reader.go @@ -0,0 +1,156 @@ +// Code generated by mockery v2.49.1. DO NOT EDIT. + +package parquet + +import mock "github.com/stretchr/testify/mock" + +// MockReader is an autogenerated mock type for the Reader type +type MockReader struct { + mock.Mock +} + +type MockReader_Expecter struct { + mock *mock.Mock +} + +func (_m *MockReader) EXPECT() *MockReader_Expecter { + return &MockReader_Expecter{mock: &_m.Mock} +} + +// RowCount provides a mock function with given fields: objName +func (_m *MockReader) RowCount(objName string) (int, error) { + ret := _m.Called(objName) + + if len(ret) == 0 { + panic("no return value specified for RowCount") + } + + var r0 int + var r1 error + if rf, ok := ret.Get(0).(func(string) (int, error)); ok { + return rf(objName) + } + if rf, ok := ret.Get(0).(func(string) int); ok { + r0 = rf(objName) + } else { + r0 = ret.Get(0).(int) + } + + if rf, ok := ret.Get(1).(func(string) error); ok { + r1 = rf(objName) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// MockReader_RowCount_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'RowCount' +type MockReader_RowCount_Call struct { + *mock.Call +} + +// RowCount is a helper method to define mock.On call +// - objName string +func (_e *MockReader_Expecter) RowCount(objName interface{}) *MockReader_RowCount_Call { + return &MockReader_RowCount_Call{Call: _e.mock.On("RowCount", objName)} +} + +func (_c *MockReader_RowCount_Call) Run(run func(objName string)) *MockReader_RowCount_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(string)) + }) + return _c +} + +func (_c *MockReader_RowCount_Call) Return(count int, err error) *MockReader_RowCount_Call { + _c.Call.Return(count, err) + return _c +} + +func (_c *MockReader_RowCount_Call) RunAndReturn(run func(string) (int, error)) *MockReader_RowCount_Call { + _c.Call.Return(run) + return _c +} + +// TopN provides a mock function with given fields: objName, count +func (_m *MockReader) TopN(objName string, count int) ([]string, [][]interface{}, error) { + ret := _m.Called(objName, count) + + if len(ret) == 0 { + panic("no return value specified for TopN") + } + + var r0 []string + var r1 [][]interface{} + var r2 error + if rf, ok := ret.Get(0).(func(string, int) ([]string, [][]interface{}, error)); ok { + return rf(objName, count) + } + if rf, ok := ret.Get(0).(func(string, int) []string); ok { + r0 = rf(objName, count) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).([]string) + } + } + + if rf, ok := ret.Get(1).(func(string, int) [][]interface{}); ok { + r1 = rf(objName, count) + } else { + if ret.Get(1) != nil { + r1 = ret.Get(1).([][]interface{}) + } + } + + if rf, ok := ret.Get(2).(func(string, int) error); ok { + r2 = rf(objName, count) + } else { + r2 = ret.Error(2) + } + + return r0, r1, r2 +} + +// MockReader_TopN_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'TopN' +type MockReader_TopN_Call struct { + *mock.Call +} + +// TopN is a helper method to define mock.On call +// - objName string +// - count int +func (_e *MockReader_Expecter) TopN(objName interface{}, count interface{}) *MockReader_TopN_Call { + return &MockReader_TopN_Call{Call: _e.mock.On("TopN", objName, count)} +} + +func (_c *MockReader_TopN_Call) Run(run func(objName string, count int)) *MockReader_TopN_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(string), args[1].(int)) + }) + return _c +} + +func (_c *MockReader_TopN_Call) Return(columns []string, rows [][]interface{}, err error) *MockReader_TopN_Call { + _c.Call.Return(columns, rows, err) + return _c +} + +func (_c *MockReader_TopN_Call) RunAndReturn(run func(string, int) ([]string, [][]interface{}, error)) *MockReader_TopN_Call { + _c.Call.Return(run) + return _c +} + +// NewMockReader creates a new instance of MockReader. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. +// The first argument is typically a *testing.T value. +func NewMockReader(t interface { + mock.TestingT + Cleanup(func()) +}) *MockReader { + mock := &MockReader{} + mock.Mock.Test(t) + + t.Cleanup(func() { mock.AssertExpectations(t) }) + + return mock +} diff --git a/_mocks/opencsg.com/csghub-server/component/mock_SensitiveComponent.go b/_mocks/opencsg.com/csghub-server/component/mock_SensitiveComponent.go new file mode 100644 index 00000000..e8d6517d --- /dev/null +++ b/_mocks/opencsg.com/csghub-server/component/mock_SensitiveComponent.go @@ -0,0 +1,211 @@ +// Code generated by mockery v2.49.1. DO NOT EDIT. + +package component + +import ( + context "context" + + mock "github.com/stretchr/testify/mock" + types "opencsg.com/csghub-server/common/types" +) + +// MockSensitiveComponent is an autogenerated mock type for the SensitiveComponent type +type MockSensitiveComponent struct { + mock.Mock +} + +type MockSensitiveComponent_Expecter struct { + mock *mock.Mock +} + +func (_m *MockSensitiveComponent) EXPECT() *MockSensitiveComponent_Expecter { + return &MockSensitiveComponent_Expecter{mock: &_m.Mock} +} + +// CheckImage provides a mock function with given fields: ctx, scenario, ossBucketName, ossObjectName +func (_m *MockSensitiveComponent) CheckImage(ctx context.Context, scenario string, ossBucketName string, ossObjectName string) (bool, error) { + ret := _m.Called(ctx, scenario, ossBucketName, ossObjectName) + + if len(ret) == 0 { + panic("no return value specified for CheckImage") + } + + var r0 bool + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, string, string, string) (bool, error)); ok { + return rf(ctx, scenario, ossBucketName, ossObjectName) + } + if rf, ok := ret.Get(0).(func(context.Context, string, string, string) bool); ok { + r0 = rf(ctx, scenario, ossBucketName, ossObjectName) + } else { + r0 = ret.Get(0).(bool) + } + + if rf, ok := ret.Get(1).(func(context.Context, string, string, string) error); ok { + r1 = rf(ctx, scenario, ossBucketName, ossObjectName) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// MockSensitiveComponent_CheckImage_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'CheckImage' +type MockSensitiveComponent_CheckImage_Call struct { + *mock.Call +} + +// CheckImage is a helper method to define mock.On call +// - ctx context.Context +// - scenario string +// - ossBucketName string +// - ossObjectName string +func (_e *MockSensitiveComponent_Expecter) CheckImage(ctx interface{}, scenario interface{}, ossBucketName interface{}, ossObjectName interface{}) *MockSensitiveComponent_CheckImage_Call { + return &MockSensitiveComponent_CheckImage_Call{Call: _e.mock.On("CheckImage", ctx, scenario, ossBucketName, ossObjectName)} +} + +func (_c *MockSensitiveComponent_CheckImage_Call) Run(run func(ctx context.Context, scenario string, ossBucketName string, ossObjectName string)) *MockSensitiveComponent_CheckImage_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(string), args[2].(string), args[3].(string)) + }) + return _c +} + +func (_c *MockSensitiveComponent_CheckImage_Call) Return(_a0 bool, _a1 error) *MockSensitiveComponent_CheckImage_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *MockSensitiveComponent_CheckImage_Call) RunAndReturn(run func(context.Context, string, string, string) (bool, error)) *MockSensitiveComponent_CheckImage_Call { + _c.Call.Return(run) + return _c +} + +// CheckRequestV2 provides a mock function with given fields: ctx, req +func (_m *MockSensitiveComponent) CheckRequestV2(ctx context.Context, req types.SensitiveRequestV2) (bool, error) { + ret := _m.Called(ctx, req) + + if len(ret) == 0 { + panic("no return value specified for CheckRequestV2") + } + + var r0 bool + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, types.SensitiveRequestV2) (bool, error)); ok { + return rf(ctx, req) + } + if rf, ok := ret.Get(0).(func(context.Context, types.SensitiveRequestV2) bool); ok { + r0 = rf(ctx, req) + } else { + r0 = ret.Get(0).(bool) + } + + if rf, ok := ret.Get(1).(func(context.Context, types.SensitiveRequestV2) error); ok { + r1 = rf(ctx, req) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// MockSensitiveComponent_CheckRequestV2_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'CheckRequestV2' +type MockSensitiveComponent_CheckRequestV2_Call struct { + *mock.Call +} + +// CheckRequestV2 is a helper method to define mock.On call +// - ctx context.Context +// - req types.SensitiveRequestV2 +func (_e *MockSensitiveComponent_Expecter) CheckRequestV2(ctx interface{}, req interface{}) *MockSensitiveComponent_CheckRequestV2_Call { + return &MockSensitiveComponent_CheckRequestV2_Call{Call: _e.mock.On("CheckRequestV2", ctx, req)} +} + +func (_c *MockSensitiveComponent_CheckRequestV2_Call) Run(run func(ctx context.Context, req types.SensitiveRequestV2)) *MockSensitiveComponent_CheckRequestV2_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(types.SensitiveRequestV2)) + }) + return _c +} + +func (_c *MockSensitiveComponent_CheckRequestV2_Call) Return(_a0 bool, _a1 error) *MockSensitiveComponent_CheckRequestV2_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *MockSensitiveComponent_CheckRequestV2_Call) RunAndReturn(run func(context.Context, types.SensitiveRequestV2) (bool, error)) *MockSensitiveComponent_CheckRequestV2_Call { + _c.Call.Return(run) + return _c +} + +// CheckText provides a mock function with given fields: ctx, scenario, text +func (_m *MockSensitiveComponent) CheckText(ctx context.Context, scenario string, text string) (bool, error) { + ret := _m.Called(ctx, scenario, text) + + if len(ret) == 0 { + panic("no return value specified for CheckText") + } + + var r0 bool + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, string, string) (bool, error)); ok { + return rf(ctx, scenario, text) + } + if rf, ok := ret.Get(0).(func(context.Context, string, string) bool); ok { + r0 = rf(ctx, scenario, text) + } else { + r0 = ret.Get(0).(bool) + } + + if rf, ok := ret.Get(1).(func(context.Context, string, string) error); ok { + r1 = rf(ctx, scenario, text) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// MockSensitiveComponent_CheckText_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'CheckText' +type MockSensitiveComponent_CheckText_Call struct { + *mock.Call +} + +// CheckText is a helper method to define mock.On call +// - ctx context.Context +// - scenario string +// - text string +func (_e *MockSensitiveComponent_Expecter) CheckText(ctx interface{}, scenario interface{}, text interface{}) *MockSensitiveComponent_CheckText_Call { + return &MockSensitiveComponent_CheckText_Call{Call: _e.mock.On("CheckText", ctx, scenario, text)} +} + +func (_c *MockSensitiveComponent_CheckText_Call) Run(run func(ctx context.Context, scenario string, text string)) *MockSensitiveComponent_CheckText_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(string), args[2].(string)) + }) + return _c +} + +func (_c *MockSensitiveComponent_CheckText_Call) Return(_a0 bool, _a1 error) *MockSensitiveComponent_CheckText_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *MockSensitiveComponent_CheckText_Call) RunAndReturn(run func(context.Context, string, string) (bool, error)) *MockSensitiveComponent_CheckText_Call { + _c.Call.Return(run) + return _c +} + +// NewMockSensitiveComponent creates a new instance of MockSensitiveComponent. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. +// The first argument is typically a *testing.T value. +func NewMockSensitiveComponent(t interface { + mock.TestingT + Cleanup(func()) +}) *MockSensitiveComponent { + mock := &MockSensitiveComponent{} + mock.Mock.Test(t) + + t.Cleanup(func() { mock.AssertExpectations(t) }) + + return mock +} diff --git a/_mocks/opencsg.com/csghub-server/component/mock_TagComponent.go b/_mocks/opencsg.com/csghub-server/component/mock_TagComponent.go index af076cd7..05021f20 100644 --- a/_mocks/opencsg.com/csghub-server/component/mock_TagComponent.go +++ b/_mocks/opencsg.com/csghub-server/component/mock_TagComponent.go @@ -24,64 +24,6 @@ func (_m *MockTagComponent) EXPECT() *MockTagComponent_Expecter { return &MockTagComponent_Expecter{mock: &_m.Mock} } -// AllTags provides a mock function with given fields: ctx -func (_m *MockTagComponent) AllTags(ctx context.Context) ([]database.Tag, error) { - ret := _m.Called(ctx) - - if len(ret) == 0 { - panic("no return value specified for AllTags") - } - - var r0 []database.Tag - var r1 error - if rf, ok := ret.Get(0).(func(context.Context) ([]database.Tag, error)); ok { - return rf(ctx) - } - if rf, ok := ret.Get(0).(func(context.Context) []database.Tag); ok { - r0 = rf(ctx) - } else { - if ret.Get(0) != nil { - r0 = ret.Get(0).([]database.Tag) - } - } - - if rf, ok := ret.Get(1).(func(context.Context) error); ok { - r1 = rf(ctx) - } else { - r1 = ret.Error(1) - } - - return r0, r1 -} - -// MockTagComponent_AllTags_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'AllTags' -type MockTagComponent_AllTags_Call struct { - *mock.Call -} - -// AllTags is a helper method to define mock.On call -// - ctx context.Context -func (_e *MockTagComponent_Expecter) AllTags(ctx interface{}) *MockTagComponent_AllTags_Call { - return &MockTagComponent_AllTags_Call{Call: _e.mock.On("AllTags", ctx)} -} - -func (_c *MockTagComponent_AllTags_Call) Run(run func(ctx context.Context)) *MockTagComponent_AllTags_Call { - _c.Call.Run(func(args mock.Arguments) { - run(args[0].(context.Context)) - }) - return _c -} - -func (_c *MockTagComponent_AllTags_Call) Return(_a0 []database.Tag, _a1 error) *MockTagComponent_AllTags_Call { - _c.Call.Return(_a0, _a1) - return _c -} - -func (_c *MockTagComponent_AllTags_Call) RunAndReturn(run func(context.Context) ([]database.Tag, error)) *MockTagComponent_AllTags_Call { - _c.Call.Return(run) - return _c -} - // AllTagsByScopeAndCategory provides a mock function with given fields: ctx, scope, category func (_m *MockTagComponent) AllTagsByScopeAndCategory(ctx context.Context, scope string, category string) ([]*database.Tag, error) { ret := _m.Called(ctx, scope, category) diff --git a/common/tests/stores.go b/common/tests/stores.go index 46380274..2197113f 100644 --- a/common/tests/stores.go +++ b/common/tests/stores.go @@ -21,6 +21,7 @@ type MockStores struct { Prompt database.PromptStore Namespace database.NamespaceStore LfsMetaObject database.LfsMetaObjectStore + LfsLock database.LfsLockStore Mirror database.MirrorStore MirrorSource database.MirrorSourceStore AccessToken database.AccessTokenStore @@ -35,6 +36,14 @@ type MockStores struct { SpaceSdk database.SpaceSdkStore Recom database.RecomStore RepoRuntimeFramework database.RepositoriesRuntimeFrameworkStore + Discussion database.DiscussionStore + RuntimeArch database.RuntimeArchitecturesStore + ResourceModel database.ResourceModelStore + GitServerAccessToken database.GitServerAccessTokenStore + Org database.OrgStore + MultiSync database.MultiSyncStore + File database.FileStore + SSH database.SSHKeyStore } func NewMockStores(t interface { @@ -56,6 +65,7 @@ func NewMockStores(t interface { Prompt: mockdb.NewMockPromptStore(t), Namespace: mockdb.NewMockNamespaceStore(t), LfsMetaObject: mockdb.NewMockLfsMetaObjectStore(t), + LfsLock: mockdb.NewMockLfsLockStore(t), Mirror: mockdb.NewMockMirrorStore(t), MirrorSource: mockdb.NewMockMirrorSourceStore(t), AccessToken: mockdb.NewMockAccessTokenStore(t), @@ -70,6 +80,14 @@ func NewMockStores(t interface { SpaceSdk: mockdb.NewMockSpaceSdkStore(t), Recom: mockdb.NewMockRecomStore(t), RepoRuntimeFramework: mockdb.NewMockRepositoriesRuntimeFrameworkStore(t), + Discussion: mockdb.NewMockDiscussionStore(t), + RuntimeArch: mockdb.NewMockRuntimeArchitecturesStore(t), + ResourceModel: mockdb.NewMockResourceModelStore(t), + GitServerAccessToken: mockdb.NewMockGitServerAccessTokenStore(t), + Org: mockdb.NewMockOrgStore(t), + MultiSync: mockdb.NewMockMultiSyncStore(t), + File: mockdb.NewMockFileStore(t), + SSH: mockdb.NewMockSSHKeyStore(t), } } @@ -129,6 +147,10 @@ func (s *MockStores) LfsMetaObjectMock() *mockdb.MockLfsMetaObjectStore { return s.LfsMetaObject.(*mockdb.MockLfsMetaObjectStore) } +func (s *MockStores) LfsLockMock() *mockdb.MockLfsLockStore { + return s.LfsLock.(*mockdb.MockLfsLockStore) +} + func (s *MockStores) MirrorMock() *mockdb.MockMirrorStore { return s.Mirror.(*mockdb.MockMirrorStore) } @@ -184,3 +206,35 @@ func (s *MockStores) RecomMock() *mockdb.MockRecomStore { func (s *MockStores) RepoRuntimeFrameworkMock() *mockdb.MockRepositoriesRuntimeFrameworkStore { return s.RepoRuntimeFramework.(*mockdb.MockRepositoriesRuntimeFrameworkStore) } + +func (s *MockStores) DiscussionMock() *mockdb.MockDiscussionStore { + return s.Discussion.(*mockdb.MockDiscussionStore) +} + +func (s *MockStores) RuntimeArchMock() *mockdb.MockRuntimeArchitecturesStore { + return s.RuntimeArch.(*mockdb.MockRuntimeArchitecturesStore) +} + +func (s *MockStores) ResourceModelMock() *mockdb.MockResourceModelStore { + return s.ResourceModel.(*mockdb.MockResourceModelStore) +} + +func (s *MockStores) GitServerAccessTokenMock() *mockdb.MockGitServerAccessTokenStore { + return s.GitServerAccessToken.(*mockdb.MockGitServerAccessTokenStore) +} + +func (s *MockStores) OrgMock() *mockdb.MockOrgStore { + return s.Org.(*mockdb.MockOrgStore) +} + +func (s *MockStores) MultiSyncMock() *mockdb.MockMultiSyncStore { + return s.MultiSync.(*mockdb.MockMultiSyncStore) +} + +func (s *MockStores) FileMock() *mockdb.MockFileStore { + return s.File.(*mockdb.MockFileStore) +} + +func (s *MockStores) SSHMock() *mockdb.MockSSHKeyStore { + return s.SSH.(*mockdb.MockSSHKeyStore) +} diff --git a/component/code.go b/component/code.go index d2cf14af..07f7a8b1 100644 --- a/component/code.go +++ b/component/code.go @@ -5,7 +5,10 @@ import ( "fmt" "log/slog" + "opencsg.com/csghub-server/builder/git" + "opencsg.com/csghub-server/builder/git/gitserver" "opencsg.com/csghub-server/builder/git/membership" + "opencsg.com/csghub-server/builder/rpc" "opencsg.com/csghub-server/builder/store/database" "opencsg.com/csghub-server/common/config" "opencsg.com/csghub-server/common/types" @@ -27,19 +30,32 @@ type CodeComponent interface { func NewCodeComponent(config *config.Config) (CodeComponent, error) { c := &codeComponentImpl{} var err error - c.repoComponentImpl, err = NewRepoComponentImpl(config) + c.repoComponent, err = NewRepoComponentImpl(config) if err != nil { return nil, err } - c.cs = database.NewCodeStore() - c.rs = database.NewRepoStore() + c.codeStore = database.NewCodeStore() + c.repoStore = database.NewRepoStore() + gs, err := git.NewGitServer(config) + if err != nil { + return nil, fmt.Errorf("failed to create git server, error: %w", err) + } + c.gitServer = gs + c.config = config + c.userLikesStore = database.NewUserLikesStore() + c.userSvcClient = rpc.NewUserSvcHttpClient(fmt.Sprintf("%s:%d", config.User.Host, config.User.Port), + rpc.AuthWithApiKey(config.APIToken)) return c, nil } type codeComponentImpl struct { - *repoComponentImpl - cs database.CodeStore - rs database.RepoStore + config *config.Config + repoComponent RepoComponent + codeStore database.CodeStore + repoStore database.RepoStore + userLikesStore database.UserLikesStore + gitServer gitserver.GitServer + userSvcClient rpc.UserSvcClient } func (c *codeComponentImpl) Create(ctx context.Context, req *types.CreateCodeReq) (*types.Code, error) { @@ -61,7 +77,7 @@ func (c *codeComponentImpl) Create(ctx context.Context, req *types.CreateCodeReq req.RepoType = types.CodeRepo req.Readme = generateReadmeData(req.License) req.Nickname = nickname - _, dbRepo, err := c.CreateRepo(ctx, req.CreateRepoReq) + _, dbRepo, err := c.repoComponent.CreateRepo(ctx, req.CreateRepoReq) if err != nil { return nil, err } @@ -71,13 +87,13 @@ func (c *codeComponentImpl) Create(ctx context.Context, req *types.CreateCodeReq RepositoryID: dbRepo.ID, } - code, err := c.cs.Create(ctx, dbCode) + code, err := c.codeStore.Create(ctx, dbCode) if err != nil { return nil, fmt.Errorf("failed to create database code, cause: %w", err) } // Create README.md file - err = c.git.CreateRepoFile(buildCreateFileReq(&types.CreateFileParams{ + err = c.gitServer.CreateRepoFile(buildCreateFileReq(&types.CreateFileParams{ Username: dbRepo.User.Username, Email: dbRepo.User.Email, Message: initCommitMessage, @@ -93,7 +109,7 @@ func (c *codeComponentImpl) Create(ctx context.Context, req *types.CreateCodeReq } // Create .gitattributes file - err = c.git.CreateRepoFile(buildCreateFileReq(&types.CreateFileParams{ + err = c.gitServer.CreateRepoFile(buildCreateFileReq(&types.CreateFileParams{ Username: dbRepo.User.Username, Email: dbRepo.User.Email, Message: initCommitMessage, @@ -149,7 +165,7 @@ func (c *codeComponentImpl) Index(ctx context.Context, filter *types.RepoFilter, err error resCodes []types.Code ) - repos, total, err := c.PublicToUser(ctx, types.CodeRepo, filter.Username, filter, per, page) + repos, total, err := c.repoComponent.PublicToUser(ctx, types.CodeRepo, filter.Username, filter, per, page) if err != nil { newError := fmt.Errorf("failed to get public code repos,error:%w", err) return nil, 0, newError @@ -158,7 +174,7 @@ func (c *codeComponentImpl) Index(ctx context.Context, filter *types.RepoFilter, for _, repo := range repos { repoIDs = append(repoIDs, repo.ID) } - codes, err := c.cs.ByRepoIDs(ctx, repoIDs) + codes, err := c.codeStore.ByRepoIDs(ctx, repoIDs) if err != nil { newError := fmt.Errorf("failed to get codes by repo ids,error:%w", err) return nil, 0, newError @@ -210,18 +226,18 @@ func (c *codeComponentImpl) Index(ctx context.Context, filter *types.RepoFilter, func (c *codeComponentImpl) Update(ctx context.Context, req *types.UpdateCodeReq) (*types.Code, error) { req.RepoType = types.CodeRepo - dbRepo, err := c.UpdateRepo(ctx, req.UpdateRepoReq) + dbRepo, err := c.repoComponent.UpdateRepo(ctx, req.UpdateRepoReq) if err != nil { return nil, err } - code, err := c.cs.ByRepoID(ctx, dbRepo.ID) + code, err := c.codeStore.ByRepoID(ctx, dbRepo.ID) if err != nil { return nil, fmt.Errorf("failed to find code repo, error: %w", err) } //update times of code - err = c.cs.Update(ctx, *code) + err = c.codeStore.Update(ctx, *code) if err != nil { return nil, fmt.Errorf("failed to update database code repo, error: %w", err) } @@ -244,7 +260,7 @@ func (c *codeComponentImpl) Update(ctx context.Context, req *types.UpdateCodeReq } func (c *codeComponentImpl) Delete(ctx context.Context, namespace, name, currentUser string) error { - code, err := c.cs.FindByPath(ctx, namespace, name) + code, err := c.codeStore.FindByPath(ctx, namespace, name) if err != nil { return fmt.Errorf("failed to find code, error: %w", err) } @@ -255,12 +271,12 @@ func (c *codeComponentImpl) Delete(ctx context.Context, namespace, name, current Name: name, RepoType: types.CodeRepo, } - _, err = c.DeleteRepo(ctx, deleteDatabaseRepoReq) + _, err = c.repoComponent.DeleteRepo(ctx, deleteDatabaseRepoReq) if err != nil { return fmt.Errorf("failed to delete repo of code, error: %w", err) } - err = c.cs.Delete(ctx, *code) + err = c.codeStore.Delete(ctx, *code) if err != nil { return fmt.Errorf("failed to delete database code, error: %w", err) } @@ -269,12 +285,12 @@ func (c *codeComponentImpl) Delete(ctx context.Context, namespace, name, current func (c *codeComponentImpl) Show(ctx context.Context, namespace, name, currentUser string) (*types.Code, error) { var tags []types.RepoTag - code, err := c.cs.FindByPath(ctx, namespace, name) + code, err := c.codeStore.FindByPath(ctx, namespace, name) if err != nil { return nil, fmt.Errorf("failed to find code, error: %w", err) } - permission, err := c.GetUserRepoPermission(ctx, currentUser, code.Repository) + permission, err := c.repoComponent.GetUserRepoPermission(ctx, currentUser, code.Repository) if err != nil { return nil, fmt.Errorf("failed to get user repo permission, error: %w", err) } @@ -282,7 +298,7 @@ func (c *codeComponentImpl) Show(ctx context.Context, namespace, name, currentUs return nil, ErrUnauthorized } - ns, err := c.GetNameSpaceInfo(ctx, namespace) + ns, err := c.repoComponent.GetNameSpaceInfo(ctx, namespace) if err != nil { return nil, fmt.Errorf("failed to get namespace info for code, error: %w", err) } @@ -338,12 +354,12 @@ func (c *codeComponentImpl) Show(ctx context.Context, namespace, name, currentUs } func (c *codeComponentImpl) Relations(ctx context.Context, namespace, name, currentUser string) (*types.Relations, error) { - code, err := c.cs.FindByPath(ctx, namespace, name) + code, err := c.codeStore.FindByPath(ctx, namespace, name) if err != nil { return nil, fmt.Errorf("failed to find code repo, error: %w", err) } - allow, _ := c.AllowReadAccessRepo(ctx, code.Repository, currentUser) + allow, _ := c.repoComponent.AllowReadAccessRepo(ctx, code.Repository, currentUser) if !allow { return nil, ErrUnauthorized } @@ -352,7 +368,7 @@ func (c *codeComponentImpl) Relations(ctx context.Context, namespace, name, curr } func (c *codeComponentImpl) getRelations(ctx context.Context, repoID int64, currentUser string) (*types.Relations, error) { - res, err := c.RelatedRepos(ctx, repoID, currentUser) + res, err := c.repoComponent.RelatedRepos(ctx, repoID, currentUser) if err != nil { return nil, err } @@ -387,7 +403,7 @@ func (c *codeComponentImpl) OrgCodes(ctx context.Context, req *types.OrgCodesReq } } onlyPublic := !r.CanRead() - codes, total, err := c.cs.ByOrgPath(ctx, req.Namespace, req.PageSize, req.Page, onlyPublic) + codes, total, err := c.codeStore.ByOrgPath(ctx, req.Namespace, req.PageSize, req.Page, onlyPublic) if err != nil { newError := fmt.Errorf("failed to get org codes,error:%w", err) slog.Error(newError.Error()) diff --git a/component/code_test.go b/component/code_test.go new file mode 100644 index 00000000..4dff2019 --- /dev/null +++ b/component/code_test.go @@ -0,0 +1,225 @@ +package component + +import ( + "context" + "testing" + + "github.com/stretchr/testify/require" + "opencsg.com/csghub-server/builder/git/membership" + "opencsg.com/csghub-server/builder/store/database" + "opencsg.com/csghub-server/common/types" +) + +func TestCodeComponent_Create(t *testing.T) { + ctx := context.TODO() + cc := initializeTestCodeComponent(ctx, t) + + req := &types.CreateCodeReq{ + CreateRepoReq: types.CreateRepoReq{ + Username: "user", + Namespace: "ns", + Name: "n", + License: "l", + Readme: "r", + }, + } + dbrepo := &database.Repository{ + ID: 1, + User: database.User{Username: "user"}, + Tags: []database.Tag{{Name: "t1"}}, + } + crq := req.CreateRepoReq + crq.Nickname = "n" + crq.Readme = generateReadmeData(req.License) + crq.RepoType = types.CodeRepo + crq.DefaultBranch = "main" + cc.mocks.components.repo.EXPECT().CreateRepo(ctx, crq).Return( + nil, dbrepo, nil, + ) + cc.mocks.stores.CodeMock().EXPECT().Create(ctx, database.Code{ + Repository: dbrepo, + RepositoryID: 1, + }).Return(&database.Code{ + RepositoryID: 1, + Repository: dbrepo, + }, nil) + cc.mocks.gitServer.EXPECT().CreateRepoFile(buildCreateFileReq(&types.CreateFileParams{ + Username: "user", + Message: initCommitMessage, + Branch: "main", + Content: crq.Readme, + NewBranch: "main", + Namespace: "ns", + Name: "n", + FilePath: readmeFileName, + }, types.CodeRepo)).Return(nil) + cc.mocks.gitServer.EXPECT().CreateRepoFile(buildCreateFileReq(&types.CreateFileParams{ + Username: "user", + Message: initCommitMessage, + Branch: "main", + Content: codeGitattributesContent, + NewBranch: "main", + Namespace: "ns", + Name: "n", + FilePath: gitattributesFileName, + }, types.CodeRepo)).Return(nil) + + resp, err := cc.Create(ctx, req) + require.Nil(t, err) + require.Equal(t, &types.Code{ + RepositoryID: 1, + User: types.User{ + Username: "user", + }, + Repository: types.Repository{ + HTTPCloneURL: "/s/.git", + SSHCloneURL: ":s/.git", + }, + Tags: []types.RepoTag{{Name: "t1"}}, + }, resp) +} + +func TestCodeComponent_Index(t *testing.T) { + ctx := context.TODO() + cc := initializeTestCodeComponent(ctx, t) + + filter := &types.RepoFilter{Username: "user"} + repos := []*database.Repository{ + {ID: 1, Name: "r1", Tags: []database.Tag{{Name: "t1"}}}, + {ID: 2, Name: "r2"}, + } + cc.mocks.components.repo.EXPECT().PublicToUser(ctx, types.CodeRepo, "user", filter, 10, 1).Return( + repos, 100, nil, + ) + cc.mocks.stores.CodeMock().EXPECT().ByRepoIDs(ctx, []int64{1, 2}).Return([]database.Code{ + {ID: 11, RepositoryID: 2}, + {ID: 12, RepositoryID: 1}, + }, nil) + + data, total, err := cc.Index(ctx, filter, 10, 1) + require.Nil(t, err) + require.Equal(t, 100, total) + require.Equal(t, []types.Code{ + {ID: 12, RepositoryID: 1, Name: "r1", Tags: []types.RepoTag{{Name: "t1"}}}, + {ID: 11, RepositoryID: 2, Name: "r2"}, + }, data) +} + +func TestCodeComponent_Update(t *testing.T) { + ctx := context.TODO() + cc := initializeTestCodeComponent(ctx, t) + + req := &types.UpdateCodeReq{ + UpdateRepoReq: types.UpdateRepoReq{ + RepoType: types.CodeRepo, + }, + } + dbrepo := &database.Repository{Name: "name"} + cc.mocks.components.repo.EXPECT().UpdateRepo(ctx, req.UpdateRepoReq).Return(dbrepo, nil) + cc.mocks.stores.CodeMock().EXPECT().ByRepoID(ctx, dbrepo.ID).Return(&database.Code{ID: 1}, nil) + cc.mocks.stores.CodeMock().EXPECT().Update(ctx, database.Code{ + ID: 1, + }).Return(nil) + + data, err := cc.Update(ctx, req) + require.Nil(t, err) + require.Equal(t, &types.Code{ID: 1, Name: "name"}, data) + +} + +func TestCodeComponent_Delete(t *testing.T) { + ctx := context.TODO() + cc := initializeTestCodeComponent(ctx, t) + + cc.mocks.stores.CodeMock().EXPECT().FindByPath(ctx, "ns", "n").Return(&database.Code{}, nil) + cc.mocks.components.repo.EXPECT().DeleteRepo(ctx, types.DeleteRepoReq{ + Username: "user", + Namespace: "ns", + Name: "n", + RepoType: types.CodeRepo, + }).Return(nil, nil) + cc.mocks.stores.CodeMock().EXPECT().Delete(ctx, database.Code{}).Return(nil) + + err := cc.Delete(ctx, "ns", "n", "user") + require.Nil(t, err) +} + +func TestCodeComponent_Show(t *testing.T) { + ctx := context.TODO() + cc := initializeTestCodeComponent(ctx, t) + + code := &database.Code{ID: 1, Repository: &database.Repository{ + ID: 11, Name: "name", User: database.User{Username: "user"}, + }} + cc.mocks.stores.CodeMock().EXPECT().FindByPath(ctx, "ns", "n").Return(code, nil) + cc.mocks.components.repo.EXPECT().GetUserRepoPermission(ctx, "user", code.Repository).Return( + &types.UserRepoPermission{CanRead: true, CanAdmin: true}, nil, + ) + cc.mocks.stores.UserLikesMock().EXPECT().IsExist(ctx, "user", int64(11)).Return(true, nil) + cc.mocks.components.repo.EXPECT().GetNameSpaceInfo(ctx, "ns").Return(&types.Namespace{}, nil) + + data, err := cc.Show(ctx, "ns", "n", "user") + require.Nil(t, err) + require.Equal(t, &types.Code{ + ID: 1, + Repository: types.Repository{ + HTTPCloneURL: "/s/.git", + SSHCloneURL: ":s/.git", + }, + RepositoryID: 11, + Namespace: &types.Namespace{}, + Name: "name", + User: types.User{Username: "user"}, + CanManage: true, + UserLikes: true, + }, data) +} + +func TestCodeComponent_Relations(t *testing.T) { + ctx := context.TODO() + cc := initializeTestCodeComponent(ctx, t) + + cc.mocks.stores.CodeMock().EXPECT().FindByPath(ctx, "ns", "n").Return(&database.Code{ + Repository: &database.Repository{}, + RepositoryID: 1, + }, nil) + cc.mocks.components.repo.EXPECT().AllowReadAccessRepo(ctx, &database.Repository{}, "user").Return(true, nil) + cc.mocks.components.repo.EXPECT().RelatedRepos(ctx, int64(1), "user").Return( + map[types.RepositoryType][]*database.Repository{ + types.ModelRepo: { + {Name: "r1"}, + }, + }, nil, + ) + + data, err := cc.Relations(ctx, "ns", "n", "user") + require.Nil(t, err) + require.Equal(t, &types.Relations{ + Models: []*types.Model{{Name: "r1"}}, + }, data) + +} + +func TestCodeComponent_OrgCodes(t *testing.T) { + ctx := context.TODO() + cc := initializeTestCodeComponent(ctx, t) + + cc.mocks.userSvcClient.EXPECT().GetMemberRole(ctx, "ns", "user").Return(membership.RoleAdmin, nil) + cc.mocks.stores.CodeMock().EXPECT().ByOrgPath(ctx, "ns", 10, 1, false).Return( + []database.Code{{ + ID: 1, Repository: &database.Repository{Name: "repo"}, + RepositoryID: 11, + }}, 100, nil, + ) + + data, total, err := cc.OrgCodes(ctx, &types.OrgDatasetsReq{ + Namespace: "ns", CurrentUser: "user", + PageOpts: types.PageOpts{Page: 1, PageSize: 10}, + }) + require.Nil(t, err) + require.Equal(t, 100, total) + require.Equal(t, []types.Code{ + {ID: 1, Name: "repo", RepositoryID: 11}, + }, data) + +} diff --git a/component/collection.go b/component/collection.go index f48d48dc..9d3a8243 100644 --- a/component/collection.go +++ b/component/collection.go @@ -31,11 +31,11 @@ type CollectionComponent interface { func NewCollectionComponent(config *config.Config) (CollectionComponent, error) { cc := &collectionComponentImpl{} - cc.cs = database.NewCollectionStore() - cc.rs = database.NewRepoStore() - cc.us = database.NewUserStore() - cc.os = database.NewOrgStore() - cc.uls = database.NewUserLikesStore() + cc.collectionStore = database.NewCollectionStore() + cc.repoStore = database.NewRepoStore() + cc.userStore = database.NewUserStore() + cc.orgStore = database.NewOrgStore() + cc.userLikesStore = database.NewUserLikesStore() cc.userSvcClient = rpc.NewUserSvcHttpClient(fmt.Sprintf("%s:%d", config.User.Host, config.User.Port), rpc.AuthWithApiKey(config.APIToken)) spaceComponent, err := NewSpaceComponent(config) @@ -47,17 +47,17 @@ func NewCollectionComponent(config *config.Config) (CollectionComponent, error) } type collectionComponentImpl struct { - os database.OrgStore - cs database.CollectionStore - rs database.RepoStore - us database.UserStore - uls database.UserLikesStore - userSvcClient rpc.UserSvcClient - spaceComponent SpaceComponent + collectionStore database.CollectionStore + orgStore database.OrgStore + repoStore database.RepoStore + userStore database.UserStore + userLikesStore database.UserLikesStore + userSvcClient rpc.UserSvcClient + spaceComponent SpaceComponent } func (cc *collectionComponentImpl) GetCollections(ctx context.Context, filter *types.CollectionFilter, per, page int) ([]types.Collection, int, error) { - collections, total, err := cc.cs.GetCollections(ctx, filter, per, page, true) + collections, total, err := cc.collectionStore.GetCollections(ctx, filter, per, page, true) if err != nil { return nil, 0, err } @@ -73,7 +73,7 @@ func (cc *collectionComponentImpl) GetCollections(ctx context.Context, filter *t func (cc *collectionComponentImpl) CreateCollection(ctx context.Context, input types.CreateCollectionReq) (*database.Collection, error) { // find by user name - user, err := cc.us.FindByUsername(ctx, input.Username) + user, err := cc.userStore.FindByUsername(ctx, input.Username) if err != nil { return nil, fmt.Errorf("cannot find user for collection, %w", err) } @@ -92,24 +92,24 @@ func (cc *collectionComponentImpl) CreateCollection(ctx context.Context, input t collection.Username = "" } - return cc.cs.CreateCollection(ctx, collection) + return cc.collectionStore.CreateCollection(ctx, collection) } func (cc *collectionComponentImpl) GetCollection(ctx context.Context, currentUser string, id int64) (*types.Collection, error) { - collection, err := cc.cs.GetCollection(ctx, id) + collection, err := cc.collectionStore.GetCollection(ctx, id) if err != nil { return nil, err } // find by user name avatar := "" if collection.Username != "" { - user, err := cc.us.FindByUsername(ctx, collection.Username) + user, err := cc.userStore.FindByUsername(ctx, collection.Username) if err != nil { return nil, fmt.Errorf("cannot find user for collection, %w", err) } avatar = user.Avatar } else if collection.Namespace != "" { - org, err := cc.os.FindByPath(ctx, collection.Namespace) + org, err := cc.orgStore.FindByPath(ctx, collection.Namespace) if err != nil { return nil, fmt.Errorf("fail to get org info, path: %s, error: %w", collection.Namespace, err) } @@ -131,7 +131,7 @@ func (cc *collectionComponentImpl) GetCollection(ctx context.Context, currentUse if err != nil { return nil, err } - likeExists, err := cc.uls.IsExistCollection(ctx, currentUser, id) + likeExists, err := cc.userLikesStore.IsExistCollection(ctx, currentUser, id) if err != nil { newError := fmt.Errorf("failed to check for the presence of the user likes,error:%w", err) return nil, newError @@ -165,7 +165,7 @@ func (cc *collectionComponentImpl) GetPublicRepos(collection types.Collection) [ } func (cc *collectionComponentImpl) UpdateCollection(ctx context.Context, input types.CreateCollectionReq) (*database.Collection, error) { - collection, err := cc.cs.GetCollection(ctx, input.ID) + collection, err := cc.collectionStore.GetCollection(ctx, input.ID) if err != nil { return nil, fmt.Errorf("cannot find collection to update, %w", err) } @@ -175,25 +175,25 @@ func (cc *collectionComponentImpl) UpdateCollection(ctx context.Context, input t collection.Private = input.Private collection.Theme = input.Theme collection.UpdatedAt = time.Now() - return cc.cs.UpdateCollection(ctx, *collection) + return cc.collectionStore.UpdateCollection(ctx, *collection) } func (cc *collectionComponentImpl) DeleteCollection(ctx context.Context, id int64, userName string) error { // find by user name - user, err := cc.us.FindByUsername(ctx, userName) + user, err := cc.userStore.FindByUsername(ctx, userName) if err != nil { return fmt.Errorf("cannot find user for collection, %w", err) } - return cc.cs.DeleteCollection(ctx, id, user.ID) + return cc.collectionStore.DeleteCollection(ctx, id, user.ID) } func (cc *collectionComponentImpl) AddReposToCollection(ctx context.Context, req types.UpdateCollectionReposReq) error { // find by user name - user, err := cc.us.FindByUsername(ctx, req.Username) + user, err := cc.userStore.FindByUsername(ctx, req.Username) if err != nil { return fmt.Errorf("cannot find user for collection, %w", err) } - collection, err := cc.cs.GetCollection(ctx, req.ID) + collection, err := cc.collectionStore.GetCollection(ctx, req.ID) if err != nil { return err } @@ -207,16 +207,16 @@ func (cc *collectionComponentImpl) AddReposToCollection(ctx context.Context, req RepositoryID: id, }) } - return cc.cs.AddCollectionRepos(ctx, collectionRepos) + return cc.collectionStore.AddCollectionRepos(ctx, collectionRepos) } func (cc *collectionComponentImpl) RemoveReposFromCollection(ctx context.Context, req types.UpdateCollectionReposReq) error { // find by user name - user, err := cc.us.FindByUsername(ctx, req.Username) + user, err := cc.userStore.FindByUsername(ctx, req.Username) if err != nil { return fmt.Errorf("cannot find user for collection, %w", err) } - collection, err := cc.cs.GetCollection(ctx, req.ID) + collection, err := cc.collectionStore.GetCollection(ctx, req.ID) if err != nil { return err } @@ -230,7 +230,7 @@ func (cc *collectionComponentImpl) RemoveReposFromCollection(ctx context.Context RepositoryID: id, }) } - return cc.cs.RemoveCollectionRepos(ctx, collectionRepos) + return cc.collectionStore.RemoveCollectionRepos(ctx, collectionRepos) } func (cc *collectionComponentImpl) getUserCollectionPermission(ctx context.Context, userName string, collection *database.Collection) (*types.UserRepoPermission, error) { @@ -290,7 +290,7 @@ func (c *collectionComponentImpl) OrgCollections(ctx context.Context, req *types } } onlyPublic := !r.CanRead() - collections, total, err := c.cs.ByUserOrgs(ctx, req.Namespace, req.PageSize, req.Page, onlyPublic) + collections, total, err := c.collectionStore.ByUserOrgs(ctx, req.Namespace, req.PageSize, req.Page, onlyPublic) if err != nil { return nil, 0, err } diff --git a/component/collection_test.go b/component/collection_test.go new file mode 100644 index 00000000..a06dbf42 --- /dev/null +++ b/component/collection_test.go @@ -0,0 +1,220 @@ +package component + +import ( + "context" + "testing" + + "github.com/stretchr/testify/mock" + "github.com/stretchr/testify/require" + "opencsg.com/csghub-server/builder/git/membership" + "opencsg.com/csghub-server/builder/store/database" + "opencsg.com/csghub-server/common/types" +) + +func TestCollectionComponent_GetCollections(t *testing.T) { + ctx := context.TODO() + cc := initializeTestCollectionComponent(ctx, t) + + filter := &types.CollectionFilter{Search: "foo"} + cc.mocks.stores.CollectionMock().EXPECT().GetCollections(ctx, filter, 10, 1, true).Return( + []database.Collection{{Name: "n"}}, 100, nil, + ) + data, total, err := cc.GetCollections(ctx, filter, 10, 1) + require.Nil(t, err) + require.Equal(t, 100, total) + require.Equal(t, []types.Collection{{Name: "n"}}, data) +} + +func TestCollectionComponent_CreateCollection(t *testing.T) { + ctx := context.TODO() + cc := initializeTestCollectionComponent(ctx, t) + + cc.mocks.stores.UserMock().EXPECT().FindByUsername(ctx, "user").Return(database.User{ + Username: "user", + }, nil) + cc.mocks.stores.CollectionMock().EXPECT().CreateCollection(ctx, database.Collection{ + Username: "user", + Name: "n", + Nickname: "nn", + Description: "d", + }).Return(&database.Collection{}, nil) + + r, err := cc.CreateCollection(ctx, types.CreateCollectionReq{ + Name: "n", + Nickname: "nn", + Description: "d", + Username: "user", + }) + require.Nil(t, err) + require.Equal(t, &database.Collection{}, r) +} + +func TestCollectionComponent_GetCollection(t *testing.T) { + ctx := context.TODO() + cc := initializeTestCollectionComponent(ctx, t) + + repos := []database.Repository{ + {RepositoryType: types.SpaceRepo, Path: "r1/foo"}, + } + cc.mocks.stores.CollectionMock().EXPECT().GetCollection(ctx, int64(1)).Return( + &database.Collection{Username: "user", Namespace: "user", Repositories: repos}, nil, + ) + cc.mocks.stores.CollectionMock().EXPECT().GetCollection(ctx, int64(2)).Return( + &database.Collection{Namespace: "ns", Repositories: repos}, nil, + ) + cc.mocks.stores.UserMock().EXPECT().FindByUsername(ctx, "user").Return(database.User{ + Username: "user", + Avatar: "aaa", + }, nil) + cc.mocks.stores.OrgMock().EXPECT().FindByPath(ctx, "ns").Return(database.Organization{ + Logo: "logo", + }, nil) + cc.mocks.userSvcClient.EXPECT().GetMemberRole(ctx, "ns", "user").Return(membership.RoleAdmin, nil) + cc.mocks.stores.UserLikesMock().EXPECT().IsExistCollection(ctx, "user", mock.Anything).Return(true, nil) + cc.mocks.components.space.EXPECT().Status(ctx, "r1", "foo").Return("", "go", nil) + + col, err := cc.GetCollection(ctx, "user", 1) + require.Nil(t, err) + require.Equal(t, &types.Collection{ + Username: "user", + Namespace: "user", + UserLikes: true, + CanWrite: true, + CanManage: true, + Avatar: "aaa", + Repositories: []types.CollectionRepository{ + {RepositoryType: types.SpaceRepo, Path: "r1/foo", Status: "go"}, + }, + }, col) + col, err = cc.GetCollection(ctx, "user", 2) + require.Nil(t, err) + require.Equal(t, &types.Collection{ + Namespace: "ns", + UserLikes: true, + CanWrite: true, + CanManage: true, + Avatar: "logo", + Repositories: []types.CollectionRepository{ + {RepositoryType: types.SpaceRepo, Path: "r1/foo", Status: "go"}, + }, + }, col) +} + +func TestCollectionComponent_GetPublicRepos(t *testing.T) { + ctx := context.TODO() + cc := initializeTestCollectionComponent(ctx, t) + + repos := []types.CollectionRepository{ + {RepositoryType: types.SpaceRepo, Path: "r1/foo", Private: true}, + {RepositoryType: types.SpaceRepo, Path: "r1/foo", Private: false}, + } + r := cc.GetPublicRepos(types.Collection{Repositories: repos}) + require.Equal(t, []types.CollectionRepository{ + {RepositoryType: types.SpaceRepo, Path: "r1/foo", Private: false}, + }, r) +} + +func TestCollectionComponent_UpdateCollection(t *testing.T) { + ctx := context.TODO() + cc := initializeTestCollectionComponent(ctx, t) + + cc.mocks.stores.CollectionMock().EXPECT().GetCollection(ctx, int64(1)).Return( + &database.Collection{}, nil, + ) + cc.mocks.stores.CollectionMock().EXPECT().UpdateCollection(ctx, mock.Anything).RunAndReturn(func(ctx context.Context, c database.Collection) (*database.Collection, error) { + require.Equal(t, c.Name, "n") + require.True(t, c.Private) + return &database.Collection{}, nil + }) + + r, err := cc.UpdateCollection(ctx, types.CreateCollectionReq{ + ID: 1, + Name: "n", + Private: true, + }) + require.Nil(t, err) + require.Equal(t, &database.Collection{}, r) +} + +func TestCollectionComponent_DeleteCollection(t *testing.T) { + ctx := context.TODO() + cc := initializeTestCollectionComponent(ctx, t) + + cc.mocks.stores.UserMock().EXPECT().FindByUsername(ctx, "user").Return(database.User{ + ID: 2, + Username: "user", + }, nil) + cc.mocks.stores.CollectionMock().EXPECT().DeleteCollection(ctx, int64(1), int64(2)).Return(nil) + + err := cc.DeleteCollection(ctx, 1, "user") + require.Nil(t, err) + +} + +func TestCollectionComponent_AddReposToCollection(t *testing.T) { + ctx := context.TODO() + cc := initializeTestCollectionComponent(ctx, t) + + cc.mocks.stores.UserMock().EXPECT().FindByUsername(ctx, "user").Return(database.User{ + ID: 2, + Username: "user", + }, nil) + cc.mocks.stores.CollectionMock().EXPECT().GetCollection(ctx, int64(1)).Return( + &database.Collection{UserID: 2}, nil, + ) + cc.mocks.stores.CollectionMock().EXPECT().AddCollectionRepos(ctx, []database.CollectionRepository{ + {CollectionID: 1, RepositoryID: 1}, + {CollectionID: 1, RepositoryID: 2}, + }).Return(nil) + + err := cc.AddReposToCollection(ctx, types.UpdateCollectionReposReq{ + RepoIDs: []int64{1, 2}, + Username: "user", + ID: 1, + }) + require.Nil(t, err) + +} + +func TestCollectionComponent_RemoveReposFromCollection(t *testing.T) { + ctx := context.TODO() + cc := initializeTestCollectionComponent(ctx, t) + + cc.mocks.stores.UserMock().EXPECT().FindByUsername(ctx, "user").Return(database.User{ + ID: 2, + Username: "user", + }, nil) + cc.mocks.stores.CollectionMock().EXPECT().GetCollection(ctx, int64(1)).Return( + &database.Collection{UserID: 2}, nil, + ) + cc.mocks.stores.CollectionMock().EXPECT().RemoveCollectionRepos(ctx, []database.CollectionRepository{ + {CollectionID: 1, RepositoryID: 1}, + {CollectionID: 1, RepositoryID: 2}, + }).Return(nil) + + err := cc.RemoveReposFromCollection(ctx, types.UpdateCollectionReposReq{ + RepoIDs: []int64{1, 2}, + Username: "user", + ID: 1, + }) + require.Nil(t, err) + +} + +func TestCollectionComponent_OrgCollections(t *testing.T) { + ctx := context.TODO() + cc := initializeTestCollectionComponent(ctx, t) + + cc.mocks.userSvcClient.EXPECT().GetMemberRole(ctx, "ns", "user").Return(membership.RoleAdmin, nil) + cc.mocks.stores.CollectionMock().EXPECT().ByUserOrgs(ctx, "ns", 10, 1, false).Return([]database.Collection{ + {Name: "col"}, + }, 100, nil) + + cols, total, err := cc.OrgCollections(ctx, &types.OrgDatasetsReq{ + Namespace: "ns", CurrentUser: "user", + PageOpts: types.PageOpts{Page: 1, PageSize: 10}, + }) + require.Nil(t, err) + require.Equal(t, 100, total) + require.Equal(t, []types.Collection{{Name: "col"}}, cols) +} diff --git a/component/dataset.go b/component/dataset.go index 44b224ed..302fbf68 100644 --- a/component/dataset.go +++ b/component/dataset.go @@ -7,7 +7,10 @@ import ( "log/slog" "time" + "opencsg.com/csghub-server/builder/git" + "opencsg.com/csghub-server/builder/git/gitserver" "opencsg.com/csghub-server/builder/git/membership" + "opencsg.com/csghub-server/builder/rpc" "opencsg.com/csghub-server/builder/store/database" "opencsg.com/csghub-server/common/config" "opencsg.com/csghub-server/common/types" @@ -91,27 +94,44 @@ type DatasetComponent interface { func NewDatasetComponent(config *config.Config) (DatasetComponent, error) { c := &datasetComponentImpl{} - c.ts = database.NewTagStore() - c.ds = database.NewDatasetStore() - c.rs = database.NewRepoStore() + c.tagStore = database.NewTagStore() + c.datasetStore = database.NewDatasetStore() + c.repoStore = database.NewRepoStore() + c.namespaceStore = database.NewNamespaceStore() + c.userStore = database.NewUserStore() + c.userLikesStore = database.NewUserLikesStore() var err error - c.repoComponentImpl, err = NewRepoComponentImpl(config) + c.repoComponent, err = NewRepoComponentImpl(config) if err != nil { return nil, fmt.Errorf("failed to create repo component, error: %w", err) } - c.sc, err = NewSensitiveComponent(config) + c.sensitiveComponent, err = NewSensitiveComponent(config) if err != nil { return nil, fmt.Errorf("failed to create sensitive component, error: %w", err) } + gs, err := git.NewGitServer(config) + if err != nil { + return nil, fmt.Errorf("failed to create git server, error: %w", err) + } + c.userSvcClient = rpc.NewUserSvcHttpClient(fmt.Sprintf("%s:%d", config.User.Host, config.User.Port), + rpc.AuthWithApiKey(config.APIToken)) + c.gitServer = gs + c.config = config return c, nil } type datasetComponentImpl struct { - *repoComponentImpl - ts database.TagStore - ds database.DatasetStore - rs database.RepoStore - sc SensitiveComponent + config *config.Config + repoComponent RepoComponent + tagStore database.TagStore + datasetStore database.DatasetStore + repoStore database.RepoStore + namespaceStore database.NamespaceStore + userStore database.UserStore + sensitiveComponent SensitiveComponent + gitServer gitserver.GitServer + userLikesStore database.UserLikesStore + userSvcClient rpc.UserSvcClient } func (c *datasetComponentImpl) Create(ctx context.Context, req *types.CreateDatasetReq) (*types.Dataset, error) { @@ -131,7 +151,7 @@ func (c *datasetComponentImpl) Create(ctx context.Context, req *types.CreateData } if !user.CanAdmin() { if namespace.NamespaceType == database.OrgNamespace { - canWrite, err := c.CheckCurrentUserPermission(ctx, req.Username, req.Namespace, membership.RoleWrite) + canWrite, err := c.repoComponent.CheckCurrentUserPermission(ctx, req.Username, req.Namespace, membership.RoleWrite) if err != nil { return nil, err } @@ -158,7 +178,7 @@ func (c *datasetComponentImpl) Create(ctx context.Context, req *types.CreateData req.RepoType = types.DatasetRepo req.Readme = generateReadmeData(req.License) req.Nickname = nickname - _, dbRepo, err := c.CreateRepo(ctx, req.CreateRepoReq) + _, dbRepo, err := c.repoComponent.CreateRepo(ctx, req.CreateRepoReq) if err != nil { return nil, err } @@ -168,13 +188,13 @@ func (c *datasetComponentImpl) Create(ctx context.Context, req *types.CreateData RepositoryID: dbRepo.ID, } - dataset, err := c.ds.Create(ctx, dbDataset) + dataset, err := c.datasetStore.Create(ctx, dbDataset) if err != nil { return nil, fmt.Errorf("failed to create database dataset, cause: %w", err) } // Create README.md file - err = c.git.CreateRepoFile(buildCreateFileReq(&types.CreateFileParams{ + err = c.gitServer.CreateRepoFile(buildCreateFileReq(&types.CreateFileParams{ Username: user.Username, Email: user.Email, Message: initCommitMessage, @@ -190,7 +210,7 @@ func (c *datasetComponentImpl) Create(ctx context.Context, req *types.CreateData } // Create .gitattributes file - err = c.git.CreateRepoFile(buildCreateFileReq(&types.CreateFileParams{ + err = c.gitServer.CreateRepoFile(buildCreateFileReq(&types.CreateFileParams{ Username: user.Username, Email: user.Email, Message: initCommitMessage, @@ -258,7 +278,7 @@ func (c *datasetComponentImpl) commonIndex(ctx context.Context, filter *types.Re err error resDatasets []types.Dataset ) - repos, total, err := c.PublicToUser(ctx, types.DatasetRepo, filter.Username, filter, per, page) + repos, total, err := c.repoComponent.PublicToUser(ctx, types.DatasetRepo, filter.Username, filter, per, page) if err != nil { newError := fmt.Errorf("failed to get public dataset repos,error:%w", err) return nil, 0, newError @@ -267,7 +287,7 @@ func (c *datasetComponentImpl) commonIndex(ctx context.Context, filter *types.Re for _, repo := range repos { repoIDs = append(repoIDs, repo.ID) } - datasets, err := c.ds.ByRepoIDs(ctx, repoIDs) + datasets, err := c.datasetStore.ByRepoIDs(ctx, repoIDs) if err != nil { newError := fmt.Errorf("failed to get datasets by repo ids,error:%w", err) return nil, 0, newError @@ -328,18 +348,18 @@ func (c *datasetComponentImpl) commonIndex(ctx context.Context, filter *types.Re func (c *datasetComponentImpl) Update(ctx context.Context, req *types.UpdateDatasetReq) (*types.Dataset, error) { req.RepoType = types.DatasetRepo - dbRepo, err := c.UpdateRepo(ctx, req.UpdateRepoReq) + dbRepo, err := c.repoComponent.UpdateRepo(ctx, req.UpdateRepoReq) if err != nil { return nil, err } - dataset, err := c.ds.ByRepoID(ctx, dbRepo.ID) + dataset, err := c.datasetStore.ByRepoID(ctx, dbRepo.ID) if err != nil { return nil, fmt.Errorf("failed to find dataset, error: %w", err) } // update times of dateset - err = c.ds.Update(ctx, *dataset) + err = c.datasetStore.Update(ctx, *dataset) if err != nil { return nil, fmt.Errorf("failed to update database dataset, error: %w", err) } @@ -362,7 +382,7 @@ func (c *datasetComponentImpl) Update(ctx context.Context, req *types.UpdateData } func (c *datasetComponentImpl) Delete(ctx context.Context, namespace, name, currentUser string) error { - dataset, err := c.ds.FindByPath(ctx, namespace, name) + dataset, err := c.datasetStore.FindByPath(ctx, namespace, name) if err != nil { return fmt.Errorf("failed to find dataset, error: %w", err) } @@ -373,12 +393,12 @@ func (c *datasetComponentImpl) Delete(ctx context.Context, namespace, name, curr Name: name, RepoType: types.DatasetRepo, } - _, err = c.DeleteRepo(ctx, deleteDatabaseRepoReq) + _, err = c.repoComponent.DeleteRepo(ctx, deleteDatabaseRepoReq) if err != nil { return fmt.Errorf("failed to delete repo of dataset, error: %w", err) } - err = c.ds.Delete(ctx, *dataset) + err = c.datasetStore.Delete(ctx, *dataset) if err != nil { return fmt.Errorf("failed to delete database dataset, error: %w", err) } @@ -387,12 +407,12 @@ func (c *datasetComponentImpl) Delete(ctx context.Context, namespace, name, curr func (c *datasetComponentImpl) Show(ctx context.Context, namespace, name, currentUser string) (*types.Dataset, error) { var tags []types.RepoTag - dataset, err := c.ds.FindByPath(ctx, namespace, name) + dataset, err := c.datasetStore.FindByPath(ctx, namespace, name) if err != nil { return nil, fmt.Errorf("failed to find dataset, error: %w", err) } - permission, err := c.GetUserRepoPermission(ctx, currentUser, dataset.Repository) + permission, err := c.repoComponent.GetUserRepoPermission(ctx, currentUser, dataset.Repository) if err != nil { return nil, fmt.Errorf("failed to get user repo permission, error: %w", err) } @@ -400,7 +420,7 @@ func (c *datasetComponentImpl) Show(ctx context.Context, namespace, name, curren return nil, ErrUnauthorized } - ns, err := c.GetNameSpaceInfo(ctx, namespace) + ns, err := c.repoComponent.GetNameSpaceInfo(ctx, namespace) if err != nil { return nil, fmt.Errorf("failed to get namespace info for dataset, error: %w", err) } @@ -458,12 +478,12 @@ func (c *datasetComponentImpl) Show(ctx context.Context, namespace, name, curren } func (c *datasetComponentImpl) Relations(ctx context.Context, namespace, name, currentUser string) (*types.Relations, error) { - dataset, err := c.ds.FindByPath(ctx, namespace, name) + dataset, err := c.datasetStore.FindByPath(ctx, namespace, name) if err != nil { return nil, fmt.Errorf("failed to find dataset repo, error: %w", err) } - allow, _ := c.AllowReadAccessRepo(ctx, dataset.Repository, currentUser) + allow, _ := c.repoComponent.AllowReadAccessRepo(ctx, dataset.Repository, currentUser) if !allow { return nil, ErrUnauthorized } @@ -472,7 +492,7 @@ func (c *datasetComponentImpl) Relations(ctx context.Context, namespace, name, c } func (c *datasetComponentImpl) getRelations(ctx context.Context, repoID int64, currentUser string) (*types.Relations, error) { - res, err := c.RelatedRepos(ctx, repoID, currentUser) + res, err := c.repoComponent.RelatedRepos(ctx, repoID, currentUser) if err != nil { return nil, err } @@ -507,7 +527,7 @@ func (c *datasetComponentImpl) OrgDatasets(ctx context.Context, req *types.OrgDa } } onlyPublic := !r.CanRead() - datasets, total, err := c.ds.ByOrgPath(ctx, req.Namespace, req.PageSize, req.Page, onlyPublic) + datasets, total, err := c.datasetStore.ByOrgPath(ctx, req.Namespace, req.PageSize, req.Page, onlyPublic) if err != nil { newError := fmt.Errorf("failed to get user datasets,error:%w", err) slog.Error(newError.Error()) diff --git a/component/dataset_test.go b/component/dataset_test.go new file mode 100644 index 00000000..69f1744d --- /dev/null +++ b/component/dataset_test.go @@ -0,0 +1,251 @@ +package component + +import ( + "context" + "testing" + + "github.com/stretchr/testify/require" + "opencsg.com/csghub-server/builder/git/gitserver" + "opencsg.com/csghub-server/builder/git/membership" + "opencsg.com/csghub-server/builder/store/database" + "opencsg.com/csghub-server/common/types" +) + +func TestDatasetCompnent_Create(t *testing.T) { + ctx := context.TODO() + dc := initializeTestDatasetComponent(ctx, t) + + req := &types.CreateDatasetReq{ + CreateRepoReq: types.CreateRepoReq{ + Username: "user", + Namespace: "ns", + Name: "n", + }, + } + dc.mocks.stores.NamespaceMock().EXPECT().FindByPath(ctx, "ns").Return( + database.Namespace{}, nil, + ) + dc.mocks.stores.UserMock().EXPECT().FindByUsername(ctx, "user").Return(database.User{ + RoleMask: "admin", + Username: "user", + }, nil) + rq := req.CreateRepoReq + rq.RepoType = types.DatasetRepo + rq.Readme = "\n---\nlicense: \n---\n\t" + rq.DefaultBranch = "main" + rq.Nickname = "n" + dc.mocks.components.repo.EXPECT().CreateRepo(ctx, rq).Return(&gitserver.CreateRepoResp{}, &database.Repository{}, nil) + dc.mocks.stores.DatasetMock().EXPECT().Create(ctx, database.Dataset{ + Repository: &database.Repository{}, + }).Return(&database.Dataset{ + Repository: &database.Repository{ + Tags: []database.Tag{{Name: "t1"}}, + }, + }, nil) + dc.mocks.gitServer.EXPECT().CreateRepoFile(buildCreateFileReq( + &types.CreateFileParams{ + Username: "user", + Message: initCommitMessage, + Branch: "main", + Content: "\n---\nlicense: \n---\n\t", + Namespace: "ns", + Name: "n", + FilePath: readmeFileName, + }, types.DatasetRepo), + ).Return(nil) + dc.mocks.gitServer.EXPECT().CreateRepoFile(buildCreateFileReq( + &types.CreateFileParams{ + Username: "user", + Message: initCommitMessage, + Branch: "main", + Content: datasetGitattributesContent, + Namespace: "ns", + Name: "n", + FilePath: gitattributesFileName, + }, types.DatasetRepo), + ).Return(nil) + + resp, err := dc.Create(ctx, req) + require.Nil(t, err) + require.Equal(t, &types.Dataset{ + User: types.User{Username: "user"}, + Repository: types.Repository{ + HTTPCloneURL: "/s/.git", + SSHCloneURL: ":s/.git", + }, + Tags: []types.RepoTag{{Name: "t1"}}, + }, resp) + +} + +func TestDatasetCompnent_Index(t *testing.T) { + ctx := context.TODO() + dc := initializeTestDatasetComponent(ctx, t) + + filter := &types.RepoFilter{Username: "user"} + dc.mocks.components.repo.EXPECT().PublicToUser(ctx, types.DatasetRepo, "user", filter, 10, 1).Return( + []*database.Repository{ + {ID: 1, Tags: []database.Tag{{Name: "t1"}}}, + {ID: 2}, + }, 100, nil, + ) + dc.mocks.stores.DatasetMock().EXPECT().ByRepoIDs(ctx, []int64{1, 2}).Return([]database.Dataset{ + { + ID: 11, RepositoryID: 2, Repository: &database.Repository{ + User: database.User{Username: "user2"}, + }, + }, + { + ID: 12, RepositoryID: 1, Repository: &database.Repository{ + User: database.User{Username: "user1"}, + }, + }, + }, nil) + + data, total, err := dc.Index(ctx, filter, 10, 1) + require.Nil(t, err) + require.Equal(t, 100, total) + require.Equal(t, []types.Dataset{ + {ID: 12, RepositoryID: 1, Repository: types.Repository{ + HTTPCloneURL: "/s/.git", + SSHCloneURL: ":s/.git", + }, User: types.User{Username: "user1"}, + Tags: []types.RepoTag{{Name: "t1"}}, + }, + {ID: 11, RepositoryID: 2, Repository: types.Repository{ + HTTPCloneURL: "/s/.git", + SSHCloneURL: ":s/.git", + }, User: types.User{Username: "user2"}}, + }, data) + +} + +func TestDatasetCompnent_Update(t *testing.T) { + ctx := context.TODO() + dc := initializeTestDatasetComponent(ctx, t) + + req := &types.UpdateDatasetReq{UpdateRepoReq: types.UpdateRepoReq{ + RepoType: types.DatasetRepo, + }} + dc.mocks.components.repo.EXPECT().UpdateRepo(ctx, req.UpdateRepoReq).Return( + &database.Repository{ID: 1, Name: "repo"}, nil, + ) + dc.mocks.stores.DatasetMock().EXPECT().ByRepoID(ctx, int64(1)).Return( + &database.Dataset{ID: 2}, nil, + ) + dc.mocks.stores.DatasetMock().EXPECT().Update(ctx, database.Dataset{ID: 2}).Return(nil) + + d, err := dc.Update(ctx, req) + require.Nil(t, err) + require.Equal(t, &types.Dataset{ + ID: 2, + RepositoryID: 1, + Name: "repo", + }, d) +} + +func TestDatasetCompnent_Delete(t *testing.T) { + ctx := context.TODO() + dc := initializeTestDatasetComponent(ctx, t) + + dc.mocks.stores.DatasetMock().EXPECT().FindByPath(ctx, "ns", "n").Return(&database.Dataset{}, nil) + dc.mocks.components.repo.EXPECT().DeleteRepo(ctx, types.DeleteRepoReq{ + Username: "user", + Namespace: "ns", + Name: "n", + RepoType: types.DatasetRepo, + }).Return(&database.Repository{}, nil) + dc.mocks.stores.DatasetMock().EXPECT().Delete(ctx, database.Dataset{}).Return(nil) + + err := dc.Delete(ctx, "ns", "n", "user") + require.Nil(t, err) + +} + +func TestDatasetCompnent_Show(t *testing.T) { + ctx := context.TODO() + dc := initializeTestDatasetComponent(ctx, t) + + dataset := &database.Dataset{ + ID: 1, + Repository: &database.Repository{ + ID: 2, + Name: "n", + Tags: []database.Tag{{Name: "t1"}}, + User: database.User{ + Username: "user", + }, + }, + } + dc.mocks.stores.DatasetMock().EXPECT().FindByPath(ctx, "ns", "n").Return(dataset, nil) + dc.mocks.components.repo.EXPECT().GetUserRepoPermission(ctx, "user", dataset.Repository).Return(&types.UserRepoPermission{CanRead: true}, nil) + dc.mocks.components.repo.EXPECT().GetNameSpaceInfo(ctx, "ns").Return(&types.Namespace{}, nil) + dc.mocks.stores.UserLikesMock().EXPECT().IsExist(ctx, "user", int64(2)).Return(true, nil) + + d, err := dc.Show(ctx, "ns", "n", "user") + require.Nil(t, err) + require.Equal(t, &types.Dataset{ + ID: 1, + Name: "n", + RepositoryID: 2, + Tags: []types.RepoTag{{Name: "t1"}}, + Repository: types.Repository{ + HTTPCloneURL: "/s/.git", + SSHCloneURL: ":s/.git", + }, + User: types.User{Username: "user"}, + UserLikes: true, + Namespace: &types.Namespace{}, + }, d) + +} + +func TestDatasetCompnent_Relations(t *testing.T) { + ctx := context.TODO() + dc := initializeTestDatasetComponent(ctx, t) + + dataset := &database.Dataset{ + Repository: &database.Repository{}, + RepositoryID: 1, + } + dc.mocks.stores.DatasetMock().EXPECT().FindByPath(ctx, "ns", "n").Return(dataset, nil) + dc.mocks.components.repo.EXPECT().RelatedRepos(ctx, int64(1), "user").Return( + map[types.RepositoryType][]*database.Repository{ + types.ModelRepo: { + {Name: "n"}, + }, + }, nil, + ) + dc.mocks.components.repo.EXPECT().AllowReadAccessRepo(ctx, dataset.Repository, "user").Return(true, nil) + + rs, err := dc.Relations(ctx, "ns", "n", "user") + require.Nil(t, err) + require.Equal(t, &types.Relations{ + Models: []*types.Model{{Name: "n"}}, + }, rs) + +} + +func TestDatasetCompnent_OrgDatasets(t *testing.T) { + ctx := context.TODO() + dc := initializeTestDatasetComponent(ctx, t) + + dc.mocks.userSvcClient.EXPECT().GetMemberRole(ctx, "ns", "user").Return(membership.RoleAdmin, nil) + dc.mocks.stores.DatasetMock().EXPECT().ByOrgPath(ctx, "ns", 10, 1, false).Return( + []database.Dataset{ + {ID: 1, Repository: &database.Repository{Name: "repo"}}, + }, 100, nil, + ) + + data, total, err := dc.OrgDatasets(ctx, &types.OrgDatasetsReq{ + Namespace: "ns", + CurrentUser: "user", + PageOpts: types.PageOpts{Page: 1, PageSize: 10}, + }) + require.Nil(t, err) + require.Equal(t, 100, total) + require.Equal(t, []types.Dataset{ + {ID: 1, Name: "repo"}, + }, data) + +} diff --git a/component/discussion.go b/component/discussion.go index 1b2caf3c..06f30e9d 100644 --- a/component/discussion.go +++ b/component/discussion.go @@ -12,9 +12,9 @@ import ( ) type discussionComponentImpl struct { - ds database.DiscussionStore - rs database.RepoStore - us database.UserStore + discussionStore database.DiscussionStore + repoStore database.RepoStore + userStore database.UserStore } type DiscussionComponent interface { @@ -33,22 +33,22 @@ func NewDiscussionComponent() DiscussionComponent { ds := database.NewDiscussionStore() rs := database.NewRepoStore() us := database.NewUserStore() - return &discussionComponentImpl{ds: ds, rs: rs, us: us} + return &discussionComponentImpl{discussionStore: ds, repoStore: rs, userStore: us} } func (c *discussionComponentImpl) CreateRepoDiscussion(ctx context.Context, req CreateRepoDiscussionRequest) (*CreateDiscussionResponse, error) { //TODO:check if the user can access the repo //get repo by namespace and name - repo, err := c.rs.FindByPath(ctx, req.RepoType, req.Namespace, req.Name) + repo, err := c.repoStore.FindByPath(ctx, req.RepoType, req.Namespace, req.Name) if err != nil { return nil, fmt.Errorf("failed to find repo by path '%s/%s/%s': %w", req.RepoType, req.Namespace, req.Name, err) } - user, err := c.us.FindByUsername(ctx, req.CurrentUser) + user, err := c.userStore.FindByUsername(ctx, req.CurrentUser) if err != nil { return nil, fmt.Errorf("failed to find user by username '%s': %w", req.CurrentUser, err) } - discussion, err := c.ds.Create(ctx, database.Discussion{ + discussion, err := c.discussionStore.Create(ctx, database.Discussion{ Title: req.Title, DiscussionableID: repo.ID, DiscussionableType: database.DiscussionableTypeRepo, @@ -72,11 +72,11 @@ func (c *discussionComponentImpl) CreateRepoDiscussion(ctx context.Context, req } func (c *discussionComponentImpl) GetDiscussion(ctx context.Context, id int64) (*ShowDiscussionResponse, error) { - discussion, err := c.ds.FindByID(ctx, id) + discussion, err := c.discussionStore.FindByID(ctx, id) if err != nil { return nil, fmt.Errorf("failed to find discussion by id '%d': %w", id, err) } - comments, err := c.ds.FindDiscussionComments(ctx, discussion.ID) + comments, err := c.discussionStore.FindDiscussionComments(ctx, discussion.ID) if err != nil && err != sql.ErrNoRows { return nil, fmt.Errorf("failed to find discussion comments by discussion id '%d': %w", discussion.ID, err) } @@ -105,18 +105,18 @@ func (c *discussionComponentImpl) GetDiscussion(ctx context.Context, id int64) ( func (c *discussionComponentImpl) UpdateDiscussion(ctx context.Context, req UpdateDiscussionRequest) error { //check if the user is the owner of the discussion - user, err := c.us.FindByUsername(ctx, req.CurrentUser) + user, err := c.userStore.FindByUsername(ctx, req.CurrentUser) if err != nil { return fmt.Errorf("failed to find user by username '%s': %w", req.CurrentUser, err) } - discussion, err := c.ds.FindByID(ctx, req.ID) + discussion, err := c.discussionStore.FindByID(ctx, req.ID) if err != nil { return fmt.Errorf("failed to find discussion by id '%d': %w", req.ID, err) } if discussion.UserID != user.ID { return fmt.Errorf("user '%s' is not the owner of the discussion '%d'", req.CurrentUser, req.ID) } - err = c.ds.UpdateByID(ctx, req.ID, req.Title) + err = c.discussionStore.UpdateByID(ctx, req.ID, req.Title) if err != nil { return fmt.Errorf("failed to update discussion by id '%d': %w", req.ID, err) } @@ -124,14 +124,14 @@ func (c *discussionComponentImpl) UpdateDiscussion(ctx context.Context, req Upda } func (c *discussionComponentImpl) DeleteDiscussion(ctx context.Context, currentUser string, id int64) error { - discussion, err := c.ds.FindByID(ctx, id) + discussion, err := c.discussionStore.FindByID(ctx, id) if err != nil { return fmt.Errorf("failed to find discussion by id '%d': %w", id, err) } if discussion.User.Username != currentUser { return fmt.Errorf("user '%s' is not the owner of the discussion '%d'", currentUser, id) } - err = c.ds.DeleteByID(ctx, id) + err = c.discussionStore.DeleteByID(ctx, id) if err != nil { return fmt.Errorf("failed to delete discussion by id '%d': %w", id, err) } @@ -140,11 +140,11 @@ func (c *discussionComponentImpl) DeleteDiscussion(ctx context.Context, currentU func (c *discussionComponentImpl) ListRepoDiscussions(ctx context.Context, req ListRepoDiscussionRequest) (*ListRepoDiscussionResponse, error) { //TODO:check if the user can access the repo - repo, err := c.rs.FindByPath(ctx, req.RepoType, req.Namespace, req.Name) + repo, err := c.repoStore.FindByPath(ctx, req.RepoType, req.Namespace, req.Name) if err != nil { return nil, fmt.Errorf("failed to find repo by path '%s/%s/%s': %w", req.RepoType, req.Namespace, req.Name, err) } - discussions, err := c.ds.FindByDiscussionableID(ctx, database.DiscussionableTypeRepo, repo.ID) + discussions, err := c.discussionStore.FindByDiscussionableID(ctx, database.DiscussionableTypeRepo, repo.ID) if err != nil { return nil, fmt.Errorf("failed to list repo discussions by repo type '%s', namespace '%s', name '%s': %w", req.RepoType, req.Namespace, req.Name, err) } @@ -168,18 +168,18 @@ func (c *discussionComponentImpl) ListRepoDiscussions(ctx context.Context, req L func (c *discussionComponentImpl) CreateDiscussionComment(ctx context.Context, req CreateCommentRequest) (*CreateCommentResponse, error) { req.CommentableType = database.CommentableTypeDiscussion // get discussion by id - _, err := c.ds.FindByID(ctx, req.CommentableID) + _, err := c.discussionStore.FindByID(ctx, req.CommentableID) if err != nil { return nil, fmt.Errorf("failed to find discussion by id '%d': %w", req.CommentableID, err) } //get user by username - user, err := c.us.FindByUsername(ctx, req.CurrentUser) + user, err := c.userStore.FindByUsername(ctx, req.CurrentUser) if err != nil { return nil, fmt.Errorf("failed to find user by username '%s': %w", req.CurrentUser, err) } // create comment - comment, err := c.ds.CreateComment(ctx, database.Comment{ + comment, err := c.discussionStore.CreateComment(ctx, database.Comment{ Content: req.Content, CommentableID: req.CommentableID, CommentableType: req.CommentableType, @@ -202,12 +202,12 @@ func (c *discussionComponentImpl) CreateDiscussionComment(ctx context.Context, r } func (c *discussionComponentImpl) UpdateComment(ctx context.Context, currentUser string, id int64, content string) error { - user, err := c.us.FindByUsername(ctx, currentUser) + user, err := c.userStore.FindByUsername(ctx, currentUser) if err != nil { return fmt.Errorf("failed to find user by username '%s': %w", currentUser, err) } //get comment by id - comment, err := c.ds.FindCommentByID(ctx, id) + comment, err := c.discussionStore.FindCommentByID(ctx, id) if err != nil { return fmt.Errorf("failed to find comment by id '%d': %w", id, err) } @@ -215,7 +215,7 @@ func (c *discussionComponentImpl) UpdateComment(ctx context.Context, currentUser if comment.UserID != user.ID { return fmt.Errorf("user '%s' is not the owner of the comment '%d'", currentUser, id) } - err = c.ds.UpdateComment(ctx, id, content) + err = c.discussionStore.UpdateComment(ctx, id, content) if err != nil { return fmt.Errorf("failed to update comment by id '%d': %w", id, err) } @@ -223,12 +223,12 @@ func (c *discussionComponentImpl) UpdateComment(ctx context.Context, currentUser } func (c *discussionComponentImpl) DeleteComment(ctx context.Context, currentUser string, id int64) error { - user, err := c.us.FindByUsername(ctx, currentUser) + user, err := c.userStore.FindByUsername(ctx, currentUser) if err != nil { return fmt.Errorf("failed to find user by username '%s': %w", currentUser, err) } //get comment by id - comment, err := c.ds.FindCommentByID(ctx, id) + comment, err := c.discussionStore.FindCommentByID(ctx, id) if err != nil { return fmt.Errorf("failed to find comment by id '%d': %w", id, err) } @@ -236,7 +236,7 @@ func (c *discussionComponentImpl) DeleteComment(ctx context.Context, currentUser if comment.UserID != user.ID { return fmt.Errorf("user '%s' is not the owner of the comment '%d'", currentUser, id) } - err = c.ds.DeleteComment(ctx, id) + err = c.discussionStore.DeleteComment(ctx, id) if err != nil { return fmt.Errorf("failed to delete comment by id '%d': %w", id, err) } @@ -244,7 +244,7 @@ func (c *discussionComponentImpl) DeleteComment(ctx context.Context, currentUser } func (c *discussionComponentImpl) ListDiscussionComments(ctx context.Context, discussionID int64) ([]*DiscussionResponse_Comment, error) { - comments, err := c.ds.FindDiscussionComments(ctx, discussionID) + comments, err := c.discussionStore.FindDiscussionComments(ctx, discussionID) if err != nil { return nil, fmt.Errorf("failed to find discussion comments by discussion id '%d': %w", discussionID, err) } diff --git a/component/discussion_test.go b/component/discussion_test.go index 1813f69a..ddceff7a 100644 --- a/component/discussion_test.go +++ b/component/discussion_test.go @@ -18,9 +18,9 @@ func TestDiscussionComponent_CreateDisucssion(t *testing.T) { mockDiscussionStore := mockdb.NewMockDiscussionStore(t) // new discussionComponentImpl from mock db store comp := &discussionComponentImpl{ - rs: mockRepoStore, - us: mockUserStore, - ds: mockDiscussionStore, + repoStore: mockRepoStore, + userStore: mockUserStore, + discussionStore: mockDiscussionStore, } repo := &database.Repository{ @@ -77,9 +77,9 @@ func TestDiscussionComponent_GetDisussion(t *testing.T) { mockDiscussionStore := mockdb.NewMockDiscussionStore(t) // new discussionComponentImpl from mock db store comp := &discussionComponentImpl{ - rs: mockRepoStore, - us: mockUserStore, - ds: mockDiscussionStore, + repoStore: mockRepoStore, + userStore: mockUserStore, + discussionStore: mockDiscussionStore, } disc := database.Discussion{ @@ -126,9 +126,9 @@ func TestDiscussionComponent_UpdateDisussion(t *testing.T) { mockDiscussionStore := mockdb.NewMockDiscussionStore(t) // new discussionComponentImpl from mock db store comp := &discussionComponentImpl{ - rs: mockRepoStore, - us: mockUserStore, - ds: mockDiscussionStore, + repoStore: mockRepoStore, + userStore: mockUserStore, + discussionStore: mockDiscussionStore, } req := UpdateDiscussionRequest{ @@ -164,9 +164,9 @@ func TestDiscussionComponent_DeleteDisussion(t *testing.T) { mockDiscussionStore := mockdb.NewMockDiscussionStore(t) // new discussionComponentImpl from mock db store comp := &discussionComponentImpl{ - rs: mockRepoStore, - us: mockUserStore, - ds: mockDiscussionStore, + repoStore: mockRepoStore, + userStore: mockUserStore, + discussionStore: mockDiscussionStore, } currentUser := "user" @@ -197,9 +197,9 @@ func TestDiscussionComponent_ListRepoDiscussions(t *testing.T) { mockDiscussionStore := mockdb.NewMockDiscussionStore(t) // new discussionComponentImpl from mock db store comp := &discussionComponentImpl{ - rs: mockRepoStore, - us: mockUserStore, - ds: mockDiscussionStore, + repoStore: mockRepoStore, + userStore: mockUserStore, + discussionStore: mockDiscussionStore, } repo := &database.Repository{ @@ -239,9 +239,9 @@ func TestDiscussionComponent_CreateDisussionComment(t *testing.T) { mockDiscussionStore := mockdb.NewMockDiscussionStore(t) // new discussionComponentImpl from mock db store comp := &discussionComponentImpl{ - rs: mockRepoStore, - us: mockUserStore, - ds: mockDiscussionStore, + repoStore: mockRepoStore, + userStore: mockUserStore, + discussionStore: mockDiscussionStore, } req := CreateCommentRequest{ @@ -293,9 +293,9 @@ func TestDiscussionComponent_UpdateComment(t *testing.T) { mockDiscussionStore := mockdb.NewMockDiscussionStore(t) // new discussionComponentImpl from mock db store comp := &discussionComponentImpl{ - rs: mockRepoStore, - us: mockUserStore, - ds: mockDiscussionStore, + repoStore: mockRepoStore, + userStore: mockUserStore, + discussionStore: mockDiscussionStore, } req := CreateCommentRequest{ @@ -332,9 +332,9 @@ func TestDiscussionComponent_DeleteComment(t *testing.T) { mockDiscussionStore := mockdb.NewMockDiscussionStore(t) // new discussionComponentImpl from mock db store comp := &discussionComponentImpl{ - rs: mockRepoStore, - us: mockUserStore, - ds: mockDiscussionStore, + repoStore: mockRepoStore, + userStore: mockUserStore, + discussionStore: mockDiscussionStore, } req := CreateCommentRequest{ @@ -371,9 +371,9 @@ func TestDiscussionComponent_ListDiscussionComments(t *testing.T) { mockDiscussionStore := mockdb.NewMockDiscussionStore(t) // new discussionComponentImpl from mock db store comp := &discussionComponentImpl{ - rs: mockRepoStore, - us: mockUserStore, - ds: mockDiscussionStore, + repoStore: mockRepoStore, + userStore: mockUserStore, + discussionStore: mockDiscussionStore, } discussionID := int64(1) diff --git a/component/git_http.go b/component/git_http.go index d1eb5305..4decc37a 100644 --- a/component/git_http.go +++ b/component/git_http.go @@ -26,13 +26,14 @@ import ( ) type gitHTTPComponentImpl struct { - git gitserver.GitServer + gitServer gitserver.GitServer config *config.Config s3Client s3.Client lfsMetaObjectStore database.LfsMetaObjectStore lfsLockStore database.LfsLockStore - repo database.RepoStore - *repoComponentImpl + repoStore database.RepoStore + userStore database.UserStore + repoComponent RepoComponent } type GitHTTPComponent interface { @@ -53,7 +54,7 @@ func NewGitHTTPComponent(config *config.Config) (GitHTTPComponent, error) { c := &gitHTTPComponentImpl{} c.config = config var err error - c.git, err = git.NewGitServer(config) + c.gitServer, err = git.NewGitServer(config) if err != nil { newError := fmt.Errorf("fail to create git server,error:%w", err) slog.Error(newError.Error()) @@ -66,9 +67,10 @@ func NewGitHTTPComponent(config *config.Config) (GitHTTPComponent, error) { return nil, newError } c.lfsMetaObjectStore = database.NewLfsMetaObjectStore() - c.repo = database.NewRepoStore() + c.repoStore = database.NewRepoStore() c.lfsLockStore = database.NewLfsLockStore() - c.repoComponentImpl, err = NewRepoComponentImpl(config) + c.userStore = database.NewUserStore() + c.repoComponent, err = NewRepoComponentImpl(config) if err != nil { return nil, err } @@ -76,13 +78,13 @@ func NewGitHTTPComponent(config *config.Config) (GitHTTPComponent, error) { } func (c *gitHTTPComponentImpl) InfoRefs(ctx context.Context, req types.InfoRefsReq) (io.Reader, error) { - repo, err := c.repo.FindByPath(ctx, req.RepoType, req.Namespace, req.Name) + repo, err := c.repoStore.FindByPath(ctx, req.RepoType, req.Namespace, req.Name) if err != nil { return nil, fmt.Errorf("failed to find repo, error: %w", err) } if req.Rpc == "git-receive-pack" { - allowed, err := c.AllowWriteAccess(ctx, req.RepoType, req.Namespace, req.Name, req.CurrentUser) + allowed, err := c.repoComponent.AllowWriteAccess(ctx, req.RepoType, req.Namespace, req.Name, req.CurrentUser) if err != nil { return nil, ErrUnauthorized } @@ -91,7 +93,7 @@ func (c *gitHTTPComponentImpl) InfoRefs(ctx context.Context, req types.InfoRefsR } } else { if repo.Private { - allowed, err := c.AllowReadAccess(ctx, req.RepoType, req.Namespace, req.Name, req.CurrentUser) + allowed, err := c.repoComponent.AllowReadAccess(ctx, req.RepoType, req.Namespace, req.Name, req.CurrentUser) if err != nil { return nil, ErrUnauthorized } @@ -101,7 +103,7 @@ func (c *gitHTTPComponentImpl) InfoRefs(ctx context.Context, req types.InfoRefsR } } - reader, err := c.git.InfoRefsResponse(ctx, gitserver.InfoRefsReq{ + reader, err := c.gitServer.InfoRefsResponse(ctx, gitserver.InfoRefsReq{ Namespace: req.Namespace, Name: req.Name, Rpc: req.Rpc, @@ -113,13 +115,13 @@ func (c *gitHTTPComponentImpl) InfoRefs(ctx context.Context, req types.InfoRefsR } func (c *gitHTTPComponentImpl) GitUploadPack(ctx context.Context, req types.GitUploadPackReq) error { - repo, err := c.repo.FindByPath(ctx, req.RepoType, req.Namespace, req.Name) + repo, err := c.repoStore.FindByPath(ctx, req.RepoType, req.Namespace, req.Name) if err != nil { return fmt.Errorf("failed to find repo, error: %w", err) } if repo.Private { - allowed, err := c.AllowReadAccess(ctx, req.RepoType, req.Namespace, req.Name, req.CurrentUser) + allowed, err := c.repoComponent.AllowReadAccess(ctx, req.RepoType, req.Namespace, req.Name, req.CurrentUser) if err != nil { return ErrUnauthorized } @@ -127,7 +129,7 @@ func (c *gitHTTPComponentImpl) GitUploadPack(ctx context.Context, req types.GitU return ErrForbidden } } - err = c.git.UploadPack(ctx, gitserver.UploadPackReq{ + err = c.gitServer.UploadPack(ctx, gitserver.UploadPackReq{ Namespace: req.Namespace, Name: req.Name, Request: req.Request, @@ -140,7 +142,7 @@ func (c *gitHTTPComponentImpl) GitUploadPack(ctx context.Context, req types.GitU } func (c *gitHTTPComponentImpl) GitReceivePack(ctx context.Context, req types.GitReceivePackReq) error { - _, err := c.repo.FindByPath(ctx, req.RepoType, req.Namespace, req.Name) + _, err := c.repoStore.FindByPath(ctx, req.RepoType, req.Namespace, req.Name) if err != nil { return fmt.Errorf("failed to find repo, error: %w", err) } @@ -150,14 +152,14 @@ func (c *gitHTTPComponentImpl) GitReceivePack(ctx context.Context, req types.Git return ErrUnauthorized } - allowed, err := c.AllowWriteAccess(ctx, req.RepoType, req.Namespace, req.Name, req.CurrentUser) + allowed, err := c.repoComponent.AllowWriteAccess(ctx, req.RepoType, req.Namespace, req.Name, req.CurrentUser) if err != nil { return ErrUnauthorized } if !allowed { return ErrForbidden } - err = c.git.ReceivePack(ctx, gitserver.ReceivePackReq{ + err = c.gitServer.ReceivePack(ctx, gitserver.ReceivePackReq{ Namespace: req.Namespace, Name: req.Name, Request: req.Request, @@ -176,7 +178,7 @@ func (c *gitHTTPComponentImpl) BuildObjectResponse(ctx context.Context, req type respObjects []*types.ObjectResponse exists bool ) - repo, err := c.repo.FindByPath(ctx, req.RepoType, req.Namespace, req.Name) + repo, err := c.repoStore.FindByPath(ctx, req.RepoType, req.Namespace, req.Name) if err != nil { return nil, fmt.Errorf("failed to find repo, error: %w", err) } @@ -223,7 +225,7 @@ func (c *gitHTTPComponentImpl) BuildObjectResponse(ctx context.Context, req type // } if exists && lfsMetaObject == nil { - allowed, err := c.AllowWriteAccess(ctx, req.RepoType, req.Namespace, req.Name, req.CurrentUser) + allowed, err := c.repoComponent.AllowWriteAccess(ctx, req.RepoType, req.Namespace, req.Name, req.CurrentUser) if err != nil { slog.Error("unable to check if user can wirte this repo", slog.String("lfs oid", obj.Oid), slog.Any("error", err)) return nil, ErrUnauthorized @@ -307,7 +309,7 @@ func (c *gitHTTPComponentImpl) buildObjectResponse(ctx context.Context, req type func (c *gitHTTPComponentImpl) LfsUpload(ctx context.Context, body io.ReadCloser, req types.UploadRequest) error { var exists bool - repo, err := c.repo.FindByPath(ctx, req.RepoType, req.Namespace, req.Name) + repo, err := c.repoStore.FindByPath(ctx, req.RepoType, req.Namespace, req.Name) if err != nil { return fmt.Errorf("failed to find repo, error: %w", err) } @@ -332,7 +334,7 @@ func (c *gitHTTPComponentImpl) LfsUpload(ctx context.Context, body io.ReadCloser } uploadOrVerify := func() error { if exists { - allowed, err := c.AllowWriteAccess(ctx, req.RepoType, req.Namespace, req.Name, req.CurrentUser) + allowed, err := c.repoComponent.AllowWriteAccess(ctx, req.RepoType, req.Namespace, req.Name, req.CurrentUser) if err != nil { slog.Error("Unable to check if LFS MetaObject [%s] is allowed. Error: %v", pointer.Oid, err) return err @@ -414,7 +416,7 @@ func (c *gitHTTPComponentImpl) LfsUpload(ctx context.Context, body io.ReadCloser } func (c *gitHTTPComponentImpl) LfsVerify(ctx context.Context, req types.VerifyRequest, p types.Pointer) error { - repo, err := c.repo.FindByPath(ctx, req.RepoType, req.Namespace, req.Name) + repo, err := c.repoStore.FindByPath(ctx, req.RepoType, req.Namespace, req.Name) if err != nil { return fmt.Errorf("failed to find repo, error: %w", err) } @@ -451,7 +453,7 @@ func (c *gitHTTPComponentImpl) CreateLock(ctx context.Context, req types.LfsLock var ( lock *database.LfsLock ) - repo, err := c.repo.FindByPath(ctx, req.RepoType, req.Namespace, req.Name) + repo, err := c.repoStore.FindByPath(ctx, req.RepoType, req.Namespace, req.Name) if err != nil { return nil, fmt.Errorf("failed to find repo, error: %w", err) } @@ -461,7 +463,7 @@ func (c *gitHTTPComponentImpl) CreateLock(ctx context.Context, req types.LfsLock return nil, ErrUnauthorized } - allowed, err := c.AllowWriteAccess(ctx, req.RepoType, req.Namespace, req.Name, req.CurrentUser) + allowed, err := c.repoComponent.AllowWriteAccess(ctx, req.RepoType, req.Namespace, req.Name, req.CurrentUser) if err != nil { slog.Error("Unable to check user write access:", slog.Any("error", err)) return nil, err @@ -492,7 +494,7 @@ func (c *gitHTTPComponentImpl) CreateLock(ctx context.Context, req types.LfsLock } func (c *gitHTTPComponentImpl) ListLocks(ctx context.Context, req types.ListLFSLockReq) (*types.LFSLockList, error) { - repo, err := c.repo.FindByPath(ctx, req.RepoType, req.Namespace, req.Name) + repo, err := c.repoStore.FindByPath(ctx, req.RepoType, req.Namespace, req.Name) if err != nil { return nil, fmt.Errorf("failed to find repo, error: %w", err) } @@ -502,7 +504,7 @@ func (c *gitHTTPComponentImpl) ListLocks(ctx context.Context, req types.ListLFSL return nil, ErrUnauthorized } - allowed, err := c.AllowReadAccess(ctx, req.RepoType, req.Namespace, req.Name, req.CurrentUser) + allowed, err := c.repoComponent.AllowReadAccess(ctx, req.RepoType, req.Namespace, req.Name, req.CurrentUser) if err != nil { slog.Error("Unable to check user write access:", slog.Any("error", err)) return nil, err @@ -557,7 +559,7 @@ func (c *gitHTTPComponentImpl) UnLock(ctx context.Context, req types.UnlockLFSRe lock *database.LfsLock err error ) - _, err = c.repo.FindByPath(ctx, req.RepoType, req.Namespace, req.Name) + _, err = c.repoStore.FindByPath(ctx, req.RepoType, req.Namespace, req.Name) if err != nil { return nil, fmt.Errorf("failed to find repo, error: %w", err) } @@ -567,7 +569,7 @@ func (c *gitHTTPComponentImpl) UnLock(ctx context.Context, req types.UnlockLFSRe return nil, ErrUnauthorized } - allowed, err := c.AllowWriteAccess(ctx, req.RepoType, req.Namespace, req.Name, req.CurrentUser) + allowed, err := c.repoComponent.AllowWriteAccess(ctx, req.RepoType, req.Namespace, req.Name, req.CurrentUser) if err != nil { slog.Error("Unable to check user write access:", slog.Any("error", err)) return nil, err @@ -602,7 +604,7 @@ func (c *gitHTTPComponentImpl) VerifyLock(ctx context.Context, req types.VerifyL theirLocks []*types.LFSLock res types.LFSLockListVerify ) - repo, err := c.repo.FindByPath(ctx, req.RepoType, req.Namespace, req.Name) + repo, err := c.repoStore.FindByPath(ctx, req.RepoType, req.Namespace, req.Name) if err != nil { return nil, fmt.Errorf("failed to find repo, error: %w", err) } @@ -612,7 +614,7 @@ func (c *gitHTTPComponentImpl) VerifyLock(ctx context.Context, req types.VerifyL return nil, ErrUnauthorized } - allowed, err := c.AllowReadAccess(ctx, req.RepoType, req.Namespace, req.Name, req.CurrentUser) + allowed, err := c.repoComponent.AllowReadAccess(ctx, req.RepoType, req.Namespace, req.Name, req.CurrentUser) if err != nil { slog.Error("Unable to check user write access:", slog.Any("error", err)) return nil, err @@ -656,11 +658,11 @@ func (c *gitHTTPComponentImpl) VerifyLock(ctx context.Context, req types.VerifyL func (c *gitHTTPComponentImpl) LfsDownload(ctx context.Context, req types.DownloadRequest) (*url.URL, error) { pointer := types.Pointer{Oid: req.Oid} - repo, err := c.repo.FindByPath(ctx, req.RepoType, req.Namespace, req.Name) + repo, err := c.repoStore.FindByPath(ctx, req.RepoType, req.Namespace, req.Name) if err != nil { return nil, fmt.Errorf("failed to find repo, error: %w", err) } - allowed, err := c.AllowReadAccess(ctx, req.RepoType, req.Namespace, req.Name, req.CurrentUser) + allowed, err := c.repoComponent.AllowReadAccess(ctx, req.RepoType, req.Namespace, req.Name, req.CurrentUser) if err != nil { return nil, fmt.Errorf("failed to check allowed, error: %w", err) } diff --git a/component/git_http_test.go b/component/git_http_test.go new file mode 100644 index 00000000..88d24547 --- /dev/null +++ b/component/git_http_test.go @@ -0,0 +1,492 @@ +package component + +import ( + "context" + "database/sql" + "errors" + "fmt" + "io" + "net/url" + "testing" + + "github.com/minio/minio-go/v7" + "github.com/stretchr/testify/require" + "opencsg.com/csghub-server/builder/git/gitserver" + "opencsg.com/csghub-server/builder/store/database" + "opencsg.com/csghub-server/common/types" +) + +func TestGitHTTPComponent_InfoRefs(t *testing.T) { + + cases := []struct { + rpc string + private bool + }{ + {"foo", true}, + {"git-receive-pack", false}, + {"foo", false}, + } + + for _, c := range cases { + t.Run(fmt.Sprintf("%+v", c), func(t *testing.T) { + ctx := context.TODO() + gc := initializeTestGitHTTPComponent(ctx, t) + + gc.mocks.stores.RepoMock().EXPECT().FindByPath( + ctx, types.ModelRepo, "ns", "n", + ).Return(&database.Repository{ + Private: c.private, + }, nil) + if c.rpc == "git-receive-pack" { + gc.mocks.components.repo.EXPECT().AllowWriteAccess(ctx, types.ModelRepo, "ns", "n", "user").Return(true, nil) + } + if c.private { + gc.mocks.components.repo.EXPECT().AllowReadAccess(ctx, types.ModelRepo, "ns", "n", "user").Return(true, nil) + } + + gc.mocks.gitServer.EXPECT().InfoRefsResponse(ctx, gitserver.InfoRefsReq{ + Namespace: "ns", + Name: "n", + Rpc: c.rpc, + RepoType: types.ModelRepo, + GitProtocol: "", + }).Return(nil, nil) + + r, err := gc.InfoRefs(ctx, types.InfoRefsReq{ + Namespace: "ns", + Name: "n", + Rpc: c.rpc, + RepoType: types.ModelRepo, + CurrentUser: "user", + }) + require.Nil(t, err) + require.Equal(t, nil, r) + + }) + } + +} + +func TestGitHTTPComponent_GitUploadPack(t *testing.T) { + ctx := context.TODO() + gc := initializeTestGitHTTPComponent(ctx, t) + + gc.mocks.stores.RepoMock().EXPECT().FindByPath( + ctx, types.ModelRepo, "ns", "n", + ).Return(&database.Repository{ + Private: true, + }, nil) + gc.mocks.components.repo.EXPECT().AllowReadAccess(ctx, types.ModelRepo, "ns", "n", "user").Return(true, nil) + gc.mocks.gitServer.EXPECT().UploadPack(ctx, gitserver.UploadPackReq{ + Namespace: "ns", + Name: "n", + Request: nil, + RepoType: types.ModelRepo, + }).Return(nil) + err := gc.GitUploadPack(ctx, types.GitUploadPackReq{ + Namespace: "ns", + Name: "n", + RepoType: types.ModelRepo, + CurrentUser: "user", + }) + require.Nil(t, err) + +} + +func TestGitHTTPComponent_GitReceivePack(t *testing.T) { + ctx := context.TODO() + gc := initializeTestGitHTTPComponent(ctx, t) + + gc.mocks.stores.RepoMock().EXPECT().FindByPath( + ctx, types.ModelRepo, "ns", "n", + ).Return(&database.Repository{ + Private: true, + }, nil) + gc.mocks.components.repo.EXPECT().AllowWriteAccess(ctx, types.ModelRepo, "ns", "n", "user").Return(true, nil) + gc.mocks.stores.UserMock().EXPECT().FindByUsername(ctx, "user").Return(database.User{}, nil) + gc.mocks.gitServer.EXPECT().ReceivePack(ctx, gitserver.UploadPackReq{ + Namespace: "ns", + Name: "n", + Request: nil, + RepoType: types.ModelRepo, + }).Return(nil) + err := gc.GitReceivePack(ctx, types.GitUploadPackReq{ + Namespace: "ns", + Name: "n", + RepoType: types.ModelRepo, + CurrentUser: "user", + }) + require.Nil(t, err) + +} + +func TestGitHTTPComponent_BuildObjectResponse(t *testing.T) { + ctx := context.TODO() + gc := initializeTestGitHTTPComponent(ctx, t) + + oid1 := "a3f8e1b4f77bb24e508906c6972f81928f0d926e6daef1b29d12e348b8a3547e" + oid2 := "c39e7f5f1d61fa58ec6dbcd3b60a50870c577f0988d7c080fc88d1b460e7f5f1" + gc.mocks.stores.RepoMock().EXPECT().FindByPath( + ctx, types.ModelRepo, "ns", "n", + ).Return(&database.Repository{ + ID: 123, + Private: true, + }, nil) + gc.mocks.s3Client.EXPECT().StatObject( + ctx, "", + "lfs/a3/f8/e1b4f77bb24e508906c6972f81928f0d926e6daef1b29d12e348b8a3547e", + minio.StatObjectOptions{}, + ).Return( + minio.ObjectInfo{Size: 100}, nil, + ) + gc.mocks.s3Client.EXPECT().StatObject( + ctx, "", + "lfs/c3/9e/7f5f1d61fa58ec6dbcd3b60a50870c577f0988d7c080fc88d1b460e7f5f1", + minio.StatObjectOptions{}, + ).Return( + minio.ObjectInfo{Size: 100}, nil, + ) + gc.mocks.stores.LfsMetaObjectMock().EXPECT().FindByOID(ctx, int64(123), oid1).Return( + &database.LfsMetaObject{}, nil, + ) + gc.mocks.stores.LfsMetaObjectMock().EXPECT().FindByOID(ctx, int64(123), oid2).Return( + nil, nil, + ) + gc.mocks.components.repo.EXPECT().AllowWriteAccess( + ctx, types.ModelRepo, "ns", "n", "user", + ).Return(true, nil) + gc.mocks.stores.LfsMetaObjectMock().EXPECT().Create(ctx, database.LfsMetaObject{ + Oid: oid2, + Size: 100, + RepositoryID: 123, + Existing: true, + }).Return(nil, nil) + + resp, err := gc.BuildObjectResponse(ctx, types.BatchRequest{ + Namespace: "ns", + Name: "n", + RepoType: types.ModelRepo, + CurrentUser: "user", + Objects: []types.Pointer{ + { + Oid: oid1, + Size: 5, + }, + { + Oid: oid2, + Size: 100, + }, + }, + }, true) + require.Nil(t, err) + require.Equal(t, &types.BatchResponse{ + Objects: []*types.ObjectResponse{ + { + Pointer: types.Pointer{Oid: oid1, Size: 5}, + Error: &types.ObjectError{ + Code: 422, + Message: "Object a3f8e1b4f77bb24e508906c6972f81928f0d926e6daef1b29d12e348b8a3547e is not 5 bytes", + }, + Actions: nil, + }, + { + Pointer: types.Pointer{Oid: oid2, Size: 100}, + Actions: map[string]*types.Link{}, + }, + }, + }, resp) + +} + +func TestGitHTTPComponent_LfsUpload(t *testing.T) { + + for _, exist := range []bool{false, true} { + t.Run(fmt.Sprintf("exist %v", exist), func(t *testing.T) { + ctx := context.TODO() + gc := initializeTestGitHTTPComponent(ctx, t) + + rc := io.NopCloser(&io.LimitedReader{}) + oid := "a3f8e1b4f77bb24e508906c6972f81928f0d926e6daef1b29d12e348b8a3547e" + gc.mocks.stores.RepoMock().EXPECT().FindByPath( + ctx, types.ModelRepo, "ns", "n", + ).Return(&database.Repository{ + ID: 123, + Private: true, + }, nil) + if exist { + gc.mocks.s3Client.EXPECT().StatObject( + ctx, "", + "lfs/a3/f8/e1b4f77bb24e508906c6972f81928f0d926e6daef1b29d12e348b8a3547e", + minio.StatObjectOptions{}, + ).Return( + minio.ObjectInfo{Size: 100}, nil, + ) + } else { + gc.mocks.s3Client.EXPECT().StatObject( + ctx, "", + "lfs/a3/f8/e1b4f77bb24e508906c6972f81928f0d926e6daef1b29d12e348b8a3547e", + minio.StatObjectOptions{}, + ).Return( + minio.ObjectInfo{Size: 100}, errors.New("zzzz"), + ) + } + + if exist { + gc.mocks.components.repo.EXPECT().AllowWriteAccess( + ctx, types.ModelRepo, "ns", "n", "user", + ).Return(true, nil) + } else { + gc.mocks.s3Client.EXPECT().PutObject( + ctx, "", + "lfs/a3/f8/e1b4f77bb24e508906c6972f81928f0d926e6daef1b29d12e348b8a3547e", + rc, int64(100), minio.PutObjectOptions{ + ContentType: "application/octet-stream", + SendContentMd5: true, + ConcurrentStreamParts: true, + NumThreads: 5, + }).Return(minio.UploadInfo{Size: 100}, nil) + } + gc.mocks.stores.LfsMetaObjectMock().EXPECT().Create(ctx, database.LfsMetaObject{ + Oid: oid, + Size: 100, + RepositoryID: 123, + Existing: true, + }).Return(nil, nil) + + err := gc.LfsUpload(ctx, rc, types.UploadRequest{ + Oid: oid, + Size: 100, + CurrentUser: "user", + Namespace: "ns", + Name: "n", + RepoType: types.ModelRepo, + }) + require.Nil(t, err) + }) + } + +} + +func TestGitHTTPComponent_LfsVerify(t *testing.T) { + ctx := context.TODO() + gc := initializeTestGitHTTPComponent(ctx, t) + + gc.mocks.stores.RepoMock().EXPECT().FindByPath( + ctx, types.ModelRepo, "ns", "n", + ).Return(&database.Repository{ + ID: 123, + Private: true, + }, nil) + + gc.mocks.s3Client.EXPECT().StatObject(ctx, "", "lfs/oid", minio.StatObjectOptions{}).Return( + minio.ObjectInfo{Size: 100}, nil, + ) + gc.mocks.stores.LfsMetaObjectMock().EXPECT().Create(ctx, database.LfsMetaObject{ + Oid: "oid", + Size: 100, + RepositoryID: 123, + Existing: true, + }).Return(nil, nil) + + err := gc.LfsVerify(ctx, types.VerifyRequest{ + CurrentUser: "user", + Namespace: "ns", + Name: "n", + RepoType: types.ModelRepo, + }, types.Pointer{Oid: "oid", Size: 100}) + require.Nil(t, err) + +} + +func TestGitHTTPComponent_CreateLock(t *testing.T) { + ctx := context.TODO() + gc := initializeTestGitHTTPComponent(ctx, t) + + gc.mocks.stores.RepoMock().EXPECT().FindByPath( + ctx, types.ModelRepo, "ns", "n", + ).Return(&database.Repository{ + ID: 123, + Private: true, + }, nil) + + gc.mocks.stores.UserMock().EXPECT().FindByUsername(ctx, "user").Return(database.User{}, nil) + gc.mocks.components.repo.EXPECT().AllowWriteAccess( + ctx, types.ModelRepo, "ns", "n", "user", + ).Return(true, nil) + lfslock := &database.LfsLock{Path: "path", RepositoryID: 123} + gc.mocks.stores.LfsLockMock().EXPECT().FindByPath(ctx, int64(123), "path").Return( + lfslock, sql.ErrNoRows, + ) + gc.mocks.stores.LfsLockMock().EXPECT().Create(ctx, *lfslock).Return(lfslock, nil) + + l, err := gc.CreateLock(ctx, types.LfsLockReq{ + Namespace: "ns", + Name: "n", + RepoType: types.ModelRepo, + CurrentUser: "user", + Path: "path", + }) + require.Nil(t, err) + require.Equal(t, lfslock, l) + +} + +func TestGitHTTPComponent_ListLocks(t *testing.T) { + ctx := context.TODO() + gc := initializeTestGitHTTPComponent(ctx, t) + + gc.mocks.stores.RepoMock().EXPECT().FindByPath( + ctx, types.ModelRepo, "ns", "n", + ).Return(&database.Repository{ + ID: 123, + Private: true, + }, nil) + + gc.mocks.stores.UserMock().EXPECT().FindByUsername(ctx, "user").Return(database.User{}, nil) + gc.mocks.components.repo.EXPECT().AllowReadAccess( + ctx, types.ModelRepo, "ns", "n", "user", + ).Return(true, nil) + + gc.mocks.stores.LfsLockMock().EXPECT().FindByID(ctx, int64(123)).Return( + &database.LfsLock{ID: 11, RepositoryID: 123}, nil, + ) + gc.mocks.stores.LfsLockMock().EXPECT().FindByPath(ctx, int64(123), "foo/bar").Return( + &database.LfsLock{ID: 12, RepositoryID: 123}, nil, + ) + gc.mocks.stores.LfsLockMock().EXPECT().FindByRepoID(ctx, int64(123), 1, 10).Return( + []database.LfsLock{{ID: 13, RepositoryID: 123}}, nil, + ) + + req := types.ListLFSLockReq{ + Namespace: "ns", + Name: "n", + RepoType: types.ModelRepo, + CurrentUser: "user", + Cursor: 1, + Limit: 10, + } + req1 := req + req1.ID = 123 + ll, err := gc.ListLocks(ctx, req1) + require.Nil(t, err) + require.Equal(t, &types.LFSLockList{ + Locks: []*types.LFSLock{{ID: "11", Owner: &types.LFSLockOwner{}}}, + }, ll) + req2 := req + req2.Path = "foo/bar" + ll, err = gc.ListLocks(ctx, req2) + require.Nil(t, err) + require.Equal(t, &types.LFSLockList{ + Locks: []*types.LFSLock{{ID: "12", Owner: &types.LFSLockOwner{}}}, + }, ll) + ll, err = gc.ListLocks(ctx, req) + require.Nil(t, err) + require.Equal(t, &types.LFSLockList{ + Locks: []*types.LFSLock{{ID: "13", Owner: &types.LFSLockOwner{}}}, + }, ll) +} + +func TestGitHTTPComponent_UnLock(t *testing.T) { + ctx := context.TODO() + gc := initializeTestGitHTTPComponent(ctx, t) + + gc.mocks.stores.RepoMock().EXPECT().FindByPath( + ctx, types.ModelRepo, "ns", "n", + ).Return(&database.Repository{ + ID: 123, + Private: true, + }, nil) + + gc.mocks.stores.UserMock().EXPECT().FindByUsername(ctx, "user").Return(database.User{}, nil) + gc.mocks.components.repo.EXPECT().AllowWriteAccess( + ctx, types.ModelRepo, "ns", "n", "user", + ).Return(true, nil) + + gc.mocks.stores.LfsLockMock().EXPECT().FindByID(ctx, int64(123)).Return( + &database.LfsLock{ID: 11, RepositoryID: 123}, nil, + ) + gc.mocks.stores.LfsLockMock().EXPECT().RemoveByID(ctx, int64(123)).Return(nil) + + lk, err := gc.UnLock(ctx, types.UnlockLFSReq{ + Namespace: "ns", + Name: "n", + RepoType: types.ModelRepo, + CurrentUser: "user", + ID: 123, + }) + require.Nil(t, err) + require.Equal(t, &database.LfsLock{ + ID: 11, + RepositoryID: 123, + }, lk) + +} + +func TestGitHTTPComponent_VerifyLock(t *testing.T) { + ctx := context.TODO() + gc := initializeTestGitHTTPComponent(ctx, t) + + gc.mocks.stores.RepoMock().EXPECT().FindByPath( + ctx, types.ModelRepo, "ns", "n", + ).Return(&database.Repository{ + ID: 123, + Private: true, + }, nil) + + gc.mocks.stores.UserMock().EXPECT().FindByUsername(ctx, "user").Return(database.User{}, nil) + gc.mocks.components.repo.EXPECT().AllowReadAccess( + ctx, types.ModelRepo, "ns", "n", "user", + ).Return(true, nil) + + gc.mocks.stores.LfsLockMock().EXPECT().FindByRepoID(ctx, int64(123), 10, 1).Return( + []database.LfsLock{{ID: 11, RepositoryID: 123, User: database.User{Username: "zzz"}}}, nil, + ) + + lk, err := gc.VerifyLock(ctx, types.VerifyLFSLockReq{ + Namespace: "ns", + Name: "n", + RepoType: types.ModelRepo, + CurrentUser: "user", + Cursor: 10, + Limit: 1, + }) + require.Nil(t, err) + require.Equal(t, &types.LFSLockListVerify{ + Ours: []*types.LFSLock{{ID: "11", Owner: &types.LFSLockOwner{Name: "zzz"}}}, + Next: "11", + }, lk) + +} + +func TestGitHTTPComponent_LfsDownload(t *testing.T) { + ctx := context.TODO() + gc := initializeTestGitHTTPComponent(ctx, t) + + gc.mocks.stores.RepoMock().EXPECT().FindByPath( + ctx, types.ModelRepo, "ns", "n", + ).Return(&database.Repository{ + ID: 123, + Private: true, + }, nil) + + gc.mocks.components.repo.EXPECT().AllowReadAccess( + ctx, types.ModelRepo, "ns", "n", "user", + ).Return(true, nil) + gc.mocks.stores.LfsMetaObjectMock().EXPECT().FindByOID(ctx, int64(123), "oid").Return(nil, nil) + reqParams := make(url.Values) + reqParams.Set("response-content-disposition", fmt.Sprintf("attachment;filename=%s", "sa")) + url := &url.URL{Scheme: "http"} + gc.mocks.s3Client.EXPECT().PresignedGetObject(ctx, "", "lfs/oid", ossFileExpireSeconds, reqParams).Return(url, nil) + + u, err := gc.LfsDownload(ctx, types.DownloadRequest{ + Oid: "oid", + Namespace: "ns", + Name: "n", + RepoType: types.ModelRepo, + SaveAs: "sa", + CurrentUser: "user", + }) + require.Nil(t, err) + require.Equal(t, url, u) + +} diff --git a/component/internal.go b/component/internal.go index d23f28ec..07af29ca 100644 --- a/component/internal.go +++ b/component/internal.go @@ -8,6 +8,7 @@ import ( "strconv" pb "gitlab.com/gitlab-org/gitaly/v16/proto/go/gitalypb" + "opencsg.com/csghub-server/builder/git" "opencsg.com/csghub-server/builder/git/gitserver" "opencsg.com/csghub-server/builder/git/gitserver/gitaly" "opencsg.com/csghub-server/builder/store/database" @@ -17,10 +18,13 @@ import ( ) type internalComponentImpl struct { - config *config.Config - sshKeyStore database.SSHKeyStore - repoStore database.RepoStore - *repoComponentImpl + config *config.Config + sshKeyStore database.SSHKeyStore + repoStore database.RepoStore + tokenStore database.AccessTokenStore + namespaceStore database.NamespaceStore + repoComponent RepoComponent + gitServer gitserver.GitServer } type InternalComponent interface { @@ -37,11 +41,17 @@ func NewInternalComponent(config *config.Config) (InternalComponent, error) { c.config = config c.sshKeyStore = database.NewSSHKeyStore() c.repoStore = database.NewRepoStore() - c.repoComponentImpl, err = NewRepoComponentImpl(config) + c.repoComponent, err = NewRepoComponentImpl(config) c.tokenStore = database.NewAccessTokenStore() + c.namespaceStore = database.NewNamespaceStore() if err != nil { return nil, err } + git, err := git.NewGitServer(config) + if err != nil { + return nil, fmt.Errorf("failed to create git server: %w", err) + } + c.gitServer = git return c, nil } @@ -70,7 +80,7 @@ func (c *internalComponentImpl) SSHAllowed(ctx context.Context, req types.SSHAll return nil, fmt.Errorf("failed to find ssh key by id, err: %v", err) } if req.Action == "git-receive-pack" { - allowed, err := c.AllowWriteAccess(ctx, req.RepoType, req.Namespace, req.Name, sshKey.User.Username) + allowed, err := c.repoComponent.AllowWriteAccess(ctx, req.RepoType, req.Namespace, req.Name, sshKey.User.Username) if err != nil { return nil, ErrUnauthorized } @@ -79,7 +89,7 @@ func (c *internalComponentImpl) SSHAllowed(ctx context.Context, req types.SSHAll } } else if req.Action == "git-upload-pack" { if repo.Private { - allowed, err := c.AllowReadAccess(ctx, req.RepoType, req.Namespace, req.Name, sshKey.User.Username) + allowed, err := c.repoComponent.AllowReadAccess(ctx, req.RepoType, req.Namespace, req.Name, sshKey.User.Username) if err != nil { return nil, ErrUnauthorized } @@ -133,7 +143,7 @@ func (c *internalComponentImpl) GetCommitDiff(ctx context.Context, req types.Get if repo == nil { return nil, errors.New("repo not found") } - diffs, err := c.git.GetDiffBetweenTwoCommits(ctx, gitserver.GetDiffBetweenTwoCommitsReq{ + diffs, err := c.gitServer.GetDiffBetweenTwoCommits(ctx, gitserver.GetDiffBetweenTwoCommitsReq{ Namespace: req.Namespace, Name: req.Name, RepoType: req.RepoType, @@ -165,7 +175,7 @@ func (c *internalComponentImpl) LfsAuthenticate(ctx context.Context, req types.L return nil, fmt.Errorf("failed to find ssh key by id, err: %v", err) } if repo.Private { - allowed, err := c.AllowReadAccess(ctx, req.RepoType, req.Namespace, req.Name, sshKey.User.Username) + allowed, err := c.repoComponent.AllowReadAccess(ctx, req.RepoType, req.Namespace, req.Name, sshKey.User.Username) if err != nil { return nil, ErrUnauthorized } diff --git a/component/internal_test.go b/component/internal_test.go new file mode 100644 index 00000000..4935fac9 --- /dev/null +++ b/component/internal_test.go @@ -0,0 +1,153 @@ +package component + +import ( + "context" + "testing" + + "github.com/stretchr/testify/require" + pb "gitlab.com/gitlab-org/gitaly/v16/proto/go/gitalypb" + "opencsg.com/csghub-server/builder/git/gitserver" + "opencsg.com/csghub-server/builder/store/database" + "opencsg.com/csghub-server/common/types" +) + +func TestInternalComponent_Allowed(t *testing.T) { + ctx := context.TODO() + ic := initializeTestInternalComponent(ctx, t) + + allowed, err := ic.Allowed(ctx) + require.Nil(t, err) + require.True(t, allowed) +} + +func TestInternalComponent_SSHAllowed(t *testing.T) { + ctx := context.TODO() + ic := initializeTestInternalComponent(ctx, t) + + ic.mocks.stores.NamespaceMock().EXPECT().FindByPath(ctx, "ns").Return(database.Namespace{ + ID: 321, + }, nil) + ic.mocks.stores.RepoMock().EXPECT().FindByPath(ctx, types.ModelRepo, "ns", "n").Return( + &database.Repository{ID: 123, Private: true}, nil, + ) + ic.mocks.stores.SSHMock().EXPECT().FindByID(ctx, int64(1)).Return(&database.SSHKey{ + ID: 111, + User: &database.User{ID: 11, Username: "user"}, + }, nil) + ic.mocks.components.repo.EXPECT().AllowWriteAccess( + ctx, types.ModelRepo, "ns", "n", "user", + ).Return(true, nil) + ic.mocks.components.repo.EXPECT().AllowReadAccess( + ctx, types.ModelRepo, "ns", "n", "user", + ).Return(true, nil) + + req := types.SSHAllowedReq{ + RepoType: types.ModelRepo, + Namespace: "ns", + Name: "n", + KeyID: "1", + Action: "git-receive-pack", + } + resp, err := ic.SSHAllowed(ctx, req) + require.Nil(t, err) + expected := &types.SSHAllowedResp{ + Success: true, + Message: "allowed", + Repo: req.Repo, + UserID: "11", + KeyType: "ssh", + KeyID: 111, + ProjectID: 123, + RootNamespaceID: 321, + GitConfigOptions: []string{"uploadpack.allowFilter=true", "uploadpack.allowAnySHA1InWant=true"}, + Gitaly: types.Gitaly{ + Repo: pb.Repository{ + RelativePath: "models_ns/n.git", + GlRepository: "models/ns/n", + }, + }, + StatusCode: 200, + } + + require.Equal(t, expected, resp) + + req.Action = "git-upload-pack" + resp, err = ic.SSHAllowed(ctx, req) + require.Nil(t, err) + require.Equal(t, expected, resp) + +} + +func TestInternalComponent_GetAuthorizedKeys(t *testing.T) { + ctx := context.TODO() + ic := initializeTestInternalComponent(ctx, t) + + ic.mocks.stores.SSHMock().EXPECT().FindByFingerpringSHA256( + ctx, "dUQ5GwtKsCPC8Scv1OLnOEvIW0QWULVSWyj5bZwQHwM", + ).Return(&database.SSHKey{}, nil) + key, err := ic.GetAuthorizedKeys(ctx, "foobar") + require.Nil(t, err) + require.Equal(t, &database.SSHKey{}, key) +} + +func TestInternalComponent_GetCommitDiff(t *testing.T) { + ctx := context.TODO() + ic := initializeTestInternalComponent(ctx, t) + + req := types.GetDiffBetweenTwoCommitsReq{ + Namespace: "ns", + Name: "n", + RepoType: types.ModelRepo, + Ref: "main", + LeftCommitId: "l", + RightCommitId: "r", + } + ic.mocks.stores.RepoMock().EXPECT().FindByPath(ctx, types.ModelRepo, "ns", "n").Return( + &database.Repository{}, nil, + ) + ic.mocks.gitServer.EXPECT().GetDiffBetweenTwoCommits(ctx, gitserver.GetDiffBetweenTwoCommitsReq{ + Namespace: req.Namespace, + Name: req.Name, + RepoType: req.RepoType, + Ref: req.Ref, + LeftCommitId: req.LeftCommitId, + RightCommitId: req.RightCommitId, + }).Return(&types.GiteaCallbackPushReq{Ref: "main"}, nil) + + resp, err := ic.GetCommitDiff(ctx, req) + require.Nil(t, err) + require.Equal(t, &types.GiteaCallbackPushReq{Ref: "main"}, resp) +} + +func TestInternalComponent_LfsAuthenticate(t *testing.T) { + ctx := context.TODO() + ic := initializeTestInternalComponent(ctx, t) + + ic.mocks.stores.RepoMock().EXPECT().FindByPath(ctx, types.ModelRepo, "ns", "n").Return( + &database.Repository{Private: true}, nil, + ) + ic.mocks.stores.SSHMock().EXPECT().FindByID(ctx, int64(1)).Return(&database.SSHKey{ + ID: 111, + User: &database.User{ID: 11, Username: "user"}, + }, nil) + ic.mocks.components.repo.EXPECT().AllowReadAccess( + ctx, types.ModelRepo, "ns", "n", "user", + ).Return(true, nil) + ic.mocks.stores.AccessTokenMock().EXPECT().GetUserGitToken(ctx, "user").Return( + &database.AccessToken{Token: "token"}, nil, + ) + + resp, err := ic.LfsAuthenticate(ctx, types.LfsAuthenticateReq{ + Namespace: "ns", + Name: "n", + RepoType: types.ModelRepo, + KeyID: "1", + }) + require.Nil(t, err) + require.Equal(t, &types.LfsAuthenticateResp{ + Username: "user", + LfsToken: "token", + RepoPath: "/models/ns/n.git", + }, resp) + +} diff --git a/component/mirror_source.go b/component/mirror_source.go index 136c16a0..b136e29d 100644 --- a/component/mirror_source.go +++ b/component/mirror_source.go @@ -11,8 +11,8 @@ import ( ) type mirrorSourceComponentImpl struct { - msStore database.MirrorSourceStore - userStore database.UserStore + mirrorSourceStore database.MirrorSourceStore + userStore database.UserStore } type MirrorSourceComponent interface { @@ -25,8 +25,8 @@ type MirrorSourceComponent interface { func NewMirrorSourceComponent(config *config.Config) (MirrorSourceComponent, error) { return &mirrorSourceComponentImpl{ - msStore: database.NewMirrorSourceStore(), - userStore: database.NewUserStore(), + mirrorSourceStore: database.NewMirrorSourceStore(), + userStore: database.NewUserStore(), }, nil } @@ -41,7 +41,7 @@ func (c *mirrorSourceComponentImpl) Create(ctx context.Context, req types.Create } ms.SourceName = req.SourceName ms.InfoAPIUrl = req.InfoAPiUrl - res, err := c.msStore.Create(ctx, &ms) + res, err := c.mirrorSourceStore.Create(ctx, &ms) if err != nil { return nil, fmt.Errorf("failed to create mirror source, error: %w", err) } @@ -56,7 +56,7 @@ func (c *mirrorSourceComponentImpl) Get(ctx context.Context, id int64, currentUs if !user.CanAdmin() { return nil, errors.New("user does not have admin permission") } - ms, err := c.msStore.Get(ctx, id) + ms, err := c.mirrorSourceStore.Get(ctx, id) if err != nil { return nil, fmt.Errorf("failed to get mirror source, error: %w", err) } @@ -71,7 +71,7 @@ func (c *mirrorSourceComponentImpl) Index(ctx context.Context, currentUser strin if !user.CanAdmin() { return nil, errors.New("user does not have admin permission") } - ms, err := c.msStore.Index(ctx) + ms, err := c.mirrorSourceStore.Index(ctx) if err != nil { return nil, fmt.Errorf("failed to get mirror source, error: %w", err) } @@ -89,7 +89,7 @@ func (c *mirrorSourceComponentImpl) Update(ctx context.Context, req types.Update ms.ID = req.ID ms.SourceName = req.SourceName ms.InfoAPIUrl = req.InfoAPiUrl - err = c.msStore.Update(ctx, &ms) + err = c.mirrorSourceStore.Update(ctx, &ms) if err != nil { return nil, fmt.Errorf("failed to update mirror source, error: %w", err) } @@ -104,11 +104,11 @@ func (c *mirrorSourceComponentImpl) Delete(ctx context.Context, id int64, curren if !user.CanAdmin() { return errors.New("user does not have admin permission") } - ms, err := c.msStore.Get(ctx, id) + ms, err := c.mirrorSourceStore.Get(ctx, id) if err != nil { return fmt.Errorf("failed to find mirror source, error: %w", err) } - err = c.msStore.Delete(ctx, ms) + err = c.mirrorSourceStore.Delete(ctx, ms) if err != nil { return fmt.Errorf("failed to delete mirror source, error: %w", err) } diff --git a/component/mirror_source_test.go b/component/mirror_source_test.go new file mode 100644 index 00000000..ccc3ec02 --- /dev/null +++ b/component/mirror_source_test.go @@ -0,0 +1,104 @@ +package component + +import ( + "context" + "testing" + + "github.com/stretchr/testify/require" + "opencsg.com/csghub-server/builder/store/database" + "opencsg.com/csghub-server/common/types" +) + +func TestMirrorSourceComponent_Create(t *testing.T) { + ctx := context.TODO() + mc := initializeTestMirrorSourceComponent(ctx, t) + + mc.mocks.stores.UserMock().EXPECT().FindByUsername(ctx, "user").Return(database.User{ + RoleMask: "admin", + }, nil) + mc.mocks.stores.MirrorSourceMock().EXPECT().Create(ctx, &database.MirrorSource{ + SourceName: "sn", + InfoAPIUrl: "url", + }).Return(&database.MirrorSource{ID: 1}, nil) + + data, err := mc.Create(ctx, types.CreateMirrorSourceReq{ + SourceName: "sn", + InfoAPiUrl: "url", + CurrentUser: "user", + }) + require.Nil(t, err) + require.Equal(t, &database.MirrorSource{ID: 1}, data) +} + +func TestMirrorSourceComponent_Get(t *testing.T) { + ctx := context.TODO() + mc := initializeTestMirrorSourceComponent(ctx, t) + + mc.mocks.stores.UserMock().EXPECT().FindByUsername(ctx, "user").Return(database.User{ + RoleMask: "admin", + }, nil) + mc.mocks.stores.MirrorSourceMock().EXPECT().Get(ctx, int64(1)).Return(&database.MirrorSource{ID: 1}, nil) + + data, err := mc.Get(ctx, 1, "user") + require.Nil(t, err) + require.Equal(t, &database.MirrorSource{ID: 1}, data) +} + +func TestMirrorSourceComponent_Index(t *testing.T) { + ctx := context.TODO() + mc := initializeTestMirrorSourceComponent(ctx, t) + + mc.mocks.stores.UserMock().EXPECT().FindByUsername(ctx, "user").Return(database.User{ + RoleMask: "admin", + }, nil) + mc.mocks.stores.MirrorSourceMock().EXPECT().Index(ctx).Return([]database.MirrorSource{ + {ID: 1}, + }, nil) + + data, err := mc.Index(ctx, "user") + require.Nil(t, err) + require.Equal(t, []database.MirrorSource{ + {ID: 1}, + }, data) +} + +func TestMirrorSourceComponent_Update(t *testing.T) { + ctx := context.TODO() + mc := initializeTestMirrorSourceComponent(ctx, t) + + mc.mocks.stores.UserMock().EXPECT().FindByUsername(ctx, "user").Return(database.User{ + RoleMask: "admin", + }, nil) + mc.mocks.stores.MirrorSourceMock().EXPECT().Update(ctx, &database.MirrorSource{ + ID: 1, + SourceName: "sn", + InfoAPIUrl: "url", + }).Return(nil) + + data, err := mc.Update(ctx, types.UpdateMirrorSourceReq{ + ID: 1, + SourceName: "sn", + InfoAPiUrl: "url", + CurrentUser: "user", + }) + require.Nil(t, err) + require.Equal(t, &database.MirrorSource{ + ID: 1, + SourceName: "sn", + InfoAPIUrl: "url", + }, data) +} + +func TestMirrorSourceComponent_Delete(t *testing.T) { + ctx := context.TODO() + mc := initializeTestMirrorSourceComponent(ctx, t) + + mc.mocks.stores.UserMock().EXPECT().FindByUsername(ctx, "user").Return(database.User{ + RoleMask: "admin", + }, nil) + mc.mocks.stores.MirrorSourceMock().EXPECT().Get(ctx, int64(1)).Return(&database.MirrorSource{ID: 1}, nil) + mc.mocks.stores.MirrorSourceMock().EXPECT().Delete(ctx, &database.MirrorSource{ID: 1}).Return(nil) + + err := mc.Delete(ctx, 1, "user") + require.Nil(t, err) +} diff --git a/component/model_test.go b/component/model_test.go index ea557131..3f49f915 100644 --- a/component/model_test.go +++ b/component/model_test.go @@ -9,7 +9,6 @@ import ( "github.com/stretchr/testify/require" "opencsg.com/csghub-server/builder/git/gitserver" "opencsg.com/csghub-server/builder/git/membership" - "opencsg.com/csghub-server/builder/inference" "opencsg.com/csghub-server/builder/store/database" "opencsg.com/csghub-server/common/types" ) @@ -458,31 +457,22 @@ func TestModelComponent_DeleteRelationDataset(t *testing.T) { require.Nil(t, err) } -func TestModelComponent_Predict(t *testing.T) { - ctx := context.TODO() - mc := initializeTestModelComponent(ctx, t) - - mc.mocks.inferenceClient.EXPECT().Predict(inference.ModelID{ - Owner: "ns", - Name: "n", - }, &inference.PredictRequest{ - Prompt: "foo", - }).Return(&inference.PredictResponse{ - GeneratedText: "abcd", - }, nil) +// func TestModelComponent_Predict(t *testing.T) { +// ctx := context.TODO() +// mc := initializeTestModelComponent(ctx, t) - resp, err := mc.Predict(ctx, &types.ModelPredictReq{ - Namespace: "ns", - Name: "n", - Input: "foo", - CurrentUser: "user", - }) - require.Nil(t, err) - require.Equal(t, &types.ModelPredictResp{ - Content: "abcd", - }, resp) +// resp, err := mc.Predict(ctx, &types.ModelPredictReq{ +// Namespace: "ns", +// Name: "n", +// Input: "foo", +// CurrentUser: "user", +// }) +// require.Nil(t, err) +// require.Equal(t, &types.ModelPredictResp{ +// Content: "abcd", +// }, resp) -} +// } // func TestModelComponent_Deploy(t *testing.T) { // ctx := context.TODO() diff --git a/component/multi_sync.go b/component/multi_sync.go index edb44597..ddd08285 100644 --- a/component/multi_sync.go +++ b/component/multi_sync.go @@ -20,16 +20,16 @@ import ( ) type multiSyncComponentImpl struct { - s database.MultiSyncStore - repo database.RepoStore - model database.ModelStore - dataset database.DatasetStore - namespace database.NamespaceStore - user database.UserStore - versionStore database.SyncVersionStore - tag database.TagStore - file database.FileStore - git gitserver.GitServer + multiSyncStore database.MultiSyncStore + repoStore database.RepoStore + modelStore database.ModelStore + datasetStore database.DatasetStore + namespaceStore database.NamespaceStore + userStore database.UserStore + syncVersionStore database.SyncVersionStore + tagStore database.TagStore + fileStore database.FileStore + gitServer gitserver.GitServer } type MultiSyncComponent interface { @@ -43,21 +43,21 @@ func NewMultiSyncComponent(config *config.Config) (MultiSyncComponent, error) { return nil, fmt.Errorf("failed to create git server: %w", err) } return &multiSyncComponentImpl{ - s: database.NewMultiSyncStore(), - repo: database.NewRepoStore(), - model: database.NewModelStore(), - dataset: database.NewDatasetStore(), - namespace: database.NewNamespaceStore(), - user: database.NewUserStore(), - versionStore: database.NewSyncVersionStore(), - tag: database.NewTagStore(), - file: database.NewFileStore(), - git: git, + multiSyncStore: database.NewMultiSyncStore(), + repoStore: database.NewRepoStore(), + modelStore: database.NewModelStore(), + datasetStore: database.NewDatasetStore(), + namespaceStore: database.NewNamespaceStore(), + userStore: database.NewUserStore(), + syncVersionStore: database.NewSyncVersionStore(), + tagStore: database.NewTagStore(), + fileStore: database.NewFileStore(), + gitServer: git, }, nil } func (c *multiSyncComponentImpl) More(ctx context.Context, cur int64, limit int64) ([]types.SyncVersion, error) { - dbVersions, err := c.s.GetAfter(ctx, cur, limit) + dbVersions, err := c.multiSyncStore.GetAfter(ctx, cur, limit) if err != nil { return nil, fmt.Errorf("failed to get sync versions after %d from db: %w", cur, err) } @@ -77,7 +77,7 @@ func (c *multiSyncComponentImpl) More(ctx context.Context, cur int64, limit int6 func (c *multiSyncComponentImpl) SyncAsClient(ctx context.Context, sc multisync.Client) error { var currentVersion int64 - v, err := c.s.GetLatest(ctx) + v, err := c.multiSyncStore.GetLatest(ctx) if err != nil { if err != sql.ErrNoRows { return fmt.Errorf("failed to get latest sync version from db: %w", err) @@ -108,7 +108,7 @@ func (c *multiSyncComponentImpl) SyncAsClient(ctx context.Context, sc multisync. } } - syncVersions, err := c.s.GetAfterDistinct(ctx, v.Version) + syncVersions, err := c.multiSyncStore.GetAfterDistinct(ctx, v.Version) if err != nil { slog.Error("failed to find distinct sync versions", slog.Any("error", err)) return err @@ -214,7 +214,7 @@ func (c *multiSyncComponentImpl) createLocalDataset(ctx context.Context, m *type // HTTPCloneURL: gitRepo.HttpCloneURL, // SSHCloneURL: gitRepo.SshCloneURL, } - newDBRepo, err := c.repo.UpdateOrCreateRepo(ctx, dbRepo) + newDBRepo, err := c.repoStore.UpdateOrCreateRepo(ctx, dbRepo) if err != nil { return fmt.Errorf("fail to create database repo, error: %w", err) } @@ -230,7 +230,7 @@ func (c *multiSyncComponentImpl) createLocalDataset(ctx context.Context, m *type ShowName: tag.ShowName, Scope: database.DatasetTagScope, } - t, err := c.tag.FindOrCreate(ctx, dbTag) + t, err := c.tagStore.FindOrCreate(ctx, dbTag) if err != nil { slog.Error("failed to create or find database tag", slog.Any("tag", dbTag)) continue @@ -241,18 +241,18 @@ func (c *multiSyncComponentImpl) createLocalDataset(ctx context.Context, m *type }) } - err = c.repo.DeleteAllTags(ctx, newDBRepo.ID) + err = c.repoStore.DeleteAllTags(ctx, newDBRepo.ID) if err != nil && err != sql.ErrNoRows { slog.Error("failed to delete database tag", slog.Any("error", err)) } - err = c.repo.BatchCreateRepoTags(ctx, repoTags) + err = c.repoStore.BatchCreateRepoTags(ctx, repoTags) if err != nil { slog.Error("failed to create database tag", slog.Any("error", err)) } } - err = c.repo.DeleteAllFiles(ctx, newDBRepo.ID) + err = c.repoStore.DeleteAllFiles(ctx, newDBRepo.ID) if err != nil && err != sql.ErrNoRows { slog.Error("failed to delete database files", slog.Any("error", err)) } @@ -277,7 +277,7 @@ func (c *multiSyncComponentImpl) createLocalDataset(ctx context.Context, m *type }) } - err = c.file.BatchCreate(ctx, dbFiles) + err = c.fileStore.BatchCreate(ctx, dbFiles) if err != nil { slog.Error("failed to create all files of repo", slog.Any("sync_version", s)) } @@ -288,7 +288,7 @@ func (c *multiSyncComponentImpl) createLocalDataset(ctx context.Context, m *type Repository: newDBRepo, RepositoryID: newDBRepo.ID, } - _, err = c.dataset.CreateIfNotExist(ctx, dbDataset) + _, err = c.datasetStore.CreateIfNotExist(ctx, dbDataset) if err != nil { return fmt.Errorf("failed to create dataset in db, cause: %w", err) } @@ -340,7 +340,7 @@ func (c *multiSyncComponentImpl) createLocalModel(ctx context.Context, m *types. // HTTPCloneURL: gitRepo.HttpCloneURL, // SSHCloneURL: gitRepo.SshCloneURL, } - newDBRepo, err := c.repo.UpdateOrCreateRepo(ctx, dbRepo) + newDBRepo, err := c.repoStore.UpdateOrCreateRepo(ctx, dbRepo) if err != nil { return fmt.Errorf("fail to create database repo, error: %w", err) } @@ -356,7 +356,7 @@ func (c *multiSyncComponentImpl) createLocalModel(ctx context.Context, m *types. ShowName: tag.ShowName, Scope: database.ModelTagScope, } - t, err := c.tag.FindOrCreate(ctx, dbTag) + t, err := c.tagStore.FindOrCreate(ctx, dbTag) if err != nil { slog.Error("failed to create or find database tag", slog.Any("tag", dbTag)) continue @@ -366,17 +366,17 @@ func (c *multiSyncComponentImpl) createLocalModel(ctx context.Context, m *types. TagID: t.ID, }) } - err = c.repo.DeleteAllTags(ctx, newDBRepo.ID) + err = c.repoStore.DeleteAllTags(ctx, newDBRepo.ID) if err != nil && err != sql.ErrNoRows { slog.Error("failed to delete database tag", slog.Any("error", err)) } - err = c.repo.BatchCreateRepoTags(ctx, repoTags) + err = c.repoStore.BatchCreateRepoTags(ctx, repoTags) if err != nil { slog.Error("failed to batch create database tag", slog.Any("error", err)) } } - err = c.repo.DeleteAllFiles(ctx, newDBRepo.ID) + err = c.repoStore.DeleteAllFiles(ctx, newDBRepo.ID) if err != nil && err != sql.ErrNoRows { slog.Error("failed to delete all files for repo", slog.Any("error", err)) } @@ -401,7 +401,7 @@ func (c *multiSyncComponentImpl) createLocalModel(ctx context.Context, m *types. }) } - err = c.file.BatchCreate(ctx, dbFiles) + err = c.fileStore.BatchCreate(ctx, dbFiles) if err != nil { slog.Error("failed to create all files of repo", slog.Any("sync_version", s)) } @@ -413,7 +413,7 @@ func (c *multiSyncComponentImpl) createLocalModel(ctx context.Context, m *types. RepositoryID: newDBRepo.ID, BaseModel: m.BaseModel, } - _, err = c.model.CreateIfNotExist(ctx, dbModel) + _, err = c.modelStore.CreateIfNotExist(ctx, dbModel) if err != nil { return fmt.Errorf("failed to create database model, cause: %w", err) } @@ -426,7 +426,7 @@ func (c *multiSyncComponentImpl) createUser(ctx context.Context, req types.Creat Username: req.Username, Email: req.Email, } - gsUserResp, err := c.git.CreateUser(gsUserReq) + gsUserResp, err := c.gitServer.CreateUser(gsUserReq) if err != nil { newError := fmt.Errorf("failed to create gitserver user,error:%w", err) return database.User{}, newError @@ -443,7 +443,7 @@ func (c *multiSyncComponentImpl) createUser(ctx context.Context, req types.Creat GitID: gsUserResp.GitID, Password: gsUserResp.Password, } - err = c.user.Create(ctx, user, namespace) + err = c.userStore.Create(ctx, user, namespace) if err != nil { newError := fmt.Errorf("failed to create user,error:%w", err) return database.User{}, newError @@ -453,7 +453,7 @@ func (c *multiSyncComponentImpl) createUser(ctx context.Context, req types.Creat } func (c *multiSyncComponentImpl) getUser(ctx context.Context, userName string) (database.User, error) { - return c.user.FindByUsername(ctx, userName) + return c.userStore.FindByUsername(ctx, userName) } func (c *multiSyncComponentImpl) createLocalSyncVersion(ctx context.Context, v types.SyncVersion) error { @@ -465,7 +465,7 @@ func (c *multiSyncComponentImpl) createLocalSyncVersion(ctx context.Context, v t LastModifiedAt: v.LastModifyTime, ChangeLog: v.ChangeLog, } - err := c.versionStore.Create(ctx, &syncVersion) + err := c.syncVersionStore.Create(ctx, &syncVersion) if err != nil { return err } diff --git a/component/multi_sync_test.go b/component/multi_sync_test.go new file mode 100644 index 00000000..d0296aea --- /dev/null +++ b/component/multi_sync_test.go @@ -0,0 +1,174 @@ +package component + +import ( + "context" + "database/sql" + "testing" + + "github.com/stretchr/testify/mock" + "github.com/stretchr/testify/require" + multisync_mock "opencsg.com/csghub-server/_mocks/opencsg.com/csghub-server/builder/multisync" + "opencsg.com/csghub-server/builder/git/gitserver" + "opencsg.com/csghub-server/builder/store/database" + "opencsg.com/csghub-server/common/types" +) + +func TestMultiSyncComponent_More(t *testing.T) { + ctx := context.TODO() + mc := initializeTestMultiSyncComponent(ctx, t) + + mc.mocks.stores.MultiSyncMock().EXPECT().GetAfter(ctx, int64(1), int64(10)).Return( + []database.SyncVersion{{Version: 2}}, nil, + ) + + data, err := mc.More(ctx, 1, 10) + require.Nil(t, err) + require.Equal(t, []types.SyncVersion{ + {Version: 2}, + }, data) +} + +func TestMultiSyncComponent_SyncAsClient(t *testing.T) { + ctx := mock.Anything + mc := initializeTestMultiSyncComponent(context.TODO(), t) + + mc.mocks.stores.MultiSyncMock().EXPECT().GetLatest(ctx).Return(database.SyncVersion{ + Version: 1, + }, nil) + mockedClient := multisync_mock.NewMockClient(t) + mockedClient.EXPECT().Latest(ctx, int64(1)).Return(types.SyncVersionResponse{ + Data: struct { + Versions []types.SyncVersion "json:\"versions\"" + HasMore bool "json:\"has_more\"" + }{ + Versions: []types.SyncVersion{ + {Version: 2}, + }, + HasMore: true, + }, + }, nil) + mockedClient.EXPECT().Latest(ctx, int64(2)).Return(types.SyncVersionResponse{ + Data: struct { + Versions []types.SyncVersion "json:\"versions\"" + HasMore bool "json:\"has_more\"" + }{ + Versions: []types.SyncVersion{ + {Version: 3}, + }, + HasMore: false, + }, + }, nil) + mc.mocks.stores.SyncVersionMock().EXPECT().Create(ctx, &database.SyncVersion{ + Version: 2, + }).Return(nil) + mc.mocks.stores.SyncVersionMock().EXPECT().Create(ctx, &database.SyncVersion{ + Version: 3, + }).Return(nil) + dsvs := []database.SyncVersion{ + {RepoType: types.ModelRepo}, + {RepoType: types.DatasetRepo}, + } + mc.mocks.stores.MultiSyncMock().EXPECT().GetAfterDistinct(ctx, int64(1)).Return( + dsvs, nil, + ) + svs := []types.SyncVersion{ + {RepoType: types.ModelRepo}, + {RepoType: types.DatasetRepo}, + } + // new model mock + mockedClient.EXPECT().ModelInfo(ctx, svs[0]).Return(&types.Model{ + User: &types.User{Nickname: "nn"}, + Path: "ns/user", + Tags: []types.RepoTag{{Name: "t1"}}, + }, nil) + mockedClient.EXPECT().ReadMeData(ctx, svs[0]).Return("readme", nil) + mc.mocks.stores.UserMock().EXPECT().FindByUsername(ctx, "CSG_ns").Return(database.User{}, sql.ErrNoRows) + mc.mocks.gitServer.EXPECT().CreateUser(gitserver.CreateUserRequest{ + Nickname: "nn", + Username: "CSG_ns", + Email: "CSG_", + }).Return(&gitserver.CreateUserResponse{GitID: 123}, nil) + mc.mocks.stores.UserMock().EXPECT().Create(ctx, &database.User{ + NickName: "nn", + Username: "CSG_ns", + Email: "CSG_", + GitID: 123, + }, &database.Namespace{ + Path: "CSG_ns", + Mirrored: true, + }).Return(nil) + dbrepo := &database.Repository{ + Path: "CSG_ns/user", + GitPath: "models_CSG_ns/user", + Name: "user", + Readme: "readme", + Source: types.OpenCSGSource, + SyncStatus: types.SyncStatusPending, + RepositoryType: types.ModelRepo, + } + mc.mocks.stores.RepoMock().EXPECT().UpdateOrCreateRepo(ctx, *dbrepo).Return(dbrepo, nil) + dbrepo.ID = 1 + mc.mocks.stores.TagMock().EXPECT().FindOrCreate(ctx, database.Tag{ + Name: "t1", Scope: database.ModelTagScope, + }).Return( + &database.Tag{Name: "t1", ID: 11}, nil, + ) + mc.mocks.stores.RepoMock().EXPECT().DeleteAllTags(ctx, int64(1)).Return(nil) + mc.mocks.stores.RepoMock().EXPECT().BatchCreateRepoTags(ctx, []database.RepositoryTag{ + {RepositoryID: 1, TagID: 11}, + }).Return(nil) + mc.mocks.stores.RepoMock().EXPECT().DeleteAllFiles(ctx, int64(1)).Return(nil) + mockedClient.EXPECT().FileList(ctx, svs[0]).Return([]types.File{ + {Name: "foo.go"}, + }, nil) + mc.mocks.stores.FileMock().EXPECT().BatchCreate(ctx, []database.File{ + {Name: "foo.go", ParentPath: "/", RepositoryID: 1}, + }).Return(nil) + mc.mocks.stores.ModelMock().EXPECT().CreateIfNotExist(ctx, database.Model{ + RepositoryID: 1, + Repository: dbrepo, + }).Return(nil, nil) + + // new dataset mock + dbrepo = &database.Repository{ + Path: "CSG_ns/user", + GitPath: "datasets_CSG_ns/user", + Name: "user", + Readme: "readme", + Source: types.OpenCSGSource, + SyncStatus: types.SyncStatusPending, + RepositoryType: types.DatasetRepo, + } + mockedClient.EXPECT().DatasetInfo(ctx, svs[1]).Return(&types.Dataset{ + User: types.User{Nickname: "nn"}, + Path: "ns/user", + Tags: []types.RepoTag{{Name: "t2"}}, + }, nil) + mockedClient.EXPECT().ReadMeData(ctx, svs[1]).Return("readme", nil) + mc.mocks.stores.RepoMock().EXPECT().UpdateOrCreateRepo(ctx, *dbrepo).Return(dbrepo, nil) + dbrepo.ID = 2 + mc.mocks.stores.TagMock().EXPECT().FindOrCreate(ctx, database.Tag{ + Name: "t2", Scope: database.DatasetTagScope, + }).Return( + &database.Tag{Name: "t2", ID: 12}, nil, + ) + mc.mocks.stores.RepoMock().EXPECT().DeleteAllTags(ctx, int64(2)).Return(nil) + mc.mocks.stores.RepoMock().EXPECT().BatchCreateRepoTags(ctx, []database.RepositoryTag{ + {RepositoryID: 2, TagID: 12}, + }).Return(nil) + mc.mocks.stores.RepoMock().EXPECT().DeleteAllFiles(ctx, int64(2)).Return(nil) + mockedClient.EXPECT().FileList(ctx, svs[1]).Return([]types.File{ + {Name: "foo.go"}, + }, nil) + mc.mocks.stores.FileMock().EXPECT().BatchCreate(ctx, []database.File{ + {Name: "foo.go", ParentPath: "/", RepositoryID: 2}, + }).Return(nil) + mc.mocks.stores.DatasetMock().EXPECT().CreateIfNotExist(ctx, database.Dataset{ + RepositoryID: 2, + Repository: dbrepo, + }).Return(nil, nil) + + err := mc.SyncAsClient(context.TODO(), mockedClient) + require.Nil(t, err) + +} diff --git a/component/recom.go b/component/recom.go index c3661891..0e3183d8 100644 --- a/component/recom.go +++ b/component/recom.go @@ -14,9 +14,9 @@ import ( ) type recomComponentImpl struct { - rs database.RecomStore - repos database.RepoStore - gs gitserver.GitServer + recomStore database.RecomStore + repoStore database.RepoStore + gitServer gitserver.GitServer } type RecomComponent interface { @@ -33,18 +33,18 @@ func NewRecomComponent(cfg *config.Config) (RecomComponent, error) { } return &recomComponentImpl{ - rs: database.NewRecomStore(), - repos: database.NewRepoStore(), - gs: gs, + recomStore: database.NewRecomStore(), + repoStore: database.NewRepoStore(), + gitServer: gs, }, nil } func (rc *recomComponentImpl) SetOpWeight(ctx context.Context, repoID, weight int64) error { - _, err := rc.repos.FindById(ctx, repoID) + _, err := rc.repoStore.FindById(ctx, repoID) if err != nil { return fmt.Errorf("failed to find repository with id %d, err:%w", repoID, err) } - return rc.rs.UpsetOpWeights(ctx, repoID, weight) + return rc.recomStore.UpsetOpWeights(ctx, repoID, weight) } // loop through repositories and calculate the recom score of the repository @@ -54,7 +54,7 @@ func (rc *recomComponentImpl) CalculateRecomScore(ctx context.Context) { slog.Error("Error loading weights", "error", err) return } - repos, err := rc.repos.All(ctx) + repos, err := rc.repoStore.All(ctx) if err != nil { slog.Error("Error fetching repositories", "error", err) return @@ -62,7 +62,7 @@ func (rc *recomComponentImpl) CalculateRecomScore(ctx context.Context) { for _, repo := range repos { repoID := repo.ID score := rc.CalcTotalScore(ctx, repo, weights) - err := rc.rs.UpsertScore(ctx, repoID, score) + err := rc.recomStore.UpsertScore(ctx, repoID, score) if err != nil { slog.Error("Error updating recom score", slog.Int64("repo_id", repoID), slog.Float64("score", score), slog.String("error", err.Error())) @@ -132,7 +132,7 @@ func (rc *recomComponentImpl) calcQualityScore(ctx context.Context, repo *databa score := 0.0 // get file counts from git server namespace, name := repo.NamespaceAndName() - files, err := getFilePaths(namespace, name, "", repo.RepositoryType, "", rc.gs.GetRepoFileTree) + files, err := getFilePaths(namespace, name, "", repo.RepositoryType, "", rc.gitServer.GetRepoFileTree) if err != nil { return 0, fmt.Errorf("failed to get repo file tree,%w", err) } @@ -157,7 +157,7 @@ func (rc *recomComponentImpl) calcQualityScore(ctx context.Context, repo *databa func (rc *recomComponentImpl) loadWeights() (map[string]string, error) { ctx := context.Background() - items, err := rc.rs.LoadWeights(ctx) + items, err := rc.recomStore.LoadWeights(ctx) if err != nil { return nil, err } diff --git a/component/recom_test.go b/component/recom_test.go index 139e9eac..bba569e3 100644 --- a/component/recom_test.go +++ b/component/recom_test.go @@ -6,25 +6,48 @@ import ( "time" "github.com/stretchr/testify/mock" - gsmock "opencsg.com/csghub-server/_mocks/opencsg.com/csghub-server/builder/git/gitserver" "opencsg.com/csghub-server/builder/git/gitserver" "opencsg.com/csghub-server/builder/store/database" - "opencsg.com/csghub-server/common/tests" ) -func NewTestRecomComponent(stores *tests.MockStores, gitServer gitserver.GitServer) *recomComponentImpl { - return &recomComponentImpl{ - repos: stores.Repo, - gs: gitServer, - } +// func TestRecomComponent_SetOpWeight(t *testing.T) { +// ctx := context.TODO() +// rc := initializeTestRecomComponent(ctx, t) + +// rc.mocks.stores.RepoMock().EXPECT().FindById(ctx, int64(1)).Return(&database.Repository{}, nil) +// rc.mocks.stores.UserMock().EXPECT().FindByUsername(ctx, "user").Return(database.User{ +// RoleMask: "admin", +// }, nil) +// rc.mocks.stores.RecomMock().EXPECT().UpsetOpWeights(ctx, int64(1), int64(100)).Return(nil) + +// err := rc.SetOpWeight(ctx, 1, 100) +// require.Nil(t, err) +// } + +func TestRecomComponent_CalculateRecomScore(t *testing.T) { + ctx := context.TODO() + rc := initializeTestRecomComponent(ctx, t) + + rc.mocks.stores.RecomMock().EXPECT().LoadWeights(mock.Anything).Return( + []*database.RecomWeight{{Name: "freshness", WeightExp: "score = 12.34"}}, nil, + ) + rc.mocks.stores.RepoMock().EXPECT().All(ctx).Return([]*database.Repository{ + {ID: 1, Path: "foo/bar"}, + }, nil) + rc.mocks.stores.RecomMock().EXPECT().UpsertScore(ctx, int64(1), 12.34).Return(nil) + rc.mocks.gitServer.EXPECT().GetRepoFileTree(mock.Anything, gitserver.GetRepoInfoByPathReq{ + Namespace: "foo", + Name: "bar", + }).Return(nil, nil) + + rc.CalculateRecomScore(ctx) } func TestRecomComponent_CalculateTotalScore(t *testing.T) { - gitServer := gsmock.NewMockGitServer(t) - rc := &recomComponentImpl{gs: gitServer} ctx := context.TODO() + rc := initializeTestRecomComponent(ctx, t) - gitServer.EXPECT().GetRepoFileTree(mock.Anything, gitserver.GetRepoInfoByPathReq{ + rc.mocks.gitServer.EXPECT().GetRepoFileTree(mock.Anything, gitserver.GetRepoInfoByPathReq{ Namespace: "foo", Name: "bar", }).Return(nil, nil) diff --git a/component/runtime_architecture.go b/component/runtime_architecture.go index cb79f139..b6f2f872 100644 --- a/component/runtime_architecture.go +++ b/component/runtime_architecture.go @@ -8,6 +8,7 @@ import ( "strings" "sync" + "opencsg.com/csghub-server/builder/git" "opencsg.com/csghub-server/builder/git/gitserver" "opencsg.com/csghub-server/builder/store/database" "opencsg.com/csghub-server/common/config" @@ -22,11 +23,14 @@ var ( ) type runtimeArchitectureComponentImpl struct { - r *repoComponentImpl - ras database.RuntimeArchitecturesStore - rfs database.RuntimeFrameworksStore - ts database.TagStore - rms database.ResourceModelStore + repoComponent RepoComponent + repoStore database.RepoStore + repoRuntimeFrameworkStore database.RepositoriesRuntimeFrameworkStore + runtimeArchStore database.RuntimeArchitecturesStore + runtimeFrameworksStore database.RuntimeFrameworksStore + tagStore database.TagStore + resouceModelStore database.ResourceModelStore + gitServer gitserver.GitServer } type RuntimeArchitectureComponent interface { @@ -47,20 +51,28 @@ type RuntimeArchitectureComponent interface { func NewRuntimeArchitectureComponent(config *config.Config) (RuntimeArchitectureComponent, error) { c := &runtimeArchitectureComponentImpl{} - c.rfs = database.NewRuntimeFrameworksStore() - c.ras = database.NewRuntimeArchitecturesStore() - c.ts = database.NewTagStore() - c.rms = database.NewResourceModelStore() + c.runtimeFrameworksStore = database.NewRuntimeFrameworksStore() + c.runtimeArchStore = database.NewRuntimeArchitecturesStore() + c.tagStore = database.NewTagStore() + c.resouceModelStore = database.NewResourceModelStore() repo, err := NewRepoComponentImpl(config) if err != nil { return nil, fmt.Errorf("fail to create repo component, %w", err) } - c.r = repo + c.repoComponent = repo + c.repoStore = database.NewRepoStore() + c.repoRuntimeFrameworkStore = database.NewRepositoriesRuntimeFramework() + c.gitServer, err = git.NewGitServer(config) + if err != nil { + newError := fmt.Errorf("fail to create git server,error:%w", err) + slog.Error(newError.Error()) + return nil, newError + } return c, nil } func (c *runtimeArchitectureComponentImpl) ListByRuntimeFrameworkID(ctx context.Context, id int64) ([]database.RuntimeArchitecture, error) { - archs, err := c.ras.ListByRuntimeFrameworkID(ctx, id) + archs, err := c.runtimeArchStore.ListByRuntimeFrameworkID(ctx, id) if err != nil { return nil, fmt.Errorf("list runtime arch failed, %w", err) } @@ -68,7 +80,7 @@ func (c *runtimeArchitectureComponentImpl) ListByRuntimeFrameworkID(ctx context. } func (c *runtimeArchitectureComponentImpl) SetArchitectures(ctx context.Context, id int64, architectures []string) ([]string, error) { - _, err := c.r.runtimeFrameworksStore.FindByID(ctx, id) + _, err := c.runtimeFrameworksStore.FindByID(ctx, id) if err != nil { return nil, fmt.Errorf("invalid runtime framework id, %w", err) } @@ -77,7 +89,7 @@ func (c *runtimeArchitectureComponentImpl) SetArchitectures(ctx context.Context, if len(strings.Trim(arch, " ")) < 1 { continue } - err := c.ras.Add(ctx, database.RuntimeArchitecture{ + err := c.runtimeArchStore.Add(ctx, database.RuntimeArchitecture{ RuntimeFrameworkID: id, ArchitectureName: strings.Trim(arch, " "), }) @@ -89,7 +101,7 @@ func (c *runtimeArchitectureComponentImpl) SetArchitectures(ctx context.Context, } func (c *runtimeArchitectureComponentImpl) DeleteArchitectures(ctx context.Context, id int64, architectures []string) ([]string, error) { - _, err := c.r.runtimeFrameworksStore.FindByID(ctx, id) + _, err := c.runtimeFrameworksStore.FindByID(ctx, id) if err != nil { return nil, fmt.Errorf("invalid runtime framework id, %w", err) } @@ -98,7 +110,7 @@ func (c *runtimeArchitectureComponentImpl) DeleteArchitectures(ctx context.Conte if len(strings.Trim(arch, " ")) < 1 { continue } - err := c.ras.DeleteByRuntimeIDAndArchName(ctx, id, strings.Trim(arch, " ")) + err := c.runtimeArchStore.DeleteByRuntimeIDAndArchName(ctx, id, strings.Trim(arch, " ")) if err != nil { failedDeletes = append(failedDeletes, arch) } @@ -107,11 +119,11 @@ func (c *runtimeArchitectureComponentImpl) DeleteArchitectures(ctx context.Conte } func (c *runtimeArchitectureComponentImpl) ScanArchitecture(ctx context.Context, id int64, scanType int, models []string) error { - frame, err := c.r.runtimeFrameworksStore.FindByID(ctx, id) + frame, err := c.runtimeFrameworksStore.FindByID(ctx, id) if err != nil { return fmt.Errorf("invalid runtime framework id, %w", err) } - archs, err := c.ras.ListByRuntimeFrameworkID(ctx, id) + archs, err := c.runtimeArchStore.ListByRuntimeFrameworkID(ctx, id) if err != nil { return fmt.Errorf("list runtime arch failed, %w", err) } @@ -156,18 +168,18 @@ func (c *runtimeArchitectureComponentImpl) ScanArchitecture(ctx context.Context, } func (c *runtimeArchitectureComponentImpl) scanNewModels(ctx context.Context, req types.ScanReq) error { - repos, err := c.r.repoStore.GetRepoWithoutRuntimeByID(ctx, req.FrameID, req.Models) + repos, err := c.repoStore.GetRepoWithoutRuntimeByID(ctx, req.FrameID, req.Models) if err != nil { return fmt.Errorf("failed to get repos without runtime by ID, %w", err) } if repos == nil { return nil } - runtime_framework, err := c.rfs.FindByID(ctx, req.FrameID) + runtime_framework, err := c.runtimeFrameworksStore.FindByID(ctx, req.FrameID) if err != nil { return fmt.Errorf("failed to get runtime framework by ID, %w", err) } - runtime_framework_tags, _ := c.ts.GetTagsByScopeAndCategories(ctx, "model", []string{"runtime_framework", "resource"}) + runtime_framework_tags, _ := c.tagStore.GetTagsByScopeAndCategories(ctx, "model", []string{"runtime_framework", "resource"}) for _, repo := range repos { namespace, name := repo.NamespaceAndName() arch, err := c.GetArchitectureFromConfig(ctx, namespace, name) @@ -187,7 +199,7 @@ func (c *runtimeArchitectureComponentImpl) scanNewModels(ctx context.Context, re if !exist && !isSupportedRM { continue } - err = c.r.repoRuntimeFrameworkStore.Add(ctx, req.FrameID, repo.ID, req.FrameType) + err = c.repoRuntimeFrameworkStore.Add(ctx, req.FrameID, repo.ID, req.FrameType) if err != nil { slog.Warn("fail to create relation", slog.Any("repo", repo.Path), slog.Any("frameid", req.FrameID), slog.Any("error", err)) } @@ -207,7 +219,7 @@ func (c *runtimeArchitectureComponentImpl) scanNewModels(ctx context.Context, re // check if it's supported model resource by name func (c *runtimeArchitectureComponentImpl) IsSupportedModelResource(ctx context.Context, modelName string, rf *database.RuntimeFramework, id int64) (bool, error) { trimModel := strings.Replace(strings.ToLower(modelName), "meta-", "", 1) - rm, err := c.rms.CheckModelNameNotInRFRepo(ctx, trimModel, id) + rm, err := c.resouceModelStore.CheckModelNameNotInRFRepo(ctx, trimModel, id) if err != nil || rm == nil { return false, err } @@ -230,7 +242,7 @@ func (c *runtimeArchitectureComponentImpl) IsSupportedModelResource(ctx context. } func (c *runtimeArchitectureComponentImpl) scanExistModels(ctx context.Context, req types.ScanReq) error { - repos, err := c.r.repoStore.GetRepoWithRuntimeByID(ctx, req.FrameID, req.Models) + repos, err := c.repoStore.GetRepoWithRuntimeByID(ctx, req.FrameID, req.Models) if err != nil { return fmt.Errorf("fail to get repos with runtime by ID, %w", err) } @@ -251,7 +263,7 @@ func (c *runtimeArchitectureComponentImpl) scanExistModels(ctx context.Context, if exist { continue } - err = c.r.repoRuntimeFrameworkStore.Delete(ctx, req.FrameID, repo.ID, req.FrameType) + err = c.repoRuntimeFrameworkStore.Delete(ctx, req.FrameID, repo.ID, req.FrameType) if err != nil { slog.Warn("fail to remove relation", slog.Any("repo", repo.Path), slog.Any("frameid", req.FrameID), slog.Any("error", err)) } @@ -282,7 +294,7 @@ func (c *runtimeArchitectureComponentImpl) GetArchitectureFromConfig(ctx context } func (c *runtimeArchitectureComponentImpl) getConfigContent(ctx context.Context, namespace, name string) (string, error) { - content, err := c.r.git.GetRepoFileRaw(ctx, gitserver.GetRepoInfoByPathReq{ + content, err := c.gitServer.GetRepoFileRaw(ctx, gitserver.GetRepoInfoByPathReq{ Namespace: namespace, Name: name, Ref: MainBranch, @@ -297,10 +309,10 @@ func (c *runtimeArchitectureComponentImpl) getConfigContent(ctx context.Context, // remove runtime_framework tag from model func (c *runtimeArchitectureComponentImpl) RemoveRuntimeFrameworkTag(ctx context.Context, rftags []*database.Tag, repoId, rfId int64) { - rfw, _ := c.rfs.FindByID(ctx, rfId) + rfw, _ := c.runtimeFrameworksStore.FindByID(ctx, rfId) for _, tag := range rftags { if strings.Contains(rfw.FrameImage, tag.Name) { - err := c.ts.RemoveRepoTags(ctx, repoId, []int64{tag.ID}) + err := c.tagStore.RemoveRepoTags(ctx, repoId, []int64{tag.ID}) if err != nil { slog.Warn("fail to remove runtime_framework tag from model repo", slog.Any("repoId", repoId), slog.Any("runtime_framework_id", rfId), slog.Any("error", err)) } @@ -310,13 +322,13 @@ func (c *runtimeArchitectureComponentImpl) RemoveRuntimeFrameworkTag(ctx context // add runtime_framework tag to model func (c *runtimeArchitectureComponentImpl) AddRuntimeFrameworkTag(ctx context.Context, rftags []*database.Tag, repoId, rfId int64) error { - rfw, err := c.rfs.FindByID(ctx, rfId) + rfw, err := c.runtimeFrameworksStore.FindByID(ctx, rfId) if err != nil { return err } for _, tag := range rftags { if strings.Contains(rfw.FrameImage, tag.Name) { - err := c.ts.UpsertRepoTags(ctx, repoId, []int64{}, []int64{tag.ID}) + err := c.tagStore.UpsertRepoTags(ctx, repoId, []int64{}, []int64{tag.ID}) if err != nil { slog.Warn("fail to add runtime_framework tag to model repo", slog.Any("repoId", repoId), slog.Any("runtime_framework_id", rfId), slog.Any("error", err)) } @@ -327,14 +339,14 @@ func (c *runtimeArchitectureComponentImpl) AddRuntimeFrameworkTag(ctx context.Co // add resource tag to model func (c *runtimeArchitectureComponentImpl) AddResourceTag(ctx context.Context, rstags []*database.Tag, modelname string, repoId int64) error { - rms, err := c.rms.FindByModelName(ctx, modelname) + rms, err := c.resouceModelStore.FindByModelName(ctx, modelname) if err != nil { return err } for _, rm := range rms { for _, tag := range rstags { if strings.Contains(rm.ResourceName, tag.Name) { - err := c.ts.UpsertRepoTags(ctx, repoId, []int64{}, []int64{tag.ID}) + err := c.tagStore.UpsertRepoTags(ctx, repoId, []int64{}, []int64{tag.ID}) if err != nil { slog.Warn("fail to add resource tag to model repo", slog.Any("repoId", repoId), slog.Any("error", err)) } diff --git a/component/runtime_architecture_test.go b/component/runtime_architecture_test.go new file mode 100644 index 00000000..5db7dfdf --- /dev/null +++ b/component/runtime_architecture_test.go @@ -0,0 +1,209 @@ +package component + +import ( + "context" + "errors" + "testing" + + "github.com/stretchr/testify/require" + "opencsg.com/csghub-server/builder/git/gitserver" + "opencsg.com/csghub-server/builder/store/database" + "opencsg.com/csghub-server/common/types" +) + +func TestRuntimeArchComponent_ListByRuntimeFrameworkID(t *testing.T) { + ctx := context.TODO() + rc := initializeTestRuntimeArchComponent(ctx, t) + + data := []database.RuntimeArchitecture{ + {ID: 123, ArchitectureName: "arch"}, + } + rc.mocks.stores.RuntimeArchMock().EXPECT().ListByRuntimeFrameworkID(ctx, int64(1)).Return( + data, nil, + ) + resp, err := rc.ListByRuntimeFrameworkID(ctx, 1) + require.Nil(t, err) + require.Equal(t, data, resp) + +} + +func TestRuntimeArchComponent_SetArchitectures(t *testing.T) { + ctx := context.TODO() + rc := initializeTestRuntimeArchComponent(ctx, t) + + rc.mocks.stores.RuntimeFrameworkMock().EXPECT().FindByID(ctx, int64(1)).Return(nil, nil) + rc.mocks.stores.RuntimeArchMock().EXPECT().Add(ctx, database.RuntimeArchitecture{ + RuntimeFrameworkID: 1, + ArchitectureName: "foo", + }).Return(nil) + rc.mocks.stores.RuntimeArchMock().EXPECT().Add(ctx, database.RuntimeArchitecture{ + RuntimeFrameworkID: 1, + ArchitectureName: "bar", + }).Return(errors.New("")) + + failed, err := rc.SetArchitectures(ctx, int64(1), []string{"foo", "bar"}) + require.Nil(t, err) + require.Equal(t, []string{"bar"}, failed) + +} + +func TestRuntimeArchComponent_DeleteArchitectures(t *testing.T) { + ctx := context.TODO() + rc := initializeTestRuntimeArchComponent(ctx, t) + + rc.mocks.stores.RuntimeFrameworkMock().EXPECT().FindByID(ctx, int64(1)).Return(nil, nil) + rc.mocks.stores.RuntimeArchMock().EXPECT().DeleteByRuntimeIDAndArchName(ctx, int64(1), "foo").Return(nil) + rc.mocks.stores.RuntimeArchMock().EXPECT().DeleteByRuntimeIDAndArchName(ctx, int64(1), "bar").Return(errors.New("")) + + failed, err := rc.DeleteArchitectures(ctx, int64(1), []string{"foo", "bar"}) + require.Nil(t, err) + require.Equal(t, []string{"bar"}, failed) + +} + +func TestRuntimeArchComponent_ScanArchitectures(t *testing.T) { + ctx := context.TODO() + rc := initializeTestRuntimeArchComponent(ctx, t) + + rc.mocks.stores.RuntimeFrameworkMock().EXPECT().FindByID(ctx, int64(1)).Return( + &database.RuntimeFramework{ + Type: 11, + }, nil, + ) + data := []database.RuntimeArchitecture{ + {ID: 123, ArchitectureName: "arch"}, + {ID: 124, ArchitectureName: "foo"}, + } + rc.mocks.stores.RuntimeArchMock().EXPECT().ListByRuntimeFrameworkID(ctx, int64(1)).Return( + data, nil, + ) + + // scan exists mocks + rc.mocks.stores.RepoMock().EXPECT().GetRepoWithRuntimeByID(ctx, int64(1), []string{"foo"}).Return([]database.Repository{ + {Path: "foo/bar"}, + }, nil) + rc.mocks.gitServer.EXPECT().GetRepoFileRaw(ctx, gitserver.GetRepoInfoByPathReq{ + Namespace: "foo", + Name: "bar", + Ref: "main", + Path: ConfigFileName, + RepoType: types.ModelRepo, + }).Return(`{"architectures": ["foo","bar"]}`, nil) + + // scan new mocks + rc.mocks.stores.RepoMock().EXPECT().GetRepoWithoutRuntimeByID(ctx, int64(1), []string{"foo"}).Return([]database.Repository{ + {Path: "foo/bar"}, + }, nil) + rc.mocks.stores.TagMock().EXPECT().GetTagsByScopeAndCategories(ctx, database.ModelTagScope, []string{ + "runtime_framework", "resource", + }).Return([]*database.Tag{}, nil) + rc.mocks.stores.RepoRuntimeFrameworkMock().EXPECT().Add(ctx, int64(1), int64(0), 11).Return(nil) + rc.mocks.stores.ResourceModelMock().EXPECT().CheckModelNameNotInRFRepo(ctx, "bar", int64(0)).Return( + &database.ResourceModel{}, nil, + ) + rc.mocks.stores.ResourceModelMock().EXPECT().FindByModelName(ctx, "bar").Return( + []*database.ResourceModel{ + {ResourceName: "r1"}, + {ResourceName: "r2"}, + }, nil, + ) + + err := rc.ScanArchitecture(ctx, 1, 0, []string{"foo"}) + require.Nil(t, err) + // wait async code finish + ScanLock.Lock() + _ = 1 + ScanLock.Unlock() + +} + +func TestRuntimeArchComponent_IsSupportedModelResource(t *testing.T) { + ctx := context.TODO() + rc := initializeTestRuntimeArchComponent(ctx, t) + + rc.mocks.stores.ResourceModelMock().EXPECT().CheckModelNameNotInRFRepo(ctx, "model", int64(1)).Return( + &database.ResourceModel{EngineName: "a"}, nil, + ) + + r, err := rc.IsSupportedModelResource(ctx, "meta-model", &database.RuntimeFramework{ + FrameImage: "a/b", + }, 1) + require.Nil(t, err, nil) + require.False(t, r) +} + +func TestRuntimeArchComponent_GetArchitectureFromConfig(t *testing.T) { + ctx := context.TODO() + rc := initializeTestRuntimeArchComponent(ctx, t) + + rc.mocks.gitServer.EXPECT().GetRepoFileRaw(ctx, gitserver.GetRepoInfoByPathReq{ + Namespace: "foo", + Name: "bar", + Ref: "main", + Path: ConfigFileName, + RepoType: types.ModelRepo, + }).Return(`{"architectures": ["foo","bar"]}`, nil) + + arch, err := rc.GetArchitectureFromConfig(ctx, "foo", "bar") + require.Nil(t, err) + require.Equal(t, "foo", arch) + +} + +// func TestRuntimeArchComponent_RemoveRuntimeFrameworkTag(t *testing.T) { +// ctx := context.TODO() +// rc := initializeTestRuntimeArchComponent(ctx, t) + +// rc.mocks.stores.RuntimeFrameworkMock().EXPECT().FindByID(ctx, int64(2)).Return( +// &database.RuntimeFramework{ +// FrameImage: "img", +// FrameNpuImage: "npu", +// }, nil, +// ) +// rc.mocks.stores.TagMock().EXPECT().RemoveRepoTags(ctx, int64(1), []int64{1}).Return(nil) +// rc.mocks.stores.TagMock().EXPECT().RemoveRepoTags(ctx, int64(1), []int64{2}).Return(nil) + +// rc.RemoveRuntimeFrameworkTag(ctx, []*database.Tag{ +// {Name: "img", ID: 1}, +// {Name: "npu", ID: 2}, +// }, int64(1), int64(2)) +// } + +// func TestRuntimeArchComponent_AddRuntimeFrameworkTag(t *testing.T) { +// ctx := context.TODO() +// rc := initializeTestRuntimeArchComponent(ctx, t) + +// rc.mocks.stores.RuntimeFrameworkMock().EXPECT().FindByID(ctx, int64(2)).Return( +// &database.RuntimeFramework{ +// FrameImage: "img", +// FrameNpuImage: "npu", +// }, nil, +// ) +// rc.mocks.stores.TagMock().EXPECT().UpsertRepoTags(ctx, int64(1), []int64{}, []int64{1}).Return(nil) +// rc.mocks.stores.TagMock().EXPECT().UpsertRepoTags(ctx, int64(1), []int64{}, []int64{2}).Return(nil) + +// err := rc.AddRuntimeFrameworkTag(ctx, []*database.Tag{ +// {Name: "img", ID: 1}, +// {Name: "npu", ID: 2}, +// }, int64(1), int64(2)) +// require.Nil(t, err) +// } + +// func TestRuntimeArchComponent_AddResourceTag(t *testing.T) { +// ctx := context.TODO() +// rc := initializeTestRuntimeArchComponent(ctx, t) + +// rc.mocks.stores.ResourceModelMock().EXPECT().FindByModelName(ctx, "model").Return( +// []*database.ResourceModel{ +// {ResourceName: "r1"}, +// {ResourceName: "r2"}, +// }, nil, +// ) +// rc.mocks.stores.TagMock().EXPECT().UpsertRepoTags(ctx, int64(1), []int64{}, []int64{1}).Return(nil) +// rc.mocks.stores.TagMock().EXPECT().UpsertRepoTags(ctx, int64(1), []int64{}, []int64{2}).Return(nil) + +// err := rc.AddResourceTag(ctx, []*database.Tag{ +// {Name: "r1", ID: 1}, +// }, "model", int64(1)) +// require.Nil(t, err) +// } diff --git a/component/space_resource_test.go b/component/space_resource_test.go new file mode 100644 index 00000000..024cc1e9 --- /dev/null +++ b/component/space_resource_test.go @@ -0,0 +1,94 @@ +package component + +// func TestSpaceResourceComponent_Index(t *testing.T) { +// ctx := context.TODO() +// sc := initializeTestSpaceResourceComponent(ctx, t) + +// sc.mocks.deployer.EXPECT().ListCluster(ctx).Return([]types.ClusterRes{ +// {ClusterID: "c1"}, +// }, nil) +// sc.mocks.stores.SpaceResourceMock().EXPECT().Index(ctx, "c1").Return( +// []database.SpaceResource{ +// {ID: 1, Name: "sr", Resources: `{"memory": "1000"}`}, +// }, nil, +// ) +// sc.mocks.deployer.EXPECT().GetClusterById(ctx, "c1").Return(&types.ClusterRes{}, nil) +// sc.mocks.stores.UserMock().EXPECT().FindByUsername(ctx, "user").Return(database.User{ +// UUID: "uid", +// }, nil) + +// data, err := sc.Index(ctx, "", 1) +// require.Nil(t, err) +// require.Equal(t, []types.SpaceResource{ +// { +// ID: 1, Name: "sr", Resources: "{\"memory\": \"1000\"}", +// IsAvailable: false, Type: "cpu", +// }, +// { +// ID: 0, Name: "", Resources: "{\"memory\": \"2000\"}", IsAvailable: true, +// Type: "cpu", +// }, +// }, data) + +// } + +// func TestSpaceResourceComponent_Update(t *testing.T) { +// ctx := context.TODO() +// sc := initializeTestSpaceResourceComponent(ctx, t) + +// sc.mocks.stores.SpaceResourceMock().EXPECT().FindByID(ctx, int64(1)).Return( +// &database.SpaceResource{}, nil, +// ) +// sc.mocks.stores.SpaceResourceMock().EXPECT().Update(ctx, database.SpaceResource{ +// Name: "n", +// Resources: "r", +// }).Return(&database.SpaceResource{ID: 1, Name: "n", Resources: "r"}, nil) + +// data, err := sc.Update(ctx, &types.UpdateSpaceResourceReq{ +// ID: 1, +// Name: "n", +// Resources: "r", +// }) +// require.Nil(t, err) +// require.Equal(t, &types.SpaceResource{ +// ID: 1, +// Name: "n", +// Resources: "r", +// }, data) +// } + +// func TestSpaceResourceComponent_Create(t *testing.T) { +// ctx := context.TODO() +// sc := initializeTestSpaceResourceComponent(ctx, t) + +// sc.mocks.stores.SpaceResourceMock().EXPECT().Create(ctx, database.SpaceResource{ +// Name: "n", +// Resources: "r", +// ClusterID: "c", +// }).Return(&database.SpaceResource{ID: 1, Name: "n", Resources: "r"}, nil) + +// data, err := sc.Create(ctx, &types.CreateSpaceResourceReq{ +// Name: "n", +// Resources: "r", +// ClusterID: "c", +// }) +// require.Nil(t, err) +// require.Equal(t, &types.SpaceResource{ +// ID: 1, +// Name: "n", +// Resources: "r", +// }, data) +// } + +// func TestSpaceResourceComponent_Delete(t *testing.T) { +// ctx := context.TODO() +// sc := initializeTestSpaceResourceComponent(ctx, t) + +// sc.mocks.stores.SpaceResourceMock().EXPECT().FindByID(ctx, int64(1)).Return( +// &database.SpaceResource{}, nil, +// ) +// sc.mocks.stores.SpaceResourceMock().EXPECT().Delete(ctx, database.SpaceResource{}).Return(nil) + +// err := sc.Delete(ctx, 1) +// require.Nil(t, err) +// } diff --git a/component/space_sdk.go b/component/space_sdk.go index 1c0eca17..497dfede 100644 --- a/component/space_sdk.go +++ b/component/space_sdk.go @@ -18,18 +18,18 @@ type SpaceSdkComponent interface { func NewSpaceSdkComponent(config *config.Config) (SpaceSdkComponent, error) { c := &spaceSdkComponentImpl{} - c.sss = database.NewSpaceSdkStore() + c.spaceSdkStore = database.NewSpaceSdkStore() return c, nil } type spaceSdkComponentImpl struct { - sss database.SpaceSdkStore + spaceSdkStore database.SpaceSdkStore } func (c *spaceSdkComponentImpl) Index(ctx context.Context) ([]types.SpaceSdk, error) { var result []types.SpaceSdk - databaseSpaceSdks, err := c.sss.Index(ctx) + databaseSpaceSdks, err := c.spaceSdkStore.Index(ctx) if err != nil { return nil, err } @@ -45,7 +45,7 @@ func (c *spaceSdkComponentImpl) Index(ctx context.Context) ([]types.SpaceSdk, er } func (c *spaceSdkComponentImpl) Update(ctx context.Context, req *types.UpdateSpaceSdkReq) (*types.SpaceSdk, error) { - ss, err := c.sss.FindByID(ctx, req.ID) + ss, err := c.spaceSdkStore.FindByID(ctx, req.ID) if err != nil { slog.Error("error getting space sdk", slog.Any("error", err)) return nil, err @@ -53,7 +53,7 @@ func (c *spaceSdkComponentImpl) Update(ctx context.Context, req *types.UpdateSpa ss.Name = req.Name ss.Version = req.Version - ss, err = c.sss.Update(ctx, *ss) + ss, err = c.spaceSdkStore.Update(ctx, *ss) if err != nil { slog.Error("error getting space sdk", slog.Any("error", err)) return nil, err @@ -73,7 +73,7 @@ func (c *spaceSdkComponentImpl) Create(ctx context.Context, req *types.CreateSpa Name: req.Name, Version: req.Version, } - res, err := c.sss.Create(ctx, ss) + res, err := c.spaceSdkStore.Create(ctx, ss) if err != nil { slog.Error("error creating space sdk", slog.Any("error", err)) return nil, err @@ -89,13 +89,13 @@ func (c *spaceSdkComponentImpl) Create(ctx context.Context, req *types.CreateSpa } func (c *spaceSdkComponentImpl) Delete(ctx context.Context, id int64) error { - ss, err := c.sss.FindByID(ctx, id) + ss, err := c.spaceSdkStore.FindByID(ctx, id) if err != nil { slog.Error("error finding space sdk", slog.Any("error", err)) return err } - err = c.sss.Delete(ctx, *ss) + err = c.spaceSdkStore.Delete(ctx, *ss) if err != nil { slog.Error("error deleting space sdk", slog.Any("error", err)) return err diff --git a/component/space_sdk_test.go b/component/space_sdk_test.go new file mode 100644 index 00000000..641f00d4 --- /dev/null +++ b/component/space_sdk_test.go @@ -0,0 +1,64 @@ +package component + +import ( + "context" + "testing" + + "github.com/stretchr/testify/require" + "opencsg.com/csghub-server/builder/store/database" + "opencsg.com/csghub-server/common/types" +) + +func TestSpaceSdkComponent_Index(t *testing.T) { + ctx := context.TODO() + sc := initializeTestSpaceSdkComponent(ctx, t) + + sc.mocks.stores.SpaceSdkMock().EXPECT().Index(ctx).Return([]database.SpaceSdk{ + {ID: 1, Name: "s", Version: "1"}, + }, nil) + + data, err := sc.Index(ctx) + require.Nil(t, err) + require.Equal(t, []types.SpaceSdk{{ID: 1, Name: "s", Version: "1"}}, data) +} + +func TestSpaceSdkComponent_Update(t *testing.T) { + ctx := context.TODO() + sc := initializeTestSpaceSdkComponent(ctx, t) + + s := &database.SpaceSdk{ID: 1} + sc.mocks.stores.SpaceSdkMock().EXPECT().FindByID(ctx, int64(1)).Return(s, nil) + s2 := *s + s2.Name = "n" + s2.Version = "v1" + sc.mocks.stores.SpaceSdkMock().EXPECT().Update(ctx, s2).Return(s, nil) + + data, err := sc.Update(ctx, &types.UpdateSpaceSdkReq{ID: 1, Name: "n", Version: "v1"}) + require.Nil(t, err) + require.Equal(t, &types.SpaceSdk{ID: 1, Name: "n", Version: "v1"}, data) +} + +func TestSpaceSdkComponent_Create(t *testing.T) { + ctx := context.TODO() + sc := initializeTestSpaceSdkComponent(ctx, t) + + s := database.SpaceSdk{Name: "n", Version: "v1"} + sc.mocks.stores.SpaceSdkMock().EXPECT().Create(ctx, s).Return(&s, nil) + s.ID = 1 + + data, err := sc.Create(ctx, &types.CreateSpaceSdkReq{Name: "n", Version: "v1"}) + require.Nil(t, err) + require.Equal(t, &types.SpaceSdk{ID: 1, Name: "n", Version: "v1"}, data) +} + +func TestSpaceSdkComponent_Delete(t *testing.T) { + ctx := context.TODO() + sc := initializeTestSpaceSdkComponent(ctx, t) + + s := &database.SpaceSdk{} + sc.mocks.stores.SpaceSdkMock().EXPECT().FindByID(ctx, int64(1)).Return(s, nil) + sc.mocks.stores.SpaceSdkMock().EXPECT().Delete(ctx, *s).Return(nil) + + err := sc.Delete(ctx, int64(1)) + require.Nil(t, err) +} diff --git a/component/tag.go b/component/tag.go index d428b895..2c95e4f3 100644 --- a/component/tag.go +++ b/component/tag.go @@ -16,7 +16,6 @@ import ( type TagComponent interface { AllTagsByScopeAndCategory(ctx context.Context, scope string, category string) ([]*database.Tag, error) - AllTags(ctx context.Context) ([]database.Tag, error) ClearMetaTags(ctx context.Context, repoType types.RepositoryType, namespace, name string) error UpdateMetaTags(ctx context.Context, tagScope database.TagScope, namespace, name, content string) ([]*database.RepositoryTag, error) UpdateLibraryTags(ctx context.Context, tagScope database.TagScope, namespace, name, oldFilePath, newFilePath string) error @@ -25,8 +24,8 @@ type TagComponent interface { func NewTagComponent(config *config.Config) (TagComponent, error) { tc := &tagComponentImpl{} - tc.ts = database.NewTagStore() - tc.rs = database.NewRepoStore() + tc.tagStore = database.NewTagStore() + tc.repoStore = database.NewRepoStore() if config.SensitiveCheck.Enable { tc.sensitiveChecker = rpc.NewModerationSvcHttpClient(fmt.Sprintf("%s:%d", config.Moderation.Host, config.Moderation.Port)) } @@ -34,21 +33,18 @@ func NewTagComponent(config *config.Config) (TagComponent, error) { } type tagComponentImpl struct { - ts database.TagStore - rs database.RepoStore + tagStore database.TagStore + repoStore database.RepoStore sensitiveChecker rpc.ModerationSvcClient } -func (c *tagComponentImpl) AllTags(ctx context.Context) ([]database.Tag, error) { - // TODO: query cache for tags at first - return c.ts.AllTags(ctx) -} -func (c *tagComponentImpl) AllTagsByScopeAndCategory(ctx context.Context, scope string, category string) ([]*database.Tag, error) { - return c.ts.AllTagsByScopeAndCategory(ctx, database.TagScope(scope), category) +func (tc *tagComponentImpl) AllTagsByScopeAndCategory(ctx context.Context, scope string, category string) ([]*database.Tag, error) { + return tc.tagStore.AllTagsByScopeAndCategory(ctx, database.TagScope(scope), category) } func (c *tagComponentImpl) ClearMetaTags(ctx context.Context, repoType types.RepositoryType, namespace, name string) error { - _, err := c.ts.SetMetaTags(ctx, repoType, namespace, name, nil) + + _, err := c.tagStore.SetMetaTags(ctx, repoType, namespace, name, nil) return err } @@ -60,13 +56,13 @@ func (c *tagComponentImpl) UpdateMetaTags(ctx context.Context, tagScope database // TODO:load from cache if tagScope == database.DatasetTagScope { - tp = tagparser.NewDatasetTagProcessor(c.ts) + tp = tagparser.NewDatasetTagProcessor(c.tagStore) repoType = types.DatasetRepo } else if tagScope == database.ModelTagScope { - tp = tagparser.NewModelTagProcessor(c.ts) + tp = tagparser.NewModelTagProcessor(c.tagStore) repoType = types.ModelRepo } else if tagScope == database.PromptTagScope { - tp = tagparser.NewPromptTagProcessor(c.ts) + tp = tagparser.NewPromptTagProcessor(c.tagStore) repoType = types.PromptRepo } else { // skip tag process for code and space now @@ -91,13 +87,13 @@ func (c *tagComponentImpl) UpdateMetaTags(ctx context.Context, tagScope database }) } - err = c.ts.SaveTags(ctx, tagToCreate) + err = c.tagStore.SaveTags(ctx, tagToCreate) if err != nil { slog.Error("Failed to save tags", slog.Any("error", err)) return nil, fmt.Errorf("failed to save tags, cause: %w", err) } - repo, err := c.rs.FindByPath(ctx, repoType, namespace, name) + repo, err := c.repoStore.FindByPath(ctx, repoType, namespace, name) if err != nil { slog.Error("failed to find repo", slog.Any("error", err)) return nil, fmt.Errorf("failed to find repo, cause: %w", err) @@ -105,14 +101,14 @@ func (c *tagComponentImpl) UpdateMetaTags(ctx context.Context, tagScope database metaTags := append(tagsMatched, tagToCreate...) var repoTags []*database.RepositoryTag - repoTags, err = c.ts.SetMetaTags(ctx, repoType, namespace, name, metaTags) + repoTags, err = c.tagStore.SetMetaTags(ctx, repoType, namespace, name, metaTags) if err != nil { slog.Error("failed to set dataset's tags", slog.String("namespace", namespace), slog.String("name", name), slog.Any("error", err)) return nil, fmt.Errorf("failed to set dataset's tags, cause: %w", err) } - err = c.rs.UpdateLicenseByTag(ctx, repo.ID) + err = c.repoStore.UpdateLicenseByTag(ctx, repo.ID) if err != nil { slog.Error("failed to update repo license tags", slog.Any("error", err)) } @@ -130,13 +126,13 @@ func (c *tagComponentImpl) UpdateLibraryTags(ctx context.Context, tagScope datab repoType types.RepositoryType ) if tagScope == database.DatasetTagScope { - allTags, err = c.ts.AllDatasetTags(ctx) + allTags, err = c.tagStore.AllDatasetTags(ctx) repoType = types.DatasetRepo } else if tagScope == database.ModelTagScope { - allTags, err = c.ts.AllModelTags(ctx) + allTags, err = c.tagStore.AllModelTags(ctx) repoType = types.ModelRepo } else if tagScope == database.PromptTagScope { - allTags, err = c.ts.AllPromptTags(ctx) + allTags, err = c.tagStore.AllPromptTags(ctx) repoType = types.PromptRepo } else { return nil @@ -156,7 +152,7 @@ func (c *tagComponentImpl) UpdateLibraryTags(ctx context.Context, tagScope datab oldLibTag = t } } - err = c.ts.SetLibraryTag(ctx, repoType, namespace, name, newLibTag, oldLibTag) + err = c.tagStore.SetLibraryTag(ctx, repoType, namespace, name, newLibTag, oldLibTag) if err != nil { slog.Error("failed to set %s's tags", string(repoType), slog.String("namespace", namespace), slog.String("name", name), slog.Any("error", err)) @@ -166,7 +162,7 @@ func (c *tagComponentImpl) UpdateLibraryTags(ctx context.Context, tagScope datab } func (c *tagComponentImpl) UpdateRepoTagsByCategory(ctx context.Context, tagScope database.TagScope, repoID int64, category string, tagNames []string) error { - allTags, err := c.ts.AllTagsByScopeAndCategory(ctx, tagScope, category) + allTags, err := c.tagStore.AllTagsByScopeAndCategory(ctx, tagScope, category) if err != nil { return fmt.Errorf("failed to get all tags of scope `%s`, error: %w", tagScope, err) } @@ -185,9 +181,9 @@ func (c *tagComponentImpl) UpdateRepoTagsByCategory(ctx context.Context, tagScop } var oldTagIDs []int64 - oldTagIDs, err = c.rs.TagIDs(ctx, repoID, category) + oldTagIDs, err = c.repoStore.TagIDs(ctx, repoID, category) if err != nil { return fmt.Errorf("failed to get old tag ids, error: %w", err) } - return c.ts.UpsertRepoTags(ctx, repoID, oldTagIDs, tagIDs) + return c.tagStore.UpsertRepoTags(ctx, repoID, oldTagIDs, tagIDs) } diff --git a/component/tag_test.go b/component/tag_test.go new file mode 100644 index 00000000..2f53aa30 --- /dev/null +++ b/component/tag_test.go @@ -0,0 +1,90 @@ +package component + +import ( + "context" + "testing" + + "github.com/stretchr/testify/require" + "opencsg.com/csghub-server/builder/store/database" + "opencsg.com/csghub-server/common/types" +) + +func TestTagComponent_AllTagsByScopeAndCategory(t *testing.T) { + ctx := context.TODO() + tc := initializeTestTagComponent(ctx, t) + + tc.mocks.stores.TagMock().EXPECT().AllTagsByScopeAndCategory(ctx, database.CodeTagScope, "cat").Return( + []*database.Tag{{Name: "t"}}, nil, + ) + + data, err := tc.AllTagsByScopeAndCategory(ctx, "code", "cat") + require.Nil(t, err) + require.Equal(t, []*database.Tag{{Name: "t"}}, data) +} + +func TestTagComponent_ClearMetaTags(t *testing.T) { + ctx := context.TODO() + tc := initializeTestTagComponent(ctx, t) + + tc.mocks.stores.TagMock().EXPECT().SetMetaTags( + ctx, types.ModelRepo, "ns", "n", []*database.Tag(nil), + ).Return(nil, nil) + + err := tc.ClearMetaTags(ctx, types.ModelRepo, "ns", "n") + require.Nil(t, err) +} + +func TestTagComponent_UpdateMetaTags(t *testing.T) { + ctx := context.TODO() + tc := initializeTestTagComponent(ctx, t) + + tc.mocks.stores.TagMock().EXPECT().AllDatasetTags(ctx).Return([]*database.Tag{}, nil) + tc.mocks.stores.TagMock().EXPECT().SaveTags(ctx, []*database.Tag(nil)).Return(nil) + tc.mocks.stores.RepoMock().EXPECT().FindByPath(ctx, types.DatasetRepo, "ns", "n").Return( + &database.Repository{ID: 1}, nil, + ) + tc.mocks.stores.TagMock().EXPECT().SetMetaTags( + ctx, types.DatasetRepo, "ns", "n", []*database.Tag(nil), + ).Return(nil, nil) + tc.mocks.stores.RepoMock().EXPECT().UpdateLicenseByTag(ctx, int64(1)).Return(nil) + + data, err := tc.UpdateMetaTags(ctx, database.DatasetTagScope, "ns", "n", "") + require.Nil(t, err) + require.Equal(t, []*database.RepositoryTag(nil), data) +} + +func TestTagComponent_UpdateLibraryTags(t *testing.T) { + ctx := context.TODO() + tc := initializeTestTagComponent(ctx, t) + + tags := []*database.Tag{ + {Category: "framework", Name: "pytorch", ID: 1}, + {Category: "framework", Name: "tensorflow", ID: 2}, + } + tc.mocks.stores.TagMock().EXPECT().AllDatasetTags(ctx).Return(tags, nil) + tc.mocks.stores.TagMock().EXPECT().SetLibraryTag( + ctx, types.DatasetRepo, "ns", "n", tags[1], tags[0], + ).Return(nil) + + err := tc.UpdateLibraryTags( + ctx, database.DatasetTagScope, "ns", "n", "pytorch_model_old.bin", "tf_model_new.h5", + ) + require.Nil(t, err) + +} + +func TestTagComponent_UpdateRepoTagsByCategory(t *testing.T) { + ctx := context.TODO() + tc := initializeTestTagComponent(ctx, t) + + tc.mocks.stores.TagMock().EXPECT().AllTagsByScopeAndCategory(ctx, database.DatasetTagScope, "c").Return( + []*database.Tag{ + {Name: "t1", ID: 2}, + }, nil, + ) + tc.mocks.stores.RepoMock().EXPECT().TagIDs(ctx, int64(1), "c").Return([]int64{1}, nil) + tc.mocks.stores.TagMock().EXPECT().UpsertRepoTags(ctx, int64(1), []int64{1}, []int64{2}).Return(nil) + + err := tc.UpdateRepoTagsByCategory(ctx, database.DatasetTagScope, 1, "c", []string{"t1"}) + require.Nil(t, err) +} diff --git a/component/wire.go b/component/wire.go index 5567b131..89b3ab39 100644 --- a/component/wire.go +++ b/component/wire.go @@ -105,3 +105,243 @@ func initializeTestAccountingComponent(ctx context.Context, t interface { ) return &testAccountingWithMocks{} } + +type testDatasetViewerWithMocks struct { + *datasetViewerComponentImpl + mocks *Mocks +} + +func initializeTestDatasetViewerComponent(ctx context.Context, t interface { + Cleanup(func()) + mock.TestingT +}) *testDatasetViewerWithMocks { + wire.Build( + MockSuperSet, DatasetViewerComponentSet, + wire.Struct(new(testDatasetViewerWithMocks), "*"), + ) + return &testDatasetViewerWithMocks{} +} + +type testGitHTTPWithMocks struct { + *gitHTTPComponentImpl + mocks *Mocks +} + +func initializeTestGitHTTPComponent(ctx context.Context, t interface { + Cleanup(func()) + mock.TestingT +}) *testGitHTTPWithMocks { + wire.Build( + MockSuperSet, GitHTTPComponentSet, + wire.Struct(new(testGitHTTPWithMocks), "*"), + ) + return &testGitHTTPWithMocks{} +} + +type testDiscussionWithMocks struct { + *discussionComponentImpl + mocks *Mocks +} + +func initializeTestDiscussionComponent(ctx context.Context, t interface { + Cleanup(func()) + mock.TestingT +}) *testDiscussionWithMocks { + wire.Build( + MockSuperSet, DiscussionComponentSet, + wire.Struct(new(testDiscussionWithMocks), "*"), + ) + return &testDiscussionWithMocks{} +} + +type testRuntimeArchWithMocks struct { + *runtimeArchitectureComponentImpl + mocks *Mocks +} + +func initializeTestRuntimeArchComponent(ctx context.Context, t interface { + Cleanup(func()) + mock.TestingT +}) *testRuntimeArchWithMocks { + wire.Build( + MockSuperSet, RuntimeArchComponentSet, + wire.Struct(new(testRuntimeArchWithMocks), "*"), + ) + return &testRuntimeArchWithMocks{} +} + +type testMirrorWithMocks struct { + *mirrorComponentImpl + mocks *Mocks +} + +func initializeTestMirrorComponent(ctx context.Context, t interface { + Cleanup(func()) + mock.TestingT +}) *testMirrorWithMocks { + wire.Build( + MockSuperSet, MirrorComponentSet, + wire.Struct(new(testMirrorWithMocks), "*"), + ) + return &testMirrorWithMocks{} +} + +type testCollectionWithMocks struct { + *collectionComponentImpl + mocks *Mocks +} + +func initializeTestCollectionComponent(ctx context.Context, t interface { + Cleanup(func()) + mock.TestingT +}) *testCollectionWithMocks { + wire.Build( + MockSuperSet, CollectionComponentSet, + wire.Struct(new(testCollectionWithMocks), "*"), + ) + return &testCollectionWithMocks{} +} + +type testDatasetWithMocks struct { + *datasetComponentImpl + mocks *Mocks +} + +func initializeTestDatasetComponent(ctx context.Context, t interface { + Cleanup(func()) + mock.TestingT +}) *testDatasetWithMocks { + wire.Build( + MockSuperSet, DatasetComponentSet, + wire.Struct(new(testDatasetWithMocks), "*"), + ) + return &testDatasetWithMocks{} +} + +type testCodeWithMocks struct { + *codeComponentImpl + mocks *Mocks +} + +func initializeTestCodeComponent(ctx context.Context, t interface { + Cleanup(func()) + mock.TestingT +}) *testCodeWithMocks { + wire.Build( + MockSuperSet, CodeComponentSet, + wire.Struct(new(testCodeWithMocks), "*"), + ) + return &testCodeWithMocks{} +} + +type testMultiSyncWithMocks struct { + *multiSyncComponentImpl + mocks *Mocks +} + +func initializeTestMultiSyncComponent(ctx context.Context, t interface { + Cleanup(func()) + mock.TestingT +}) *testMultiSyncWithMocks { + wire.Build( + MockSuperSet, MultiSyncComponentSet, + wire.Struct(new(testMultiSyncWithMocks), "*"), + ) + return &testMultiSyncWithMocks{} +} + +type testInternalWithMocks struct { + *internalComponentImpl + mocks *Mocks +} + +func initializeTestInternalComponent(ctx context.Context, t interface { + Cleanup(func()) + mock.TestingT +}) *testInternalWithMocks { + wire.Build( + MockSuperSet, InternalComponentSet, + wire.Struct(new(testInternalWithMocks), "*"), + ) + return &testInternalWithMocks{} +} + +type testMirrorSourceWithMocks struct { + *mirrorSourceComponentImpl + mocks *Mocks +} + +func initializeTestMirrorSourceComponent(ctx context.Context, t interface { + Cleanup(func()) + mock.TestingT +}) *testMirrorSourceWithMocks { + wire.Build( + MockSuperSet, MirrorSourceComponentSet, + wire.Struct(new(testMirrorSourceWithMocks), "*"), + ) + return &testMirrorSourceWithMocks{} +} + +type testSpaceResourceWithMocks struct { + *spaceResourceComponentImpl + mocks *Mocks +} + +func initializeTestSpaceResourceComponent(ctx context.Context, t interface { + Cleanup(func()) + mock.TestingT +}) *testSpaceResourceWithMocks { + wire.Build( + MockSuperSet, SpaceResourceComponentSet, + wire.Struct(new(testSpaceResourceWithMocks), "*"), + ) + return &testSpaceResourceWithMocks{} +} + +type testTagWithMocks struct { + *tagComponentImpl + mocks *Mocks +} + +func initializeTestTagComponent(ctx context.Context, t interface { + Cleanup(func()) + mock.TestingT +}) *testTagWithMocks { + wire.Build( + MockSuperSet, TagComponentSet, + wire.Struct(new(testTagWithMocks), "*"), + ) + return &testTagWithMocks{} +} + +type testRecomWithMocks struct { + *recomComponentImpl + mocks *Mocks +} + +func initializeTestRecomComponent(ctx context.Context, t interface { + Cleanup(func()) + mock.TestingT +}) *testRecomWithMocks { + wire.Build( + MockSuperSet, RecomComponentSet, + wire.Struct(new(testRecomWithMocks), "*"), + ) + return &testRecomWithMocks{} +} + +type testSpaceSdkWithMocks struct { + *spaceSdkComponentImpl + mocks *Mocks +} + +func initializeTestSpaceSdkComponent(ctx context.Context, t interface { + Cleanup(func()) + mock.TestingT +}) *testSpaceSdkWithMocks { + wire.Build( + MockSuperSet, SpaceSdkComponentSet, + wire.Struct(new(testSpaceSdkWithMocks), "*"), + ) + return &testSpaceSdkWithMocks{} +} diff --git a/component/wire_gen_test.go b/component/wire_gen_test.go index 8290271a..2fac5ee9 100644 --- a/component/wire_gen_test.go +++ b/component/wire_gen_test.go @@ -13,7 +13,7 @@ import ( "opencsg.com/csghub-server/_mocks/opencsg.com/csghub-server/builder/deploy" "opencsg.com/csghub-server/_mocks/opencsg.com/csghub-server/builder/git/gitserver" "opencsg.com/csghub-server/_mocks/opencsg.com/csghub-server/builder/git/mirrorserver" - "opencsg.com/csghub-server/_mocks/opencsg.com/csghub-server/builder/inference" + "opencsg.com/csghub-server/_mocks/opencsg.com/csghub-server/builder/parquet" "opencsg.com/csghub-server/_mocks/opencsg.com/csghub-server/builder/rpc" "opencsg.com/csghub-server/_mocks/opencsg.com/csghub-server/builder/store/s3" "opencsg.com/csghub-server/_mocks/opencsg.com/csghub-server/component" @@ -41,15 +41,18 @@ func initializeTestRepoComponent(ctx context.Context, t interface { mockRepoComponent := component.NewMockRepoComponent(t) mockSpaceComponent := component.NewMockSpaceComponent(t) mockRuntimeArchitectureComponent := component.NewMockRuntimeArchitectureComponent(t) + mockSensitiveComponent := component.NewMockSensitiveComponent(t) componentMockedComponents := &mockedComponents{ accounting: mockAccountingComponent, repo: mockRepoComponent, tag: mockTagComponent, space: mockSpaceComponent, runtimeArchitecture: mockRuntimeArchitectureComponent, + sensitive: mockSensitiveComponent, } - inferenceMockClient := inference.NewMockClient(t) mockAccountingClient := accounting.NewMockAccountingClient(t) + mockReader := parquet.NewMockReader(t) + mockModerationSvcClient := rpc.NewMockModerationSvcClient(t) mocks := &Mocks{ stores: mockStores, components: componentMockedComponents, @@ -59,8 +62,9 @@ func initializeTestRepoComponent(ctx context.Context, t interface { mirrorServer: mockMirrorServer, mirrorQueue: mockPriorityQueue, deployer: mockDeployer, - inferenceClient: inferenceMockClient, accountingClient: mockAccountingClient, + preader: mockReader, + moderationClient: mockModerationSvcClient, } componentTestRepoWithMocks := &testRepoWithMocks{ repoComponentImpl: componentRepoComponentImpl, @@ -83,19 +87,22 @@ func initializeTestPromptComponent(ctx context.Context, t interface { mockTagComponent := component.NewMockTagComponent(t) mockSpaceComponent := component.NewMockSpaceComponent(t) mockRuntimeArchitectureComponent := component.NewMockRuntimeArchitectureComponent(t) + mockSensitiveComponent := component.NewMockSensitiveComponent(t) componentMockedComponents := &mockedComponents{ accounting: mockAccountingComponent, repo: mockRepoComponent, tag: mockTagComponent, space: mockSpaceComponent, runtimeArchitecture: mockRuntimeArchitectureComponent, + sensitive: mockSensitiveComponent, } mockClient := s3.NewMockClient(t) mockMirrorServer := mirrorserver.NewMockMirrorServer(t) mockPriorityQueue := queue.NewMockPriorityQueue(t) mockDeployer := deploy.NewMockDeployer(t) - inferenceMockClient := inference.NewMockClient(t) mockAccountingClient := accounting.NewMockAccountingClient(t) + mockReader := parquet.NewMockReader(t) + mockModerationSvcClient := rpc.NewMockModerationSvcClient(t) mocks := &Mocks{ stores: mockStores, components: componentMockedComponents, @@ -105,8 +112,9 @@ func initializeTestPromptComponent(ctx context.Context, t interface { mirrorServer: mockMirrorServer, mirrorQueue: mockPriorityQueue, deployer: mockDeployer, - inferenceClient: inferenceMockClient, accountingClient: mockAccountingClient, + preader: mockReader, + moderationClient: mockModerationSvcClient, } componentTestPromptWithMocks := &testPromptWithMocks{ promptComponentImpl: componentPromptComponentImpl, @@ -128,19 +136,22 @@ func initializeTestUserComponent(ctx context.Context, t interface { componentUserComponentImpl := NewTestUserComponent(mockStores, mockGitServer, mockSpaceComponent, mockRepoComponent, mockDeployer, mockAccountingComponent) mockTagComponent := component.NewMockTagComponent(t) mockRuntimeArchitectureComponent := component.NewMockRuntimeArchitectureComponent(t) + mockSensitiveComponent := component.NewMockSensitiveComponent(t) componentMockedComponents := &mockedComponents{ accounting: mockAccountingComponent, repo: mockRepoComponent, tag: mockTagComponent, space: mockSpaceComponent, runtimeArchitecture: mockRuntimeArchitectureComponent, + sensitive: mockSensitiveComponent, } mockUserSvcClient := rpc.NewMockUserSvcClient(t) mockClient := s3.NewMockClient(t) mockMirrorServer := mirrorserver.NewMockMirrorServer(t) mockPriorityQueue := queue.NewMockPriorityQueue(t) - inferenceMockClient := inference.NewMockClient(t) mockAccountingClient := accounting.NewMockAccountingClient(t) + mockReader := parquet.NewMockReader(t) + mockModerationSvcClient := rpc.NewMockModerationSvcClient(t) mocks := &Mocks{ stores: mockStores, components: componentMockedComponents, @@ -150,8 +161,9 @@ func initializeTestUserComponent(ctx context.Context, t interface { mirrorServer: mockMirrorServer, mirrorQueue: mockPriorityQueue, deployer: mockDeployer, - inferenceClient: inferenceMockClient, accountingClient: mockAccountingClient, + preader: mockReader, + moderationClient: mockModerationSvcClient, } componentTestUserWithMocks := &testUserWithMocks{ userComponentImpl: componentUserComponentImpl, @@ -175,18 +187,21 @@ func initializeTestSpaceComponent(ctx context.Context, t interface { mockTagComponent := component.NewMockTagComponent(t) mockSpaceComponent := component.NewMockSpaceComponent(t) mockRuntimeArchitectureComponent := component.NewMockRuntimeArchitectureComponent(t) + mockSensitiveComponent := component.NewMockSensitiveComponent(t) componentMockedComponents := &mockedComponents{ accounting: mockAccountingComponent, repo: mockRepoComponent, tag: mockTagComponent, space: mockSpaceComponent, runtimeArchitecture: mockRuntimeArchitectureComponent, + sensitive: mockSensitiveComponent, } mockClient := s3.NewMockClient(t) mockMirrorServer := mirrorserver.NewMockMirrorServer(t) mockPriorityQueue := queue.NewMockPriorityQueue(t) - inferenceMockClient := inference.NewMockClient(t) mockAccountingClient := accounting.NewMockAccountingClient(t) + mockReader := parquet.NewMockReader(t) + mockModerationSvcClient := rpc.NewMockModerationSvcClient(t) mocks := &Mocks{ stores: mockStores, components: componentMockedComponents, @@ -196,8 +211,9 @@ func initializeTestSpaceComponent(ctx context.Context, t interface { mirrorServer: mockMirrorServer, mirrorQueue: mockPriorityQueue, deployer: mockDeployer, - inferenceClient: inferenceMockClient, accountingClient: mockAccountingClient, + preader: mockReader, + moderationClient: mockModerationSvcClient, } componentTestSpaceWithMocks := &testSpaceWithMocks{ spaceComponentImpl: componentSpaceComponentImpl, @@ -214,62 +230,614 @@ func initializeTestModelComponent(ctx context.Context, t interface { mockStores := tests.NewMockStores(t) mockRepoComponent := component.NewMockRepoComponent(t) mockSpaceComponent := component.NewMockSpaceComponent(t) - mockClient := inference.NewMockClient(t) mockDeployer := deploy.NewMockDeployer(t) mockAccountingComponent := component.NewMockAccountingComponent(t) mockRuntimeArchitectureComponent := component.NewMockRuntimeArchitectureComponent(t) mockGitServer := gitserver.NewMockGitServer(t) mockUserSvcClient := rpc.NewMockUserSvcClient(t) - componentModelComponentImpl := NewTestModelComponent(config, mockStores, mockRepoComponent, mockSpaceComponent, mockClient, mockDeployer, mockAccountingComponent, mockRuntimeArchitectureComponent, mockGitServer, mockUserSvcClient) + componentModelComponentImpl := NewTestModelComponent(config, mockStores, mockRepoComponent, mockSpaceComponent, mockDeployer, mockAccountingComponent, mockRuntimeArchitectureComponent, mockGitServer, mockUserSvcClient) mockTagComponent := component.NewMockTagComponent(t) + mockSensitiveComponent := component.NewMockSensitiveComponent(t) componentMockedComponents := &mockedComponents{ accounting: mockAccountingComponent, repo: mockRepoComponent, tag: mockTagComponent, space: mockSpaceComponent, runtimeArchitecture: mockRuntimeArchitectureComponent, + sensitive: mockSensitiveComponent, } - s3MockClient := s3.NewMockClient(t) + mockClient := s3.NewMockClient(t) + mockMirrorServer := mirrorserver.NewMockMirrorServer(t) + mockPriorityQueue := queue.NewMockPriorityQueue(t) + mockAccountingClient := accounting.NewMockAccountingClient(t) + mockReader := parquet.NewMockReader(t) + mockModerationSvcClient := rpc.NewMockModerationSvcClient(t) + mocks := &Mocks{ + stores: mockStores, + components: componentMockedComponents, + gitServer: mockGitServer, + userSvcClient: mockUserSvcClient, + s3Client: mockClient, + mirrorServer: mockMirrorServer, + mirrorQueue: mockPriorityQueue, + deployer: mockDeployer, + accountingClient: mockAccountingClient, + preader: mockReader, + moderationClient: mockModerationSvcClient, + } + componentTestModelWithMocks := &testModelWithMocks{ + modelComponentImpl: componentModelComponentImpl, + mocks: mocks, + } + return componentTestModelWithMocks +} + +func initializeTestAccountingComponent(ctx context.Context, t interface { + Cleanup(func()) + mock.TestingT +}) *testAccountingWithMocks { + mockStores := tests.NewMockStores(t) + mockAccountingClient := accounting.NewMockAccountingClient(t) + componentAccountingComponentImpl := NewTestAccountingComponent(mockStores, mockAccountingClient) + mockAccountingComponent := component.NewMockAccountingComponent(t) + mockRepoComponent := component.NewMockRepoComponent(t) + mockTagComponent := component.NewMockTagComponent(t) + mockSpaceComponent := component.NewMockSpaceComponent(t) + mockRuntimeArchitectureComponent := component.NewMockRuntimeArchitectureComponent(t) + mockSensitiveComponent := component.NewMockSensitiveComponent(t) + componentMockedComponents := &mockedComponents{ + accounting: mockAccountingComponent, + repo: mockRepoComponent, + tag: mockTagComponent, + space: mockSpaceComponent, + runtimeArchitecture: mockRuntimeArchitectureComponent, + sensitive: mockSensitiveComponent, + } + mockGitServer := gitserver.NewMockGitServer(t) + mockUserSvcClient := rpc.NewMockUserSvcClient(t) + mockClient := s3.NewMockClient(t) + mockMirrorServer := mirrorserver.NewMockMirrorServer(t) + mockPriorityQueue := queue.NewMockPriorityQueue(t) + mockDeployer := deploy.NewMockDeployer(t) + mockReader := parquet.NewMockReader(t) + mockModerationSvcClient := rpc.NewMockModerationSvcClient(t) + mocks := &Mocks{ + stores: mockStores, + components: componentMockedComponents, + gitServer: mockGitServer, + userSvcClient: mockUserSvcClient, + s3Client: mockClient, + mirrorServer: mockMirrorServer, + mirrorQueue: mockPriorityQueue, + deployer: mockDeployer, + accountingClient: mockAccountingClient, + preader: mockReader, + moderationClient: mockModerationSvcClient, + } + componentTestAccountingWithMocks := &testAccountingWithMocks{ + accountingComponentImpl: componentAccountingComponentImpl, + mocks: mocks, + } + return componentTestAccountingWithMocks +} + +func initializeTestDatasetViewerComponent(ctx context.Context, t interface { + Cleanup(func()) + mock.TestingT +}) *testDatasetViewerWithMocks { + mockStores := tests.NewMockStores(t) + config := ProvideTestConfig() + mockRepoComponent := component.NewMockRepoComponent(t) + mockGitServer := gitserver.NewMockGitServer(t) + mockReader := parquet.NewMockReader(t) + componentDatasetViewerComponentImpl := NewTestDatasetViewerComponent(mockStores, config, mockRepoComponent, mockGitServer, mockReader) + mockAccountingComponent := component.NewMockAccountingComponent(t) + mockTagComponent := component.NewMockTagComponent(t) + mockSpaceComponent := component.NewMockSpaceComponent(t) + mockRuntimeArchitectureComponent := component.NewMockRuntimeArchitectureComponent(t) + mockSensitiveComponent := component.NewMockSensitiveComponent(t) + componentMockedComponents := &mockedComponents{ + accounting: mockAccountingComponent, + repo: mockRepoComponent, + tag: mockTagComponent, + space: mockSpaceComponent, + runtimeArchitecture: mockRuntimeArchitectureComponent, + sensitive: mockSensitiveComponent, + } + mockUserSvcClient := rpc.NewMockUserSvcClient(t) + mockClient := s3.NewMockClient(t) + mockMirrorServer := mirrorserver.NewMockMirrorServer(t) + mockPriorityQueue := queue.NewMockPriorityQueue(t) + mockDeployer := deploy.NewMockDeployer(t) + mockAccountingClient := accounting.NewMockAccountingClient(t) + mockModerationSvcClient := rpc.NewMockModerationSvcClient(t) + mocks := &Mocks{ + stores: mockStores, + components: componentMockedComponents, + gitServer: mockGitServer, + userSvcClient: mockUserSvcClient, + s3Client: mockClient, + mirrorServer: mockMirrorServer, + mirrorQueue: mockPriorityQueue, + deployer: mockDeployer, + accountingClient: mockAccountingClient, + preader: mockReader, + moderationClient: mockModerationSvcClient, + } + componentTestDatasetViewerWithMocks := &testDatasetViewerWithMocks{ + datasetViewerComponentImpl: componentDatasetViewerComponentImpl, + mocks: mocks, + } + return componentTestDatasetViewerWithMocks +} + +func initializeTestGitHTTPComponent(ctx context.Context, t interface { + Cleanup(func()) + mock.TestingT +}) *testGitHTTPWithMocks { + config := ProvideTestConfig() + mockStores := tests.NewMockStores(t) + mockRepoComponent := component.NewMockRepoComponent(t) + mockGitServer := gitserver.NewMockGitServer(t) + mockClient := s3.NewMockClient(t) + componentGitHTTPComponentImpl := NewTestGitHTTPComponent(config, mockStores, mockRepoComponent, mockGitServer, mockClient) + mockAccountingComponent := component.NewMockAccountingComponent(t) + mockTagComponent := component.NewMockTagComponent(t) + mockSpaceComponent := component.NewMockSpaceComponent(t) + mockRuntimeArchitectureComponent := component.NewMockRuntimeArchitectureComponent(t) + mockSensitiveComponent := component.NewMockSensitiveComponent(t) + componentMockedComponents := &mockedComponents{ + accounting: mockAccountingComponent, + repo: mockRepoComponent, + tag: mockTagComponent, + space: mockSpaceComponent, + runtimeArchitecture: mockRuntimeArchitectureComponent, + sensitive: mockSensitiveComponent, + } + mockUserSvcClient := rpc.NewMockUserSvcClient(t) + mockMirrorServer := mirrorserver.NewMockMirrorServer(t) + mockPriorityQueue := queue.NewMockPriorityQueue(t) + mockDeployer := deploy.NewMockDeployer(t) + mockAccountingClient := accounting.NewMockAccountingClient(t) + mockReader := parquet.NewMockReader(t) + mockModerationSvcClient := rpc.NewMockModerationSvcClient(t) + mocks := &Mocks{ + stores: mockStores, + components: componentMockedComponents, + gitServer: mockGitServer, + userSvcClient: mockUserSvcClient, + s3Client: mockClient, + mirrorServer: mockMirrorServer, + mirrorQueue: mockPriorityQueue, + deployer: mockDeployer, + accountingClient: mockAccountingClient, + preader: mockReader, + moderationClient: mockModerationSvcClient, + } + componentTestGitHTTPWithMocks := &testGitHTTPWithMocks{ + gitHTTPComponentImpl: componentGitHTTPComponentImpl, + mocks: mocks, + } + return componentTestGitHTTPWithMocks +} + +func initializeTestDiscussionComponent(ctx context.Context, t interface { + Cleanup(func()) + mock.TestingT +}) *testDiscussionWithMocks { + mockStores := tests.NewMockStores(t) + componentDiscussionComponentImpl := NewTestDiscussionComponent(mockStores) + mockAccountingComponent := component.NewMockAccountingComponent(t) + mockRepoComponent := component.NewMockRepoComponent(t) + mockTagComponent := component.NewMockTagComponent(t) + mockSpaceComponent := component.NewMockSpaceComponent(t) + mockRuntimeArchitectureComponent := component.NewMockRuntimeArchitectureComponent(t) + mockSensitiveComponent := component.NewMockSensitiveComponent(t) + componentMockedComponents := &mockedComponents{ + accounting: mockAccountingComponent, + repo: mockRepoComponent, + tag: mockTagComponent, + space: mockSpaceComponent, + runtimeArchitecture: mockRuntimeArchitectureComponent, + sensitive: mockSensitiveComponent, + } + mockGitServer := gitserver.NewMockGitServer(t) + mockUserSvcClient := rpc.NewMockUserSvcClient(t) + mockClient := s3.NewMockClient(t) + mockMirrorServer := mirrorserver.NewMockMirrorServer(t) + mockPriorityQueue := queue.NewMockPriorityQueue(t) + mockDeployer := deploy.NewMockDeployer(t) + mockAccountingClient := accounting.NewMockAccountingClient(t) + mockReader := parquet.NewMockReader(t) + mockModerationSvcClient := rpc.NewMockModerationSvcClient(t) + mocks := &Mocks{ + stores: mockStores, + components: componentMockedComponents, + gitServer: mockGitServer, + userSvcClient: mockUserSvcClient, + s3Client: mockClient, + mirrorServer: mockMirrorServer, + mirrorQueue: mockPriorityQueue, + deployer: mockDeployer, + accountingClient: mockAccountingClient, + preader: mockReader, + moderationClient: mockModerationSvcClient, + } + componentTestDiscussionWithMocks := &testDiscussionWithMocks{ + discussionComponentImpl: componentDiscussionComponentImpl, + mocks: mocks, + } + return componentTestDiscussionWithMocks +} + +func initializeTestRuntimeArchComponent(ctx context.Context, t interface { + Cleanup(func()) + mock.TestingT +}) *testRuntimeArchWithMocks { + mockStores := tests.NewMockStores(t) + mockRepoComponent := component.NewMockRepoComponent(t) + mockGitServer := gitserver.NewMockGitServer(t) + componentRuntimeArchitectureComponentImpl := NewTestRuntimeArchitectureComponent(mockStores, mockRepoComponent, mockGitServer) + mockAccountingComponent := component.NewMockAccountingComponent(t) + mockTagComponent := component.NewMockTagComponent(t) + mockSpaceComponent := component.NewMockSpaceComponent(t) + mockRuntimeArchitectureComponent := component.NewMockRuntimeArchitectureComponent(t) + mockSensitiveComponent := component.NewMockSensitiveComponent(t) + componentMockedComponents := &mockedComponents{ + accounting: mockAccountingComponent, + repo: mockRepoComponent, + tag: mockTagComponent, + space: mockSpaceComponent, + runtimeArchitecture: mockRuntimeArchitectureComponent, + sensitive: mockSensitiveComponent, + } + mockUserSvcClient := rpc.NewMockUserSvcClient(t) + mockClient := s3.NewMockClient(t) + mockMirrorServer := mirrorserver.NewMockMirrorServer(t) + mockPriorityQueue := queue.NewMockPriorityQueue(t) + mockDeployer := deploy.NewMockDeployer(t) + mockAccountingClient := accounting.NewMockAccountingClient(t) + mockReader := parquet.NewMockReader(t) + mockModerationSvcClient := rpc.NewMockModerationSvcClient(t) + mocks := &Mocks{ + stores: mockStores, + components: componentMockedComponents, + gitServer: mockGitServer, + userSvcClient: mockUserSvcClient, + s3Client: mockClient, + mirrorServer: mockMirrorServer, + mirrorQueue: mockPriorityQueue, + deployer: mockDeployer, + accountingClient: mockAccountingClient, + preader: mockReader, + moderationClient: mockModerationSvcClient, + } + componentTestRuntimeArchWithMocks := &testRuntimeArchWithMocks{ + runtimeArchitectureComponentImpl: componentRuntimeArchitectureComponentImpl, + mocks: mocks, + } + return componentTestRuntimeArchWithMocks +} + +func initializeTestMirrorComponent(ctx context.Context, t interface { + Cleanup(func()) + mock.TestingT +}) *testMirrorWithMocks { + config := ProvideTestConfig() + mockStores := tests.NewMockStores(t) + mockMirrorServer := mirrorserver.NewMockMirrorServer(t) + mockRepoComponent := component.NewMockRepoComponent(t) + mockGitServer := gitserver.NewMockGitServer(t) + mockClient := s3.NewMockClient(t) + mockPriorityQueue := queue.NewMockPriorityQueue(t) + componentMirrorComponentImpl := NewTestMirrorComponent(config, mockStores, mockMirrorServer, mockRepoComponent, mockGitServer, mockClient, mockPriorityQueue) + mockAccountingComponent := component.NewMockAccountingComponent(t) + mockTagComponent := component.NewMockTagComponent(t) + mockSpaceComponent := component.NewMockSpaceComponent(t) + mockRuntimeArchitectureComponent := component.NewMockRuntimeArchitectureComponent(t) + mockSensitiveComponent := component.NewMockSensitiveComponent(t) + componentMockedComponents := &mockedComponents{ + accounting: mockAccountingComponent, + repo: mockRepoComponent, + tag: mockTagComponent, + space: mockSpaceComponent, + runtimeArchitecture: mockRuntimeArchitectureComponent, + sensitive: mockSensitiveComponent, + } + mockUserSvcClient := rpc.NewMockUserSvcClient(t) + mockDeployer := deploy.NewMockDeployer(t) + mockAccountingClient := accounting.NewMockAccountingClient(t) + mockReader := parquet.NewMockReader(t) + mockModerationSvcClient := rpc.NewMockModerationSvcClient(t) + mocks := &Mocks{ + stores: mockStores, + components: componentMockedComponents, + gitServer: mockGitServer, + userSvcClient: mockUserSvcClient, + s3Client: mockClient, + mirrorServer: mockMirrorServer, + mirrorQueue: mockPriorityQueue, + deployer: mockDeployer, + accountingClient: mockAccountingClient, + preader: mockReader, + moderationClient: mockModerationSvcClient, + } + componentTestMirrorWithMocks := &testMirrorWithMocks{ + mirrorComponentImpl: componentMirrorComponentImpl, + mocks: mocks, + } + return componentTestMirrorWithMocks +} + +func initializeTestCollectionComponent(ctx context.Context, t interface { + Cleanup(func()) + mock.TestingT +}) *testCollectionWithMocks { + mockStores := tests.NewMockStores(t) + mockUserSvcClient := rpc.NewMockUserSvcClient(t) + mockSpaceComponent := component.NewMockSpaceComponent(t) + componentCollectionComponentImpl := NewTestCollectionComponent(mockStores, mockUserSvcClient, mockSpaceComponent) + mockAccountingComponent := component.NewMockAccountingComponent(t) + mockRepoComponent := component.NewMockRepoComponent(t) + mockTagComponent := component.NewMockTagComponent(t) + mockRuntimeArchitectureComponent := component.NewMockRuntimeArchitectureComponent(t) + mockSensitiveComponent := component.NewMockSensitiveComponent(t) + componentMockedComponents := &mockedComponents{ + accounting: mockAccountingComponent, + repo: mockRepoComponent, + tag: mockTagComponent, + space: mockSpaceComponent, + runtimeArchitecture: mockRuntimeArchitectureComponent, + sensitive: mockSensitiveComponent, + } + mockGitServer := gitserver.NewMockGitServer(t) + mockClient := s3.NewMockClient(t) + mockMirrorServer := mirrorserver.NewMockMirrorServer(t) + mockPriorityQueue := queue.NewMockPriorityQueue(t) + mockDeployer := deploy.NewMockDeployer(t) + mockAccountingClient := accounting.NewMockAccountingClient(t) + mockReader := parquet.NewMockReader(t) + mockModerationSvcClient := rpc.NewMockModerationSvcClient(t) + mocks := &Mocks{ + stores: mockStores, + components: componentMockedComponents, + gitServer: mockGitServer, + userSvcClient: mockUserSvcClient, + s3Client: mockClient, + mirrorServer: mockMirrorServer, + mirrorQueue: mockPriorityQueue, + deployer: mockDeployer, + accountingClient: mockAccountingClient, + preader: mockReader, + moderationClient: mockModerationSvcClient, + } + componentTestCollectionWithMocks := &testCollectionWithMocks{ + collectionComponentImpl: componentCollectionComponentImpl, + mocks: mocks, + } + return componentTestCollectionWithMocks +} + +func initializeTestDatasetComponent(ctx context.Context, t interface { + Cleanup(func()) + mock.TestingT +}) *testDatasetWithMocks { + config := ProvideTestConfig() + mockStores := tests.NewMockStores(t) + mockRepoComponent := component.NewMockRepoComponent(t) + mockUserSvcClient := rpc.NewMockUserSvcClient(t) + mockSensitiveComponent := component.NewMockSensitiveComponent(t) + mockGitServer := gitserver.NewMockGitServer(t) + componentDatasetComponentImpl := NewTestDatasetComponent(config, mockStores, mockRepoComponent, mockUserSvcClient, mockSensitiveComponent, mockGitServer) + mockAccountingComponent := component.NewMockAccountingComponent(t) + mockTagComponent := component.NewMockTagComponent(t) + mockSpaceComponent := component.NewMockSpaceComponent(t) + mockRuntimeArchitectureComponent := component.NewMockRuntimeArchitectureComponent(t) + componentMockedComponents := &mockedComponents{ + accounting: mockAccountingComponent, + repo: mockRepoComponent, + tag: mockTagComponent, + space: mockSpaceComponent, + runtimeArchitecture: mockRuntimeArchitectureComponent, + sensitive: mockSensitiveComponent, + } + mockClient := s3.NewMockClient(t) + mockMirrorServer := mirrorserver.NewMockMirrorServer(t) + mockPriorityQueue := queue.NewMockPriorityQueue(t) + mockDeployer := deploy.NewMockDeployer(t) + mockAccountingClient := accounting.NewMockAccountingClient(t) + mockReader := parquet.NewMockReader(t) + mockModerationSvcClient := rpc.NewMockModerationSvcClient(t) + mocks := &Mocks{ + stores: mockStores, + components: componentMockedComponents, + gitServer: mockGitServer, + userSvcClient: mockUserSvcClient, + s3Client: mockClient, + mirrorServer: mockMirrorServer, + mirrorQueue: mockPriorityQueue, + deployer: mockDeployer, + accountingClient: mockAccountingClient, + preader: mockReader, + moderationClient: mockModerationSvcClient, + } + componentTestDatasetWithMocks := &testDatasetWithMocks{ + datasetComponentImpl: componentDatasetComponentImpl, + mocks: mocks, + } + return componentTestDatasetWithMocks +} + +func initializeTestCodeComponent(ctx context.Context, t interface { + Cleanup(func()) + mock.TestingT +}) *testCodeWithMocks { + config := ProvideTestConfig() + mockStores := tests.NewMockStores(t) + mockRepoComponent := component.NewMockRepoComponent(t) + mockUserSvcClient := rpc.NewMockUserSvcClient(t) + mockGitServer := gitserver.NewMockGitServer(t) + componentCodeComponentImpl := NewTestCodeComponent(config, mockStores, mockRepoComponent, mockUserSvcClient, mockGitServer) + mockAccountingComponent := component.NewMockAccountingComponent(t) + mockTagComponent := component.NewMockTagComponent(t) + mockSpaceComponent := component.NewMockSpaceComponent(t) + mockRuntimeArchitectureComponent := component.NewMockRuntimeArchitectureComponent(t) + mockSensitiveComponent := component.NewMockSensitiveComponent(t) + componentMockedComponents := &mockedComponents{ + accounting: mockAccountingComponent, + repo: mockRepoComponent, + tag: mockTagComponent, + space: mockSpaceComponent, + runtimeArchitecture: mockRuntimeArchitectureComponent, + sensitive: mockSensitiveComponent, + } + mockClient := s3.NewMockClient(t) + mockMirrorServer := mirrorserver.NewMockMirrorServer(t) + mockPriorityQueue := queue.NewMockPriorityQueue(t) + mockDeployer := deploy.NewMockDeployer(t) + mockAccountingClient := accounting.NewMockAccountingClient(t) + mockReader := parquet.NewMockReader(t) + mockModerationSvcClient := rpc.NewMockModerationSvcClient(t) + mocks := &Mocks{ + stores: mockStores, + components: componentMockedComponents, + gitServer: mockGitServer, + userSvcClient: mockUserSvcClient, + s3Client: mockClient, + mirrorServer: mockMirrorServer, + mirrorQueue: mockPriorityQueue, + deployer: mockDeployer, + accountingClient: mockAccountingClient, + preader: mockReader, + moderationClient: mockModerationSvcClient, + } + componentTestCodeWithMocks := &testCodeWithMocks{ + codeComponentImpl: componentCodeComponentImpl, + mocks: mocks, + } + return componentTestCodeWithMocks +} + +func initializeTestMultiSyncComponent(ctx context.Context, t interface { + Cleanup(func()) + mock.TestingT +}) *testMultiSyncWithMocks { + config := ProvideTestConfig() + mockStores := tests.NewMockStores(t) + mockGitServer := gitserver.NewMockGitServer(t) + componentMultiSyncComponentImpl := NewTestMultiSyncComponent(config, mockStores, mockGitServer) + mockAccountingComponent := component.NewMockAccountingComponent(t) + mockRepoComponent := component.NewMockRepoComponent(t) + mockTagComponent := component.NewMockTagComponent(t) + mockSpaceComponent := component.NewMockSpaceComponent(t) + mockRuntimeArchitectureComponent := component.NewMockRuntimeArchitectureComponent(t) + mockSensitiveComponent := component.NewMockSensitiveComponent(t) + componentMockedComponents := &mockedComponents{ + accounting: mockAccountingComponent, + repo: mockRepoComponent, + tag: mockTagComponent, + space: mockSpaceComponent, + runtimeArchitecture: mockRuntimeArchitectureComponent, + sensitive: mockSensitiveComponent, + } + mockUserSvcClient := rpc.NewMockUserSvcClient(t) + mockClient := s3.NewMockClient(t) mockMirrorServer := mirrorserver.NewMockMirrorServer(t) mockPriorityQueue := queue.NewMockPriorityQueue(t) + mockDeployer := deploy.NewMockDeployer(t) mockAccountingClient := accounting.NewMockAccountingClient(t) + mockReader := parquet.NewMockReader(t) + mockModerationSvcClient := rpc.NewMockModerationSvcClient(t) mocks := &Mocks{ stores: mockStores, components: componentMockedComponents, gitServer: mockGitServer, userSvcClient: mockUserSvcClient, - s3Client: s3MockClient, + s3Client: mockClient, mirrorServer: mockMirrorServer, mirrorQueue: mockPriorityQueue, deployer: mockDeployer, - inferenceClient: mockClient, accountingClient: mockAccountingClient, + preader: mockReader, + moderationClient: mockModerationSvcClient, } - componentTestModelWithMocks := &testModelWithMocks{ - modelComponentImpl: componentModelComponentImpl, - mocks: mocks, + componentTestMultiSyncWithMocks := &testMultiSyncWithMocks{ + multiSyncComponentImpl: componentMultiSyncComponentImpl, + mocks: mocks, } - return componentTestModelWithMocks + return componentTestMultiSyncWithMocks } -func initializeTestAccountingComponent(ctx context.Context, t interface { +func initializeTestInternalComponent(ctx context.Context, t interface { Cleanup(func()) mock.TestingT -}) *testAccountingWithMocks { +}) *testInternalWithMocks { + config := ProvideTestConfig() mockStores := tests.NewMockStores(t) + mockRepoComponent := component.NewMockRepoComponent(t) + mockGitServer := gitserver.NewMockGitServer(t) + componentInternalComponentImpl := NewTestInternalComponent(config, mockStores, mockRepoComponent, mockGitServer) + mockAccountingComponent := component.NewMockAccountingComponent(t) + mockTagComponent := component.NewMockTagComponent(t) + mockSpaceComponent := component.NewMockSpaceComponent(t) + mockRuntimeArchitectureComponent := component.NewMockRuntimeArchitectureComponent(t) + mockSensitiveComponent := component.NewMockSensitiveComponent(t) + componentMockedComponents := &mockedComponents{ + accounting: mockAccountingComponent, + repo: mockRepoComponent, + tag: mockTagComponent, + space: mockSpaceComponent, + runtimeArchitecture: mockRuntimeArchitectureComponent, + sensitive: mockSensitiveComponent, + } + mockUserSvcClient := rpc.NewMockUserSvcClient(t) + mockClient := s3.NewMockClient(t) + mockMirrorServer := mirrorserver.NewMockMirrorServer(t) + mockPriorityQueue := queue.NewMockPriorityQueue(t) + mockDeployer := deploy.NewMockDeployer(t) mockAccountingClient := accounting.NewMockAccountingClient(t) - componentAccountingComponentImpl := NewTestAccountingComponent(mockStores, mockAccountingClient) + mockReader := parquet.NewMockReader(t) + mockModerationSvcClient := rpc.NewMockModerationSvcClient(t) + mocks := &Mocks{ + stores: mockStores, + components: componentMockedComponents, + gitServer: mockGitServer, + userSvcClient: mockUserSvcClient, + s3Client: mockClient, + mirrorServer: mockMirrorServer, + mirrorQueue: mockPriorityQueue, + deployer: mockDeployer, + accountingClient: mockAccountingClient, + preader: mockReader, + moderationClient: mockModerationSvcClient, + } + componentTestInternalWithMocks := &testInternalWithMocks{ + internalComponentImpl: componentInternalComponentImpl, + mocks: mocks, + } + return componentTestInternalWithMocks +} + +func initializeTestMirrorSourceComponent(ctx context.Context, t interface { + Cleanup(func()) + mock.TestingT +}) *testMirrorSourceWithMocks { + config := ProvideTestConfig() + mockStores := tests.NewMockStores(t) + componentMirrorSourceComponentImpl := NewTestMirrorSourceComponent(config, mockStores) mockAccountingComponent := component.NewMockAccountingComponent(t) mockRepoComponent := component.NewMockRepoComponent(t) mockTagComponent := component.NewMockTagComponent(t) mockSpaceComponent := component.NewMockSpaceComponent(t) mockRuntimeArchitectureComponent := component.NewMockRuntimeArchitectureComponent(t) + mockSensitiveComponent := component.NewMockSensitiveComponent(t) componentMockedComponents := &mockedComponents{ accounting: mockAccountingComponent, repo: mockRepoComponent, tag: mockTagComponent, space: mockSpaceComponent, runtimeArchitecture: mockRuntimeArchitectureComponent, + sensitive: mockSensitiveComponent, } mockGitServer := gitserver.NewMockGitServer(t) mockUserSvcClient := rpc.NewMockUserSvcClient(t) @@ -277,7 +845,9 @@ func initializeTestAccountingComponent(ctx context.Context, t interface { mockMirrorServer := mirrorserver.NewMockMirrorServer(t) mockPriorityQueue := queue.NewMockPriorityQueue(t) mockDeployer := deploy.NewMockDeployer(t) - inferenceMockClient := inference.NewMockClient(t) + mockAccountingClient := accounting.NewMockAccountingClient(t) + mockReader := parquet.NewMockReader(t) + mockModerationSvcClient := rpc.NewMockModerationSvcClient(t) mocks := &Mocks{ stores: mockStores, components: componentMockedComponents, @@ -287,14 +857,215 @@ func initializeTestAccountingComponent(ctx context.Context, t interface { mirrorServer: mockMirrorServer, mirrorQueue: mockPriorityQueue, deployer: mockDeployer, - inferenceClient: inferenceMockClient, accountingClient: mockAccountingClient, + preader: mockReader, + moderationClient: mockModerationSvcClient, } - componentTestAccountingWithMocks := &testAccountingWithMocks{ - accountingComponentImpl: componentAccountingComponentImpl, - mocks: mocks, + componentTestMirrorSourceWithMocks := &testMirrorSourceWithMocks{ + mirrorSourceComponentImpl: componentMirrorSourceComponentImpl, + mocks: mocks, } - return componentTestAccountingWithMocks + return componentTestMirrorSourceWithMocks +} + +func initializeTestSpaceResourceComponent(ctx context.Context, t interface { + Cleanup(func()) + mock.TestingT +}) *testSpaceResourceWithMocks { + config := ProvideTestConfig() + mockStores := tests.NewMockStores(t) + mockDeployer := deploy.NewMockDeployer(t) + mockAccountingComponent := component.NewMockAccountingComponent(t) + componentSpaceResourceComponentImpl := NewTestSpaceResourceComponent(config, mockStores, mockDeployer, mockAccountingComponent) + mockRepoComponent := component.NewMockRepoComponent(t) + mockTagComponent := component.NewMockTagComponent(t) + mockSpaceComponent := component.NewMockSpaceComponent(t) + mockRuntimeArchitectureComponent := component.NewMockRuntimeArchitectureComponent(t) + mockSensitiveComponent := component.NewMockSensitiveComponent(t) + componentMockedComponents := &mockedComponents{ + accounting: mockAccountingComponent, + repo: mockRepoComponent, + tag: mockTagComponent, + space: mockSpaceComponent, + runtimeArchitecture: mockRuntimeArchitectureComponent, + sensitive: mockSensitiveComponent, + } + mockGitServer := gitserver.NewMockGitServer(t) + mockUserSvcClient := rpc.NewMockUserSvcClient(t) + mockClient := s3.NewMockClient(t) + mockMirrorServer := mirrorserver.NewMockMirrorServer(t) + mockPriorityQueue := queue.NewMockPriorityQueue(t) + mockAccountingClient := accounting.NewMockAccountingClient(t) + mockReader := parquet.NewMockReader(t) + mockModerationSvcClient := rpc.NewMockModerationSvcClient(t) + mocks := &Mocks{ + stores: mockStores, + components: componentMockedComponents, + gitServer: mockGitServer, + userSvcClient: mockUserSvcClient, + s3Client: mockClient, + mirrorServer: mockMirrorServer, + mirrorQueue: mockPriorityQueue, + deployer: mockDeployer, + accountingClient: mockAccountingClient, + preader: mockReader, + moderationClient: mockModerationSvcClient, + } + componentTestSpaceResourceWithMocks := &testSpaceResourceWithMocks{ + spaceResourceComponentImpl: componentSpaceResourceComponentImpl, + mocks: mocks, + } + return componentTestSpaceResourceWithMocks +} + +func initializeTestTagComponent(ctx context.Context, t interface { + Cleanup(func()) + mock.TestingT +}) *testTagWithMocks { + config := ProvideTestConfig() + mockStores := tests.NewMockStores(t) + mockModerationSvcClient := rpc.NewMockModerationSvcClient(t) + componentTagComponentImpl := NewTestTagComponent(config, mockStores, mockModerationSvcClient) + mockAccountingComponent := component.NewMockAccountingComponent(t) + mockRepoComponent := component.NewMockRepoComponent(t) + mockTagComponent := component.NewMockTagComponent(t) + mockSpaceComponent := component.NewMockSpaceComponent(t) + mockRuntimeArchitectureComponent := component.NewMockRuntimeArchitectureComponent(t) + mockSensitiveComponent := component.NewMockSensitiveComponent(t) + componentMockedComponents := &mockedComponents{ + accounting: mockAccountingComponent, + repo: mockRepoComponent, + tag: mockTagComponent, + space: mockSpaceComponent, + runtimeArchitecture: mockRuntimeArchitectureComponent, + sensitive: mockSensitiveComponent, + } + mockGitServer := gitserver.NewMockGitServer(t) + mockUserSvcClient := rpc.NewMockUserSvcClient(t) + mockClient := s3.NewMockClient(t) + mockMirrorServer := mirrorserver.NewMockMirrorServer(t) + mockPriorityQueue := queue.NewMockPriorityQueue(t) + mockDeployer := deploy.NewMockDeployer(t) + mockAccountingClient := accounting.NewMockAccountingClient(t) + mockReader := parquet.NewMockReader(t) + mocks := &Mocks{ + stores: mockStores, + components: componentMockedComponents, + gitServer: mockGitServer, + userSvcClient: mockUserSvcClient, + s3Client: mockClient, + mirrorServer: mockMirrorServer, + mirrorQueue: mockPriorityQueue, + deployer: mockDeployer, + accountingClient: mockAccountingClient, + preader: mockReader, + moderationClient: mockModerationSvcClient, + } + componentTestTagWithMocks := &testTagWithMocks{ + tagComponentImpl: componentTagComponentImpl, + mocks: mocks, + } + return componentTestTagWithMocks +} + +func initializeTestRecomComponent(ctx context.Context, t interface { + Cleanup(func()) + mock.TestingT +}) *testRecomWithMocks { + config := ProvideTestConfig() + mockStores := tests.NewMockStores(t) + mockGitServer := gitserver.NewMockGitServer(t) + componentRecomComponentImpl := NewTestRecomComponent(config, mockStores, mockGitServer) + mockAccountingComponent := component.NewMockAccountingComponent(t) + mockRepoComponent := component.NewMockRepoComponent(t) + mockTagComponent := component.NewMockTagComponent(t) + mockSpaceComponent := component.NewMockSpaceComponent(t) + mockRuntimeArchitectureComponent := component.NewMockRuntimeArchitectureComponent(t) + mockSensitiveComponent := component.NewMockSensitiveComponent(t) + componentMockedComponents := &mockedComponents{ + accounting: mockAccountingComponent, + repo: mockRepoComponent, + tag: mockTagComponent, + space: mockSpaceComponent, + runtimeArchitecture: mockRuntimeArchitectureComponent, + sensitive: mockSensitiveComponent, + } + mockUserSvcClient := rpc.NewMockUserSvcClient(t) + mockClient := s3.NewMockClient(t) + mockMirrorServer := mirrorserver.NewMockMirrorServer(t) + mockPriorityQueue := queue.NewMockPriorityQueue(t) + mockDeployer := deploy.NewMockDeployer(t) + mockAccountingClient := accounting.NewMockAccountingClient(t) + mockReader := parquet.NewMockReader(t) + mockModerationSvcClient := rpc.NewMockModerationSvcClient(t) + mocks := &Mocks{ + stores: mockStores, + components: componentMockedComponents, + gitServer: mockGitServer, + userSvcClient: mockUserSvcClient, + s3Client: mockClient, + mirrorServer: mockMirrorServer, + mirrorQueue: mockPriorityQueue, + deployer: mockDeployer, + accountingClient: mockAccountingClient, + preader: mockReader, + moderationClient: mockModerationSvcClient, + } + componentTestRecomWithMocks := &testRecomWithMocks{ + recomComponentImpl: componentRecomComponentImpl, + mocks: mocks, + } + return componentTestRecomWithMocks +} + +func initializeTestSpaceSdkComponent(ctx context.Context, t interface { + Cleanup(func()) + mock.TestingT +}) *testSpaceSdkWithMocks { + config := ProvideTestConfig() + mockStores := tests.NewMockStores(t) + componentSpaceSdkComponentImpl := NewTestSpaceSdkComponent(config, mockStores) + mockAccountingComponent := component.NewMockAccountingComponent(t) + mockRepoComponent := component.NewMockRepoComponent(t) + mockTagComponent := component.NewMockTagComponent(t) + mockSpaceComponent := component.NewMockSpaceComponent(t) + mockRuntimeArchitectureComponent := component.NewMockRuntimeArchitectureComponent(t) + mockSensitiveComponent := component.NewMockSensitiveComponent(t) + componentMockedComponents := &mockedComponents{ + accounting: mockAccountingComponent, + repo: mockRepoComponent, + tag: mockTagComponent, + space: mockSpaceComponent, + runtimeArchitecture: mockRuntimeArchitectureComponent, + sensitive: mockSensitiveComponent, + } + mockGitServer := gitserver.NewMockGitServer(t) + mockUserSvcClient := rpc.NewMockUserSvcClient(t) + mockClient := s3.NewMockClient(t) + mockMirrorServer := mirrorserver.NewMockMirrorServer(t) + mockPriorityQueue := queue.NewMockPriorityQueue(t) + mockDeployer := deploy.NewMockDeployer(t) + mockAccountingClient := accounting.NewMockAccountingClient(t) + mockReader := parquet.NewMockReader(t) + mockModerationSvcClient := rpc.NewMockModerationSvcClient(t) + mocks := &Mocks{ + stores: mockStores, + components: componentMockedComponents, + gitServer: mockGitServer, + userSvcClient: mockUserSvcClient, + s3Client: mockClient, + mirrorServer: mockMirrorServer, + mirrorQueue: mockPriorityQueue, + deployer: mockDeployer, + accountingClient: mockAccountingClient, + preader: mockReader, + moderationClient: mockModerationSvcClient, + } + componentTestSpaceSdkWithMocks := &testSpaceSdkWithMocks{ + spaceSdkComponentImpl: componentSpaceSdkComponentImpl, + mocks: mocks, + } + return componentTestSpaceSdkWithMocks } // wire.go: @@ -328,3 +1099,78 @@ type testAccountingWithMocks struct { *accountingComponentImpl mocks *Mocks } + +type testDatasetViewerWithMocks struct { + *datasetViewerComponentImpl + mocks *Mocks +} + +type testGitHTTPWithMocks struct { + *gitHTTPComponentImpl + mocks *Mocks +} + +type testDiscussionWithMocks struct { + *discussionComponentImpl + mocks *Mocks +} + +type testRuntimeArchWithMocks struct { + *runtimeArchitectureComponentImpl + mocks *Mocks +} + +type testMirrorWithMocks struct { + *mirrorComponentImpl + mocks *Mocks +} + +type testCollectionWithMocks struct { + *collectionComponentImpl + mocks *Mocks +} + +type testDatasetWithMocks struct { + *datasetComponentImpl + mocks *Mocks +} + +type testCodeWithMocks struct { + *codeComponentImpl + mocks *Mocks +} + +type testMultiSyncWithMocks struct { + *multiSyncComponentImpl + mocks *Mocks +} + +type testInternalWithMocks struct { + *internalComponentImpl + mocks *Mocks +} + +type testMirrorSourceWithMocks struct { + *mirrorSourceComponentImpl + mocks *Mocks +} + +type testSpaceResourceWithMocks struct { + *spaceResourceComponentImpl + mocks *Mocks +} + +type testTagWithMocks struct { + *tagComponentImpl + mocks *Mocks +} + +type testRecomWithMocks struct { + *recomComponentImpl + mocks *Mocks +} + +type testSpaceSdkWithMocks struct { + *spaceSdkComponentImpl + mocks *Mocks +} diff --git a/component/wireset.go b/component/wireset.go index 06f01a7d..b25cc075 100644 --- a/component/wireset.go +++ b/component/wireset.go @@ -6,7 +6,7 @@ import ( mock_deploy "opencsg.com/csghub-server/_mocks/opencsg.com/csghub-server/builder/deploy" mock_git "opencsg.com/csghub-server/_mocks/opencsg.com/csghub-server/builder/git/gitserver" mock_mirror "opencsg.com/csghub-server/_mocks/opencsg.com/csghub-server/builder/git/mirrorserver" - mock_inference "opencsg.com/csghub-server/_mocks/opencsg.com/csghub-server/builder/inference" + mock_preader "opencsg.com/csghub-server/_mocks/opencsg.com/csghub-server/builder/parquet" mock_rpc "opencsg.com/csghub-server/_mocks/opencsg.com/csghub-server/builder/rpc" mock_s3 "opencsg.com/csghub-server/_mocks/opencsg.com/csghub-server/builder/store/s3" mock_component "opencsg.com/csghub-server/_mocks/opencsg.com/csghub-server/component" @@ -15,8 +15,8 @@ import ( "opencsg.com/csghub-server/builder/deploy" "opencsg.com/csghub-server/builder/git/gitserver" "opencsg.com/csghub-server/builder/git/mirrorserver" - "opencsg.com/csghub-server/builder/inference" "opencsg.com/csghub-server/builder/llm" + "opencsg.com/csghub-server/builder/parquet" "opencsg.com/csghub-server/builder/rpc" "opencsg.com/csghub-server/builder/store/s3" "opencsg.com/csghub-server/common/config" @@ -30,6 +30,7 @@ type mockedComponents struct { tag *mock_component.MockTagComponent space *mock_component.MockSpaceComponent runtimeArchitecture *mock_component.MockRuntimeArchitectureComponent + sensitive *mock_component.MockSensitiveComponent } var MockedStoreSet = wire.NewSet( @@ -47,6 +48,8 @@ var MockedComponentSet = wire.NewSet( wire.Bind(new(SpaceComponent), new(*mock_component.MockSpaceComponent)), mock_component.NewMockRuntimeArchitectureComponent, wire.Bind(new(RuntimeArchitectureComponent), new(*mock_component.MockRuntimeArchitectureComponent)), + mock_component.NewMockSensitiveComponent, + wire.Bind(new(SensitiveComponent), new(*mock_component.MockSensitiveComponent)), ) var MockedGitServerSet = wire.NewSet( @@ -79,16 +82,21 @@ var MockedMirrorQueueSet = wire.NewSet( wire.Bind(new(queue.PriorityQueue), new(*mock_mirror_queue.MockPriorityQueue)), ) -var MockedInferenceClientSet = wire.NewSet( - mock_inference.NewMockClient, - wire.Bind(new(inference.Client), new(*mock_inference.MockClient)), -) - var MockedAccountingClientSet = wire.NewSet( mock_accounting.NewMockAccountingClient, wire.Bind(new(accounting.AccountingClient), new(*mock_accounting.MockAccountingClient)), ) +var MockedParquetReaderSet = wire.NewSet( + mock_preader.NewMockReader, + wire.Bind(new(parquet.Reader), new(*mock_preader.MockReader)), +) + +var MockedModerationSvcClientSet = wire.NewSet( + mock_rpc.NewMockModerationSvcClient, + wire.Bind(new(rpc.ModerationSvcClient), new(*mock_rpc.MockModerationSvcClient)), +) + type Mocks struct { stores *tests.MockStores components *mockedComponents @@ -98,8 +106,9 @@ type Mocks struct { mirrorServer *mock_mirror.MockMirrorServer mirrorQueue *mock_mirror_queue.MockPriorityQueue deployer *mock_deploy.MockDeployer - inferenceClient *mock_inference.MockClient accountingClient *mock_accounting.MockAccountingClient + preader *mock_preader.MockReader + moderationClient *mock_rpc.MockModerationSvcClient } var AllMockSet = wire.NewSet( @@ -114,7 +123,8 @@ func ProvideTestConfig() *config.Config { var MockSuperSet = wire.NewSet( MockedComponentSet, AllMockSet, MockedStoreSet, MockedGitServerSet, MockedUserSvcSet, MockedS3Set, MockedDeployerSet, ProvideTestConfig, MockedMirrorServerSet, - MockedMirrorQueueSet, MockedInferenceClientSet, MockedAccountingClientSet, + MockedMirrorQueueSet, MockedAccountingClientSet, MockedParquetReaderSet, + MockedModerationSvcClientSet, ) func NewTestRepoComponent(config *config.Config, stores *tests.MockStores, rpcUser rpc.UserSvcClient, gitServer gitserver.GitServer, tagComponent TagComponent, s3Client s3.Client, deployer deploy.Deployer, accountingComponent AccountingComponent, mq queue.PriorityQueue, mirrorServer mirrorserver.MirrorServer) *repoComponentImpl { @@ -232,7 +242,6 @@ func NewTestModelComponent( stores *tests.MockStores, repoComponent RepoComponent, spaceComponent SpaceComponent, - inferClient inference.Client, deployer deploy.Deployer, accountingComponent AccountingComponent, runtimeArchComponent RuntimeArchitectureComponent, @@ -248,7 +257,6 @@ func NewTestModelComponent( modelStore: stores.Model, repoStore: stores.Repo, spaceResourceStore: stores.SpaceResource, - inferClient: inferClient, userStore: stores.User, deployer: deployer, accountingComponent: accountingComponent, @@ -276,3 +284,206 @@ func NewTestAccountingComponent(stores *tests.MockStores, accountingClient accou } var AccountingComponentSet = wire.NewSet(NewTestAccountingComponent) + +func NewTestDatasetViewerComponent(stores *tests.MockStores, cfg *config.Config, repoComponent RepoComponent, gitServer gitserver.GitServer, preader parquet.Reader) *datasetViewerComponentImpl { + return &datasetViewerComponentImpl{ + cfg: cfg, + preader: preader, + } +} + +var DatasetViewerComponentSet = wire.NewSet(NewTestDatasetViewerComponent) + +func NewTestGitHTTPComponent( + config *config.Config, + stores *tests.MockStores, + repoComponent RepoComponent, + gitServer gitserver.GitServer, + s3Client s3.Client, +) *gitHTTPComponentImpl { + config.APIServer.PublicDomain = "https://foo.com" + config.APIServer.SSHDomain = "ssh://test@127.0.0.1" + return &gitHTTPComponentImpl{ + config: config, + repoComponent: repoComponent, + repoStore: stores.Repo, + userStore: stores.User, + gitServer: gitServer, + s3Client: s3Client, + lfsMetaObjectStore: stores.LfsMetaObject, + lfsLockStore: stores.LfsLock, + } +} + +var GitHTTPComponentSet = wire.NewSet(NewTestGitHTTPComponent) + +func NewTestDiscussionComponent( + stores *tests.MockStores, +) *discussionComponentImpl { + return &discussionComponentImpl{ + repoStore: stores.Repo, + userStore: stores.User, + discussionStore: stores.Discussion, + } +} + +var DiscussionComponentSet = wire.NewSet(NewTestDiscussionComponent) + +func NewTestRuntimeArchitectureComponent(stores *tests.MockStores, repoComponent RepoComponent, gitServer gitserver.GitServer) *runtimeArchitectureComponentImpl { + return &runtimeArchitectureComponentImpl{ + repoComponent: repoComponent, + repoStore: stores.Repo, + repoRuntimeFrameworkStore: stores.RepoRuntimeFramework, + runtimeFrameworksStore: stores.RuntimeFramework, + runtimeArchStore: stores.RuntimeArch, + resouceModelStore: stores.ResourceModel, + tagStore: stores.Tag, + gitServer: gitServer, + } +} + +var RuntimeArchComponentSet = wire.NewSet(NewTestRuntimeArchitectureComponent) + +func NewTestMirrorComponent(config *config.Config, stores *tests.MockStores, mirrorServer mirrorserver.MirrorServer, repoComponent RepoComponent, gitServer gitserver.GitServer, s3Client s3.Client, mq queue.PriorityQueue) *mirrorComponentImpl { + return &mirrorComponentImpl{ + tokenStore: stores.GitServerAccessToken, + mirrorServer: mirrorServer, + repoComp: repoComponent, + git: gitServer, + s3Client: s3Client, + modelStore: stores.Model, + datasetStore: stores.Dataset, + codeStore: stores.Code, + repoStore: stores.Repo, + mirrorStore: stores.Mirror, + mirrorSourceStore: stores.MirrorSource, + namespaceStore: stores.Namespace, + userStore: stores.User, + config: config, + mq: mq, + } +} + +var MirrorComponentSet = wire.NewSet(NewTestMirrorComponent) + +func NewTestCollectionComponent(stores *tests.MockStores, userSvcClient rpc.UserSvcClient, spaceComponent SpaceComponent) *collectionComponentImpl { + return &collectionComponentImpl{ + collectionStore: stores.Collection, + orgStore: stores.Org, + repoStore: stores.Repo, + userStore: stores.User, + userLikesStore: stores.UserLikes, + userSvcClient: userSvcClient, + spaceComponent: spaceComponent, + } +} + +var CollectionComponentSet = wire.NewSet(NewTestCollectionComponent) + +func NewTestDatasetComponent(config *config.Config, stores *tests.MockStores, repoComponent RepoComponent, userSvcClient rpc.UserSvcClient, sensitiveComponent SensitiveComponent, gitServer gitserver.GitServer) *datasetComponentImpl { + return &datasetComponentImpl{ + config: config, + repoComponent: repoComponent, + tagStore: stores.Tag, + datasetStore: stores.Dataset, + repoStore: stores.Repo, + namespaceStore: stores.Namespace, + userStore: stores.User, + sensitiveComponent: sensitiveComponent, + gitServer: gitServer, + userLikesStore: stores.UserLikes, + userSvcClient: userSvcClient, + } +} + +var DatasetComponentSet = wire.NewSet(NewTestDatasetComponent) + +func NewTestCodeComponent(config *config.Config, stores *tests.MockStores, repoComponent RepoComponent, userSvcClient rpc.UserSvcClient, gitServer gitserver.GitServer) *codeComponentImpl { + return &codeComponentImpl{ + config: config, + repoComponent: repoComponent, + codeStore: stores.Code, + repoStore: stores.Repo, + userLikesStore: stores.UserLikes, + gitServer: gitServer, + userSvcClient: userSvcClient, + } +} + +var CodeComponentSet = wire.NewSet(NewTestCodeComponent) + +func NewTestMultiSyncComponent(config *config.Config, stores *tests.MockStores, gitServer gitserver.GitServer) *multiSyncComponentImpl { + return &multiSyncComponentImpl{ + multiSyncStore: stores.MultiSync, + repoStore: stores.Repo, + modelStore: stores.Model, + datasetStore: stores.Dataset, + namespaceStore: stores.Namespace, + userStore: stores.User, + syncVersionStore: stores.SyncVersion, + tagStore: stores.Tag, + fileStore: stores.File, + gitServer: gitServer, + } +} + +var MultiSyncComponentSet = wire.NewSet(NewTestMultiSyncComponent) + +func NewTestInternalComponent(config *config.Config, stores *tests.MockStores, repoComponent RepoComponent, gitServer gitserver.GitServer) *internalComponentImpl { + return &internalComponentImpl{ + config: config, + sshKeyStore: stores.SSH, + repoStore: stores.Repo, + tokenStore: stores.AccessToken, + namespaceStore: stores.Namespace, + repoComponent: repoComponent, + gitServer: gitServer, + } +} + +var InternalComponentSet = wire.NewSet(NewTestInternalComponent) + +func NewTestMirrorSourceComponent(config *config.Config, stores *tests.MockStores) *mirrorSourceComponentImpl { + return &mirrorSourceComponentImpl{ + mirrorSourceStore: stores.MirrorSource, + userStore: stores.User, + } +} + +var MirrorSourceComponentSet = wire.NewSet(NewTestMirrorSourceComponent) + +func NewTestSpaceResourceComponent(config *config.Config, stores *tests.MockStores, deployer deploy.Deployer, accountComponent AccountingComponent) *spaceResourceComponentImpl { + return &spaceResourceComponentImpl{ + deployer: deployer, + } +} + +var SpaceResourceComponentSet = wire.NewSet(NewTestSpaceResourceComponent) + +func NewTestTagComponent(config *config.Config, stores *tests.MockStores, sensitiveChecker rpc.ModerationSvcClient) *tagComponentImpl { + return &tagComponentImpl{ + tagStore: stores.Tag, + repoStore: stores.Repo, + sensitiveChecker: sensitiveChecker, + } +} + +var TagComponentSet = wire.NewSet(NewTestTagComponent) + +func NewTestRecomComponent(config *config.Config, stores *tests.MockStores, gitServer gitserver.GitServer) *recomComponentImpl { + return &recomComponentImpl{ + recomStore: stores.Recom, + repoStore: stores.Repo, + gitServer: gitServer, + } +} + +var RecomComponentSet = wire.NewSet(NewTestRecomComponent) + +func NewTestSpaceSdkComponent(config *config.Config, stores *tests.MockStores) *spaceSdkComponentImpl { + return &spaceSdkComponentImpl{ + spaceSdkStore: stores.SpaceSdk, + } +} + +var SpaceSdkComponentSet = wire.NewSet(NewTestSpaceSdkComponent) diff --git a/go.sum b/go.sum index e6efc3c7..db84171c 100644 --- a/go.sum +++ b/go.sum @@ -307,6 +307,7 @@ github.com/google/pprof v0.0.0-20240525223248-4bfdf5a9a2af h1:kmjWCqn2qkEml422C2 github.com/google/pprof v0.0.0-20240525223248-4bfdf5a9a2af/go.mod h1:K1liHPHnj73Fdn/EKuT8nrFqBihUSKXoLYU0BuatOYo= github.com/google/s2a-go v0.1.7 h1:60BLSyTrOV4/haCDW4zb1guZItoSq8foHCXrAnjBo/o= github.com/google/s2a-go v0.1.7/go.mod h1:50CgR4k1jNlWBu4UfS4AcfhVe1r6pdZPygJ3R8F0Qdw= +github.com/google/subcommands v1.2.0 h1:vWQspBTo2nEqTUFita5/KeEWlUL8kQObDFbub/EN9oE= github.com/google/subcommands v1.2.0/go.mod h1:ZjhPrFU+Olkh9WazFPsl27BQ4UPiG37m3yTrtFlrHVk= github.com/google/uuid v1.0.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= github.com/google/uuid v1.1.2/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= From 21c282b6668332806d8271c1a067558987a41abf Mon Sep 17 00:00:00 2001 From: vincent Date: Mon, 16 Dec 2024 14:35:51 +0800 Subject: [PATCH 05/34] Add repo access check tests (#208) * Add repo access check tests * Fix tests --- component/repo_test.go | 143 ++++++++++++++++++++++++++++++++++++++++- 1 file changed, 142 insertions(+), 1 deletion(-) diff --git a/component/repo_test.go b/component/repo_test.go index 4c36b190..3959b7c3 100644 --- a/component/repo_test.go +++ b/component/repo_test.go @@ -1404,7 +1404,7 @@ func TestRepoComponent_UpdateTags(t *testing.T) { } -func TestRepoComponent_checkCurrentUserPermission(t *testing.T) { +func TestRepoComponent_CheckCurrentUserPermission(t *testing.T) { t.Run("can read self-owned", func(t *testing.T) { ctx := context.TODO() @@ -1672,3 +1672,144 @@ func TestRepoComponent_Tree(t *testing.T) { } } + +func TestRepoComponent_AllowReadAccess(t *testing.T) { + t.Run("should return false if repo find return error", func(t *testing.T) { + ctx := context.TODO() + repoComp := initializeTestRepoComponent(ctx, t) + repoComp.mocks.stores.RepoMock().EXPECT().FindByPath(ctx, types.ModelRepo, "namespace", "name").Return(&database.Repository{}, errors.New("error")) + allow, err := repoComp.AllowReadAccess(ctx, types.ModelRepo, "namespace", "name", "user_name") + require.Error(t, fmt.Errorf("failed to find repo, error: %w", err)) + require.False(t, allow) + }) +} + +func TestRepoComponent_AllowWriteAccess(t *testing.T) { + t.Run("should return false if username is empty", func(t *testing.T) { + ctx := context.TODO() + repoComp := initializeTestRepoComponent(ctx, t) + repoComp.mocks.stores.RepoMock().EXPECT().FindByPath(ctx, types.ModelRepo, "namespace", "name").Return(&database.Repository{ + ID: 1, + Name: "name", + Path: "namespace/name", + Private: false, + }, nil) + allow, err := repoComp.AllowWriteAccess(ctx, types.ModelRepo, "namespace", "name", "") + require.Error(t, err, ErrUserNotFound) + require.False(t, allow) + }) + + t.Run("should return false if repo find return error", func(t *testing.T) { + ctx := context.TODO() + repoComp := initializeTestRepoComponent(ctx, t) + repoComp.mocks.stores.RepoMock().EXPECT().FindByPath(ctx, types.ModelRepo, "namespace", "name").Return(&database.Repository{}, errors.New("error")) + allow, err := repoComp.AllowWriteAccess(ctx, types.ModelRepo, "namespace", "name", "user_name") + require.Error(t, err, fmt.Errorf("failed to find repo, error: %w", err)) + require.False(t, allow) + }) + + t.Run("should return false if user has no write access for public repo", func(t *testing.T) { + ctx := context.TODO() + repoComp := initializeTestRepoComponent(ctx, t) + repoComp.mocks.stores.RepoMock().EXPECT().FindByPath(ctx, types.ModelRepo, "namespace", "name").Return(&database.Repository{ + ID: 1, + Name: "name", + Path: "namespace/name", + Private: false, + }, nil) + repoComp.mocks.stores.NamespaceMock().EXPECT().FindByPath(ctx, "namespace").Return(database.Namespace{ + ID: 1, + Path: "namespace", + NamespaceType: database.UserNamespace, + }, nil) + repoComp.mocks.stores.UserMock().EXPECT().FindByUsername(ctx, "user_name").Return(database.User{ + ID: 1, + Username: "user_name", + Email: "user@example.com", + RoleMask: "", + }, nil) + allow, err := repoComp.AllowAdminAccess(ctx, types.ModelRepo, "namespace", "name", "user_name") + require.NoError(t, err) + require.False(t, allow) + }) +} + +func TestRepoComponent_AllowAdminAccess(t *testing.T) { + t.Run("should return false if username is empty", func(t *testing.T) { + ctx := context.TODO() + repoComp := initializeTestRepoComponent(ctx, t) + repoComp.mocks.stores.RepoMock().EXPECT().FindByPath(ctx, types.ModelRepo, "namespace", "name").Return(&database.Repository{ + ID: 1, + Name: "name", + Path: "namespace/name", + Private: false, + }, nil) + allow, err := repoComp.AllowAdminAccess(ctx, types.ModelRepo, "namespace", "name", "") + require.Error(t, err, ErrUserNotFound) + require.False(t, allow) + }) + + t.Run("should return false if repo find return error", func(t *testing.T) { + ctx := context.TODO() + repoComp := initializeTestRepoComponent(ctx, t) + repoComp.mocks.stores.RepoMock().EXPECT().FindByPath(ctx, types.ModelRepo, "namespace", "name").Return(&database.Repository{}, errors.New("error")) + allow, err := repoComp.AllowAdminAccess(ctx, types.ModelRepo, "namespace", "name", "user_name") + require.Error(t, err, fmt.Errorf("failed to find repo, error: %w", err)) + require.False(t, allow) + }) + + t.Run("should return false if user has no admin access for public repo", func(t *testing.T) { + ctx := context.TODO() + repoComp := initializeTestRepoComponent(ctx, t) + repoComp.mocks.stores.RepoMock().EXPECT().FindByPath(ctx, types.ModelRepo, "namespace", "name").Return(&database.Repository{ + ID: 1, + Name: "name", + Path: "namespace/name", + Private: false, + }, nil) + repoComp.mocks.stores.NamespaceMock().EXPECT().FindByPath(ctx, "namespace").Return(database.Namespace{ + ID: 1, + Path: "namespace", + NamespaceType: database.UserNamespace, + }, nil) + repoComp.mocks.stores.UserMock().EXPECT().FindByUsername(ctx, "user_name").Return(database.User{ + ID: 1, + Username: "user_name", + Email: "user@example.com", + RoleMask: "", + }, nil) + allow, err := repoComp.AllowAdminAccess(ctx, types.ModelRepo, "namespace", "name", "user_name") + require.NoError(t, err) + require.False(t, allow) + }) +} + +func TestRepoComponent_AllowReadAccessRepo(t *testing.T) { + t.Run("should return true if repo is public", func(t *testing.T) { + ctx := context.TODO() + repoComp := initializeTestRepoComponent(ctx, t) + + allow, err := repoComp.AllowReadAccessRepo(ctx, &database.Repository{ + ID: 1, + Name: "name", + Path: "namespace/name", + Private: false, + }, "user_name") + require.NoError(t, err) + require.True(t, allow) + }) + + t.Run("should return false if repo is private and username is empty", func(t *testing.T) { + ctx := context.TODO() + repoComp := initializeTestRepoComponent(ctx, t) + + allow, err := repoComp.AllowReadAccessRepo(ctx, &database.Repository{ + ID: 1, + Name: "name", + Path: "namespace/name", + Private: true, + }, "") + require.Error(t, err, ErrUserNotFound) + require.False(t, allow) + }) +} From a26e3fbb0c60b5b4b28daf1f0f0d7fc0f21308f3 Mon Sep 17 00:00:00 2001 From: Yiling-J Date: Mon, 16 Dec 2024 16:30:00 +0800 Subject: [PATCH 06/34] add missing component tests (#212) --- Makefile | 6 +- api/handler/callback/git_callback.go | 2 +- common/tests/stores.go | 24 ++ common/types/prompt.go | 32 +- component/callback/git_callback.go | 174 ++++---- component/callback/git_callback_test.go | 187 ++++++++ component/callback/wire.go | 27 ++ component/callback/wire_gen_test.go | 52 +++ component/callback/wireset.go | 53 +++ component/cluster_test.go | 42 ++ component/evaluation.go | 26 +- component/evaluation_test.go | 84 ++-- component/event.go | 6 +- component/event_test.go | 22 + component/hf_dataset.go | 35 +- component/hf_dataset_test.go | 71 +++ component/list.go | 16 +- component/list_test.go | 46 ++ component/repo_file.go | 24 +- component/repo_file_test.go | 89 ++++ component/sensitive_test.go | 29 +- component/sshkey.go | 31 +- component/sshkey_test.go | 67 +++ component/sync_client_setting_test.go | 43 ++ component/telemetry.go | 21 +- component/telemetry_test.go | 56 +++ component/wire.go | 160 +++++++ component/wire_gen_test.go | 550 ++++++++++++++++++++++++ component/wireset.go | 101 +++++ 29 files changed, 1837 insertions(+), 239 deletions(-) create mode 100644 component/callback/git_callback_test.go create mode 100644 component/callback/wire.go create mode 100644 component/callback/wire_gen_test.go create mode 100644 component/callback/wireset.go create mode 100644 component/cluster_test.go create mode 100644 component/event_test.go create mode 100644 component/hf_dataset_test.go create mode 100644 component/list_test.go create mode 100644 component/repo_file_test.go create mode 100644 component/sshkey_test.go create mode 100644 component/sync_client_setting_test.go create mode 100644 component/telemetry_test.go diff --git a/Makefile b/Makefile index 21c7b05a..739643b5 100644 --- a/Makefile +++ b/Makefile @@ -13,10 +13,12 @@ cover: mock_wire: @echo "Running wire for component mocks..." - @go run -mod=mod github.com/google/wire/cmd/wire opencsg.com/csghub-server/component + @go run -mod=mod github.com/google/wire/cmd/wire opencsg.com/csghub-server/component/... @if [ $$? -eq 0 ]; then \ - echo "Renaming wire_gen.go to wire_gen_test.go..."; \ + echo "Renaming component wire_gen.go to wire_gen_test.go..."; \ mv component/wire_gen.go component/wire_gen_test.go; \ + echo "Renaming component/callback wire_gen.go to wire_gen_test.go..."; \ + mv component/callback/wire_gen.go component/callback/wire_gen_test.go; \ else \ echo "Wire failed, skipping renaming."; \ fi diff --git a/api/handler/callback/git_callback.go b/api/handler/callback/git_callback.go index c0d62056..1fda23f8 100644 --- a/api/handler/callback/git_callback.go +++ b/api/handler/callback/git_callback.go @@ -14,7 +14,7 @@ import ( ) type GitCallbackHandler struct { - cbc *component.GitCallbackComponent + cbc component.GitCallbackComponent config *config.Config } diff --git a/common/tests/stores.go b/common/tests/stores.go index 2197113f..507a7157 100644 --- a/common/tests/stores.go +++ b/common/tests/stores.go @@ -14,6 +14,7 @@ type MockStores struct { Model database.ModelStore SpaceResource database.SpaceResourceStore Tag database.TagStore + TagRule database.TagRuleStore Dataset database.DatasetStore PromptConversation database.PromptConversationStore PromptPrefix database.PromptPrefixStore @@ -44,6 +45,9 @@ type MockStores struct { MultiSync database.MultiSyncStore File database.FileStore SSH database.SSHKeyStore + Telemetry database.TelemetryStore + RepoFile database.RepoFileStore + Event database.EventStore } func NewMockStores(t interface { @@ -88,6 +92,10 @@ func NewMockStores(t interface { MultiSync: mockdb.NewMockMultiSyncStore(t), File: mockdb.NewMockFileStore(t), SSH: mockdb.NewMockSSHKeyStore(t), + Telemetry: mockdb.NewMockTelemetryStore(t), + RepoFile: mockdb.NewMockRepoFileStore(t), + Event: mockdb.NewMockEventStore(t), + TagRule: mockdb.NewMockTagRuleStore(t), } } @@ -119,6 +127,10 @@ func (s *MockStores) TagMock() *mockdb.MockTagStore { return s.Tag.(*mockdb.MockTagStore) } +func (s *MockStores) TagRuleMock() *mockdb.MockTagRuleStore { + return s.TagRule.(*mockdb.MockTagRuleStore) +} + func (s *MockStores) DatasetMock() *mockdb.MockDatasetStore { return s.Dataset.(*mockdb.MockDatasetStore) } @@ -238,3 +250,15 @@ func (s *MockStores) FileMock() *mockdb.MockFileStore { func (s *MockStores) SSHMock() *mockdb.MockSSHKeyStore { return s.SSH.(*mockdb.MockSSHKeyStore) } + +func (s *MockStores) TelemetryMock() *mockdb.MockTelemetryStore { + return s.Telemetry.(*mockdb.MockTelemetryStore) +} + +func (s *MockStores) RepoFileMock() *mockdb.MockRepoFileStore { + return s.RepoFile.(*mockdb.MockRepoFileStore) +} + +func (s *MockStores) EventMock() *mockdb.MockEventStore { + return s.Event.(*mockdb.MockEventStore) +} diff --git a/common/types/prompt.go b/common/types/prompt.go index 31522bbe..7ad62e90 100644 --- a/common/types/prompt.go +++ b/common/types/prompt.go @@ -1,6 +1,8 @@ package types -import "time" +import ( + "time" +) type PromptReq struct { Namespace string `json:"namespace"` @@ -101,3 +103,31 @@ type PromptRes struct { CanManage bool `json:"can_manage"` Namespace *Namespace `json:"namespace"` } + +type Prompt struct { + Title string `json:"title" binding:"required"` + Content string `json:"content" binding:"required"` + Language string `json:"language" binding:"required"` + Tags []string `json:"tags"` + Type string `json:"type"` // "text|image|video|audio" + Source string `json:"source"` + Author string `json:"author"` + Time string `json:"time"` + Copyright string `json:"copyright"` + Feedback []string `json:"feedback"` +} + +type PromptOutput struct { + Prompt + FilePath string `json:"file_path"` + CanWrite bool `json:"can_write"` + CanManage bool `json:"can_manage"` +} + +type CreatePromptReq struct { + Prompt +} + +type UpdatePromptReq struct { + Prompt +} diff --git a/component/callback/git_callback.go b/component/callback/git_callback.go index b1bddcd4..0810f1e4 100644 --- a/component/callback/git_callback.go +++ b/component/callback/git_callback.go @@ -21,33 +21,39 @@ import ( "opencsg.com/csghub-server/component" ) -// define GitCallbackComponent struct -type GitCallbackComponent struct { - config *config.Config - gs gitserver.GitServer - tc component.TagComponent - modSvcClient rpc.ModerationSvcClient - ms database.ModelStore - ds database.DatasetStore - sc component.SpaceComponent - ss database.SpaceStore - rs database.RepoStore - rrs database.RepoRelationsStore - mirrorStore database.MirrorStore - rrf database.RepositoriesRuntimeFrameworkStore - rac component.RuntimeArchitectureComponent - ras database.RuntimeArchitecturesStore - rfs database.RuntimeFrameworksStore - ts database.TagStore - dt database.TagRuleStore +type GitCallbackComponent interface { + SetRepoVisibility(yes bool) + WatchSpaceChange(ctx context.Context, req *types.GiteaCallbackPushReq) error + WatchRepoRelation(ctx context.Context, req *types.GiteaCallbackPushReq) error + SetRepoUpdateTime(ctx context.Context, req *types.GiteaCallbackPushReq) error + UpdateRepoInfos(ctx context.Context, req *types.GiteaCallbackPushReq) error +} + +type gitCallbackComponentImpl struct { + config *config.Config + gitServer gitserver.GitServer + tagComponent component.TagComponent + modSvcClient rpc.ModerationSvcClient + modelStore database.ModelStore + datasetStore database.DatasetStore + spaceComponent component.SpaceComponent + spaceStore database.SpaceStore + repoStore database.RepoStore + repoRelationStore database.RepoRelationsStore + mirrorStore database.MirrorStore + repoRuntimeFrameworkStore database.RepositoriesRuntimeFrameworkStore + runtimeArchComponent component.RuntimeArchitectureComponent + runtimeArchStore database.RuntimeArchitecturesStore + runtimeFrameworkStore database.RuntimeFrameworksStore + tagStore database.TagStore + tagRuleStore database.TagRuleStore // set visibility if file content is sensitive setRepoVisibility bool - pp component.PromptComponent maxPromptFS int64 } // new CallbackComponent -func NewGitCallback(config *config.Config) (*GitCallbackComponent, error) { +func NewGitCallback(config *config.Config) (*gitCallbackComponentImpl, error) { gs, err := git.NewGitServer(config) if err != nil { return nil, err @@ -74,45 +80,40 @@ func NewGitCallback(config *config.Config) (*GitCallbackComponent, error) { } rfs := database.NewRuntimeFrameworksStore() ts := database.NewTagStore() - pp, err := component.NewPromptComponent(config) - if err != nil { - return nil, err - } var modSvcClient rpc.ModerationSvcClient if config.SensitiveCheck.Enable { modSvcClient = rpc.NewModerationSvcHttpClient(fmt.Sprintf("%s:%d", config.Moderation.Host, config.Moderation.Port)) } dt := database.NewTagRuleStore() - return &GitCallbackComponent{ - config: config, - gs: gs, - tc: tc, - ms: ms, - ds: ds, - ss: ss, - sc: sc, - rs: rs, - rrs: rrs, - mirrorStore: mirrorStore, - modSvcClient: modSvcClient, - rrf: rrf, - rac: rac, - ras: ras, - rfs: rfs, - pp: pp, - ts: ts, - dt: dt, - maxPromptFS: config.Dataset.PromptMaxJsonlFileSize, + return &gitCallbackComponentImpl{ + config: config, + gitServer: gs, + tagComponent: tc, + modelStore: ms, + datasetStore: ds, + spaceStore: ss, + spaceComponent: sc, + repoStore: rs, + repoRelationStore: rrs, + mirrorStore: mirrorStore, + modSvcClient: modSvcClient, + repoRuntimeFrameworkStore: rrf, + runtimeArchComponent: rac, + runtimeArchStore: ras, + runtimeFrameworkStore: rfs, + tagStore: ts, + tagRuleStore: dt, + maxPromptFS: config.Dataset.PromptMaxJsonlFileSize, }, nil } // SetRepoVisibility sets a flag whether change repo's visibility if file content is sensitive -func (c *GitCallbackComponent) SetRepoVisibility(yes bool) { +func (c *gitCallbackComponentImpl) SetRepoVisibility(yes bool) { c.setRepoVisibility = yes } -func (c *GitCallbackComponent) WatchSpaceChange(ctx context.Context, req *types.GiteaCallbackPushReq) error { - err := WatchSpaceChange(req, c.ss, c.sc).Run() +func (c *gitCallbackComponentImpl) WatchSpaceChange(ctx context.Context, req *types.GiteaCallbackPushReq) error { + err := WatchSpaceChange(req, c.spaceStore, c.spaceComponent).Run() if err != nil { slog.Error("watch space change failed", slog.Any("error", err)) return err @@ -120,8 +121,8 @@ func (c *GitCallbackComponent) WatchSpaceChange(ctx context.Context, req *types. return nil } -func (c *GitCallbackComponent) WatchRepoRelation(ctx context.Context, req *types.GiteaCallbackPushReq) error { - err := WatchRepoRelation(req, c.rs, c.rrs, c.gs).Run() +func (c *gitCallbackComponentImpl) WatchRepoRelation(ctx context.Context, req *types.GiteaCallbackPushReq) error { + err := WatchRepoRelation(req, c.repoStore, c.repoRelationStore, c.gitServer).Run() if err != nil { slog.Error("watch repo relation failed", slog.Any("error", err)) return err @@ -129,7 +130,7 @@ func (c *GitCallbackComponent) WatchRepoRelation(ctx context.Context, req *types return nil } -func (c *GitCallbackComponent) SetRepoUpdateTime(ctx context.Context, req *types.GiteaCallbackPushReq) error { +func (c *gitCallbackComponentImpl) SetRepoUpdateTime(ctx context.Context, req *types.GiteaCallbackPushReq) error { // split req.Repository.FullName by '/' splits := strings.Split(req.Repository.FullName, "/") fullNamespace, repoName := splits[0], splits[1] @@ -138,7 +139,7 @@ func (c *GitCallbackComponent) SetRepoUpdateTime(ctx context.Context, req *types ctx, cancel := context.WithTimeout(ctx, 5*time.Second) defer cancel() - isMirrorRepo, err := c.rs.IsMirrorRepo(ctx, adjustedRepoType, namespace, repoName) + isMirrorRepo, err := c.repoStore.IsMirrorRepo(ctx, adjustedRepoType, namespace, repoName) if err != nil { slog.Error("failed to check if a mirror repo", slog.Any("error", err), slog.String("repo_type", string(adjustedRepoType)), slog.String("namespace", namespace), slog.String("name", repoName)) return err @@ -149,7 +150,7 @@ func (c *GitCallbackComponent) SetRepoUpdateTime(ctx context.Context, req *types slog.Error("Error parsing time:", slog.Any("error", err), slog.String("timestamp", req.HeadCommit.Timestamp)) return err } - err = c.rs.SetUpdateTimeByPath(ctx, adjustedRepoType, namespace, repoName, updated) + err = c.repoStore.SetUpdateTimeByPath(ctx, adjustedRepoType, namespace, repoName, updated) if err != nil { slog.Error("failed to set repo update time", slog.Any("error", err), slog.String("repo_type", string(adjustedRepoType)), slog.String("namespace", namespace), slog.String("name", repoName)) return err @@ -166,7 +167,7 @@ func (c *GitCallbackComponent) SetRepoUpdateTime(ctx context.Context, req *types return err } } else { - err := c.rs.SetUpdateTimeByPath(ctx, adjustedRepoType, namespace, repoName, time.Now()) + err := c.repoStore.SetUpdateTimeByPath(ctx, adjustedRepoType, namespace, repoName, time.Now()) if err != nil { slog.Error("failed to set repo update time", slog.Any("error", err), slog.String("repo_type", string(adjustedRepoType)), slog.String("namespace", namespace), slog.String("name", repoName)) return err @@ -175,7 +176,7 @@ func (c *GitCallbackComponent) SetRepoUpdateTime(ctx context.Context, req *types return nil } -func (c *GitCallbackComponent) UpdateRepoInfos(ctx context.Context, req *types.GiteaCallbackPushReq) error { +func (c *gitCallbackComponentImpl) UpdateRepoInfos(ctx context.Context, req *types.GiteaCallbackPushReq) error { commits := req.Commits ref := req.Ref // split req.Repository.FullName by '/' @@ -193,7 +194,7 @@ func (c *GitCallbackComponent) UpdateRepoInfos(ctx context.Context, req *types.G return err } -func (c *GitCallbackComponent) SensitiveCheck(ctx context.Context, req *types.GiteaCallbackPushReq) error { +func (c *gitCallbackComponentImpl) SensitiveCheck(ctx context.Context, req *types.GiteaCallbackPushReq) error { // split req.Repository.FullName by '/' splits := strings.Split(req.Repository.FullName, "/") fullNamespace, repoName := splits[0], splits[1] @@ -208,11 +209,12 @@ func (c *GitCallbackComponent) SensitiveCheck(ctx context.Context, req *types.Gi slog.Error("fail to submit repo sensitive check", slog.Any("error", err), slog.Any("repo_type", adjustedRepoType), slog.String("namespace", namespace), slog.String("name", repoName)) return err } + return nil } // modifyFiles method handles modified files, skip if not modify README.md -func (c *GitCallbackComponent) modifyFiles(ctx context.Context, repoType, namespace, repoName, ref string, fileNames []string) error { +func (c *gitCallbackComponentImpl) modifyFiles(ctx context.Context, repoType, namespace, repoName, ref string, fileNames []string) error { for _, fileName := range fileNames { slog.Debug("modify file", slog.String("file", fileName)) // update model runtime @@ -232,7 +234,7 @@ func (c *GitCallbackComponent) modifyFiles(ctx context.Context, repoType, namesp return nil } -func (c *GitCallbackComponent) removeFiles(ctx context.Context, repoType, namespace, repoName, ref string, fileNames []string) error { +func (c *gitCallbackComponentImpl) removeFiles(ctx context.Context, repoType, namespace, repoName, ref string, fileNames []string) error { // handle removed files // delete tags for _, fileName := range fileNames { @@ -244,7 +246,7 @@ func (c *GitCallbackComponent) removeFiles(ctx context.Context, repoType, namesp // use empty content to clear all the meta tags const content string = "" adjustedRepoType := types.RepositoryType(strings.TrimSuffix(repoType, "s")) - err := c.tc.ClearMetaTags(ctx, adjustedRepoType, namespace, repoName) + err := c.tagComponent.ClearMetaTags(ctx, adjustedRepoType, namespace, repoName) if err != nil { slog.Error("failed to clear meta tags", slog.String("content", content), slog.String("repo", path.Join(namespace, repoName)), slog.String("ref", ref), @@ -267,7 +269,7 @@ func (c *GitCallbackComponent) removeFiles(ctx context.Context, repoType, namesp // case SpaceRepoType: // tagScope = database.SpaceTagScope } - err := c.tc.UpdateLibraryTags(ctx, tagScope, namespace, repoName, fileName, "") + err := c.tagComponent.UpdateLibraryTags(ctx, tagScope, namespace, repoName, fileName, "") if err != nil { slog.Error("failed to remove Library tag", slog.String("namespace", namespace), slog.String("name", repoName), slog.String("ref", ref), slog.String("fileName", fileName), @@ -279,7 +281,7 @@ func (c *GitCallbackComponent) removeFiles(ctx context.Context, repoType, namesp return nil } -func (c *GitCallbackComponent) addFiles(ctx context.Context, repoType, namespace, repoName, ref string, fileNames []string) error { +func (c *gitCallbackComponentImpl) addFiles(ctx context.Context, repoType, namespace, repoName, ref string, fileNames []string) error { for _, fileName := range fileNames { slog.Debug("add file", slog.String("file", fileName)) // update model runtime @@ -310,7 +312,7 @@ func (c *GitCallbackComponent) addFiles(ctx context.Context, repoType, namespace // case SpaceRepoType: // tagScope = database.SpaceTagScope } - err := c.tc.UpdateLibraryTags(ctx, tagScope, namespace, repoName, "", fileName) + err := c.tagComponent.UpdateLibraryTags(ctx, tagScope, namespace, repoName, "", fileName) if err != nil { slog.Error("failed to add Library tag", slog.String("namespace", namespace), slog.String("name", repoName), slog.String("ref", ref), slog.String("fileName", fileName), @@ -322,7 +324,7 @@ func (c *GitCallbackComponent) addFiles(ctx context.Context, repoType, namespace return nil } -func (c *GitCallbackComponent) updateMetaTags(ctx context.Context, repoType, namespace, repoName, ref, content string) error { +func (c *gitCallbackComponentImpl) updateMetaTags(ctx context.Context, repoType, namespace, repoName, ref, content string) error { var ( err error tagScope database.TagScope @@ -342,7 +344,7 @@ func (c *GitCallbackComponent) updateMetaTags(ctx context.Context, repoType, nam // case SpaceRepoType: // tagScope = database.SpaceTagScope } - _, err = c.tc.UpdateMetaTags(ctx, tagScope, namespace, repoName, content) + _, err = c.tagComponent.UpdateMetaTags(ctx, tagScope, namespace, repoName, content) if err != nil { slog.Error("failed to update meta tags", slog.String("namespace", namespace), slog.String("content", content), slog.String("repo", repoName), slog.String("ref", ref), @@ -353,7 +355,7 @@ func (c *GitCallbackComponent) updateMetaTags(ctx context.Context, repoType, nam return nil } -func (c *GitCallbackComponent) getFileRaw(repoType, namespace, repoName, ref, fileName string) (string, error) { +func (c *gitCallbackComponentImpl) getFileRaw(repoType, namespace, repoName, ref, fileName string) (string, error) { var ( content string err error @@ -366,7 +368,7 @@ func (c *GitCallbackComponent) getFileRaw(repoType, namespace, repoName, ref, fi Path: fileName, RepoType: types.RepositoryType(repoType), } - content, err = c.gs.GetRepoFileRaw(context.Background(), getFileRawReq) + content, err = c.gitServer.GetRepoFileRaw(context.Background(), getFileRawReq) if err != nil { slog.Error("failed to get file content", slog.String("namespace", namespace), slog.String("file", fileName), slog.String("repo", repoName), slog.String("ref", ref), @@ -380,7 +382,7 @@ func (c *GitCallbackComponent) getFileRaw(repoType, namespace, repoName, ref, fi } // update repo relations -func (c *GitCallbackComponent) updateRepoRelations(ctx context.Context, repoType, namespace, repoName, ref, fileName string, deleteAction bool, fileNames []string) { +func (c *gitCallbackComponentImpl) updateRepoRelations(ctx context.Context, repoType, namespace, repoName, ref, fileName string, deleteAction bool, fileNames []string) { slog.Debug("update model relation for git callback", slog.Any("namespace", namespace), slog.Any("repoName", repoName), slog.Any("repoType", repoType), slog.Any("fileName", fileName), slog.Any("branch", ref)) if repoType == fmt.Sprintf("%ss", types.ModelRepo) { c.updateModelRuntimeFrameworks(ctx, repoType, namespace, repoName, ref, fileName, deleteAction) @@ -391,19 +393,19 @@ func (c *GitCallbackComponent) updateRepoRelations(ctx context.Context, repoType } // update dataset tags for evaluation -func (c *GitCallbackComponent) updateDatasetTags(ctx context.Context, namespace, repoName string, fileNames []string) { +func (c *gitCallbackComponentImpl) updateDatasetTags(ctx context.Context, namespace, repoName string, fileNames []string) { // script dataset repo was not supported so far scriptName := fmt.Sprintf("%s.py", repoName) if slices.Contains(fileNames, scriptName) { return } - repo, err := c.rs.FindByPath(ctx, types.DatasetRepo, namespace, repoName) + repo, err := c.repoStore.FindByPath(ctx, types.DatasetRepo, namespace, repoName) if err != nil || repo == nil { slog.Warn("fail to query repo for in callback", slog.Any("namespace", namespace), slog.Any("repoName", repoName), slog.Any("error", err)) return } // check if it's evaluation dataset - evalDataset, err := c.dt.FindByRepo(ctx, string(types.EvaluationCategory), namespace, repoName, string(types.DatasetRepo)) + evalDataset, err := c.tagRuleStore.FindByRepo(ctx, string(types.EvaluationCategory), namespace, repoName, string(types.DatasetRepo)) if err != nil { if errors.Is(err, sql.ErrNoRows) { // check if it's a mirror repo @@ -415,7 +417,7 @@ func (c *GitCallbackComponent) updateDatasetTags(ctx context.Context, namespace, namespace := strings.Split(mirror.SourceRepoPath, "/")[0] name := strings.Split(mirror.SourceRepoPath, "/")[1] // use mirror namespace and name to find dataset - evalDataset, err = c.dt.FindByRepo(ctx, string(types.EvaluationCategory), namespace, name, string(types.DatasetRepo)) + evalDataset, err = c.tagRuleStore.FindByRepo(ctx, string(types.EvaluationCategory), namespace, name, string(types.DatasetRepo)) if err != nil { slog.Debug("not an evaluation dataset, ignore it", slog.Any("repo id", repo.Path)) return @@ -429,13 +431,13 @@ func (c *GitCallbackComponent) updateDatasetTags(ctx context.Context, namespace, tagIds := []int64{} tagIds = append(tagIds, evalDataset.Tag.ID) if evalDataset.RuntimeFramework != "" { - rTag, _ := c.ts.FindTag(ctx, evalDataset.RuntimeFramework, string(types.DatasetRepo), "runtime_framework") + rTag, _ := c.tagStore.FindTag(ctx, evalDataset.RuntimeFramework, string(types.DatasetRepo), "runtime_framework") if rTag != nil { tagIds = append(tagIds, rTag.ID) } } - err = c.ts.UpsertRepoTags(ctx, repo.ID, []int64{}, tagIds) + err = c.tagStore.UpsertRepoTags(ctx, repo.ID, []int64{}, tagIds) if err != nil { slog.Warn("fail to add dataset tag", slog.Any("repoId", repo.ID), slog.Any("tag id", tagIds), slog.Any("error", err)) } @@ -443,39 +445,39 @@ func (c *GitCallbackComponent) updateDatasetTags(ctx context.Context, namespace, } // update model runtime frameworks -func (c *GitCallbackComponent) updateModelRuntimeFrameworks(ctx context.Context, repoType, namespace, repoName, ref, fileName string, deleteAction bool) { +func (c *gitCallbackComponentImpl) updateModelRuntimeFrameworks(ctx context.Context, repoType, namespace, repoName, ref, fileName string, deleteAction bool) { // must be model repo and config.json if repoType != fmt.Sprintf("%ss", types.ModelRepo) || fileName != component.ConfigFileName || (ref != ("refs/heads/"+component.MainBranch) && ref != ("refs/heads/"+component.MasterBranch)) { return } - repo, err := c.rs.FindByPath(ctx, types.ModelRepo, namespace, repoName) + repo, err := c.repoStore.FindByPath(ctx, types.ModelRepo, namespace, repoName) if err != nil || repo == nil { slog.Warn("fail to query repo for git callback", slog.Any("namespace", namespace), slog.Any("repoName", repoName), slog.Any("error", err)) return } // delete event if deleteAction { - err := c.rrf.DeleteByRepoID(ctx, repo.ID) + err := c.repoRuntimeFrameworkStore.DeleteByRepoID(ctx, repo.ID) if err != nil { slog.Warn("fail to remove repo runtimes for git callback", slog.Any("namespace", namespace), slog.Any("repoName", repoName), slog.Any("repoid", repo.ID), slog.Any("error", err)) } return } - arch, err := c.rac.GetArchitectureFromConfig(ctx, namespace, repoName) + arch, err := c.runtimeArchComponent.GetArchitectureFromConfig(ctx, namespace, repoName) if err != nil { slog.Warn("fail to get config.json content for git callback", slog.Any("namespace", namespace), slog.Any("repoName", repoName), slog.Any("error", err)) return } slog.Debug("get arch for git callback", slog.Any("namespace", namespace), slog.Any("repoName", repoName), slog.Any("arch", arch)) //add resource tag, like ascend - runtime_framework_tags, _ := c.ts.GetTagsByScopeAndCategories(ctx, "model", []string{"runtime_framework", "resource"}) + runtime_framework_tags, _ := c.tagStore.GetTagsByScopeAndCategories(ctx, "model", []string{"runtime_framework", "resource"}) fields := strings.Split(repo.Path, "/") - err = c.rac.AddResourceTag(ctx, runtime_framework_tags, fields[1], repo.ID) + err = c.runtimeArchComponent.AddResourceTag(ctx, runtime_framework_tags, fields[1], repo.ID) if err != nil { slog.Warn("fail to add resource tag", slog.Any("error", err)) return } - runtimes, err := c.ras.ListByRArchNameAndModel(ctx, arch, fields[1]) + runtimes, err := c.runtimeArchStore.ListByRArchNameAndModel(ctx, arch, fields[1]) // to do check resource models if err != nil { slog.Warn("fail to get runtime ids by arch for git callback", slog.Any("arch", arch), slog.Any("error", err)) @@ -487,7 +489,7 @@ func (c *GitCallbackComponent) updateModelRuntimeFrameworks(ctx context.Context, frameIDs = append(frameIDs, runtime.RuntimeFrameworkID) } slog.Debug("get new frame ids for git callback", slog.Any("namespace", namespace), slog.Any("repoName", repoName), slog.Any("frameIDs", frameIDs)) - newFrames, err := c.rfs.ListByIDs(ctx, frameIDs) + newFrames, err := c.runtimeFrameworkStore.ListByIDs(ctx, frameIDs) if err != nil { slog.Warn("fail to get runtime frameworks for git callback", slog.Any("arch", arch), slog.Any("error", err)) return @@ -498,7 +500,7 @@ func (c *GitCallbackComponent) updateModelRuntimeFrameworks(ctx context.Context, newFrameMap[strconv.FormatInt(frame.ID, 10)] = strconv.FormatInt(frame.ID, 10) } slog.Debug("get new frame map by arch for git callback", slog.Any("namespace", namespace), slog.Any("repoName", repoName), slog.Any("newFrameMap", newFrameMap)) - oldRepoRuntimes, err := c.rrf.GetByRepoIDs(ctx, repo.ID) + oldRepoRuntimes, err := c.repoRuntimeFrameworkStore.GetByRepoIDs(ctx, repo.ID) if err != nil { slog.Warn("fail to get repo runtimes for git callback", slog.Any("repo.ID", repo.ID), slog.Any("error", err)) return @@ -516,12 +518,12 @@ func (c *GitCallbackComponent) updateModelRuntimeFrameworks(ctx context.Context, _, exist := newFrameMap[strconv.FormatInt(old.RuntimeFrameworkID, 10)] if !exist { // remove incorrect relations - err := c.rrf.Delete(ctx, old.RuntimeFrameworkID, repo.ID, old.Type) + err := c.repoRuntimeFrameworkStore.Delete(ctx, old.RuntimeFrameworkID, repo.ID, old.Type) if err != nil { slog.Warn("fail to delete old repo runtimes for git callback", slog.Any("repo.ID", repo.ID), slog.Any("runtime framework id", old.RuntimeFrameworkID), slog.Any("error", err)) } // remove runtime framework tags - c.rac.RemoveRuntimeFrameworkTag(ctx, runtime_framework_tags, repo.ID, old.RuntimeFrameworkID) + c.runtimeArchComponent.RemoveRuntimeFrameworkTag(ctx, runtime_framework_tags, repo.ID, old.RuntimeFrameworkID) } } @@ -531,12 +533,12 @@ func (c *GitCallbackComponent) updateModelRuntimeFrameworks(ctx context.Context, _, exist := oldFrameMap[strconv.FormatInt(new.ID, 10)] if !exist { // add new relations - err := c.rrf.Add(ctx, new.ID, repo.ID, new.Type) + err := c.repoRuntimeFrameworkStore.Add(ctx, new.ID, repo.ID, new.Type) if err != nil { slog.Warn("fail to add new repo runtimes for git callback", slog.Any("repo.ID", repo.ID), slog.Any("runtime framework id", new.ID), slog.Any("error", err)) } // add runtime framework and resource tags - err = c.rac.AddRuntimeFrameworkTag(ctx, runtime_framework_tags, repo.ID, new.ID) + err = c.runtimeArchComponent.AddRuntimeFrameworkTag(ctx, runtime_framework_tags, repo.ID, new.ID) if err != nil { slog.Warn("fail to add runtime framework tag for git callback", slog.Any("repo.ID", repo.ID), slog.Any("runtime framework id", new.ID), slog.Any("error", err)) } diff --git a/component/callback/git_callback_test.go b/component/callback/git_callback_test.go new file mode 100644 index 00000000..bafa8d85 --- /dev/null +++ b/component/callback/git_callback_test.go @@ -0,0 +1,187 @@ +package callback + +import ( + "context" + "fmt" + "testing" + "time" + + "github.com/stretchr/testify/mock" + "github.com/stretchr/testify/require" + "opencsg.com/csghub-server/builder/git/gitserver" + "opencsg.com/csghub-server/builder/store/database" + "opencsg.com/csghub-server/common/types" + "opencsg.com/csghub-server/component" +) + +func TestGitCallbackComponent_SetRepoVisibility(t *testing.T) { + ctx := context.TODO() + gc := initializeTestGitCallbackComponent(ctx, t) + + require.False(t, gc.setRepoVisibility) + gc.SetRepoVisibility(true) + require.True(t, gc.setRepoVisibility) +} + +func TestGitCallbackComponent_WatchSpaceChange(t *testing.T) { + ctx := mock.Anything + gc := initializeTestGitCallbackComponent(context.TODO(), t) + + gc.mocks.stores.SpaceMock().EXPECT().FindByPath(ctx, "b", "c").Return( + &database.Space{HasAppFile: true}, nil, + ) + gc.mocks.spaceComponent.EXPECT().FixHasEntryFile(ctx, &database.Space{ + HasAppFile: true, + }).Return(nil) + gc.mocks.spaceComponent.EXPECT().Deploy(ctx, "b", "c", "b").Return(100, nil) + + err := gc.WatchSpaceChange(context.TODO(), &types.GiteaCallbackPushReq{ + Ref: "main", + Repository: types.GiteaCallbackPushReq_Repository{ + FullName: "spaces_b/c/d", + }, + }) + require.Nil(t, err) +} + +func TestGitCallbackComponent_WatchRepoRelation(t *testing.T) { + ctx := mock.Anything + gc := initializeTestGitCallbackComponent(context.TODO(), t) + + gc.mocks.gitServer.EXPECT().GetRepoFileRaw(ctx, gitserver.GetRepoInfoByPathReq{ + Namespace: "b", + Name: "c", + Ref: "refs/heads/main", + Path: "README.md", + RepoType: types.SpaceRepo, + }).Return("", nil) + gc.mocks.stores.RepoMock().EXPECT().FindByPath(ctx, types.SpaceRepo, "b", "c").Return( + &database.Repository{ID: 1}, nil, + ) + gc.mocks.stores.RepoRelationMock().EXPECT().Override(ctx, int64(1)).Return(nil) + + err := gc.WatchRepoRelation(context.TODO(), &types.GiteaCallbackPushReq{ + Ref: "refs/heads/main", + Repository: types.GiteaCallbackPushReq_Repository{ + FullName: "spaces_b/c/d", + }, + Commits: []types.GiteaCallbackPushReq_Commit{ + {Modified: []string{types.ReadmeFileName}}, + }, + }) + require.Nil(t, err) +} + +func TestGitCallbackComponent_SetRepoUpdateTime(t *testing.T) { + for _, mirror := range []bool{false, true} { + t.Run(fmt.Sprintf("mirror %v", mirror), func(t *testing.T) { + dt := time.Date(2022, 2, 2, 2, 0, 0, 0, time.UTC) + ctx := mock.Anything + gc := initializeTestGitCallbackComponent(context.TODO(), t) + + gc.mocks.stores.RepoMock().EXPECT().IsMirrorRepo( + ctx, types.ModelRepo, "ns", "n", + ).Return(mirror, nil) + + if mirror { + gc.mocks.stores.RepoMock().EXPECT().SetUpdateTimeByPath( + ctx, types.ModelRepo, "ns", "n", dt, + ).Return(nil) + gc.mocks.stores.MirrorMock().EXPECT().FindByRepoPath( + ctx, types.ModelRepo, "ns", "n", + ).Return(&database.Mirror{}, nil) + gc.mocks.stores.MirrorMock().EXPECT().Update( + ctx, mock.Anything, + ).RunAndReturn(func(ctx context.Context, m *database.Mirror) error { + require.GreaterOrEqual(t, m.LastUpdatedAt, time.Now().Add(-5*time.Second)) + return nil + }) + } else { + gc.mocks.stores.RepoMock().EXPECT().SetUpdateTimeByPath( + ctx, types.ModelRepo, "ns", "n", mock.Anything, + ).RunAndReturn(func(ctx context.Context, rt types.RepositoryType, s1, s2 string, tt time.Time) error { + require.GreaterOrEqual(t, tt, time.Now().Add(-5*time.Second)) + return nil + }) + } + + err := gc.SetRepoUpdateTime(context.TODO(), &types.GiteaCallbackPushReq{ + Repository: types.GiteaCallbackPushReq_Repository{ + FullName: "models_ns/n", + }, + HeadCommit: types.GiteaCallbackPushReq_HeadCommit{ + Timestamp: dt.Format(time.RFC3339), + }, + }) + require.Nil(t, err) + }) + } +} + +func TestGitCallbackComponent_UpdateRepoInfos(t *testing.T) { + ctx := context.TODO() + gc := initializeTestGitCallbackComponent(context.TODO(), t) + + // modified mock + gc.mocks.stores.RepoMock().EXPECT().FindByPath(ctx, types.ModelRepo, "ns", "n").Return( + &database.Repository{ID: 1, Path: "foo/bar"}, nil, + ) + gc.mocks.runtimeArchComponent.EXPECT().GetArchitectureFromConfig(ctx, "ns", "n").Return("foo", nil) + gc.mocks.stores.TagMock().EXPECT().GetTagsByScopeAndCategories( + ctx, database.ModelTagScope, []string{"runtime_framework", "resource"}, + ).Return([]*database.Tag{{Name: "t1"}}, nil) + gc.mocks.runtimeArchComponent.EXPECT().AddResourceTag( + ctx, []*database.Tag{{Name: "t1"}}, "bar", int64(1), + ).Return(nil) + gc.mocks.stores.RuntimeArchMock().EXPECT().ListByRArchNameAndModel(ctx, "foo", "bar").Return( + []database.RuntimeArchitecture{{ID: 11, RuntimeFrameworkID: 111}}, nil, + ) + gc.mocks.stores.RuntimeFrameworkMock().EXPECT().ListByIDs(ctx, []int64{111}).Return( + []database.RuntimeFramework{{ID: 12, FrameName: "fm"}}, nil, + ) + gc.mocks.stores.RepoRuntimeFrameworkMock().EXPECT().GetByRepoIDs(ctx, int64(1)).Return( + []database.RepositoriesRuntimeFramework{{RuntimeFrameworkID: 13}}, nil, + ) + gc.mocks.stores.RepoRuntimeFrameworkMock().EXPECT().Delete(ctx, int64(13), int64(1), 0).Return(nil) + gc.mocks.runtimeArchComponent.EXPECT().RemoveRuntimeFrameworkTag( + ctx, []*database.Tag{{Name: "t1"}}, int64(1), int64(13), + ).Return() + gc.mocks.stores.RepoRuntimeFrameworkMock().EXPECT().Add(ctx, int64(12), int64(1), 0).Return(nil) + gc.mocks.runtimeArchComponent.EXPECT().AddRuntimeFrameworkTag( + ctx, []*database.Tag{{Name: "t1"}}, int64(1), int64(12), + ).Return(nil) + // removed mock + gc.mocks.tagComponent.EXPECT().UpdateLibraryTags( + ctx, database.ModelTagScope, "ns", "n", "bar.go", "", + ).Return(nil) + gc.mocks.tagComponent.EXPECT().ClearMetaTags(ctx, types.ModelRepo, "ns", "n").Return(nil) + // added mock + gc.mocks.tagComponent.EXPECT().UpdateLibraryTags( + ctx, database.ModelTagScope, "ns", "n", "", "foo.go", + ).Return(nil) + gc.mocks.gitServer.EXPECT().GetRepoFileRaw(mock.Anything, gitserver.GetRepoInfoByPathReq{ + Namespace: "ns", + Name: "n", + Ref: "refs/heads/main", + Path: "README.md", + RepoType: types.ModelRepo, + }).Return("", nil) + gc.mocks.tagComponent.EXPECT().UpdateMetaTags( + ctx, database.ModelTagScope, "ns", "n", "", + ).Return(nil, nil) + + err := gc.UpdateRepoInfos(ctx, &types.GiteaCallbackPushReq{ + Ref: "refs/heads/main", + Repository: types.GiteaCallbackPushReq_Repository{ + FullName: "models_ns/n", + }, + Commits: []types.GiteaCallbackPushReq_Commit{ + { + Modified: []string{component.ConfigFileName}, + Removed: []string{"bar.go", types.ReadmeFileName}, + Added: []string{"foo.go", types.ReadmeFileName}, + }, + }, + }) + require.Nil(t, err) +} diff --git a/component/callback/wire.go b/component/callback/wire.go new file mode 100644 index 00000000..e9abaea1 --- /dev/null +++ b/component/callback/wire.go @@ -0,0 +1,27 @@ +//go:build wireinject +// +build wireinject + +package callback + +import ( + "context" + + "github.com/google/wire" + "github.com/stretchr/testify/mock" +) + +type testGitCallbackWithMocks struct { + *gitCallbackComponentImpl + mocks *Mocks +} + +func initializeTestGitCallbackComponent(ctx context.Context, t interface { + Cleanup(func()) + mock.TestingT +}) *testGitCallbackWithMocks { + wire.Build( + MockCallbackSuperSet, GitCallbackComponentSet, + wire.Struct(new(testGitCallbackWithMocks), "*"), + ) + return &testGitCallbackWithMocks{} +} diff --git a/component/callback/wire_gen_test.go b/component/callback/wire_gen_test.go new file mode 100644 index 00000000..f05a8e39 --- /dev/null +++ b/component/callback/wire_gen_test.go @@ -0,0 +1,52 @@ +// Code generated by Wire. DO NOT EDIT. + +//go:generate go run -mod=mod github.com/google/wire/cmd/wire +//go:build !wireinject +// +build !wireinject + +package callback + +import ( + "context" + "github.com/stretchr/testify/mock" + "opencsg.com/csghub-server/_mocks/opencsg.com/csghub-server/builder/git/gitserver" + "opencsg.com/csghub-server/_mocks/opencsg.com/csghub-server/builder/rpc" + component2 "opencsg.com/csghub-server/_mocks/opencsg.com/csghub-server/component" + "opencsg.com/csghub-server/common/tests" + "opencsg.com/csghub-server/component" +) + +// Injectors from wire.go: + +func initializeTestGitCallbackComponent(ctx context.Context, t interface { + Cleanup(func()) + mock.TestingT +}) *testGitCallbackWithMocks { + config := component.ProvideTestConfig() + mockStores := tests.NewMockStores(t) + mockGitServer := gitserver.NewMockGitServer(t) + mockTagComponent := component2.NewMockTagComponent(t) + mockModerationSvcClient := rpc.NewMockModerationSvcClient(t) + mockRuntimeArchitectureComponent := component2.NewMockRuntimeArchitectureComponent(t) + mockSpaceComponent := component2.NewMockSpaceComponent(t) + callbackGitCallbackComponentImpl := NewTestGitCallbackComponent(config, mockStores, mockGitServer, mockTagComponent, mockModerationSvcClient, mockRuntimeArchitectureComponent, mockSpaceComponent) + mocks := &Mocks{ + stores: mockStores, + tagComponent: mockTagComponent, + spaceComponent: mockSpaceComponent, + gitServer: mockGitServer, + runtimeArchComponent: mockRuntimeArchitectureComponent, + } + callbackTestGitCallbackWithMocks := &testGitCallbackWithMocks{ + gitCallbackComponentImpl: callbackGitCallbackComponentImpl, + mocks: mocks, + } + return callbackTestGitCallbackWithMocks +} + +// wire.go: + +type testGitCallbackWithMocks struct { + *gitCallbackComponentImpl + mocks *Mocks +} diff --git a/component/callback/wireset.go b/component/callback/wireset.go new file mode 100644 index 00000000..4401844c --- /dev/null +++ b/component/callback/wireset.go @@ -0,0 +1,53 @@ +package callback + +import ( + "github.com/google/wire" + mock_git "opencsg.com/csghub-server/_mocks/opencsg.com/csghub-server/builder/git/gitserver" + mock_component "opencsg.com/csghub-server/_mocks/opencsg.com/csghub-server/component" + "opencsg.com/csghub-server/builder/git/gitserver" + "opencsg.com/csghub-server/builder/rpc" + "opencsg.com/csghub-server/common/config" + "opencsg.com/csghub-server/common/tests" + "opencsg.com/csghub-server/component" +) + +type Mocks struct { + stores *tests.MockStores + tagComponent *mock_component.MockTagComponent + spaceComponent *mock_component.MockSpaceComponent + gitServer *mock_git.MockGitServer + runtimeArchComponent *mock_component.MockRuntimeArchitectureComponent +} + +var AllMockSet = wire.NewSet( + wire.Struct(new(Mocks), "*"), +) + +var MockCallbackSuperSet = wire.NewSet( + component.MockedStoreSet, component.MockedComponentSet, AllMockSet, + component.ProvideTestConfig, component.MockedGitServerSet, component.MockedModerationSvcClientSet, +) + +func NewTestGitCallbackComponent(config *config.Config, stores *tests.MockStores, gitServer gitserver.GitServer, tagComponent component.TagComponent, modSvcClient rpc.ModerationSvcClient, runtimeArchComponent component.RuntimeArchitectureComponent, spaceComponent component.SpaceComponent) *gitCallbackComponentImpl { + return &gitCallbackComponentImpl{ + config: config, + gitServer: gitServer, + tagComponent: tagComponent, + modSvcClient: modSvcClient, + modelStore: stores.Model, + datasetStore: stores.Dataset, + spaceComponent: spaceComponent, + spaceStore: stores.Space, + repoStore: stores.Repo, + repoRelationStore: stores.RepoRelation, + mirrorStore: stores.Mirror, + repoRuntimeFrameworkStore: stores.RepoRuntimeFramework, + runtimeArchComponent: runtimeArchComponent, + runtimeArchStore: stores.RuntimeArch, + runtimeFrameworkStore: stores.RuntimeFramework, + tagStore: stores.Tag, + tagRuleStore: stores.TagRule, + } +} + +var GitCallbackComponentSet = wire.NewSet(NewTestGitCallbackComponent) diff --git a/component/cluster_test.go b/component/cluster_test.go new file mode 100644 index 00000000..04e7820d --- /dev/null +++ b/component/cluster_test.go @@ -0,0 +1,42 @@ +package component + +import ( + "context" + "testing" + + "github.com/stretchr/testify/require" + "opencsg.com/csghub-server/common/types" +) + +func TestClusterComponent_Index(t *testing.T) { + ctx := context.TODO() + cc := initializeTestClusterComponent(ctx, t) + + cc.mocks.deployer.EXPECT().ListCluster(ctx).Return(nil, nil) + + data, err := cc.Index(ctx) + require.Nil(t, err) + require.Equal(t, []types.ClusterRes([]types.ClusterRes(nil)), data) +} + +func TestClusterComponent_GetClusterById(t *testing.T) { + ctx := context.TODO() + cc := initializeTestClusterComponent(ctx, t) + + cc.mocks.deployer.EXPECT().GetClusterById(ctx, "c1").Return(nil, nil) + + data, err := cc.GetClusterById(ctx, "c1") + require.Nil(t, err) + require.Equal(t, (*types.ClusterRes)(nil), data) +} + +func TestClusterComponent_Update(t *testing.T) { + ctx := context.TODO() + cc := initializeTestClusterComponent(ctx, t) + + cc.mocks.deployer.EXPECT().UpdateCluster(ctx, types.ClusterRequest{}).Return(nil, nil) + + data, err := cc.Update(ctx, types.ClusterRequest{}) + require.Nil(t, err) + require.Equal(t, (*types.UpdateClusterResponse)(nil), data) +} diff --git a/component/evaluation.go b/component/evaluation.go index 683261cb..e437a11a 100644 --- a/component/evaluation.go +++ b/component/evaluation.go @@ -15,16 +15,16 @@ import ( ) type evaluationComponentImpl struct { - deployer deploy.Deployer - userStore database.UserStore - modelStore database.ModelStore - datasetStore database.DatasetStore - mirrorStore database.MirrorStore - spaceResourceStore database.SpaceResourceStore - tokenStore database.AccessTokenStore - rtfm database.RuntimeFrameworksStore - config *config.Config - ac AccountingComponent + deployer deploy.Deployer + userStore database.UserStore + modelStore database.ModelStore + datasetStore database.DatasetStore + mirrorStore database.MirrorStore + spaceResourceStore database.SpaceResourceStore + tokenStore database.AccessTokenStore + runtimeFrameworkStore database.RuntimeFrameworksStore + config *config.Config + accountingComponent AccountingComponent } type EvaluationComponent interface { @@ -43,13 +43,13 @@ func NewEvaluationComponent(config *config.Config) (EvaluationComponent, error) c.datasetStore = database.NewDatasetStore() c.mirrorStore = database.NewMirrorStore() c.tokenStore = database.NewAccessTokenStore() - c.rtfm = database.NewRuntimeFrameworksStore() + c.runtimeFrameworkStore = database.NewRuntimeFrameworksStore() c.config = config ac, err := NewAccountingComponent(config) if err != nil { return nil, fmt.Errorf("failed to create accounting component, %w", err) } - c.ac = ac + c.accountingComponent = ac return c, nil } @@ -97,7 +97,7 @@ func (c *evaluationComponentImpl) CreateEvaluation(ctx context.Context, req type hardware.Cpu.Num = "8" hardware.Memory = "32Gi" } - frame, err := c.rtfm.FindEnabledByID(ctx, req.RuntimeFrameworkId) + frame, err := c.runtimeFrameworkStore.FindEnabledByID(ctx, req.RuntimeFrameworkId) if err != nil { return nil, fmt.Errorf("cannot find available runtime framework, %w", err) } diff --git a/component/evaluation_test.go b/component/evaluation_test.go index a1620f14..cccc2ca9 100644 --- a/component/evaluation_test.go +++ b/component/evaluation_test.go @@ -6,32 +6,10 @@ import ( "testing" "github.com/stretchr/testify/require" - mock_deploy "opencsg.com/csghub-server/_mocks/opencsg.com/csghub-server/builder/deploy" - mock_component "opencsg.com/csghub-server/_mocks/opencsg.com/csghub-server/component" - "opencsg.com/csghub-server/builder/deploy" "opencsg.com/csghub-server/builder/store/database" - "opencsg.com/csghub-server/common/config" - "opencsg.com/csghub-server/common/tests" "opencsg.com/csghub-server/common/types" ) -func NewTestEvaluationComponent(deployer deploy.Deployer, stores *tests.MockStores, ac AccountingComponent) EvaluationComponent { - cfg := &config.Config{} - cfg.Argo.QuotaGPUNumber = "1" - return &evaluationComponentImpl{ - deployer: deployer, - config: cfg, - userStore: stores.User, - modelStore: stores.Model, - datasetStore: stores.Dataset, - mirrorStore: stores.Mirror, - spaceResourceStore: stores.SpaceResource, - tokenStore: stores.AccessToken, - rtfm: stores.RuntimeFramework, - ac: ac, - } -} - func TestEvaluationComponent_CreateEvaluation(t *testing.T) { req := types.EvaluationReq{ TaskName: "test", @@ -66,30 +44,28 @@ func TestEvaluationComponent_CreateEvaluation(t *testing.T) { Token: "foo", } t.Run("create evaluation without resource id", func(t *testing.T) { - deployerMock := &mock_deploy.MockDeployer{} - stores := tests.NewMockStores(t) - ac := &mock_component.MockAccountingComponent{} - c := NewTestEvaluationComponent(deployerMock, stores, ac) - stores.UserMock().EXPECT().FindByUsername(ctx, req.Username).Return(database.User{ + c := initializeTestEvaluationComponent(ctx, t) + c.config.Argo.QuotaGPUNumber = "1" + c.mocks.stores.UserMock().EXPECT().FindByUsername(ctx, req.Username).Return(database.User{ RoleMask: "admin", Username: req.Username, UUID: req.Username, ID: 1, }, nil).Once() - stores.ModelMock().EXPECT().FindByPath(ctx, "opencsg", "wukong").Return( + c.mocks.stores.ModelMock().EXPECT().FindByPath(ctx, "opencsg", "wukong").Return( &database.Model{ ID: 1, }, nil, ).Maybe() - stores.MirrorMock().EXPECT().FindByRepoPath(ctx, types.DatasetRepo, "opencsg", "hellaswag").Return(&database.Mirror{ + c.mocks.stores.MirrorMock().EXPECT().FindByRepoPath(ctx, types.DatasetRepo, "opencsg", "hellaswag").Return(&database.Mirror{ SourceRepoPath: "Rowan/hellaswag", }, nil) - stores.AccessTokenMock().EXPECT().FindByUID(ctx, int64(1)).Return(&database.AccessToken{Token: "foo"}, nil) - stores.RuntimeFrameworkMock().EXPECT().FindEnabledByID(ctx, int64(1)).Return(&database.RuntimeFramework{ + c.mocks.stores.AccessTokenMock().EXPECT().FindByUID(ctx, int64(1)).Return(&database.AccessToken{Token: "foo"}, nil) + c.mocks.stores.RuntimeFrameworkMock().EXPECT().FindEnabledByID(ctx, int64(1)).Return(&database.RuntimeFramework{ ID: 1, FrameImage: "lm-evaluation-harness:0.4.6", }, nil) - deployerMock.EXPECT().SubmitEvaluation(ctx, req2).Return(&types.ArgoWorkFlowRes{ + c.mocks.deployer.EXPECT().SubmitEvaluation(ctx, req2).Return(&types.ArgoWorkFlowRes{ ID: 1, TaskName: "test", }, nil) @@ -101,36 +77,36 @@ func TestEvaluationComponent_CreateEvaluation(t *testing.T) { t.Run("create evaluation with resource id", func(t *testing.T) { req.ResourceId = 1 req2.ResourceId = 1 - deployerMock := &mock_deploy.MockDeployer{} - stores := tests.NewMockStores(t) - ac := &mock_component.MockAccountingComponent{} - c := NewTestEvaluationComponent(deployerMock, stores, ac) - stores.UserMock().EXPECT().FindByUsername(ctx, req.Username).Return(database.User{ + c := initializeTestEvaluationComponent(ctx, t) + c.config.Argo.QuotaGPUNumber = "1" + c.mocks.stores.UserMock().EXPECT().FindByUsername(ctx, req.Username).Return(database.User{ RoleMask: "admin", Username: req.Username, UUID: req.Username, ID: 1, }, nil).Once() - stores.ModelMock().EXPECT().FindByPath(ctx, "opencsg", "wukong").Return( + c.mocks.stores.ModelMock().EXPECT().FindByPath(ctx, "opencsg", "wukong").Return( &database.Model{ ID: 1, }, nil, ).Maybe() - stores.MirrorMock().EXPECT().FindByRepoPath(ctx, types.DatasetRepo, "opencsg", "hellaswag").Return(&database.Mirror{ + c.mocks.stores.MirrorMock().EXPECT().FindByRepoPath(ctx, types.DatasetRepo, "opencsg", "hellaswag").Return(&database.Mirror{ SourceRepoPath: "Rowan/hellaswag", }, nil) - stores.AccessTokenMock().EXPECT().FindByUID(ctx, int64(1)).Return(&database.AccessToken{Token: "foo"}, nil) - stores.RuntimeFrameworkMock().EXPECT().FindEnabledByID(ctx, int64(1)).Return(&database.RuntimeFramework{ + c.mocks.stores.AccessTokenMock().EXPECT().FindByUID(ctx, int64(1)).Return(&database.AccessToken{Token: "foo"}, nil) + c.mocks.stores.RuntimeFrameworkMock().EXPECT().FindEnabledByID(ctx, int64(1)).Return(&database.RuntimeFramework{ ID: 1, FrameImage: "lm-evaluation-harness:0.4.6", }, nil) + resource, err := json.Marshal(req2.Hardware) require.Nil(t, err) - stores.SpaceResourceMock().EXPECT().FindByID(ctx, int64(1)).Return(&database.SpaceResource{ + c.mocks.stores.SpaceResourceMock().EXPECT().FindByID(ctx, int64(1)).Return(&database.SpaceResource{ ID: 1, Resources: string(resource), }, nil) - deployerMock.EXPECT().SubmitEvaluation(ctx, req2).Return(&types.ArgoWorkFlowRes{ + c.mocks.deployer.EXPECT().SubmitEvaluation(ctx, req2).Return(&types.ArgoWorkFlowRes{ + ID: 1, TaskName: "test", }, nil) @@ -142,15 +118,13 @@ func TestEvaluationComponent_CreateEvaluation(t *testing.T) { } func TestEvaluationComponent_GetEvaluation(t *testing.T) { - deployerMock := &mock_deploy.MockDeployer{} - stores := tests.NewMockStores(t) - ac := &mock_component.MockAccountingComponent{} - c := NewTestEvaluationComponent(deployerMock, stores, ac) + ctx := context.TODO() + c := initializeTestEvaluationComponent(ctx, t) + c.config.Argo.QuotaGPUNumber = "1" req := types.EvaluationGetReq{ Username: "test", } - ctx := context.TODO() - deployerMock.EXPECT().GetEvaluation(ctx, req).Return(&types.ArgoWorkFlowRes{ + c.mocks.deployer.EXPECT().GetEvaluation(ctx, req).Return(&types.ArgoWorkFlowRes{ ID: 1, RepoIds: []string{"Rowan/hellaswag"}, Datasets: []string{"Rowan/hellaswag"}, @@ -161,7 +135,7 @@ func TestEvaluationComponent_GetEvaluation(t *testing.T) { TaskType: "evaluation", Status: "Succeed", }, nil) - stores.DatasetMock().EXPECT().ListByPath(ctx, []string{"Rowan/hellaswag"}).Return([]database.Dataset{ + c.mocks.stores.DatasetMock().EXPECT().ListByPath(ctx, []string{"Rowan/hellaswag"}).Return([]database.Dataset{ { Repository: &database.Repository{ Path: "Rowan/hellaswag", @@ -184,15 +158,13 @@ func TestEvaluationComponent_GetEvaluation(t *testing.T) { } func TestEvaluationComponent_DeleteEvaluation(t *testing.T) { - deployerMock := &mock_deploy.MockDeployer{} - stores := tests.NewMockStores(t) - ac := &mock_component.MockAccountingComponent{} - c := NewTestEvaluationComponent(deployerMock, stores, ac) + ctx := context.TODO() + c := initializeTestEvaluationComponent(ctx, t) + c.config.Argo.QuotaGPUNumber = "1" req := types.EvaluationDelReq{ Username: "test", } - ctx := context.TODO() - deployerMock.EXPECT().DeleteEvaluation(ctx, req).Return(nil) + c.mocks.deployer.EXPECT().DeleteEvaluation(ctx, req).Return(nil) err := c.DeleteEvaluation(ctx, req) require.Nil(t, err) } diff --git a/component/event.go b/component/event.go index 2ea4a1ec..3a28bb69 100644 --- a/component/event.go +++ b/component/event.go @@ -8,7 +8,7 @@ import ( ) type eventComponentImpl struct { - es database.EventStore + eventStore database.EventStore } // NewEventComponent creates a new EventComponent @@ -19,7 +19,7 @@ type EventComponent interface { func NewEventComponent() EventComponent { return &eventComponentImpl{ - es: database.NewEventStore(), + eventStore: database.NewEventStore(), } } @@ -34,5 +34,5 @@ func (ec *eventComponentImpl) NewEvents(ctx context.Context, events []types.Even }) } - return ec.es.BatchSave(ctx, dbevents) + return ec.eventStore.BatchSave(ctx, dbevents) } diff --git a/component/event_test.go b/component/event_test.go new file mode 100644 index 00000000..ebd0a039 --- /dev/null +++ b/component/event_test.go @@ -0,0 +1,22 @@ +package component + +import ( + "context" + "testing" + + "github.com/stretchr/testify/require" + "opencsg.com/csghub-server/builder/store/database" + "opencsg.com/csghub-server/common/types" +) + +func TestEventComponent_NewEvent(t *testing.T) { + ctx := context.TODO() + ec := initializeTestEventComponent(ctx, t) + + ec.mocks.stores.EventMock().EXPECT().BatchSave(ctx, []database.Event{ + {EventID: "e1"}, + }).Return(nil) + + err := ec.NewEvents(ctx, []types.Event{{ID: "e1"}}) + require.Nil(t, err) +} diff --git a/component/hf_dataset.go b/component/hf_dataset.go index 03c48e89..5b669376 100644 --- a/component/hf_dataset.go +++ b/component/hf_dataset.go @@ -6,6 +6,7 @@ import ( "log/slog" "strings" + "opencsg.com/csghub-server/builder/git" "opencsg.com/csghub-server/builder/git/gitserver" "opencsg.com/csghub-server/builder/store/database" "opencsg.com/csghub-server/common/config" @@ -19,22 +20,28 @@ type HFDatasetComponent interface { func NewHFDatasetComponent(config *config.Config) (HFDatasetComponent, error) { c := &hFDatasetComponentImpl{} - c.ts = database.NewTagStore() - c.ds = database.NewDatasetStore() - c.rs = database.NewRepoStore() + c.tagStore = database.NewTagStore() + c.datasetStore = database.NewDatasetStore() + c.repoStore = database.NewRepoStore() var err error - c.repoComponentImpl, err = NewRepoComponentImpl(config) + c.repoComponent, err = NewRepoComponentImpl(config) if err != nil { return nil, err } + gs, err := git.NewGitServer(config) + if err != nil { + return nil, fmt.Errorf("failed to create git server, error: %w", err) + } + c.gitServer = gs return c, nil } type hFDatasetComponentImpl struct { - *repoComponentImpl - ts database.TagStore - ds database.DatasetStore - rs database.RepoStore + repoComponent RepoComponent + tagStore database.TagStore + datasetStore database.DatasetStore + repoStore database.RepoStore + gitServer gitserver.GitServer } func convertFilePathFromRoute(path string) string { @@ -42,12 +49,12 @@ func convertFilePathFromRoute(path string) string { } func (h *hFDatasetComponentImpl) GetPathsInfo(ctx context.Context, req types.PathReq) ([]types.HFDSPathInfo, error) { - ds, err := h.ds.FindByPath(ctx, req.Namespace, req.Name) + ds, err := h.datasetStore.FindByPath(ctx, req.Namespace, req.Name) if err != nil { return nil, fmt.Errorf("failed to find dataset, error: %w", err) } - allow, err := h.AllowReadAccessRepo(ctx, ds.Repository, req.CurrentUser) + allow, err := h.repoComponent.AllowReadAccessRepo(ctx, ds.Repository, req.CurrentUser) if err != nil { return nil, fmt.Errorf("failed to check dataset permission, error: %w", err) } @@ -62,7 +69,7 @@ func (h *hFDatasetComponentImpl) GetPathsInfo(ctx context.Context, req types.Pat Path: convertFilePathFromRoute(req.Path), RepoType: types.DatasetRepo, } - file, _ := h.git.GetRepoFileContents(ctx, getRepoFileTree) + file, _ := h.gitServer.GetRepoFileContents(ctx, getRepoFileTree) if file == nil { return []types.HFDSPathInfo{}, nil } @@ -81,12 +88,12 @@ func (h *hFDatasetComponentImpl) GetPathsInfo(ctx context.Context, req types.Pat } func (h *hFDatasetComponentImpl) GetDatasetTree(ctx context.Context, req types.PathReq) ([]types.HFDSPathInfo, error) { - ds, err := h.ds.FindByPath(ctx, req.Namespace, req.Name) + ds, err := h.datasetStore.FindByPath(ctx, req.Namespace, req.Name) if err != nil { return nil, fmt.Errorf("failed to find dataset tree, error: %w", err) } - allow, err := h.AllowReadAccessRepo(ctx, ds.Repository, req.CurrentUser) + allow, err := h.repoComponent.AllowReadAccessRepo(ctx, ds.Repository, req.CurrentUser) if err != nil { return nil, fmt.Errorf("failed to check dataset permission, error: %w", err) } @@ -102,7 +109,7 @@ func (h *hFDatasetComponentImpl) GetDatasetTree(ctx context.Context, req types.P Path: req.Path, RepoType: types.DatasetRepo, } - tree, err := h.git.GetRepoFileTree(ctx, getRepoFileTree) + tree, err := h.gitServer.GetRepoFileTree(ctx, getRepoFileTree) if err != nil { slog.Warn("failed to get repo file tree", slog.Any("getRepoFileTree", getRepoFileTree), slog.String("error", err.Error())) return []types.HFDSPathInfo{}, nil diff --git a/component/hf_dataset_test.go b/component/hf_dataset_test.go new file mode 100644 index 00000000..b8e0bda1 --- /dev/null +++ b/component/hf_dataset_test.go @@ -0,0 +1,71 @@ +package component + +import ( + "context" + "testing" + + "github.com/stretchr/testify/require" + "opencsg.com/csghub-server/builder/git/gitserver" + "opencsg.com/csghub-server/builder/store/database" + "opencsg.com/csghub-server/common/types" +) + +func TestHFDataset_GetPathsInfo(t *testing.T) { + ctx := context.TODO() + hc := initializeTestHFDatasetComponent(ctx, t) + + dataset := &database.Dataset{} + hc.mocks.stores.DatasetMock().EXPECT().FindByPath(ctx, "ns", "n").Return(dataset, nil) + hc.mocks.components.repo.EXPECT().AllowReadAccessRepo(ctx, dataset.Repository, "user").Return(true, nil) + hc.mocks.gitServer.EXPECT().GetRepoFileContents(ctx, gitserver.GetRepoInfoByPathReq{ + Namespace: "ns", + Name: "n", + Path: "a/b", + Ref: "main", + RepoType: types.DatasetRepo, + }).Return(&types.File{ + Type: "go", LastCommitSHA: "sha", Size: 5, Path: "foo", + }, nil) + + data, err := hc.GetPathsInfo(ctx, types.PathReq{ + Namespace: "ns", + Name: "n", + Ref: "main", + Path: "a/b", + CurrentUser: "user", + }) + require.Nil(t, err) + require.Equal(t, []types.HFDSPathInfo{ + {Type: "file", Path: "foo", Size: 5, OID: "sha"}, + }, data) + +} + +func TestHFDataset_GetDatasetTree(t *testing.T) { + ctx := context.TODO() + hc := initializeTestHFDatasetComponent(ctx, t) + + dataset := &database.Dataset{} + hc.mocks.stores.DatasetMock().EXPECT().FindByPath(ctx, "ns", "n").Return(dataset, nil) + hc.mocks.components.repo.EXPECT().AllowReadAccessRepo(ctx, dataset.Repository, "user").Return(true, nil) + hc.mocks.gitServer.EXPECT().GetRepoFileTree(ctx, gitserver.GetRepoInfoByPathReq{ + Namespace: "ns", + Name: "n", + Path: "a/b", + RepoType: types.DatasetRepo, + }).Return([]*types.File{ + {Type: "go", LastCommitSHA: "sha", Size: 5, Path: "foo"}, + }, nil) + + data, err := hc.GetDatasetTree(ctx, types.PathReq{ + Namespace: "ns", + Name: "n", + Ref: "main", + Path: "a/b", + CurrentUser: "user", + }) + require.Nil(t, err) + require.Equal(t, []types.HFDSPathInfo{ + {Type: "go", Path: "foo", Size: 5, OID: "sha"}, + }, data) +} diff --git a/component/list.go b/component/list.go index 3567692b..2fcccdde 100644 --- a/component/list.go +++ b/component/list.go @@ -16,22 +16,22 @@ type ListComponent interface { func NewListComponent(config *config.Config) (ListComponent, error) { c := &listComponentImpl{} - c.ds = database.NewDatasetStore() - c.ms = database.NewModelStore() - c.ss = database.NewSpaceStore() + c.datasetStore = database.NewDatasetStore() + c.modelStore = database.NewModelStore() + c.spaceStore = database.NewSpaceStore() return c, nil } type listComponentImpl struct { - ms database.ModelStore - ds database.DatasetStore - ss database.SpaceStore + modelStore database.ModelStore + datasetStore database.DatasetStore + spaceStore database.SpaceStore } func (c *listComponentImpl) ListModelsByPath(ctx context.Context, req *types.ListByPathReq) ([]*types.ModelResp, error) { var modelResp []*types.ModelResp - models, err := c.ms.ListByPath(ctx, req.Paths) + models, err := c.modelStore.ListByPath(ctx, req.Paths) if err != nil { slog.Error("error listing models by path", "error", err, slog.Any("paths", req.Paths)) return nil, err @@ -67,7 +67,7 @@ func (c *listComponentImpl) ListModelsByPath(ctx context.Context, req *types.Lis func (c *listComponentImpl) ListDatasetsByPath(ctx context.Context, req *types.ListByPathReq) ([]*types.DatasetResp, error) { var datasetResp []*types.DatasetResp - datasets, err := c.ds.ListByPath(ctx, req.Paths) + datasets, err := c.datasetStore.ListByPath(ctx, req.Paths) if err != nil { slog.Error("error listing datasets by path", "error", err, slog.Any("paths", req.Paths)) return nil, err diff --git a/component/list_test.go b/component/list_test.go new file mode 100644 index 00000000..490503cf --- /dev/null +++ b/component/list_test.go @@ -0,0 +1,46 @@ +package component + +import ( + "context" + "testing" + + "github.com/stretchr/testify/require" + "opencsg.com/csghub-server/builder/store/database" + "opencsg.com/csghub-server/common/types" +) + +func TestListComponent_ListModelsByPath(t *testing.T) { + ctx := context.TODO() + lc := initializeTestListComponent(ctx, t) + + lc.mocks.stores.ModelMock().EXPECT().ListByPath(ctx, []string{"foo"}).Return( + []database.Model{ + {Repository: &database.Repository{ + Name: "r1", + Tags: []database.Tag{{Name: "t1"}}, + }}, + }, nil, + ) + + data, err := lc.ListModelsByPath(ctx, &types.ListByPathReq{Paths: []string{"foo"}}) + require.Nil(t, err) + require.Equal(t, []*types.ModelResp{{Name: "r1", Tags: []types.RepoTag{{Name: "t1"}}}}, data) +} + +func TestListComponent_ListDatasetByPath(t *testing.T) { + ctx := context.TODO() + lc := initializeTestListComponent(ctx, t) + + lc.mocks.stores.DatasetMock().EXPECT().ListByPath(ctx, []string{"foo"}).Return( + []database.Dataset{ + {Repository: &database.Repository{ + Name: "r1", + Tags: []database.Tag{{Name: "t1"}}, + }}, + }, nil, + ) + + data, err := lc.ListDatasetsByPath(ctx, &types.ListByPathReq{Paths: []string{"foo"}}) + require.Nil(t, err) + require.Equal(t, []*types.ModelResp{{Name: "r1", Tags: []types.RepoTag{{Name: "t1"}}}}, data) +} diff --git a/component/repo_file.go b/component/repo_file.go index 22481ca6..c1f35b0f 100644 --- a/component/repo_file.go +++ b/component/repo_file.go @@ -14,9 +14,9 @@ import ( ) type repoFileComponentImpl struct { - rfs database.RepoFileStore - rs database.RepoStore - gs gitserver.GitServer + repoFileStore database.RepoFileStore + repoStore database.RepoStore + gitServer gitserver.GitServer } type RepoFileComponent interface { @@ -26,23 +26,23 @@ type RepoFileComponent interface { func NewRepoFileComponent(conf *config.Config) (RepoFileComponent, error) { c := &repoFileComponentImpl{ - rfs: database.NewRepoFileStore(), - rs: database.NewRepoStore(), + repoFileStore: database.NewRepoFileStore(), + repoStore: database.NewRepoStore(), } gs, err := git.NewGitServer(conf) if err != nil { return nil, fmt.Errorf("failed to create git server, error: %w", err) } - c.gs = gs + c.gitServer = gs return c, nil } func (c *repoFileComponentImpl) GenRepoFileRecords(ctx context.Context, repoType types.RepositoryType, namespace, name string) error { - repo, err := c.rs.FindByPath(ctx, repoType, namespace, name) + repo, err := c.repoStore.FindByPath(ctx, repoType, namespace, name) if err != nil { return fmt.Errorf("failed to find repo, error: %w", err) } - return c.createRepoFileRecords(ctx, *repo, "", c.gs.GetRepoFileTree) + return c.createRepoFileRecords(ctx, *repo, "", c.gitServer.GetRepoFileTree) } func (c *repoFileComponentImpl) GenRepoFileRecordsBatch(ctx context.Context, repoType types.RepositoryType, lastRepoID int64, concurrency int) error { @@ -54,7 +54,7 @@ func (c *repoFileComponentImpl) GenRepoFileRecordsBatch(ctx context.Context, rep //TODO: load last repo id from redis cache batch := 10 for { - repos, err := c.rs.BatchGet(ctx, repoType, lastRepoID, batch) + repos, err := c.repoStore.BatchGet(ctx, repoType, lastRepoID, batch) if err != nil { return fmt.Errorf("failed to get repos in batch, error: %w", err) } @@ -65,7 +65,7 @@ func (c *repoFileComponentImpl) GenRepoFileRecordsBatch(ctx context.Context, rep go func(repo database.Repository) { slog.Info("start to get files of repository", slog.Any("repoType", repoType), slog.String("path", repo.Path)) //get file paths of repo - err := c.createRepoFileRecords(ctx, repo, "", c.gs.GetRepoFileTree) + err := c.createRepoFileRecords(ctx, repo, "", c.gitServer.GetRepoFileTree) if err != nil { slog.Error("fail to get all files of repository", slog.String("path", repo.Path), slog.String("repo_type", string(repo.RepositoryType)), @@ -127,7 +127,7 @@ func (c *repoFileComponentImpl) createRepoFileRecords(ctx context.Context, repo var exists bool var err error - if exists, err = c.rfs.Exists(ctx, rf); err != nil { + if exists, err = c.repoFileStore.Exists(ctx, rf); err != nil { slog.Error("failed to check repository file exists", slog.Any("repo_id", repo.ID), slog.String("file_path", rf.Path), slog.String("error", err.Error())) continue @@ -137,7 +137,7 @@ func (c *repoFileComponentImpl) createRepoFileRecords(ctx context.Context, repo slog.Info("skip create exist repository file", slog.Any("repo_id", repo.ID), slog.String("file_path", rf.Path)) continue } - if err := c.rfs.Create(ctx, &rf); err != nil { + if err := c.repoFileStore.Create(ctx, &rf); err != nil { slog.Error("failed to save repository file", slog.Any("repo_id", repo.ID), slog.String("error", err.Error())) return fmt.Errorf("failed to save repository file, error: %w", err) diff --git a/component/repo_file_test.go b/component/repo_file_test.go new file mode 100644 index 00000000..c40c4978 --- /dev/null +++ b/component/repo_file_test.go @@ -0,0 +1,89 @@ +package component + +import ( + "context" + "testing" + + "github.com/stretchr/testify/mock" + "github.com/stretchr/testify/require" + "opencsg.com/csghub-server/builder/git/gitserver" + "opencsg.com/csghub-server/builder/store/database" + "opencsg.com/csghub-server/common/types" +) + +func TestRepoFileComponent_GenRepoFileRecords(t *testing.T) { + ctx := context.TODO() + rc := initializeTestRepoFileComponent(ctx, t) + + rc.mocks.stores.RepoMock().EXPECT().FindByPath(ctx, types.ModelRepo, "ns", "n").Return( + &database.Repository{ID: 1, Path: "foo/bar"}, nil, + ) + rc.mocks.gitServer.EXPECT().GetRepoFileTree(mock.Anything, gitserver.GetRepoInfoByPathReq{ + Namespace: "foo", + Name: "bar", + }).Return( + []*types.File{ + {Path: "a/b", Type: "dir"}, + {Path: "foo.go", Type: "go"}, + }, nil, + ) + rc.mocks.gitServer.EXPECT().GetRepoFileTree(mock.Anything, gitserver.GetRepoInfoByPathReq{ + Path: "a/b", + Namespace: "foo", + Name: "bar", + }).Return( + []*types.File{}, nil, + ) + rc.mocks.stores.RepoFileMock().EXPECT().Exists(ctx, database.RepositoryFile{ + RepositoryID: 1, + Path: "foo.go", + FileType: "go", + }).Return(false, nil) + rc.mocks.stores.RepoFileMock().EXPECT().Create(ctx, &database.RepositoryFile{ + RepositoryID: 1, + Path: "foo.go", + FileType: "go", + }).Return(nil) + + err := rc.GenRepoFileRecords(ctx, types.ModelRepo, "ns", "n") + require.Nil(t, err) + +} + +func TestRepoFileComponent_GenRepoFileRecordsBatch(t *testing.T) { + ctx := context.TODO() + rc := initializeTestRepoFileComponent(ctx, t) + + rc.mocks.stores.RepoMock().EXPECT().BatchGet(ctx, types.ModelRepo, int64(1), 10).Return( + []database.Repository{{ID: 1, Path: "foo/bar"}}, nil, + ) + rc.mocks.gitServer.EXPECT().GetRepoFileTree(mock.Anything, gitserver.GetRepoInfoByPathReq{ + Namespace: "foo", + Name: "bar", + }).Return( + []*types.File{ + {Path: "a/b", Type: "dir"}, + {Path: "foo.go", Type: "go"}, + }, nil, + ) + rc.mocks.gitServer.EXPECT().GetRepoFileTree(mock.Anything, gitserver.GetRepoInfoByPathReq{ + Path: "a/b", + Namespace: "foo", + Name: "bar", + }).Return( + []*types.File{}, nil, + ) + rc.mocks.stores.RepoFileMock().EXPECT().Exists(ctx, database.RepositoryFile{ + RepositoryID: 1, + Path: "foo.go", + FileType: "go", + }).Return(false, nil) + rc.mocks.stores.RepoFileMock().EXPECT().Create(ctx, &database.RepositoryFile{ + RepositoryID: 1, + Path: "foo.go", + FileType: "go", + }).Return(nil) + + err := rc.GenRepoFileRecordsBatch(ctx, types.ModelRepo, 1, 10) + require.Nil(t, err) +} diff --git a/component/sensitive_test.go b/component/sensitive_test.go index a7c1eb97..3d760373 100644 --- a/component/sensitive_test.go +++ b/component/sensitive_test.go @@ -6,7 +6,6 @@ import ( "github.com/stretchr/testify/mock" "github.com/stretchr/testify/require" - mockrpc "opencsg.com/csghub-server/_mocks/opencsg.com/csghub-server/builder/rpc" mocktypes "opencsg.com/csghub-server/_mocks/opencsg.com/csghub-server/common/types" "opencsg.com/csghub-server/builder/rpc" "opencsg.com/csghub-server/builder/sensitive" @@ -15,13 +14,12 @@ import ( ) func TestSensitiveComponent_CheckText(t *testing.T) { - mockModeration := mockrpc.NewMockModerationSvcClient(t) - mockModeration.EXPECT().PassTextCheck(mock.Anything, mock.Anything, mock.Anything).Return(&rpc.CheckResult{ + ctx := context.TODO() + comp := initializeTestSensitiveComponent(ctx, t) + + comp.mocks.moderationClient.EXPECT().PassTextCheck(mock.Anything, mock.Anything, mock.Anything).Return(&rpc.CheckResult{ IsSensitive: false, }, nil) - comp := &sensitiveComponentImpl{ - checker: mockModeration, - } success, err := comp.CheckText(context.TODO(), string(sensitive.ScenarioChatDetection), "test") require.Nil(t, err) @@ -29,13 +27,12 @@ func TestSensitiveComponent_CheckText(t *testing.T) { } func TestSensitiveComponent_CheckImage(t *testing.T) { - mockModeration := mockrpc.NewMockModerationSvcClient(t) - mockModeration.EXPECT().PassImageCheck(mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(&rpc.CheckResult{ + ctx := context.TODO() + comp := initializeTestSensitiveComponent(ctx, t) + + comp.mocks.moderationClient.EXPECT().PassImageCheck(mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(&rpc.CheckResult{ IsSensitive: false, }, nil) - comp := &sensitiveComponentImpl{ - checker: mockModeration, - } success, err := comp.CheckImage(context.TODO(), string(sensitive.ScenarioChatDetection), "ossBucketName", "ossObjectName") require.Nil(t, err) @@ -43,13 +40,13 @@ func TestSensitiveComponent_CheckImage(t *testing.T) { } func TestSensitiveComponent_CheckRequestV2(t *testing.T) { - mockModeration := mockrpc.NewMockModerationSvcClient(t) - mockModeration.EXPECT().PassTextCheck(mock.Anything, mock.Anything, mock.Anything).Return(&rpc.CheckResult{ + ctx := context.TODO() + comp := initializeTestSensitiveComponent(ctx, t) + + comp.mocks.moderationClient.EXPECT().PassTextCheck(mock.Anything, mock.Anything, mock.Anything).Return(&rpc.CheckResult{ IsSensitive: false, }, nil).Twice() - comp := &sensitiveComponentImpl{ - checker: mockModeration, - } + mockRequest := mocktypes.NewMockSensitiveRequestV2(t) mockRequest.EXPECT().GetSensitiveFields().Return([]types.SensitiveField{ { diff --git a/component/sshkey.go b/component/sshkey.go index a0f4cd10..9ccc9fa1 100644 --- a/component/sshkey.go +++ b/component/sshkey.go @@ -23,10 +23,10 @@ type SSHKeyComponent interface { func NewSSHKeyComponent(config *config.Config) (SSHKeyComponent, error) { c := &sSHKeyComponentImpl{} - c.ss = database.NewSSHKeyStore() - c.us = database.NewUserStore() + c.sshKeyStore = database.NewSSHKeyStore() + c.userStore = database.NewUserStore() var err error - c.gs, err = git.NewGitServer(config) + c.gitServer, err = git.NewGitServer(config) if err != nil { newError := fmt.Errorf("failed to create git server,error:%w", err) slog.Error(newError.Error()) @@ -36,17 +36,17 @@ func NewSSHKeyComponent(config *config.Config) (SSHKeyComponent, error) { } type sSHKeyComponentImpl struct { - ss database.SSHKeyStore - us database.UserStore - gs gitserver.GitServer + sshKeyStore database.SSHKeyStore + userStore database.UserStore + gitServer gitserver.GitServer } func (c *sSHKeyComponentImpl) Create(ctx context.Context, req *types.CreateSSHKeyRequest) (*database.SSHKey, error) { - user, err := c.us.FindByUsername(ctx, req.Username) + user, err := c.userStore.FindByUsername(ctx, req.Username) if err != nil { return nil, fmt.Errorf("failed to find user,error:%w", err) } - nameExistsKey, err := c.ss.FindByNameAndUserID(ctx, req.Name, user.ID) + nameExistsKey, err := c.sshKeyStore.FindByNameAndUserID(ctx, req.Name, user.ID) if err != nil && !errors.Is(err, sql.ErrNoRows) { return nil, fmt.Errorf("failed to find if ssh key exists,error:%w", err) } @@ -54,15 +54,14 @@ func (c *sSHKeyComponentImpl) Create(ctx context.Context, req *types.CreateSSHKe return nil, fmt.Errorf("ssh key name already exists") } - contentExistsKey, err := c.ss.FindByKeyContent(ctx, req.Content) + contentExistsKey, err := c.sshKeyStore.FindByKeyContent(ctx, req.Content) if err != nil && !errors.Is(err, sql.ErrNoRows) { return nil, fmt.Errorf("failed to find if ssh key exists,error:%w", err) } if contentExistsKey.ID != 0 { return nil, fmt.Errorf("ssh key already exists") } - - sk, err := c.gs.CreateSSHKey(req) + sk, err := c.gitServer.CreateSSHKey(req) if err != nil { return nil, fmt.Errorf("failed to create git SSH key,error:%w", err) } @@ -80,7 +79,7 @@ func (c *sSHKeyComponentImpl) Create(ctx context.Context, req *types.CreateSSHKe } sk.UserID = user.ID sk.FingerprintSHA256 = fingerprint - resSk, err := c.ss.Create(ctx, sk) + resSk, err := c.sshKeyStore.Create(ctx, sk) if err != nil { return nil, fmt.Errorf("failed to create database SSH key,error:%w", err) } @@ -88,7 +87,7 @@ func (c *sSHKeyComponentImpl) Create(ctx context.Context, req *types.CreateSSHKe } func (c *sSHKeyComponentImpl) Index(ctx context.Context, username string, per, page int) ([]database.SSHKey, error) { - sks, err := c.ss.Index(ctx, username, per, page) + sks, err := c.sshKeyStore.Index(ctx, username, per, page) if err != nil { return nil, fmt.Errorf("failed to get database SSH keys,error:%w", err) } @@ -96,15 +95,15 @@ func (c *sSHKeyComponentImpl) Index(ctx context.Context, username string, per, p } func (c *sSHKeyComponentImpl) Delete(ctx context.Context, username, name string) error { - sshKey, err := c.ss.FindByUsernameAndName(ctx, username, name) + sshKey, err := c.sshKeyStore.FindByUsernameAndName(ctx, username, name) if err != nil { return fmt.Errorf("failed to get database SSH keys,error:%w", err) } - err = c.gs.DeleteSSHKey(int(sshKey.GitID)) + err = c.gitServer.DeleteSSHKey(int(sshKey.GitID)) if err != nil { return fmt.Errorf("failed to delete git SSH keys,error:%w", err) } - err = c.ss.Delete(ctx, sshKey.GitID) + err = c.sshKeyStore.Delete(ctx, sshKey.ID) if err != nil { return fmt.Errorf("failed to delete database SSH keys,error:%w", err) } diff --git a/component/sshkey_test.go b/component/sshkey_test.go new file mode 100644 index 00000000..5bececec --- /dev/null +++ b/component/sshkey_test.go @@ -0,0 +1,67 @@ +package component + +import ( + "context" + "testing" + + "github.com/stretchr/testify/require" + "opencsg.com/csghub-server/builder/store/database" + "opencsg.com/csghub-server/common/types" +) + +const testKey = ` +ssh-rsa AAAAB3NzaC1yc2EAAAADAQABAAABAQCn4yeHw9InFrZIxYxFhs5Giam76NPIJ1kOqEq1xvWz4vJJMGkoqosTsqUf+V4Pj18qSUbSEDbwibzkIAPFNRiF1lQWgpFvZrZsTmD6rV1ODYjGPu5HLHqjCY/ffY+n/cAz66sZ5TQUMh+9HmUkVriu/Flfo7dWrbsrC73vgfVptSzSIEehkm4wL40XaZI4wQ7JffdXyqz5CU/lK+CFaPU2nLnxVoL9CEaFbCglcP4sO2jir2Rcx5ZNBMHYpsqk9N4cOxpS/IA9YX2tla3o4wltJoO83Vp0qH1ds15WBAlwUAdpJGDajh93kgYki6Kn2v41IgmqgFcXpmBQ+48QZXfh +` + +func TestSSHKeyComponent_Create(t *testing.T) { + ctx := context.TODO() + sc := initializeTestSSHKeyComponent(ctx, t) + + req := &types.CreateSSHKeyRequest{ + Username: "user", + Name: "n", + Content: testKey, + } + sc.mocks.stores.UserMock().EXPECT().FindByUsername(ctx, "user").Return(database.User{ID: 1}, nil) + sc.mocks.stores.SSHMock().EXPECT().FindByNameAndUserID(ctx, "n", int64(1)).Return( + &database.SSHKey{}, nil, + ) + sc.mocks.stores.SSHMock().EXPECT().FindByKeyContent(ctx, testKey).Return(&database.SSHKey{}, nil) + sc.mocks.gitServer.EXPECT().CreateSSHKey(req).Return(&database.SSHKey{}, nil) + sc.mocks.stores.SSHMock().EXPECT().Create(ctx, &database.SSHKey{ + UserID: 1, + FingerprintSHA256: "DZMgXySN8FuYZo2qvIAZOXNB0J81NMAv1SikyHvCPmw", + }).Return(&database.SSHKey{}, nil) + + data, err := sc.Create(ctx, req) + require.NoError(t, err) + require.Equal(t, &database.SSHKey{}, data) + +} + +func TestSSHKeyComponent_Index(t *testing.T) { + ctx := context.TODO() + sc := initializeTestSSHKeyComponent(ctx, t) + + sc.mocks.stores.SSHMock().EXPECT().Index(ctx, "user", 10, 1).Return( + []database.SSHKey{{Name: "a"}}, nil, + ) + + data, err := sc.Index(ctx, "user", 10, 1) + require.Nil(t, err) + require.Equal(t, data, []database.SSHKey{{Name: "a"}}) +} + +func TestSSHKeyComponent_Delete(t *testing.T) { + ctx := context.TODO() + sc := initializeTestSSHKeyComponent(ctx, t) + + sc.mocks.stores.SSHMock().EXPECT().FindByUsernameAndName(ctx, "user", "key").Return( + database.SSHKey{ID: 1, GitID: 123}, nil, + ) + sc.mocks.gitServer.EXPECT().DeleteSSHKey(123).Return(nil) + sc.mocks.stores.SSHMock().EXPECT().Delete(ctx, int64(1)).Return(nil) + + err := sc.Delete(ctx, "user", "key") + require.Nil(t, err) +} diff --git a/component/sync_client_setting_test.go b/component/sync_client_setting_test.go new file mode 100644 index 00000000..9b1efeba --- /dev/null +++ b/component/sync_client_setting_test.go @@ -0,0 +1,43 @@ +package component + +import ( + "context" + "testing" + + "github.com/stretchr/testify/require" + "opencsg.com/csghub-server/builder/store/database" + "opencsg.com/csghub-server/common/types" +) + +func TestSyncClientSettingComponent_Create(t *testing.T) { + ctx := context.TODO() + sc := initializeTestSyncClientSettingComponent(ctx, t) + + sc.mocks.stores.SyncClientSettingMock().EXPECT().SyncClientSettingExists(ctx).Return(true, nil) + sc.mocks.stores.SyncClientSettingMock().EXPECT().DeleteAll(ctx).Return(nil) + sc.mocks.stores.SyncClientSettingMock().EXPECT().Create(ctx, &database.SyncClientSetting{ + Token: "t", + ConcurrentCount: 1, + MaxBandwidth: 5, + }).Return(&database.SyncClientSetting{}, nil) + + data, err := sc.Create(ctx, types.CreateSyncClientSettingReq{ + Token: "t", + ConcurrentCount: 1, + MaxBandwidth: 5, + }) + require.Nil(t, err) + require.Equal(t, &database.SyncClientSetting{}, data) + +} + +func TestSyncClientSettingComponent_Show(t *testing.T) { + ctx := context.TODO() + sc := initializeTestSyncClientSettingComponent(ctx, t) + + sc.mocks.stores.SyncClientSettingMock().EXPECT().First(ctx).Return(&database.SyncClientSetting{}, nil) + + data, err := sc.Show(ctx) + require.Nil(t, err) + require.Equal(t, &database.SyncClientSetting{}, data) +} diff --git a/component/telemetry.go b/component/telemetry.go index a1654ae9..bbdd362f 100644 --- a/component/telemetry.go +++ b/component/telemetry.go @@ -12,10 +12,9 @@ import ( ) type telemetryComponentImpl struct { - // Add telemetry related fields and methods here - ts database.TelemetryStore - us database.UserStore - rs database.RepoStore + telemetryStore database.TelemetryStore + userStore database.UserStore + repoStore database.RepoStore } type TelemetryComponent interface { @@ -27,7 +26,7 @@ func NewTelemetryComponent() (TelemetryComponent, error) { ts := database.NewTelemetryStore() us := database.NewUserStore() rs := database.NewRepoStore() - return &telemetryComponentImpl{ts: ts, us: us, rs: rs}, nil + return &telemetryComponentImpl{telemetryStore: ts, userStore: us, repoStore: rs}, nil } func (tc *telemetryComponentImpl) SaveUsageData(ctx context.Context, usage telemetry.Usage) error { @@ -52,7 +51,7 @@ func (tc *telemetryComponentImpl) SaveUsageData(ctx context.Context, usage telem Settings: usage.Settings, Counts: usage.Counts, } - err := tc.ts.Save(ctx, &t) + err := tc.telemetryStore.Save(ctx, &t) if err != nil { return fmt.Errorf("failed to save telemetry data to db: %w", err) } @@ -105,27 +104,27 @@ func (tc *telemetryComponentImpl) GenUsageData(ctx context.Context) (telemetry.U } func (tc *telemetryComponentImpl) getUserCnt(ctx context.Context) (int, error) { - return tc.us.CountUsers(ctx) + return tc.userStore.CountUsers(ctx) } func (tc *telemetryComponentImpl) getCounts(ctx context.Context) (telemetry.Counts, error) { var counts telemetry.Counts - modelCnt, err := tc.rs.CountByRepoType(ctx, types.ModelRepo) + modelCnt, err := tc.repoStore.CountByRepoType(ctx, types.ModelRepo) if err != nil { return counts, fmt.Errorf("failed to get model repo count: %w", err) } - dsCnt, err := tc.rs.CountByRepoType(ctx, types.DatasetRepo) + dsCnt, err := tc.repoStore.CountByRepoType(ctx, types.DatasetRepo) if err != nil { return counts, fmt.Errorf("failed to get dataset repo count: %w", err) } - codeCnt, err := tc.rs.CountByRepoType(ctx, types.CodeRepo) + codeCnt, err := tc.repoStore.CountByRepoType(ctx, types.CodeRepo) if err != nil { return counts, fmt.Errorf("failed to get code repo count: %w", err) } - spaceCnt, err := tc.rs.CountByRepoType(ctx, types.SpaceRepo) + spaceCnt, err := tc.repoStore.CountByRepoType(ctx, types.SpaceRepo) if err != nil { return counts, fmt.Errorf("failed to get space repo count: %w", err) } diff --git a/component/telemetry_test.go b/component/telemetry_test.go new file mode 100644 index 00000000..dd839952 --- /dev/null +++ b/component/telemetry_test.go @@ -0,0 +1,56 @@ +package component + +import ( + "context" + "testing" + "time" + + "github.com/stretchr/testify/require" + "opencsg.com/csghub-server/builder/store/database" + "opencsg.com/csghub-server/common/types" + "opencsg.com/csghub-server/common/types/telemetry" +) + +func TestTelemetryComponent_SaveUsageData(t *testing.T) { + ctx := context.TODO() + tc := initializeTestTelemetryComponent(ctx, t) + + tc.mocks.stores.TelemetryMock().EXPECT().Save(ctx, &database.Telemetry{ + UUID: "uid", + Version: "v1", + Licensee: telemetry.Licensee{}, + Settings: telemetry.Settings{}, + Counts: telemetry.Counts{}, + }).Return(nil) + + err := tc.SaveUsageData(ctx, telemetry.Usage{ + UUID: "uid", + Version: "v1", + }) + require.Nil(t, err) + +} + +func TestTelemetryComponent_GenUsageData(t *testing.T) { + ctx := context.TODO() + tc := initializeTestTelemetryComponent(ctx, t) + + tc.mocks.stores.UserMock().EXPECT().CountUsers(ctx).Return(100, nil) + tc.mocks.stores.RepoMock().EXPECT().CountByRepoType(ctx, types.ModelRepo).Return(10, nil) + tc.mocks.stores.RepoMock().EXPECT().CountByRepoType(ctx, types.DatasetRepo).Return(20, nil) + tc.mocks.stores.RepoMock().EXPECT().CountByRepoType(ctx, types.CodeRepo).Return(30, nil) + tc.mocks.stores.RepoMock().EXPECT().CountByRepoType(ctx, types.SpaceRepo).Return(40, nil) + + data, err := tc.GenUsageData(ctx) + require.Nil(t, err) + + require.Equal(t, 100, data.ActiveUserCount) + require.Equal(t, 30, data.Counts.Codes) + require.Equal(t, 20, data.Counts.Datasets) + require.Equal(t, 10, data.Counts.Models) + require.Equal(t, 40, data.Counts.Spaces) + require.Equal(t, 100, data.Counts.TotalRepos) + require.NotEmpty(t, data.UUID) + require.GreaterOrEqual(t, time.Now(), data.RecordedAt) + require.LessOrEqual(t, time.Now().Add(-5*time.Second), data.RecordedAt) +} diff --git a/component/wire.go b/component/wire.go index 89b3ab39..76ebad30 100644 --- a/component/wire.go +++ b/component/wire.go @@ -345,3 +345,163 @@ func initializeTestSpaceSdkComponent(ctx context.Context, t interface { ) return &testSpaceSdkWithMocks{} } + +type testTelemetryWithMocks struct { + *telemetryComponentImpl + mocks *Mocks +} + +func initializeTestTelemetryComponent(ctx context.Context, t interface { + Cleanup(func()) + mock.TestingT +}) *testTelemetryWithMocks { + wire.Build( + MockSuperSet, TelemetryComponentSet, + wire.Struct(new(testTelemetryWithMocks), "*"), + ) + return &testTelemetryWithMocks{} +} + +type testClusterWithMocks struct { + *clusterComponentImpl + mocks *Mocks +} + +func initializeTestClusterComponent(ctx context.Context, t interface { + Cleanup(func()) + mock.TestingT +}) *testClusterWithMocks { + wire.Build( + MockSuperSet, ClusterComponentSet, + wire.Struct(new(testClusterWithMocks), "*"), + ) + return &testClusterWithMocks{} +} + +type testEvaluationWithMocks struct { + *evaluationComponentImpl + mocks *Mocks +} + +func initializeTestEvaluationComponent(ctx context.Context, t interface { + Cleanup(func()) + mock.TestingT +}) *testEvaluationWithMocks { + wire.Build( + MockSuperSet, EvaluationComponentSet, + wire.Struct(new(testEvaluationWithMocks), "*"), + ) + return &testEvaluationWithMocks{} +} + +type testHFDatasetWithMocks struct { + *hFDatasetComponentImpl + mocks *Mocks +} + +func initializeTestHFDatasetComponent(ctx context.Context, t interface { + Cleanup(func()) + mock.TestingT +}) *testHFDatasetWithMocks { + wire.Build( + MockSuperSet, HFDatasetComponentSet, + wire.Struct(new(testHFDatasetWithMocks), "*"), + ) + return &testHFDatasetWithMocks{} +} + +type testRepoFileWithMocks struct { + *repoFileComponentImpl + mocks *Mocks +} + +func initializeTestRepoFileComponent(ctx context.Context, t interface { + Cleanup(func()) + mock.TestingT +}) *testRepoFileWithMocks { + wire.Build( + MockSuperSet, RepoFileComponentSet, + wire.Struct(new(testRepoFileWithMocks), "*"), + ) + return &testRepoFileWithMocks{} +} + +type testSensitiveWithMocks struct { + *sensitiveComponentImpl + mocks *Mocks +} + +func initializeTestSensitiveComponent(ctx context.Context, t interface { + Cleanup(func()) + mock.TestingT +}) *testSensitiveWithMocks { + wire.Build( + MockSuperSet, SensitiveComponentSet, + wire.Struct(new(testSensitiveWithMocks), "*"), + ) + return &testSensitiveWithMocks{} +} + +type testSSHKeyWithMocks struct { + *sSHKeyComponentImpl + mocks *Mocks +} + +func initializeTestSSHKeyComponent(ctx context.Context, t interface { + Cleanup(func()) + mock.TestingT +}) *testSSHKeyWithMocks { + wire.Build( + MockSuperSet, SSHKeyComponentSet, + wire.Struct(new(testSSHKeyWithMocks), "*"), + ) + return &testSSHKeyWithMocks{} +} + +type testListWithMocks struct { + *listComponentImpl + mocks *Mocks +} + +func initializeTestListComponent(ctx context.Context, t interface { + Cleanup(func()) + mock.TestingT +}) *testListWithMocks { + wire.Build( + MockSuperSet, ListComponentSet, + wire.Struct(new(testListWithMocks), "*"), + ) + return &testListWithMocks{} +} + +type testSyncClientSettingWithMocks struct { + *syncClientSettingComponentImpl + mocks *Mocks +} + +func initializeTestSyncClientSettingComponent(ctx context.Context, t interface { + Cleanup(func()) + mock.TestingT +}) *testSyncClientSettingWithMocks { + wire.Build( + MockSuperSet, SyncClientSettingComponentSet, + wire.Struct(new(testSyncClientSettingWithMocks), "*"), + ) + return &testSyncClientSettingWithMocks{} +} + +type testEventWithMocks struct { + *eventComponentImpl + mocks *Mocks +} + +func initializeTestEventComponent(ctx context.Context, t interface { + Cleanup(func()) + mock.TestingT +}) *testEventWithMocks { + wire.Build( + MockSuperSet, EventComponentSet, + wire.Struct(new(testEventWithMocks), "*"), + ) + return &testEventWithMocks{} +} diff --git a/component/wire_gen_test.go b/component/wire_gen_test.go index 2fac5ee9..a8fec516 100644 --- a/component/wire_gen_test.go +++ b/component/wire_gen_test.go @@ -1068,6 +1068,506 @@ func initializeTestSpaceSdkComponent(ctx context.Context, t interface { return componentTestSpaceSdkWithMocks } +func initializeTestTelemetryComponent(ctx context.Context, t interface { + Cleanup(func()) + mock.TestingT +}) *testTelemetryWithMocks { + config := ProvideTestConfig() + mockStores := tests.NewMockStores(t) + componentTelemetryComponentImpl := NewTestTelemetryComponent(config, mockStores) + mockAccountingComponent := component.NewMockAccountingComponent(t) + mockRepoComponent := component.NewMockRepoComponent(t) + mockTagComponent := component.NewMockTagComponent(t) + mockSpaceComponent := component.NewMockSpaceComponent(t) + mockRuntimeArchitectureComponent := component.NewMockRuntimeArchitectureComponent(t) + mockSensitiveComponent := component.NewMockSensitiveComponent(t) + componentMockedComponents := &mockedComponents{ + accounting: mockAccountingComponent, + repo: mockRepoComponent, + tag: mockTagComponent, + space: mockSpaceComponent, + runtimeArchitecture: mockRuntimeArchitectureComponent, + sensitive: mockSensitiveComponent, + } + mockGitServer := gitserver.NewMockGitServer(t) + mockUserSvcClient := rpc.NewMockUserSvcClient(t) + mockClient := s3.NewMockClient(t) + mockMirrorServer := mirrorserver.NewMockMirrorServer(t) + mockPriorityQueue := queue.NewMockPriorityQueue(t) + mockDeployer := deploy.NewMockDeployer(t) + mockAccountingClient := accounting.NewMockAccountingClient(t) + mockReader := parquet.NewMockReader(t) + mockModerationSvcClient := rpc.NewMockModerationSvcClient(t) + mocks := &Mocks{ + stores: mockStores, + components: componentMockedComponents, + gitServer: mockGitServer, + userSvcClient: mockUserSvcClient, + s3Client: mockClient, + mirrorServer: mockMirrorServer, + mirrorQueue: mockPriorityQueue, + deployer: mockDeployer, + accountingClient: mockAccountingClient, + preader: mockReader, + moderationClient: mockModerationSvcClient, + } + componentTestTelemetryWithMocks := &testTelemetryWithMocks{ + telemetryComponentImpl: componentTelemetryComponentImpl, + mocks: mocks, + } + return componentTestTelemetryWithMocks +} + +func initializeTestClusterComponent(ctx context.Context, t interface { + Cleanup(func()) + mock.TestingT +}) *testClusterWithMocks { + config := ProvideTestConfig() + mockDeployer := deploy.NewMockDeployer(t) + componentClusterComponentImpl := NewTestClusterComponent(config, mockDeployer) + mockStores := tests.NewMockStores(t) + mockAccountingComponent := component.NewMockAccountingComponent(t) + mockRepoComponent := component.NewMockRepoComponent(t) + mockTagComponent := component.NewMockTagComponent(t) + mockSpaceComponent := component.NewMockSpaceComponent(t) + mockRuntimeArchitectureComponent := component.NewMockRuntimeArchitectureComponent(t) + mockSensitiveComponent := component.NewMockSensitiveComponent(t) + componentMockedComponents := &mockedComponents{ + accounting: mockAccountingComponent, + repo: mockRepoComponent, + tag: mockTagComponent, + space: mockSpaceComponent, + runtimeArchitecture: mockRuntimeArchitectureComponent, + sensitive: mockSensitiveComponent, + } + mockGitServer := gitserver.NewMockGitServer(t) + mockUserSvcClient := rpc.NewMockUserSvcClient(t) + mockClient := s3.NewMockClient(t) + mockMirrorServer := mirrorserver.NewMockMirrorServer(t) + mockPriorityQueue := queue.NewMockPriorityQueue(t) + mockAccountingClient := accounting.NewMockAccountingClient(t) + mockReader := parquet.NewMockReader(t) + mockModerationSvcClient := rpc.NewMockModerationSvcClient(t) + mocks := &Mocks{ + stores: mockStores, + components: componentMockedComponents, + gitServer: mockGitServer, + userSvcClient: mockUserSvcClient, + s3Client: mockClient, + mirrorServer: mockMirrorServer, + mirrorQueue: mockPriorityQueue, + deployer: mockDeployer, + accountingClient: mockAccountingClient, + preader: mockReader, + moderationClient: mockModerationSvcClient, + } + componentTestClusterWithMocks := &testClusterWithMocks{ + clusterComponentImpl: componentClusterComponentImpl, + mocks: mocks, + } + return componentTestClusterWithMocks +} + +func initializeTestEvaluationComponent(ctx context.Context, t interface { + Cleanup(func()) + mock.TestingT +}) *testEvaluationWithMocks { + config := ProvideTestConfig() + mockStores := tests.NewMockStores(t) + mockDeployer := deploy.NewMockDeployer(t) + mockAccountingComponent := component.NewMockAccountingComponent(t) + componentEvaluationComponentImpl := NewTestEvaluationComponent(config, mockStores, mockDeployer, mockAccountingComponent) + mockRepoComponent := component.NewMockRepoComponent(t) + mockTagComponent := component.NewMockTagComponent(t) + mockSpaceComponent := component.NewMockSpaceComponent(t) + mockRuntimeArchitectureComponent := component.NewMockRuntimeArchitectureComponent(t) + mockSensitiveComponent := component.NewMockSensitiveComponent(t) + componentMockedComponents := &mockedComponents{ + accounting: mockAccountingComponent, + repo: mockRepoComponent, + tag: mockTagComponent, + space: mockSpaceComponent, + runtimeArchitecture: mockRuntimeArchitectureComponent, + sensitive: mockSensitiveComponent, + } + mockGitServer := gitserver.NewMockGitServer(t) + mockUserSvcClient := rpc.NewMockUserSvcClient(t) + mockClient := s3.NewMockClient(t) + mockMirrorServer := mirrorserver.NewMockMirrorServer(t) + mockPriorityQueue := queue.NewMockPriorityQueue(t) + mockAccountingClient := accounting.NewMockAccountingClient(t) + mockReader := parquet.NewMockReader(t) + mockModerationSvcClient := rpc.NewMockModerationSvcClient(t) + mocks := &Mocks{ + stores: mockStores, + components: componentMockedComponents, + gitServer: mockGitServer, + userSvcClient: mockUserSvcClient, + s3Client: mockClient, + mirrorServer: mockMirrorServer, + mirrorQueue: mockPriorityQueue, + deployer: mockDeployer, + accountingClient: mockAccountingClient, + preader: mockReader, + moderationClient: mockModerationSvcClient, + } + componentTestEvaluationWithMocks := &testEvaluationWithMocks{ + evaluationComponentImpl: componentEvaluationComponentImpl, + mocks: mocks, + } + return componentTestEvaluationWithMocks +} + +func initializeTestHFDatasetComponent(ctx context.Context, t interface { + Cleanup(func()) + mock.TestingT +}) *testHFDatasetWithMocks { + config := ProvideTestConfig() + mockStores := tests.NewMockStores(t) + mockRepoComponent := component.NewMockRepoComponent(t) + mockGitServer := gitserver.NewMockGitServer(t) + componentHFDatasetComponentImpl := NewTestHFDatasetComponent(config, mockStores, mockRepoComponent, mockGitServer) + mockAccountingComponent := component.NewMockAccountingComponent(t) + mockTagComponent := component.NewMockTagComponent(t) + mockSpaceComponent := component.NewMockSpaceComponent(t) + mockRuntimeArchitectureComponent := component.NewMockRuntimeArchitectureComponent(t) + mockSensitiveComponent := component.NewMockSensitiveComponent(t) + componentMockedComponents := &mockedComponents{ + accounting: mockAccountingComponent, + repo: mockRepoComponent, + tag: mockTagComponent, + space: mockSpaceComponent, + runtimeArchitecture: mockRuntimeArchitectureComponent, + sensitive: mockSensitiveComponent, + } + mockUserSvcClient := rpc.NewMockUserSvcClient(t) + mockClient := s3.NewMockClient(t) + mockMirrorServer := mirrorserver.NewMockMirrorServer(t) + mockPriorityQueue := queue.NewMockPriorityQueue(t) + mockDeployer := deploy.NewMockDeployer(t) + mockAccountingClient := accounting.NewMockAccountingClient(t) + mockReader := parquet.NewMockReader(t) + mockModerationSvcClient := rpc.NewMockModerationSvcClient(t) + mocks := &Mocks{ + stores: mockStores, + components: componentMockedComponents, + gitServer: mockGitServer, + userSvcClient: mockUserSvcClient, + s3Client: mockClient, + mirrorServer: mockMirrorServer, + mirrorQueue: mockPriorityQueue, + deployer: mockDeployer, + accountingClient: mockAccountingClient, + preader: mockReader, + moderationClient: mockModerationSvcClient, + } + componentTestHFDatasetWithMocks := &testHFDatasetWithMocks{ + hFDatasetComponentImpl: componentHFDatasetComponentImpl, + mocks: mocks, + } + return componentTestHFDatasetWithMocks +} + +func initializeTestRepoFileComponent(ctx context.Context, t interface { + Cleanup(func()) + mock.TestingT +}) *testRepoFileWithMocks { + config := ProvideTestConfig() + mockStores := tests.NewMockStores(t) + mockGitServer := gitserver.NewMockGitServer(t) + componentRepoFileComponentImpl := NewTestRepoFileComponent(config, mockStores, mockGitServer) + mockAccountingComponent := component.NewMockAccountingComponent(t) + mockRepoComponent := component.NewMockRepoComponent(t) + mockTagComponent := component.NewMockTagComponent(t) + mockSpaceComponent := component.NewMockSpaceComponent(t) + mockRuntimeArchitectureComponent := component.NewMockRuntimeArchitectureComponent(t) + mockSensitiveComponent := component.NewMockSensitiveComponent(t) + componentMockedComponents := &mockedComponents{ + accounting: mockAccountingComponent, + repo: mockRepoComponent, + tag: mockTagComponent, + space: mockSpaceComponent, + runtimeArchitecture: mockRuntimeArchitectureComponent, + sensitive: mockSensitiveComponent, + } + mockUserSvcClient := rpc.NewMockUserSvcClient(t) + mockClient := s3.NewMockClient(t) + mockMirrorServer := mirrorserver.NewMockMirrorServer(t) + mockPriorityQueue := queue.NewMockPriorityQueue(t) + mockDeployer := deploy.NewMockDeployer(t) + mockAccountingClient := accounting.NewMockAccountingClient(t) + mockReader := parquet.NewMockReader(t) + mockModerationSvcClient := rpc.NewMockModerationSvcClient(t) + mocks := &Mocks{ + stores: mockStores, + components: componentMockedComponents, + gitServer: mockGitServer, + userSvcClient: mockUserSvcClient, + s3Client: mockClient, + mirrorServer: mockMirrorServer, + mirrorQueue: mockPriorityQueue, + deployer: mockDeployer, + accountingClient: mockAccountingClient, + preader: mockReader, + moderationClient: mockModerationSvcClient, + } + componentTestRepoFileWithMocks := &testRepoFileWithMocks{ + repoFileComponentImpl: componentRepoFileComponentImpl, + mocks: mocks, + } + return componentTestRepoFileWithMocks +} + +func initializeTestSensitiveComponent(ctx context.Context, t interface { + Cleanup(func()) + mock.TestingT +}) *testSensitiveWithMocks { + config := ProvideTestConfig() + mockModerationSvcClient := rpc.NewMockModerationSvcClient(t) + componentSensitiveComponentImpl := NewTestSensitiveComponent(config, mockModerationSvcClient) + mockStores := tests.NewMockStores(t) + mockAccountingComponent := component.NewMockAccountingComponent(t) + mockRepoComponent := component.NewMockRepoComponent(t) + mockTagComponent := component.NewMockTagComponent(t) + mockSpaceComponent := component.NewMockSpaceComponent(t) + mockRuntimeArchitectureComponent := component.NewMockRuntimeArchitectureComponent(t) + mockSensitiveComponent := component.NewMockSensitiveComponent(t) + componentMockedComponents := &mockedComponents{ + accounting: mockAccountingComponent, + repo: mockRepoComponent, + tag: mockTagComponent, + space: mockSpaceComponent, + runtimeArchitecture: mockRuntimeArchitectureComponent, + sensitive: mockSensitiveComponent, + } + mockGitServer := gitserver.NewMockGitServer(t) + mockUserSvcClient := rpc.NewMockUserSvcClient(t) + mockClient := s3.NewMockClient(t) + mockMirrorServer := mirrorserver.NewMockMirrorServer(t) + mockPriorityQueue := queue.NewMockPriorityQueue(t) + mockDeployer := deploy.NewMockDeployer(t) + mockAccountingClient := accounting.NewMockAccountingClient(t) + mockReader := parquet.NewMockReader(t) + mocks := &Mocks{ + stores: mockStores, + components: componentMockedComponents, + gitServer: mockGitServer, + userSvcClient: mockUserSvcClient, + s3Client: mockClient, + mirrorServer: mockMirrorServer, + mirrorQueue: mockPriorityQueue, + deployer: mockDeployer, + accountingClient: mockAccountingClient, + preader: mockReader, + moderationClient: mockModerationSvcClient, + } + componentTestSensitiveWithMocks := &testSensitiveWithMocks{ + sensitiveComponentImpl: componentSensitiveComponentImpl, + mocks: mocks, + } + return componentTestSensitiveWithMocks +} + +func initializeTestSSHKeyComponent(ctx context.Context, t interface { + Cleanup(func()) + mock.TestingT +}) *testSSHKeyWithMocks { + config := ProvideTestConfig() + mockStores := tests.NewMockStores(t) + mockGitServer := gitserver.NewMockGitServer(t) + componentSSHKeyComponentImpl := NewTestSSHKeyComponent(config, mockStores, mockGitServer) + mockAccountingComponent := component.NewMockAccountingComponent(t) + mockRepoComponent := component.NewMockRepoComponent(t) + mockTagComponent := component.NewMockTagComponent(t) + mockSpaceComponent := component.NewMockSpaceComponent(t) + mockRuntimeArchitectureComponent := component.NewMockRuntimeArchitectureComponent(t) + mockSensitiveComponent := component.NewMockSensitiveComponent(t) + componentMockedComponents := &mockedComponents{ + accounting: mockAccountingComponent, + repo: mockRepoComponent, + tag: mockTagComponent, + space: mockSpaceComponent, + runtimeArchitecture: mockRuntimeArchitectureComponent, + sensitive: mockSensitiveComponent, + } + mockUserSvcClient := rpc.NewMockUserSvcClient(t) + mockClient := s3.NewMockClient(t) + mockMirrorServer := mirrorserver.NewMockMirrorServer(t) + mockPriorityQueue := queue.NewMockPriorityQueue(t) + mockDeployer := deploy.NewMockDeployer(t) + mockAccountingClient := accounting.NewMockAccountingClient(t) + mockReader := parquet.NewMockReader(t) + mockModerationSvcClient := rpc.NewMockModerationSvcClient(t) + mocks := &Mocks{ + stores: mockStores, + components: componentMockedComponents, + gitServer: mockGitServer, + userSvcClient: mockUserSvcClient, + s3Client: mockClient, + mirrorServer: mockMirrorServer, + mirrorQueue: mockPriorityQueue, + deployer: mockDeployer, + accountingClient: mockAccountingClient, + preader: mockReader, + moderationClient: mockModerationSvcClient, + } + componentTestSSHKeyWithMocks := &testSSHKeyWithMocks{ + sSHKeyComponentImpl: componentSSHKeyComponentImpl, + mocks: mocks, + } + return componentTestSSHKeyWithMocks +} + +func initializeTestListComponent(ctx context.Context, t interface { + Cleanup(func()) + mock.TestingT +}) *testListWithMocks { + config := ProvideTestConfig() + mockStores := tests.NewMockStores(t) + componentListComponentImpl := NewTestListComponent(config, mockStores) + mockAccountingComponent := component.NewMockAccountingComponent(t) + mockRepoComponent := component.NewMockRepoComponent(t) + mockTagComponent := component.NewMockTagComponent(t) + mockSpaceComponent := component.NewMockSpaceComponent(t) + mockRuntimeArchitectureComponent := component.NewMockRuntimeArchitectureComponent(t) + mockSensitiveComponent := component.NewMockSensitiveComponent(t) + componentMockedComponents := &mockedComponents{ + accounting: mockAccountingComponent, + repo: mockRepoComponent, + tag: mockTagComponent, + space: mockSpaceComponent, + runtimeArchitecture: mockRuntimeArchitectureComponent, + sensitive: mockSensitiveComponent, + } + mockGitServer := gitserver.NewMockGitServer(t) + mockUserSvcClient := rpc.NewMockUserSvcClient(t) + mockClient := s3.NewMockClient(t) + mockMirrorServer := mirrorserver.NewMockMirrorServer(t) + mockPriorityQueue := queue.NewMockPriorityQueue(t) + mockDeployer := deploy.NewMockDeployer(t) + mockAccountingClient := accounting.NewMockAccountingClient(t) + mockReader := parquet.NewMockReader(t) + mockModerationSvcClient := rpc.NewMockModerationSvcClient(t) + mocks := &Mocks{ + stores: mockStores, + components: componentMockedComponents, + gitServer: mockGitServer, + userSvcClient: mockUserSvcClient, + s3Client: mockClient, + mirrorServer: mockMirrorServer, + mirrorQueue: mockPriorityQueue, + deployer: mockDeployer, + accountingClient: mockAccountingClient, + preader: mockReader, + moderationClient: mockModerationSvcClient, + } + componentTestListWithMocks := &testListWithMocks{ + listComponentImpl: componentListComponentImpl, + mocks: mocks, + } + return componentTestListWithMocks +} + +func initializeTestSyncClientSettingComponent(ctx context.Context, t interface { + Cleanup(func()) + mock.TestingT +}) *testSyncClientSettingWithMocks { + config := ProvideTestConfig() + mockStores := tests.NewMockStores(t) + componentSyncClientSettingComponentImpl := NewTestSyncClientSettingComponent(config, mockStores) + mockAccountingComponent := component.NewMockAccountingComponent(t) + mockRepoComponent := component.NewMockRepoComponent(t) + mockTagComponent := component.NewMockTagComponent(t) + mockSpaceComponent := component.NewMockSpaceComponent(t) + mockRuntimeArchitectureComponent := component.NewMockRuntimeArchitectureComponent(t) + mockSensitiveComponent := component.NewMockSensitiveComponent(t) + componentMockedComponents := &mockedComponents{ + accounting: mockAccountingComponent, + repo: mockRepoComponent, + tag: mockTagComponent, + space: mockSpaceComponent, + runtimeArchitecture: mockRuntimeArchitectureComponent, + sensitive: mockSensitiveComponent, + } + mockGitServer := gitserver.NewMockGitServer(t) + mockUserSvcClient := rpc.NewMockUserSvcClient(t) + mockClient := s3.NewMockClient(t) + mockMirrorServer := mirrorserver.NewMockMirrorServer(t) + mockPriorityQueue := queue.NewMockPriorityQueue(t) + mockDeployer := deploy.NewMockDeployer(t) + mockAccountingClient := accounting.NewMockAccountingClient(t) + mockReader := parquet.NewMockReader(t) + mockModerationSvcClient := rpc.NewMockModerationSvcClient(t) + mocks := &Mocks{ + stores: mockStores, + components: componentMockedComponents, + gitServer: mockGitServer, + userSvcClient: mockUserSvcClient, + s3Client: mockClient, + mirrorServer: mockMirrorServer, + mirrorQueue: mockPriorityQueue, + deployer: mockDeployer, + accountingClient: mockAccountingClient, + preader: mockReader, + moderationClient: mockModerationSvcClient, + } + componentTestSyncClientSettingWithMocks := &testSyncClientSettingWithMocks{ + syncClientSettingComponentImpl: componentSyncClientSettingComponentImpl, + mocks: mocks, + } + return componentTestSyncClientSettingWithMocks +} + +func initializeTestEventComponent(ctx context.Context, t interface { + Cleanup(func()) + mock.TestingT +}) *testEventWithMocks { + config := ProvideTestConfig() + mockStores := tests.NewMockStores(t) + componentEventComponentImpl := NewTestEventComponent(config, mockStores) + mockAccountingComponent := component.NewMockAccountingComponent(t) + mockRepoComponent := component.NewMockRepoComponent(t) + mockTagComponent := component.NewMockTagComponent(t) + mockSpaceComponent := component.NewMockSpaceComponent(t) + mockRuntimeArchitectureComponent := component.NewMockRuntimeArchitectureComponent(t) + mockSensitiveComponent := component.NewMockSensitiveComponent(t) + componentMockedComponents := &mockedComponents{ + accounting: mockAccountingComponent, + repo: mockRepoComponent, + tag: mockTagComponent, + space: mockSpaceComponent, + runtimeArchitecture: mockRuntimeArchitectureComponent, + sensitive: mockSensitiveComponent, + } + mockGitServer := gitserver.NewMockGitServer(t) + mockUserSvcClient := rpc.NewMockUserSvcClient(t) + mockClient := s3.NewMockClient(t) + mockMirrorServer := mirrorserver.NewMockMirrorServer(t) + mockPriorityQueue := queue.NewMockPriorityQueue(t) + mockDeployer := deploy.NewMockDeployer(t) + mockAccountingClient := accounting.NewMockAccountingClient(t) + mockReader := parquet.NewMockReader(t) + mockModerationSvcClient := rpc.NewMockModerationSvcClient(t) + mocks := &Mocks{ + stores: mockStores, + components: componentMockedComponents, + gitServer: mockGitServer, + userSvcClient: mockUserSvcClient, + s3Client: mockClient, + mirrorServer: mockMirrorServer, + mirrorQueue: mockPriorityQueue, + deployer: mockDeployer, + accountingClient: mockAccountingClient, + preader: mockReader, + moderationClient: mockModerationSvcClient, + } + componentTestEventWithMocks := &testEventWithMocks{ + eventComponentImpl: componentEventComponentImpl, + mocks: mocks, + } + return componentTestEventWithMocks +} + // wire.go: type testRepoWithMocks struct { @@ -1174,3 +1674,53 @@ type testSpaceSdkWithMocks struct { *spaceSdkComponentImpl mocks *Mocks } + +type testTelemetryWithMocks struct { + *telemetryComponentImpl + mocks *Mocks +} + +type testClusterWithMocks struct { + *clusterComponentImpl + mocks *Mocks +} + +type testEvaluationWithMocks struct { + *evaluationComponentImpl + mocks *Mocks +} + +type testHFDatasetWithMocks struct { + *hFDatasetComponentImpl + mocks *Mocks +} + +type testRepoFileWithMocks struct { + *repoFileComponentImpl + mocks *Mocks +} + +type testSensitiveWithMocks struct { + *sensitiveComponentImpl + mocks *Mocks +} + +type testSSHKeyWithMocks struct { + *sSHKeyComponentImpl + mocks *Mocks +} + +type testListWithMocks struct { + *listComponentImpl + mocks *Mocks +} + +type testSyncClientSettingWithMocks struct { + *syncClientSettingComponentImpl + mocks *Mocks +} + +type testEventWithMocks struct { + *eventComponentImpl + mocks *Mocks +} diff --git a/component/wireset.go b/component/wireset.go index b25cc075..6723031f 100644 --- a/component/wireset.go +++ b/component/wireset.go @@ -487,3 +487,104 @@ func NewTestSpaceSdkComponent(config *config.Config, stores *tests.MockStores) * } var SpaceSdkComponentSet = wire.NewSet(NewTestSpaceSdkComponent) + +func NewTestTelemetryComponent(config *config.Config, stores *tests.MockStores) *telemetryComponentImpl { + return &telemetryComponentImpl{ + telemetryStore: stores.Telemetry, + userStore: stores.User, + repoStore: stores.Repo, + } +} + +var TelemetryComponentSet = wire.NewSet(NewTestTelemetryComponent) + +func NewTestClusterComponent(config *config.Config, deployer deploy.Deployer) *clusterComponentImpl { + return &clusterComponentImpl{ + deployer: deployer, + } +} + +var ClusterComponentSet = wire.NewSet(NewTestClusterComponent) + +func NewTestEvaluationComponent(config *config.Config, stores *tests.MockStores, deployer deploy.Deployer, accountingComponent AccountingComponent) *evaluationComponentImpl { + return &evaluationComponentImpl{ + deployer: deployer, + userStore: stores.User, + modelStore: stores.Model, + datasetStore: stores.Dataset, + mirrorStore: stores.Mirror, + spaceResourceStore: stores.SpaceResource, + tokenStore: stores.AccessToken, + runtimeFrameworkStore: stores.RuntimeFramework, + config: config, + accountingComponent: accountingComponent, + } +} + +var EvaluationComponentSet = wire.NewSet(NewTestEvaluationComponent) + +func NewTestHFDatasetComponent(config *config.Config, stores *tests.MockStores, repoComponent RepoComponent, gitServer gitserver.GitServer) *hFDatasetComponentImpl { + return &hFDatasetComponentImpl{ + repoComponent: repoComponent, + tagStore: stores.Tag, + datasetStore: stores.Dataset, + repoStore: stores.Repo, + gitServer: gitServer, + } +} + +var HFDatasetComponentSet = wire.NewSet(NewTestHFDatasetComponent) + +func NewTestRepoFileComponent(config *config.Config, stores *tests.MockStores, gitServer gitserver.GitServer) *repoFileComponentImpl { + return &repoFileComponentImpl{ + repoFileStore: stores.RepoFile, + repoStore: stores.Repo, + gitServer: gitServer, + } +} + +var RepoFileComponentSet = wire.NewSet(NewTestRepoFileComponent) + +func NewTestSensitiveComponent(config *config.Config, checker rpc.ModerationSvcClient) *sensitiveComponentImpl { + return &sensitiveComponentImpl{ + checker: checker, + } +} + +var SensitiveComponentSet = wire.NewSet(NewTestSensitiveComponent) + +func NewTestSSHKeyComponent(config *config.Config, stores *tests.MockStores, gitServer gitserver.GitServer) *sSHKeyComponentImpl { + return &sSHKeyComponentImpl{ + sshKeyStore: stores.SSH, + userStore: stores.User, + gitServer: gitServer, + } +} + +var SSHKeyComponentSet = wire.NewSet(NewTestSSHKeyComponent) + +func NewTestListComponent(config *config.Config, stores *tests.MockStores) *listComponentImpl { + return &listComponentImpl{ + modelStore: stores.Model, + datasetStore: stores.Dataset, + spaceStore: stores.Space, + } +} + +var ListComponentSet = wire.NewSet(NewTestListComponent) + +func NewTestSyncClientSettingComponent(config *config.Config, stores *tests.MockStores) *syncClientSettingComponentImpl { + return &syncClientSettingComponentImpl{ + settingStore: stores.SyncClientSetting, + } +} + +var SyncClientSettingComponentSet = wire.NewSet(NewTestSyncClientSettingComponent) + +func NewTestEventComponent(config *config.Config, stores *tests.MockStores) *eventComponentImpl { + return &eventComponentImpl{ + eventStore: stores.Event, + } +} + +var EventComponentSet = wire.NewSet(NewTestEventComponent) From d3b09358ba309ef47470a5851a17020cd15b0a56 Mon Sep 17 00:00:00 2001 From: yiling Date: Wed, 18 Dec 2024 09:32:06 +0800 Subject: [PATCH 07/34] sync code component --- common/types/code.go | 45 +++++++++++++++++++++--------------------- common/types/repo.go | 20 +++++++++++++++++++ component/code.go | 3 +++ component/code_test.go | 13 ++++++------ 4 files changed, 53 insertions(+), 28 deletions(-) diff --git a/common/types/code.go b/common/types/code.go index bc1ff658..b90b8bdf 100644 --- a/common/types/code.go +++ b/common/types/code.go @@ -17,26 +17,27 @@ type UpdateCodeReq struct { } type Code struct { - ID int64 `json:"id"` - Name string `json:"name"` - Nickname string `json:"nickname"` - Description string `json:"description"` - Likes int64 `json:"likes"` - Downloads int64 `json:"downloads"` - Path string `json:"path"` - RepositoryID int64 `json:"repository_id"` - Repository Repository `json:"repository"` - Private bool `json:"private"` - User User `json:"user"` - Tags []RepoTag `json:"tags"` - DefaultBranch string `json:"default_branch"` - CreatedAt time.Time `json:"created_at"` - UpdatedAt time.Time `json:"updated_at"` - UserLikes bool `json:"user_likes"` - Source RepositorySource `json:"source"` - SyncStatus RepositorySyncStatus `json:"sync_status"` - License string `json:"license"` - CanWrite bool `json:"can_write"` - CanManage bool `json:"can_manage"` - Namespace *Namespace `json:"namespace"` + ID int64 `json:"id"` + Name string `json:"name"` + Nickname string `json:"nickname"` + Description string `json:"description"` + Likes int64 `json:"likes"` + Downloads int64 `json:"downloads"` + Path string `json:"path"` + RepositoryID int64 `json:"repository_id"` + Repository Repository `json:"repository"` + Private bool `json:"private"` + User User `json:"user"` + Tags []RepoTag `json:"tags"` + DefaultBranch string `json:"default_branch"` + CreatedAt time.Time `json:"created_at"` + UpdatedAt time.Time `json:"updated_at"` + UserLikes bool `json:"user_likes"` + Source RepositorySource `json:"source"` + SyncStatus RepositorySyncStatus `json:"sync_status"` + License string `json:"license"` + CanWrite bool `json:"can_write"` + CanManage bool `json:"can_manage"` + Namespace *Namespace `json:"namespace"` + SensitiveCheckStatus string `json:"sensitive_check_status"` } diff --git a/common/types/repo.go b/common/types/repo.go index 63c93013..ba387aac 100644 --- a/common/types/repo.go +++ b/common/types/repo.go @@ -9,6 +9,26 @@ type RepositorySource string type RepositorySyncStatus string type SensitiveCheckStatus int +// String returns a string representation of the sensitive check status. +// +// It returns one of "Fail", "Pending", "Pass", "Skip", "Exception", or "Unknown". +func (s SensitiveCheckStatus) String() string { + switch s { + case SensitiveCheckFail: + return "Fail" + case SensitiveCheckPending: + return "Pending" + case SensitiveCheckPass: + return "Pass" + case SensitiveCheckSkip: + return "Skip" + case SensitiveCheckException: + return "Exception" + default: + return "Unknown" + } +} + const ( ResTypeKey string = "hub-res-type" ResNameKey string = "hub-res-name" diff --git a/component/code.go b/component/code.go index 07f7a8b1..4db155e2 100644 --- a/component/code.go +++ b/component/code.go @@ -349,6 +349,9 @@ func (c *codeComponentImpl) Show(ctx context.Context, namespace, name, currentUs CanManage: permission.CanAdmin, Namespace: ns, } + if permission.CanAdmin { + resCode.SensitiveCheckStatus = code.Repository.SensitiveCheckStatus.String() + } return resCode, nil } diff --git a/component/code_test.go b/component/code_test.go index 4dff2019..56a9530e 100644 --- a/component/code_test.go +++ b/component/code_test.go @@ -166,12 +166,13 @@ func TestCodeComponent_Show(t *testing.T) { HTTPCloneURL: "/s/.git", SSHCloneURL: ":s/.git", }, - RepositoryID: 11, - Namespace: &types.Namespace{}, - Name: "name", - User: types.User{Username: "user"}, - CanManage: true, - UserLikes: true, + RepositoryID: 11, + Namespace: &types.Namespace{}, + Name: "name", + User: types.User{Username: "user"}, + CanManage: true, + UserLikes: true, + SensitiveCheckStatus: "Pending", }, data) } From 723e92192122847479a33eeecdc3af33dd0a2507 Mon Sep 17 00:00:00 2001 From: yiling Date: Wed, 18 Dec 2024 09:36:58 +0800 Subject: [PATCH 08/34] sync dataset component --- common/types/dataset.go | 49 +++++++++++++++++++------------------- component/dataset.go | 5 +++- component/git_http.go | 4 ++-- component/git_http_test.go | 2 +- component/repo.go | 4 ++-- component/repo_test.go | 4 ++-- 6 files changed, 36 insertions(+), 32 deletions(-) diff --git a/common/types/dataset.go b/common/types/dataset.go index f81915f1..14462801 100644 --- a/common/types/dataset.go +++ b/common/types/dataset.go @@ -18,28 +18,29 @@ type UpdateDatasetReq struct { } type Dataset struct { - ID int64 `json:"id,omitempty"` - Name string `json:"name"` - Nickname string `json:"nickname"` - Description string `json:"description"` - Likes int64 `json:"likes"` - Downloads int64 `json:"downloads"` - Path string `json:"path"` - RepositoryID int64 `json:"repository_id"` - Repository Repository `json:"repository"` - Private bool `json:"private"` - User User `json:"user"` - Tags []RepoTag `json:"tags"` - Readme string `json:"readme"` - DefaultBranch string `json:"default_branch"` - CreatedAt time.Time `json:"created_at"` - UpdatedAt time.Time `json:"updated_at"` - UserLikes bool `json:"user_likes"` - Source RepositorySource `json:"source"` - SyncStatus RepositorySyncStatus `json:"sync_status"` - License string `json:"license"` - CanWrite bool `json:"can_write"` - CanManage bool `json:"can_manage"` - Namespace *Namespace `json:"namespace"` - MirrorLastUpdatedAt time.Time `json:"mirror_last_updated_at"` + ID int64 `json:"id,omitempty"` + Name string `json:"name"` + Nickname string `json:"nickname"` + Description string `json:"description"` + Likes int64 `json:"likes"` + Downloads int64 `json:"downloads"` + Path string `json:"path"` + RepositoryID int64 `json:"repository_id"` + Repository Repository `json:"repository"` + Private bool `json:"private"` + User User `json:"user"` + Tags []RepoTag `json:"tags"` + Readme string `json:"readme"` + DefaultBranch string `json:"default_branch"` + CreatedAt time.Time `json:"created_at"` + UpdatedAt time.Time `json:"updated_at"` + UserLikes bool `json:"user_likes"` + Source RepositorySource `json:"source"` + SyncStatus RepositorySyncStatus `json:"sync_status"` + License string `json:"license"` + CanWrite bool `json:"can_write"` + CanManage bool `json:"can_manage"` + Namespace *Namespace `json:"namespace"` + SensitiveCheckStatus string `json:"sensitive_check_status"` + MirrorLastUpdatedAt time.Time `json:"mirror_last_updated_at"` } diff --git a/component/dataset.go b/component/dataset.go index 302fbf68..875a2329 100644 --- a/component/dataset.go +++ b/component/dataset.go @@ -77,7 +77,7 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text const ( initCommitMessage = "initial commit" - ossFileExpireSeconds = 259200 * time.Second + ossFileExpire = 259200 * time.Second readmeFileName = "README.md" gitattributesFileName = ".gitattributes" ) @@ -473,6 +473,9 @@ func (c *datasetComponentImpl) Show(ctx context.Context, namespace, name, curren CanManage: permission.CanAdmin, Namespace: ns, } + if permission.CanAdmin { + resDataset.SensitiveCheckStatus = dataset.Repository.SensitiveCheckStatus.String() + } return resDataset, nil } diff --git a/component/git_http.go b/component/git_http.go index 4decc37a..d69ab5f4 100644 --- a/component/git_http.go +++ b/component/git_http.go @@ -281,7 +281,7 @@ func (c *gitHTTPComponentImpl) buildObjectResponse(ctx context.Context, req type var link *types.Link reqParams := make(url.Values) objectKey := path.Join("lfs", pointer.RelativePath()) - url, err := c.s3Client.PresignedGetObject(ctx, c.config.S3.Bucket, objectKey, ossFileExpireSeconds, reqParams) + url, err := c.s3Client.PresignedGetObject(ctx, c.config.S3.Bucket, objectKey, ossFileExpire, reqParams) if url != nil && err == nil { delete(header, "Authorization") link = &types.Link{Href: url.String(), Header: header} @@ -682,7 +682,7 @@ func (c *gitHTTPComponentImpl) LfsDownload(ctx context.Context, req types.Downlo // allow rename when download through content-disposition header reqParams.Set("response-content-disposition", fmt.Sprintf("attachment;filename=%s", req.SaveAs)) } - signedUrl, err := c.s3Client.PresignedGetObject(ctx, c.config.S3.Bucket, objectKey, ossFileExpireSeconds, reqParams) + signedUrl, err := c.s3Client.PresignedGetObject(ctx, c.config.S3.Bucket, objectKey, ossFileExpire, reqParams) if err != nil { return nil, err } diff --git a/component/git_http_test.go b/component/git_http_test.go index 88d24547..25de03b4 100644 --- a/component/git_http_test.go +++ b/component/git_http_test.go @@ -476,7 +476,7 @@ func TestGitHTTPComponent_LfsDownload(t *testing.T) { reqParams := make(url.Values) reqParams.Set("response-content-disposition", fmt.Sprintf("attachment;filename=%s", "sa")) url := &url.URL{Scheme: "http"} - gc.mocks.s3Client.EXPECT().PresignedGetObject(ctx, "", "lfs/oid", ossFileExpireSeconds, reqParams).Return(url, nil) + gc.mocks.s3Client.EXPECT().PresignedGetObject(ctx, "", "lfs/oid", ossFileExpire, reqParams).Return(url, nil) u, err := gc.LfsDownload(ctx, types.DownloadRequest{ Oid: "oid", diff --git a/component/repo.go b/component/repo.go index 95a73d01..9977780e 100644 --- a/component/repo.go +++ b/component/repo.go @@ -968,7 +968,7 @@ func (c *repoComponentImpl) DownloadFile(ctx context.Context, req *types.GetFile // allow rename when download through content-disposition header reqParams.Set("response-content-disposition", fmt.Sprintf("attachment;filename=%s", req.SaveAs)) } - signedUrl, err := c.s3Client.PresignedGetObject(ctx, c.lfsBucket, objectKey, ossFileExpireSeconds, reqParams) + signedUrl, err := c.s3Client.PresignedGetObject(ctx, c.lfsBucket, objectKey, ossFileExpire, reqParams) if err != nil { return nil, 0, downloadUrl, err } @@ -1289,7 +1289,7 @@ func (c *repoComponentImpl) SDKDownloadFile(ctx context.Context, req *types.GetF // allow rename when download through content-disposition header reqParams.Set("response-content-disposition", fmt.Sprintf("attachment;filename=%s", req.SaveAs)) } - signedUrl, err := c.s3Client.PresignedGetObject(ctx, c.lfsBucket, objectKey, ossFileExpireSeconds, reqParams) + signedUrl, err := c.s3Client.PresignedGetObject(ctx, c.lfsBucket, objectKey, ossFileExpire, reqParams) if err != nil { if err.Error() == ErrNotFoundMessage || err.Error() == ErrGetContentsOrList { return nil, 0, downloadUrl, ErrNotFound diff --git a/component/repo_test.go b/component/repo_test.go index 3959b7c3..c19df4d2 100644 --- a/component/repo_test.go +++ b/component/repo_test.go @@ -652,7 +652,7 @@ func TestRepoComponent_DownloadFile(t *testing.T) { reqParams := make(url.Values) reqParams.Set("response-content-disposition", fmt.Sprintf("attachment;filename=%s", "zzz")) repo.mocks.s3Client.EXPECT().PresignedGetObject( - ctx, repo.lfsBucket, "lfs/path", ossFileExpireSeconds, reqParams, + ctx, repo.lfsBucket, "lfs/path", ossFileExpire, reqParams, ).Return(&url.URL{Path: "foobar"}, nil) } else { repo.mocks.gitServer.EXPECT().GetRepoFileReader(ctx, gitserver.GetRepoInfoByPathReq{ @@ -765,7 +765,7 @@ func TestRepoComponent_SDKDownloadFile(t *testing.T) { reqParams := make(url.Values) reqParams.Set("response-content-disposition", fmt.Sprintf("attachment;filename=%s", "zzz")) repo.mocks.s3Client.EXPECT().PresignedGetObject( - ctx, repo.lfsBucket, "lfs/qqq", ossFileExpireSeconds, reqParams, + ctx, repo.lfsBucket, "lfs/qqq", ossFileExpire, reqParams, ).Return(&url.URL{Path: "foobar"}, nil) } else { repo.mocks.gitServer.EXPECT().GetRepoFileReader(ctx, gitserver.GetRepoInfoByPathReq{ From 001c56c3b573648d739b7972da4b3c6866e86e20 Mon Sep 17 00:00:00 2001 From: yiling Date: Wed, 18 Dec 2024 09:41:35 +0800 Subject: [PATCH 09/34] sync list component --- component/list.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/component/list.go b/component/list.go index 2fcccdde..a906522b 100644 --- a/component/list.go +++ b/component/list.go @@ -33,7 +33,7 @@ func (c *listComponentImpl) ListModelsByPath(ctx context.Context, req *types.Lis models, err := c.modelStore.ListByPath(ctx, req.Paths) if err != nil { - slog.Error("error listing models by path", "error", err, slog.Any("paths", req.Paths)) + slog.Error("error listing models by path: %v", slog.Any("error", err), slog.Any("paths", req.Paths)) return nil, err } for _, model := range models { @@ -69,7 +69,7 @@ func (c *listComponentImpl) ListDatasetsByPath(ctx context.Context, req *types.L datasets, err := c.datasetStore.ListByPath(ctx, req.Paths) if err != nil { - slog.Error("error listing datasets by path", "error", err, slog.Any("paths", req.Paths)) + slog.Error("error listing datasets by path: %v", slog.Any("error", err), slog.Any("paths", req.Paths)) return nil, err } for _, dataset := range datasets { From 67007de107dbc58e48cf81eb4dad769465ed273a Mon Sep 17 00:00:00 2001 From: yiling Date: Wed, 18 Dec 2024 09:48:06 +0800 Subject: [PATCH 10/34] sync mirror component --- .../store/database/mock_MirrorStore.go | 91 ++++- api/handler/mirror.go | 3 +- builder/store/database/mirror.go | 39 +- builder/store/database/mirror_test.go | 13 +- common/types/mirror.go | 5 + component/mirror.go | 74 ++-- component/mirror_test.go | 355 ++++++++++++++++++ 7 files changed, 534 insertions(+), 46 deletions(-) create mode 100644 component/mirror_test.go diff --git a/_mocks/opencsg.com/csghub-server/builder/store/database/mock_MirrorStore.go b/_mocks/opencsg.com/csghub-server/builder/store/database/mock_MirrorStore.go index 4f47d431..48656ec1 100644 --- a/_mocks/opencsg.com/csghub-server/builder/store/database/mock_MirrorStore.go +++ b/_mocks/opencsg.com/csghub-server/builder/store/database/mock_MirrorStore.go @@ -429,9 +429,9 @@ func (_c *MockMirrorStore_Finished_Call) RunAndReturn(run func(context.Context) return _c } -// IndexWithPagination provides a mock function with given fields: ctx, per, page -func (_m *MockMirrorStore) IndexWithPagination(ctx context.Context, per int, page int) ([]database.Mirror, int, error) { - ret := _m.Called(ctx, per, page) +// IndexWithPagination provides a mock function with given fields: ctx, per, page, search +func (_m *MockMirrorStore) IndexWithPagination(ctx context.Context, per int, page int, search string) ([]database.Mirror, int, error) { + ret := _m.Called(ctx, per, page, search) if len(ret) == 0 { panic("no return value specified for IndexWithPagination") @@ -440,25 +440,25 @@ func (_m *MockMirrorStore) IndexWithPagination(ctx context.Context, per int, pag var r0 []database.Mirror var r1 int var r2 error - if rf, ok := ret.Get(0).(func(context.Context, int, int) ([]database.Mirror, int, error)); ok { - return rf(ctx, per, page) + if rf, ok := ret.Get(0).(func(context.Context, int, int, string) ([]database.Mirror, int, error)); ok { + return rf(ctx, per, page, search) } - if rf, ok := ret.Get(0).(func(context.Context, int, int) []database.Mirror); ok { - r0 = rf(ctx, per, page) + if rf, ok := ret.Get(0).(func(context.Context, int, int, string) []database.Mirror); ok { + r0 = rf(ctx, per, page, search) } else { if ret.Get(0) != nil { r0 = ret.Get(0).([]database.Mirror) } } - if rf, ok := ret.Get(1).(func(context.Context, int, int) int); ok { - r1 = rf(ctx, per, page) + if rf, ok := ret.Get(1).(func(context.Context, int, int, string) int); ok { + r1 = rf(ctx, per, page, search) } else { r1 = ret.Get(1).(int) } - if rf, ok := ret.Get(2).(func(context.Context, int, int) error); ok { - r2 = rf(ctx, per, page) + if rf, ok := ret.Get(2).(func(context.Context, int, int, string) error); ok { + r2 = rf(ctx, per, page, search) } else { r2 = ret.Error(2) } @@ -475,13 +475,14 @@ type MockMirrorStore_IndexWithPagination_Call struct { // - ctx context.Context // - per int // - page int -func (_e *MockMirrorStore_Expecter) IndexWithPagination(ctx interface{}, per interface{}, page interface{}) *MockMirrorStore_IndexWithPagination_Call { - return &MockMirrorStore_IndexWithPagination_Call{Call: _e.mock.On("IndexWithPagination", ctx, per, page)} +// - search string +func (_e *MockMirrorStore_Expecter) IndexWithPagination(ctx interface{}, per interface{}, page interface{}, search interface{}) *MockMirrorStore_IndexWithPagination_Call { + return &MockMirrorStore_IndexWithPagination_Call{Call: _e.mock.On("IndexWithPagination", ctx, per, page, search)} } -func (_c *MockMirrorStore_IndexWithPagination_Call) Run(run func(ctx context.Context, per int, page int)) *MockMirrorStore_IndexWithPagination_Call { +func (_c *MockMirrorStore_IndexWithPagination_Call) Run(run func(ctx context.Context, per int, page int, search string)) *MockMirrorStore_IndexWithPagination_Call { _c.Call.Run(func(args mock.Arguments) { - run(args[0].(context.Context), args[1].(int), args[2].(int)) + run(args[0].(context.Context), args[1].(int), args[2].(int), args[3].(string)) }) return _c } @@ -491,7 +492,7 @@ func (_c *MockMirrorStore_IndexWithPagination_Call) Return(mirrors []database.Mi return _c } -func (_c *MockMirrorStore_IndexWithPagination_Call) RunAndReturn(run func(context.Context, int, int) ([]database.Mirror, int, error)) *MockMirrorStore_IndexWithPagination_Call { +func (_c *MockMirrorStore_IndexWithPagination_Call) RunAndReturn(run func(context.Context, int, int, string) ([]database.Mirror, int, error)) *MockMirrorStore_IndexWithPagination_Call { _c.Call.Return(run) return _c } @@ -728,6 +729,64 @@ func (_c *MockMirrorStore_PushedMirror_Call) RunAndReturn(run func(context.Conte return _c } +// StatusCount provides a mock function with given fields: ctx +func (_m *MockMirrorStore) StatusCount(ctx context.Context) ([]database.MirrorStatusCount, error) { + ret := _m.Called(ctx) + + if len(ret) == 0 { + panic("no return value specified for StatusCount") + } + + var r0 []database.MirrorStatusCount + var r1 error + if rf, ok := ret.Get(0).(func(context.Context) ([]database.MirrorStatusCount, error)); ok { + return rf(ctx) + } + if rf, ok := ret.Get(0).(func(context.Context) []database.MirrorStatusCount); ok { + r0 = rf(ctx) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).([]database.MirrorStatusCount) + } + } + + if rf, ok := ret.Get(1).(func(context.Context) error); ok { + r1 = rf(ctx) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// MockMirrorStore_StatusCount_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'StatusCount' +type MockMirrorStore_StatusCount_Call struct { + *mock.Call +} + +// StatusCount is a helper method to define mock.On call +// - ctx context.Context +func (_e *MockMirrorStore_Expecter) StatusCount(ctx interface{}) *MockMirrorStore_StatusCount_Call { + return &MockMirrorStore_StatusCount_Call{Call: _e.mock.On("StatusCount", ctx)} +} + +func (_c *MockMirrorStore_StatusCount_Call) Run(run func(ctx context.Context)) *MockMirrorStore_StatusCount_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context)) + }) + return _c +} + +func (_c *MockMirrorStore_StatusCount_Call) Return(_a0 []database.MirrorStatusCount, _a1 error) *MockMirrorStore_StatusCount_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *MockMirrorStore_StatusCount_Call) RunAndReturn(run func(context.Context) ([]database.MirrorStatusCount, error)) *MockMirrorStore_StatusCount_Call { + _c.Call.Return(run) + return _c +} + // ToSyncLfs provides a mock function with given fields: ctx func (_m *MockMirrorStore) ToSyncLfs(ctx context.Context) ([]database.Mirror, error) { ret := _m.Called(ctx) diff --git a/api/handler/mirror.go b/api/handler/mirror.go index 36d35d46..bc74824b 100644 --- a/api/handler/mirror.go +++ b/api/handler/mirror.go @@ -123,7 +123,8 @@ func (h *MirrorHandler) Index(ctx *gin.Context) { httpbase.UnauthorizedError(ctx, component.ErrUserNotFound) return } - repos, total, err := h.mc.Index(ctx, currentUser, per, page) + search := ctx.Query("search") + repos, total, err := h.mc.Index(ctx, currentUser, per, page, search) if err != nil { slog.Error("failed to get mirror repos", slog.Any("error", err)) httpbase.ServerError(ctx, err) diff --git a/builder/store/database/mirror.go b/builder/store/database/mirror.go index d9ccd448..2d266d02 100644 --- a/builder/store/database/mirror.go +++ b/builder/store/database/mirror.go @@ -3,6 +3,7 @@ package database import ( "context" "fmt" + "strings" "time" "github.com/uptrace/bun" @@ -31,7 +32,8 @@ type MirrorStore interface { Finished(ctx context.Context) ([]Mirror, error) ToSyncRepo(ctx context.Context) ([]Mirror, error) ToSyncLfs(ctx context.Context) ([]Mirror, error) - IndexWithPagination(ctx context.Context, per, page int) (mirrors []Mirror, count int, err error) + IndexWithPagination(ctx context.Context, per, page int, search string) (mirrors []Mirror, count int, err error) + StatusCount(ctx context.Context) ([]MirrorStatusCount, error) UpdateMirrorAndRepository(ctx context.Context, mirror *Mirror, repo *Repository) error } @@ -76,6 +78,11 @@ type Mirror struct { times } +type MirrorStatusCount struct { + Status types.MirrorTaskStatus `bun:"status"` + Count int `bun:"count"` +} + func (s *mirrorStoreImpl) IsExist(ctx context.Context, repoID int64) (exists bool, err error) { var mirror Mirror exists, err = s.db.Operator.Core. @@ -90,7 +97,8 @@ func (s *mirrorStoreImpl) IsRepoExist(ctx context.Context, repoType types.Reposi exists, err = s.db.Operator.Core. NewSelect(). Model(&repo). - Where("git_path=?", fmt.Sprintf("%ss_%s/%s", repoType, namespace, name)). + Where("path=?", fmt.Sprintf("%s/%s", namespace, name)). + Where("repository_type=?", repoType). Exists(ctx) return } @@ -273,7 +281,13 @@ func (s *mirrorStoreImpl) ToSyncRepo(ctx context.Context) ([]Mirror, error) { var mirrors []Mirror err := s.db.Operator.Core.NewSelect(). Model(&mirrors). - Where("next_execution_timestamp < ? or status in (?,?,?)", time.Now(), types.MirrorIncomplete, types.MirrorFailed, types.MirrorWaiting). + Where( + "next_execution_timestamp < ? or status in (?,?,?,?)", + time.Now(), + types.MirrorIncomplete, + types.MirrorFailed, + types.MirrorWaiting, + types.MirrorRunning). Scan(ctx) if err != nil { return nil, err @@ -293,11 +307,17 @@ func (s *mirrorStoreImpl) ToSyncLfs(ctx context.Context) ([]Mirror, error) { return mirrors, nil } -func (s *mirrorStoreImpl) IndexWithPagination(ctx context.Context, per, page int) (mirrors []Mirror, count int, err error) { +func (s *mirrorStoreImpl) IndexWithPagination(ctx context.Context, per, page int, search string) (mirrors []Mirror, count int, err error) { q := s.db.Operator.Core.NewSelect(). Model(&mirrors). Relation("Repository"). Relation("MirrorSource") + if search != "" { + q = q.Where("LOWER(mirror.source_url) like ? or LOWER(mirror.local_repo_path) like ?", + fmt.Sprintf("%%%s%%", strings.ToLower(search)), + fmt.Sprintf("%%%s%%", strings.ToLower(search)), + ) + } count, err = q.Count(ctx) if err != nil { return @@ -313,6 +333,17 @@ func (s *mirrorStoreImpl) IndexWithPagination(ctx context.Context, per, page int return } +func (s *mirrorStoreImpl) StatusCount(ctx context.Context) ([]MirrorStatusCount, error) { + var statusCounts []MirrorStatusCount + err := s.db.Operator.Core.NewSelect(). + Model((*Mirror)(nil)). + Column("status"). + ColumnExpr("COUNT(*) AS count"). + Group("status"). + Scan(ctx, &statusCounts) + return statusCounts, err +} + func (s *mirrorStoreImpl) UpdateMirrorAndRepository(ctx context.Context, mirror *Mirror, repo *Repository) error { err := s.db.Operator.Core.RunInTx(ctx, nil, func(ctx context.Context, tx bun.Tx) error { _, err := tx.NewUpdate().Model(mirror).WherePK().Exec(ctx) diff --git a/builder/store/database/mirror_test.go b/builder/store/database/mirror_test.go index ab056121..6175ebc3 100644 --- a/builder/store/database/mirror_test.go +++ b/builder/store/database/mirror_test.go @@ -51,6 +51,7 @@ func TestMirrorStore_CRUD(t *testing.T) { RepositoryType: types.ModelRepo, GitPath: "models_ns/n", Name: "repo", + Path: "ns/n", } err = db.Core.NewInsert().Model(repo).Scan(ctx, repo) require.Nil(t, err) @@ -194,7 +195,7 @@ func TestMirrorStore_ToSync(t *testing.T) { for _, m := range ms { names = append(names, m.Interval) } - require.ElementsMatch(t, []string{"m1", "m3", "m6", "m7"}, names) + require.ElementsMatch(t, []string{"m1", "m3", "m5", "m6", "m7"}, names) ms, err = store.ToSyncLfs(ctx) require.Nil(t, err) @@ -222,7 +223,7 @@ func TestMirrorStore_IndexWithPagination(t *testing.T) { require.Nil(t, err) } - ms, count, err := store.IndexWithPagination(ctx, 10, 1) + ms, count, err := store.IndexWithPagination(ctx, 10, 1, "foo") require.Nil(t, err) names := []string{} for _, m := range ms { @@ -250,4 +251,12 @@ func TestMirrorStore_StatusCount(t *testing.T) { require.Nil(t, err) } + cs, err := store.StatusCount(ctx) + require.Nil(t, err) + require.Equal(t, 2, len(cs)) + require.ElementsMatch(t, []database.MirrorStatusCount{ + {types.MirrorFailed, 2}, + {types.MirrorFinished, 1}, + }, cs) + } diff --git a/common/types/mirror.go b/common/types/mirror.go index 294bed0c..d079597c 100644 --- a/common/types/mirror.go +++ b/common/types/mirror.go @@ -135,3 +135,8 @@ type Mirror struct { type MirrorSource struct { SourceName string `json:"source_name"` } + +type MirrorStatusCount struct { + Status MirrorTaskStatus + Count int +} diff --git a/component/mirror.go b/component/mirror.go index 050ae92e..67aa642b 100644 --- a/component/mirror.go +++ b/component/mirror.go @@ -48,7 +48,8 @@ type MirrorComponent interface { CreateMirrorRepo(ctx context.Context, req types.CreateMirrorRepoReq) (*database.Mirror, error) CheckMirrorProgress(ctx context.Context) error Repos(ctx context.Context, currentUser string, per, page int) ([]types.MirrorRepo, int, error) - Index(ctx context.Context, currentUser string, per, page int) ([]types.Mirror, int, error) + Index(ctx context.Context, currentUser string, per, page int, search string) ([]types.Mirror, int, error) + Statistics(ctx context.Context, currentUser string) ([]types.MirrorStatusCount, error) } func NewMirrorComponent(config *config.Config) (MirrorComponent, error) { @@ -66,7 +67,8 @@ func NewMirrorComponent(config *config.Config) (MirrorComponent, error) { if err != nil { return nil, fmt.Errorf("failed to get priority queue: %v", err) } - c.repoComp, err = NewRepoComponent(config) + + c.repoComp, err = NewRepoComponentImpl(config) if err != nil { return nil, fmt.Errorf("fail to create repo component,error:%w", err) } @@ -246,6 +248,7 @@ func (c *mirrorComponentImpl) CreateMirrorRepo(ctx context.Context, req types.Cr mirror.LocalRepoPath = fmt.Sprintf("%s_%s_%s_%s", mirrorSource.SourceName, req.RepoType, req.SourceNamespace, req.SourceName) mirror.SourceRepoPath = fmt.Sprintf("%s/%s", req.SourceNamespace, req.SourceName) mirror.Priority = types.HighMirrorPriority + var taskId int64 if c.config.GitServer.Type == types.GitServerTypeGitea { taskId, err = c.mirrorServer.CreateMirrorRepo(ctx, mirrorserver.CreateMirrorRepoReq{ @@ -262,14 +265,12 @@ func (c *mirrorComponentImpl) CreateMirrorRepo(ctx context.Context, req types.Cr return nil, fmt.Errorf("failed to create push mirror in mirror server: %v", err) } } - mirror.MirrorTaskID = taskId reqMirror, err := c.mirrorStore.Create(ctx, &mirror) if err != nil { return nil, fmt.Errorf("failed to create mirror") } - if c.config.GitServer.Type == types.GitServerTypeGitaly { c.mq.PushRepoMirror(&queue.MirrorTask{ MirrorID: reqMirror.ID, @@ -283,6 +284,7 @@ func (c *mirrorComponentImpl) CreateMirrorRepo(ctx context.Context, req types.Cr } return reqMirror, nil + } func (m *mirrorComponentImpl) mapNamespaceAndName(sourceNamespace string) string { @@ -620,7 +622,7 @@ func (c *mirrorComponentImpl) Repos(ctx context.Context, currentUser string, per return mirrorRepos, total, nil } -func (c *mirrorComponentImpl) Index(ctx context.Context, currentUser string, per, page int) ([]types.Mirror, int, error) { +func (c *mirrorComponentImpl) Index(ctx context.Context, currentUser string, per, page int, search string) ([]types.Mirror, int, error) { var mirrorsResp []types.Mirror user, err := c.userStore.FindByUsername(ctx, currentUser) if err != nil { @@ -629,28 +631,54 @@ func (c *mirrorComponentImpl) Index(ctx context.Context, currentUser string, per if !user.CanAdmin() { return nil, 0, errors.New("user does not have admin permission") } - mirrors, total, err := c.mirrorStore.IndexWithPagination(ctx, per, page) + mirrors, total, err := c.mirrorStore.IndexWithPagination(ctx, per, page, search) if err != nil { return nil, 0, fmt.Errorf("failed to get mirror mirrors: %v", err) } for _, mirror := range mirrors { - mirrorsResp = append(mirrorsResp, types.Mirror{ - SourceUrl: mirror.SourceUrl, - MirrorSource: types.MirrorSource{ - SourceName: mirror.MirrorSource.SourceName, - }, - Username: mirror.Username, - AccessToken: mirror.AccessToken, - PushUrl: mirror.PushUrl, - PushUsername: mirror.PushUsername, - PushAccessToken: mirror.PushAccessToken, - LastUpdatedAt: mirror.LastUpdatedAt, - SourceRepoPath: mirror.SourceRepoPath, - LocalRepoPath: fmt.Sprintf("%ss/%s", mirror.Repository.RepositoryType, mirror.Repository.Path), - LastMessage: mirror.LastMessage, - Status: mirror.Status, - Progress: mirror.Progress, - }) + if mirror.Repository != nil { + mirrorsResp = append(mirrorsResp, types.Mirror{ + SourceUrl: mirror.SourceUrl, + MirrorSource: types.MirrorSource{ + SourceName: mirror.MirrorSource.SourceName, + }, + Username: mirror.Username, + AccessToken: mirror.AccessToken, + PushUrl: mirror.PushUrl, + PushUsername: mirror.PushUsername, + PushAccessToken: mirror.PushAccessToken, + LastUpdatedAt: mirror.LastUpdatedAt, + SourceRepoPath: mirror.SourceRepoPath, + LocalRepoPath: fmt.Sprintf("%ss/%s", mirror.Repository.RepositoryType, mirror.Repository.Path), + LastMessage: mirror.LastMessage, + Status: mirror.Status, + Progress: mirror.Progress, + }) + } } return mirrorsResp, total, nil } + +func (c *mirrorComponentImpl) Statistics(ctx context.Context, currentUser string) ([]types.MirrorStatusCount, error) { + var scs []types.MirrorStatusCount + user, err := c.userStore.FindByUsername(ctx, currentUser) + if err != nil { + return nil, errors.New("user does not exist") + } + if !user.CanAdmin() { + return nil, errors.New("user does not have admin permission") + } + statusCounts, err := c.mirrorStore.StatusCount(ctx) + if err != nil { + return nil, fmt.Errorf("failed to get mirror statistics: %v", err) + } + + for _, statusCount := range statusCounts { + scs = append(scs, types.MirrorStatusCount{ + Status: statusCount.Status, + Count: statusCount.Count, + }) + } + + return scs, nil +} diff --git a/component/mirror_test.go b/component/mirror_test.go new file mode 100644 index 00000000..375002c3 --- /dev/null +++ b/component/mirror_test.go @@ -0,0 +1,355 @@ +package component + +import ( + "context" + "database/sql" + "fmt" + "testing" + + "github.com/stretchr/testify/mock" + "github.com/stretchr/testify/require" + "opencsg.com/csghub-server/builder/git/gitserver" + "opencsg.com/csghub-server/builder/git/mirrorserver" + "opencsg.com/csghub-server/builder/store/database" + "opencsg.com/csghub-server/common/types" + "opencsg.com/csghub-server/mirror/queue" +) + +func TestMirrorComponent_CreatePushMirrorForFinishedMirrorTask(t *testing.T) { + ctx := context.TODO() + mc := initializeTestMirrorComponent(ctx, t) + + mc.mocks.stores.MirrorMock().EXPECT().NoPushMirror(ctx).Return([]database.Mirror{ + {MirrorTaskID: 1}, + {MirrorTaskID: 2, LocalRepoPath: "foo"}, + }, nil) + mc.mocks.mirrorServer.EXPECT().GetMirrorTaskInfo(ctx, int64(1)).Return( + &mirrorserver.MirrorTaskInfo{}, nil, + ) + mc.mocks.mirrorServer.EXPECT().GetMirrorTaskInfo(ctx, int64(2)).Return( + &mirrorserver.MirrorTaskInfo{ + Status: mirrorserver.TaskStatusFinished, + }, nil, + ) + mc.mocks.mirrorServer.EXPECT().CreatePushMirror(ctx, mirrorserver.CreatePushMirrorReq{ + Name: "foo", + Interval: "8h", + }).Return(nil) + mc.mocks.stores.MirrorMock().EXPECT().Update(ctx, &database.Mirror{ + MirrorTaskID: 2, LocalRepoPath: "foo", PushMirrorCreated: true, + }).Return(nil) + + err := mc.CreatePushMirrorForFinishedMirrorTask(ctx) + require.Nil(t, err) +} + +func TestMirrorComponent_CreateMirrorRepo(t *testing.T) { + + cases := []struct { + repoType types.RepositoryType + gitea bool + }{ + {types.ModelRepo, false}, + {types.DatasetRepo, false}, + {types.CodeRepo, false}, + {types.CodeRepo, true}, + } + + for _, c := range cases { + t.Run(fmt.Sprintf("%+v", c), func(t *testing.T) { + + ctx := context.TODO() + mc := initializeTestMirrorComponent(ctx, t) + + req := types.CreateMirrorRepoReq{ + SourceNamespace: "sns", + SourceName: "sn", + RepoType: c.repoType, + CurrentUser: "user", + } + + if c.gitea { + mc.config.GitServer.Type = types.GitServerTypeGitea + } else { + mc.config.GitServer.Type = types.GitServerTypeGitaly + } + + mc.mocks.stores.UserMock().EXPECT().FindByUsername(ctx, "user").Return(database.User{ + RoleMask: "admin", + }, nil) + repo := &database.Repository{} + mc.mocks.stores.RepoMock().EXPECT().FindByPath( + ctx, req.RepoType, "AIWizards", "sn", + ).Return( + repo, nil, + ) + mc.mocks.stores.RepoMock().EXPECT().FindByPath( + ctx, req.RepoType, "AIWizards", "sns_sn", + ).Return( + nil, sql.ErrNoRows, + ) + mc.mocks.stores.NamespaceMock().EXPECT().FindByPath( + ctx, "AIWizards", + ).Return(database.Namespace{ + User: database.User{Username: "user"}, + }, nil) + mc.mocks.components.repo.EXPECT().CreateRepo(ctx, types.CreateRepoReq{ + Username: "user", + Namespace: "AIWizards", + Name: "sns_sn", + Nickname: "sns_sn", + Description: req.Description, + Private: true, + License: req.License, + DefaultBranch: req.DefaultBranch, + RepoType: req.RepoType, + }).Return(&gitserver.CreateRepoResp{}, &database.Repository{}, nil) + switch req.RepoType { + case types.ModelRepo: + mc.mocks.stores.ModelMock().EXPECT().Create(ctx, database.Model{ + Repository: repo, + RepositoryID: repo.ID, + }).Return(nil, nil) + case types.DatasetRepo: + mc.mocks.stores.DatasetMock().EXPECT().Create(ctx, database.Dataset{ + Repository: repo, + RepositoryID: repo.ID, + }).Return(nil, nil) + case types.CodeRepo: + mc.mocks.stores.CodeMock().EXPECT().Create(ctx, database.Code{ + Repository: repo, + RepositoryID: repo.ID, + }).Return(nil, nil) + } + mc.mocks.stores.GitServerAccessTokenMock().EXPECT().FindByType(ctx, "git").Return( + []database.GitServerAccessToken{ + {}, + }, nil, + ) + mc.mocks.stores.MirrorSourceMock().EXPECT().Get(ctx, int64(0)).Return( + &database.MirrorSource{}, nil, + ) + if c.gitea { + mc.mocks.mirrorServer.EXPECT().CreateMirrorRepo(ctx, mirrorserver.CreateMirrorRepoReq{ + Name: "_code_sns_sn", + Namespace: "root", + Private: false, + SyncLfs: req.SyncLfs, + }).Return(123, nil) + } + reqMirror := &database.Mirror{ + ID: 1, + Priority: types.HighMirrorPriority, + } + localRepoPath := "" + switch req.RepoType { + case types.ModelRepo: + localRepoPath = "_model_sns_sn" + case types.DatasetRepo: + localRepoPath = "_dataset_sns_sn" + case types.CodeRepo: + localRepoPath = "_code_sns_sn" + } + + cm := &database.Mirror{ + Username: "sns", + PushUsername: "root", + SourceRepoPath: "sns/sn", + LocalRepoPath: localRepoPath, + Priority: types.HighMirrorPriority, + Repository: &database.Repository{}, + } + if c.gitea { + cm.MirrorTaskID = 123 + } + mc.mocks.stores.MirrorMock().EXPECT().Create(ctx, cm).Return( + reqMirror, nil, + ) + if !c.gitea { + mc.mocks.mirrorQueue.EXPECT().PushRepoMirror(&queue.MirrorTask{ + MirrorID: reqMirror.ID, + Priority: queue.PriorityMap[reqMirror.Priority], + }) + mc.mocks.stores.MirrorMock().EXPECT().Update(ctx, reqMirror).Return(nil) + } + + m, err := mc.CreateMirrorRepo(ctx, req) + require.Nil(t, err) + require.Equal(t, reqMirror, m) + + }) + } + +} + +func TestMirrorComponent_CheckMirrorProgress(t *testing.T) { + + for _, saas := range []bool{false, true} { + t.Run(fmt.Sprintf("saas %v", saas), func(t *testing.T) { + ctx := context.TODO() + mc := initializeTestMirrorComponent(ctx, t) + mc.saas = saas + + mirrors := []database.Mirror{ + { + ID: 1, MirrorTaskID: 11, + Repository: &database.Repository{ + ID: 111, Path: "foo/bar", RepositoryType: types.ModelRepo, + }, + }, + { + ID: 2, MirrorTaskID: 12, + Repository: &database.Repository{ + ID: 111, Path: "foo/bar", RepositoryType: types.ModelRepo, + }, + }, + { + ID: 3, MirrorTaskID: 13, + Repository: &database.Repository{ + ID: 111, Path: "foo/bar", RepositoryType: types.ModelRepo, + }, + }, + { + ID: 4, MirrorTaskID: 14, + Repository: &database.Repository{ + ID: 111, Path: "foo/bar", RepositoryType: types.ModelRepo, + }, + }, + } + mc.mocks.stores.MirrorMock().EXPECT().Unfinished(ctx).Return(mirrors, nil) + + if saas { + mc.mocks.mirrorServer.EXPECT().GetMirrorTaskInfo(ctx, int64(11)).Return( + &mirrorserver.MirrorTaskInfo{ + Status: mirrorserver.TaskStatusQueued, + }, nil, + ) + mc.mocks.mirrorServer.EXPECT().GetMirrorTaskInfo(ctx, int64(12)).Return( + &mirrorserver.MirrorTaskInfo{ + Status: mirrorserver.TaskStatusRunning, + }, nil, + ) + mc.mocks.mirrorServer.EXPECT().GetMirrorTaskInfo(ctx, int64(13)).Return( + &mirrorserver.MirrorTaskInfo{ + Status: mirrorserver.TaskStatusFailed, + }, nil, + ) + mc.mocks.mirrorServer.EXPECT().GetMirrorTaskInfo(ctx, int64(14)).Return( + &mirrorserver.MirrorTaskInfo{ + Status: mirrorserver.TaskStatusFinished, + }, nil, + ) + } else { + mc.mocks.gitServer.EXPECT().GetMirrorTaskInfo(ctx, int64(11)).Return( + &gitserver.MirrorTaskInfo{ + Status: gitserver.TaskStatusQueued, + }, nil, + ) + mc.mocks.gitServer.EXPECT().GetMirrorTaskInfo(ctx, int64(12)).Return( + &gitserver.MirrorTaskInfo{ + Status: gitserver.TaskStatusRunning, + }, nil, + ) + mc.mocks.gitServer.EXPECT().GetMirrorTaskInfo(ctx, int64(13)).Return( + &gitserver.MirrorTaskInfo{ + Status: gitserver.TaskStatusFailed, + }, nil, + ) + mc.mocks.gitServer.EXPECT().GetMirrorTaskInfo(ctx, int64(14)).Return( + &gitserver.MirrorTaskInfo{ + Status: gitserver.TaskStatusFinished, + }, nil, + ) + } + mirrors[0].Status = types.MirrorWaiting + mirrors[1].Status = types.MirrorRunning + mirrors[1].Progress = 100 + mirrors[2].Status = types.MirrorFailed + mirrors[3].Status = types.MirrorFinished + mirrors[3].Progress = 100 + mc.mocks.gitServer.EXPECT().GetRepo(ctx, gitserver.GetRepoReq{ + Namespace: "foo", + Name: "bar", + RepoType: types.ModelRepo, + }).Return(&gitserver.CreateRepoResp{}, nil) + for _, m := range mirrors { + m.Repository.SyncStatus = mirrorStatusAndRepoSyncStatusMapping[m.Status] + mv := m + mc.mocks.stores.MirrorMock().EXPECT().Update(ctx, &mv).Return(nil).Once() + mc.mocks.stores.RepoMock().EXPECT().UpdateRepo( + ctx, database.Repository{ + ID: 111, + Path: "foo/bar", + RepositoryType: types.ModelRepo, + SyncStatus: mirrorStatusAndRepoSyncStatusMapping[mv.Status], + }, + ).Return(nil, nil).Once() + } + mc.mocks.gitServer.EXPECT().GetRepoFileTree( + mock.Anything, gitserver.GetRepoInfoByPathReq{ + Namespace: "foo", Name: "bar", RepoType: "model"}, + ).Return([]*types.File{{Name: "foo.go"}}, nil) + + err := mc.CheckMirrorProgress(ctx) + require.Nil(t, err) + }) + } + +} + +func TestMirrorComponent_Repos(t *testing.T) { + ctx := context.TODO() + mc := initializeTestMirrorComponent(ctx, t) + + mc.mocks.stores.UserMock().EXPECT().FindByUsername(ctx, "user").Return(database.User{ + RoleMask: "admin", + }, nil) + mc.mocks.stores.RepoMock().EXPECT().WithMirror(ctx, 10, 1).Return([]database.Repository{ + {Path: "foo", SyncStatus: types.SyncStatusCompleted, RepositoryType: types.ModelRepo}, + }, 100, nil) + + data, total, err := mc.Repos(ctx, "user", 10, 1) + require.Nil(t, err) + require.Equal(t, 100, total) + require.Equal(t, []types.MirrorRepo{ + {Path: "foo", SyncStatus: types.SyncStatusCompleted, RepoType: types.ModelRepo}, + }, data) +} + +func TestMirrorComponent_Index(t *testing.T) { + ctx := context.TODO() + mc := initializeTestMirrorComponent(ctx, t) + + mc.mocks.stores.UserMock().EXPECT().FindByUsername(ctx, "user").Return(database.User{ + RoleMask: "admin", + }, nil) + mc.mocks.stores.MirrorMock().EXPECT().IndexWithPagination(ctx, 10, 1, "foo").Return( + []database.Mirror{{Username: "user", LastMessage: "msg", Repository: &database.Repository{}}}, 100, nil, + ) + + data, total, err := mc.Index(ctx, "user", 10, 1, "foo") + require.Nil(t, err) + require.Equal(t, 100, total) + require.Equal(t, []types.Mirror{ + {Username: "user", LastMessage: "msg", LocalRepoPath: "s/"}, + }, data) +} + +func TestMirrorComponent_Statistic(t *testing.T) { + ctx := context.TODO() + mc := initializeTestMirrorComponent(ctx, t) + + mc.mocks.stores.UserMock().EXPECT().FindByUsername(ctx, "user").Return(database.User{ + RoleMask: "admin", + }, nil) + mc.mocks.stores.MirrorMock().EXPECT().StatusCount(ctx).Return([]database.MirrorStatusCount{ + {Status: types.MirrorFinished, Count: 100}, + }, nil) + + s, err := mc.Statistics(ctx, "user") + require.Nil(t, err) + require.Equal(t, []types.MirrorStatusCount{ + {Status: types.MirrorFinished, Count: 100}, + }, s) + +} From c2bff3c021229c272b1e407cbd6d8caac2de77bd Mon Sep 17 00:00:00 2001 From: SeanHH86 <154984842+SeanHH86@users.noreply.github.com> Date: Wed, 18 Dec 2024 11:10:35 +0800 Subject: [PATCH 11/34] [Tags] add tag management feature (#214) * [Tags] add tag management feature --------- Co-authored-by: Haihui.Wang --- .../builder/store/database/mock_TagStore.go | 165 +++++++++++ .../component/mock_TagComponent.go | 229 +++++++++++++++ api/handler/tag.go | 150 +++++++++- api/handler/tag_test.go | 197 +++++++++++++ api/router/api.go | 17 +- builder/store/database/tag.go | 47 +++ builder/store/database/tag_test.go | 83 ++++++ common/types/tag.go | 15 +- component/tag.go | 104 +++++++ component/tag_test.go | 147 ++++++++++ component/wireset.go | 1 + docs/docs.go | 268 +++++++++++++++++- docs/swagger.json | 268 +++++++++++++++++- docs/swagger.yaml | 174 +++++++++++- 14 files changed, 1844 insertions(+), 21 deletions(-) create mode 100644 api/handler/tag_test.go diff --git a/_mocks/opencsg.com/csghub-server/builder/store/database/mock_TagStore.go b/_mocks/opencsg.com/csghub-server/builder/store/database/mock_TagStore.go index f2638d83..6a02617a 100644 --- a/_mocks/opencsg.com/csghub-server/builder/store/database/mock_TagStore.go +++ b/_mocks/opencsg.com/csghub-server/builder/store/database/mock_TagStore.go @@ -841,6 +841,53 @@ func (_c *MockTagStore_CreateTag_Call) RunAndReturn(run func(context.Context, st return _c } +// DeleteTagByID provides a mock function with given fields: ctx, id +func (_m *MockTagStore) DeleteTagByID(ctx context.Context, id int64) error { + ret := _m.Called(ctx, id) + + if len(ret) == 0 { + panic("no return value specified for DeleteTagByID") + } + + var r0 error + if rf, ok := ret.Get(0).(func(context.Context, int64) error); ok { + r0 = rf(ctx, id) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// MockTagStore_DeleteTagByID_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'DeleteTagByID' +type MockTagStore_DeleteTagByID_Call struct { + *mock.Call +} + +// DeleteTagByID is a helper method to define mock.On call +// - ctx context.Context +// - id int64 +func (_e *MockTagStore_Expecter) DeleteTagByID(ctx interface{}, id interface{}) *MockTagStore_DeleteTagByID_Call { + return &MockTagStore_DeleteTagByID_Call{Call: _e.mock.On("DeleteTagByID", ctx, id)} +} + +func (_c *MockTagStore_DeleteTagByID_Call) Run(run func(ctx context.Context, id int64)) *MockTagStore_DeleteTagByID_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(int64)) + }) + return _c +} + +func (_c *MockTagStore_DeleteTagByID_Call) Return(_a0 error) *MockTagStore_DeleteTagByID_Call { + _c.Call.Return(_a0) + return _c +} + +func (_c *MockTagStore_DeleteTagByID_Call) RunAndReturn(run func(context.Context, int64) error) *MockTagStore_DeleteTagByID_Call { + _c.Call.Return(run) + return _c +} + // FindOrCreate provides a mock function with given fields: ctx, tag func (_m *MockTagStore) FindOrCreate(ctx context.Context, tag database.Tag) (*database.Tag, error) { ret := _m.Called(ctx, tag) @@ -961,6 +1008,65 @@ func (_c *MockTagStore_FindTag_Call) RunAndReturn(run func(context.Context, stri return _c } +// FindTagByID provides a mock function with given fields: ctx, id +func (_m *MockTagStore) FindTagByID(ctx context.Context, id int64) (*database.Tag, error) { + ret := _m.Called(ctx, id) + + if len(ret) == 0 { + panic("no return value specified for FindTagByID") + } + + var r0 *database.Tag + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, int64) (*database.Tag, error)); ok { + return rf(ctx, id) + } + if rf, ok := ret.Get(0).(func(context.Context, int64) *database.Tag); ok { + r0 = rf(ctx, id) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*database.Tag) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, int64) error); ok { + r1 = rf(ctx, id) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// MockTagStore_FindTagByID_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'FindTagByID' +type MockTagStore_FindTagByID_Call struct { + *mock.Call +} + +// FindTagByID is a helper method to define mock.On call +// - ctx context.Context +// - id int64 +func (_e *MockTagStore_Expecter) FindTagByID(ctx interface{}, id interface{}) *MockTagStore_FindTagByID_Call { + return &MockTagStore_FindTagByID_Call{Call: _e.mock.On("FindTagByID", ctx, id)} +} + +func (_c *MockTagStore_FindTagByID_Call) Run(run func(ctx context.Context, id int64)) *MockTagStore_FindTagByID_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(int64)) + }) + return _c +} + +func (_c *MockTagStore_FindTagByID_Call) Return(_a0 *database.Tag, _a1 error) *MockTagStore_FindTagByID_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *MockTagStore_FindTagByID_Call) RunAndReturn(run func(context.Context, int64) (*database.Tag, error)) *MockTagStore_FindTagByID_Call { + _c.Call.Return(run) + return _c +} + // GetTagsByScopeAndCategories provides a mock function with given fields: ctx, scope, categories func (_m *MockTagStore) GetTagsByScopeAndCategories(ctx context.Context, scope database.TagScope, categories []string) ([]*database.Tag, error) { ret := _m.Called(ctx, scope, categories) @@ -1229,6 +1335,65 @@ func (_c *MockTagStore_SetMetaTags_Call) RunAndReturn(run func(context.Context, return _c } +// UpdateTagByID provides a mock function with given fields: ctx, tag +func (_m *MockTagStore) UpdateTagByID(ctx context.Context, tag *database.Tag) (*database.Tag, error) { + ret := _m.Called(ctx, tag) + + if len(ret) == 0 { + panic("no return value specified for UpdateTagByID") + } + + var r0 *database.Tag + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, *database.Tag) (*database.Tag, error)); ok { + return rf(ctx, tag) + } + if rf, ok := ret.Get(0).(func(context.Context, *database.Tag) *database.Tag); ok { + r0 = rf(ctx, tag) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*database.Tag) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, *database.Tag) error); ok { + r1 = rf(ctx, tag) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// MockTagStore_UpdateTagByID_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'UpdateTagByID' +type MockTagStore_UpdateTagByID_Call struct { + *mock.Call +} + +// UpdateTagByID is a helper method to define mock.On call +// - ctx context.Context +// - tag *database.Tag +func (_e *MockTagStore_Expecter) UpdateTagByID(ctx interface{}, tag interface{}) *MockTagStore_UpdateTagByID_Call { + return &MockTagStore_UpdateTagByID_Call{Call: _e.mock.On("UpdateTagByID", ctx, tag)} +} + +func (_c *MockTagStore_UpdateTagByID_Call) Run(run func(ctx context.Context, tag *database.Tag)) *MockTagStore_UpdateTagByID_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(*database.Tag)) + }) + return _c +} + +func (_c *MockTagStore_UpdateTagByID_Call) Return(_a0 *database.Tag, _a1 error) *MockTagStore_UpdateTagByID_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *MockTagStore_UpdateTagByID_Call) RunAndReturn(run func(context.Context, *database.Tag) (*database.Tag, error)) *MockTagStore_UpdateTagByID_Call { + _c.Call.Return(run) + return _c +} + // UpsertRepoTags provides a mock function with given fields: ctx, repoID, oldTagIDs, newTagIDs func (_m *MockTagStore) UpsertRepoTags(ctx context.Context, repoID int64, oldTagIDs []int64, newTagIDs []int64) error { ret := _m.Called(ctx, repoID, oldTagIDs, newTagIDs) diff --git a/_mocks/opencsg.com/csghub-server/component/mock_TagComponent.go b/_mocks/opencsg.com/csghub-server/component/mock_TagComponent.go index 05021f20..d6e7c76a 100644 --- a/_mocks/opencsg.com/csghub-server/component/mock_TagComponent.go +++ b/_mocks/opencsg.com/csghub-server/component/mock_TagComponent.go @@ -133,6 +133,174 @@ func (_c *MockTagComponent_ClearMetaTags_Call) RunAndReturn(run func(context.Con return _c } +// CreateTag provides a mock function with given fields: ctx, username, req +func (_m *MockTagComponent) CreateTag(ctx context.Context, username string, req types.CreateTag) (*database.Tag, error) { + ret := _m.Called(ctx, username, req) + + if len(ret) == 0 { + panic("no return value specified for CreateTag") + } + + var r0 *database.Tag + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, string, types.CreateTag) (*database.Tag, error)); ok { + return rf(ctx, username, req) + } + if rf, ok := ret.Get(0).(func(context.Context, string, types.CreateTag) *database.Tag); ok { + r0 = rf(ctx, username, req) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*database.Tag) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, string, types.CreateTag) error); ok { + r1 = rf(ctx, username, req) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// MockTagComponent_CreateTag_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'CreateTag' +type MockTagComponent_CreateTag_Call struct { + *mock.Call +} + +// CreateTag is a helper method to define mock.On call +// - ctx context.Context +// - username string +// - req types.CreateTag +func (_e *MockTagComponent_Expecter) CreateTag(ctx interface{}, username interface{}, req interface{}) *MockTagComponent_CreateTag_Call { + return &MockTagComponent_CreateTag_Call{Call: _e.mock.On("CreateTag", ctx, username, req)} +} + +func (_c *MockTagComponent_CreateTag_Call) Run(run func(ctx context.Context, username string, req types.CreateTag)) *MockTagComponent_CreateTag_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(string), args[2].(types.CreateTag)) + }) + return _c +} + +func (_c *MockTagComponent_CreateTag_Call) Return(_a0 *database.Tag, _a1 error) *MockTagComponent_CreateTag_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *MockTagComponent_CreateTag_Call) RunAndReturn(run func(context.Context, string, types.CreateTag) (*database.Tag, error)) *MockTagComponent_CreateTag_Call { + _c.Call.Return(run) + return _c +} + +// DeleteTag provides a mock function with given fields: ctx, username, id +func (_m *MockTagComponent) DeleteTag(ctx context.Context, username string, id int64) error { + ret := _m.Called(ctx, username, id) + + if len(ret) == 0 { + panic("no return value specified for DeleteTag") + } + + var r0 error + if rf, ok := ret.Get(0).(func(context.Context, string, int64) error); ok { + r0 = rf(ctx, username, id) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// MockTagComponent_DeleteTag_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'DeleteTag' +type MockTagComponent_DeleteTag_Call struct { + *mock.Call +} + +// DeleteTag is a helper method to define mock.On call +// - ctx context.Context +// - username string +// - id int64 +func (_e *MockTagComponent_Expecter) DeleteTag(ctx interface{}, username interface{}, id interface{}) *MockTagComponent_DeleteTag_Call { + return &MockTagComponent_DeleteTag_Call{Call: _e.mock.On("DeleteTag", ctx, username, id)} +} + +func (_c *MockTagComponent_DeleteTag_Call) Run(run func(ctx context.Context, username string, id int64)) *MockTagComponent_DeleteTag_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(string), args[2].(int64)) + }) + return _c +} + +func (_c *MockTagComponent_DeleteTag_Call) Return(_a0 error) *MockTagComponent_DeleteTag_Call { + _c.Call.Return(_a0) + return _c +} + +func (_c *MockTagComponent_DeleteTag_Call) RunAndReturn(run func(context.Context, string, int64) error) *MockTagComponent_DeleteTag_Call { + _c.Call.Return(run) + return _c +} + +// GetTagByID provides a mock function with given fields: ctx, username, id +func (_m *MockTagComponent) GetTagByID(ctx context.Context, username string, id int64) (*database.Tag, error) { + ret := _m.Called(ctx, username, id) + + if len(ret) == 0 { + panic("no return value specified for GetTagByID") + } + + var r0 *database.Tag + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, string, int64) (*database.Tag, error)); ok { + return rf(ctx, username, id) + } + if rf, ok := ret.Get(0).(func(context.Context, string, int64) *database.Tag); ok { + r0 = rf(ctx, username, id) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*database.Tag) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, string, int64) error); ok { + r1 = rf(ctx, username, id) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// MockTagComponent_GetTagByID_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'GetTagByID' +type MockTagComponent_GetTagByID_Call struct { + *mock.Call +} + +// GetTagByID is a helper method to define mock.On call +// - ctx context.Context +// - username string +// - id int64 +func (_e *MockTagComponent_Expecter) GetTagByID(ctx interface{}, username interface{}, id interface{}) *MockTagComponent_GetTagByID_Call { + return &MockTagComponent_GetTagByID_Call{Call: _e.mock.On("GetTagByID", ctx, username, id)} +} + +func (_c *MockTagComponent_GetTagByID_Call) Run(run func(ctx context.Context, username string, id int64)) *MockTagComponent_GetTagByID_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(string), args[2].(int64)) + }) + return _c +} + +func (_c *MockTagComponent_GetTagByID_Call) Return(_a0 *database.Tag, _a1 error) *MockTagComponent_GetTagByID_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *MockTagComponent_GetTagByID_Call) RunAndReturn(run func(context.Context, string, int64) (*database.Tag, error)) *MockTagComponent_GetTagByID_Call { + _c.Call.Return(run) + return _c +} + // UpdateLibraryTags provides a mock function with given fields: ctx, tagScope, namespace, name, oldFilePath, newFilePath func (_m *MockTagComponent) UpdateLibraryTags(ctx context.Context, tagScope database.TagScope, namespace string, name string, oldFilePath string, newFilePath string) error { ret := _m.Called(ctx, tagScope, namespace, name, oldFilePath, newFilePath) @@ -296,6 +464,67 @@ func (_c *MockTagComponent_UpdateRepoTagsByCategory_Call) RunAndReturn(run func( return _c } +// UpdateTag provides a mock function with given fields: ctx, username, id, req +func (_m *MockTagComponent) UpdateTag(ctx context.Context, username string, id int64, req types.UpdateTag) (*database.Tag, error) { + ret := _m.Called(ctx, username, id, req) + + if len(ret) == 0 { + panic("no return value specified for UpdateTag") + } + + var r0 *database.Tag + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, string, int64, types.UpdateTag) (*database.Tag, error)); ok { + return rf(ctx, username, id, req) + } + if rf, ok := ret.Get(0).(func(context.Context, string, int64, types.UpdateTag) *database.Tag); ok { + r0 = rf(ctx, username, id, req) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*database.Tag) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, string, int64, types.UpdateTag) error); ok { + r1 = rf(ctx, username, id, req) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// MockTagComponent_UpdateTag_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'UpdateTag' +type MockTagComponent_UpdateTag_Call struct { + *mock.Call +} + +// UpdateTag is a helper method to define mock.On call +// - ctx context.Context +// - username string +// - id int64 +// - req types.UpdateTag +func (_e *MockTagComponent_Expecter) UpdateTag(ctx interface{}, username interface{}, id interface{}, req interface{}) *MockTagComponent_UpdateTag_Call { + return &MockTagComponent_UpdateTag_Call{Call: _e.mock.On("UpdateTag", ctx, username, id, req)} +} + +func (_c *MockTagComponent_UpdateTag_Call) Run(run func(ctx context.Context, username string, id int64, req types.UpdateTag)) *MockTagComponent_UpdateTag_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(string), args[2].(int64), args[3].(types.UpdateTag)) + }) + return _c +} + +func (_c *MockTagComponent_UpdateTag_Call) Return(_a0 *database.Tag, _a1 error) *MockTagComponent_UpdateTag_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *MockTagComponent_UpdateTag_Call) RunAndReturn(run func(context.Context, string, int64, types.UpdateTag) (*database.Tag, error)) *MockTagComponent_UpdateTag_Call { + _c.Call.Return(run) + return _c +} + // NewMockTagComponent creates a new instance of MockTagComponent. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. // The first argument is typically a *testing.T value. func NewMockTagComponent(t interface { diff --git a/api/handler/tag.go b/api/handler/tag.go index e4e62ca2..53a10ebc 100644 --- a/api/handler/tag.go +++ b/api/handler/tag.go @@ -3,10 +3,12 @@ package handler import ( "log/slog" "net/http" + "strconv" "github.com/gin-gonic/gin" "opencsg.com/csghub-server/api/httpbase" "opencsg.com/csghub-server/common/config" + "opencsg.com/csghub-server/common/types" "opencsg.com/csghub-server/component" ) @@ -27,13 +29,14 @@ type TagsHandler struct { // GetAllTags godoc // @Security ApiKey // @Summary Get all tags -// @Description get all tags +// @Description Get all tags // @Tags Tag // @Accept json // @Produce json // @Param category query string false "category name" // @Param scope query string false "scope name" Enums(model, dataset) -// @Success 200 {object} types.ResponseWithTotal{data=[]database.Tag,total=int} "tags" +// @Success 200 {object} types.ResponseWithTotal{data=[]database.Tag} "tags" +// @Failure 400 {object} types.APIBadRequest "Bad request" // @Failure 500 {object} types.APIInternalServerError "Internal server error" // @Router /tags [get] func (t *TagsHandler) AllTags(ctx *gin.Context) { @@ -42,14 +45,151 @@ func (t *TagsHandler) AllTags(ctx *gin.Context) { scope := ctx.Query("scope") tags, err := t.tc.AllTagsByScopeAndCategory(ctx, scope, category) if err != nil { - slog.Error("Failed to load tags", "error", err) + slog.Error("Failed to load tags", slog.Any("category", category), slog.Any("scope", scope), slog.Any("error", err)) httpbase.ServerError(ctx, err) return } respData := gin.H{ "data": tags, } - - slog.Info("Tags loaded successfully", "count", len(tags)) ctx.JSON(http.StatusOK, respData) } + +// CreateTag godoc +// @Security ApiKey +// @Summary Create new tag +// @Description Create new tag +// @Tags Tag +// @Accept json +// @Produce json +// @Param body body types.CreateTag true "body" +// @Success 200 {object} types.Response{database.Tag} "OK" +// @Failure 400 {object} types.APIBadRequest "Bad request" +// @Failure 500 {object} types.APIInternalServerError "Internal server error" +// @Router /tags [post] +func (t *TagsHandler) CreateTag(ctx *gin.Context) { + userName := httpbase.GetCurrentUser(ctx) + if userName == "" { + httpbase.UnauthorizedError(ctx, component.ErrUserNotFound) + return + } + var req types.CreateTag + if err := ctx.ShouldBindJSON(&req); err != nil { + slog.Error("Bad request format", slog.Any("error", err)) + httpbase.BadRequest(ctx, err.Error()) + return + } + tag, err := t.tc.CreateTag(ctx, userName, req) + if err != nil { + slog.Error("Failed to create tag", slog.Any("req", req), slog.Any("error", err)) + httpbase.ServerError(ctx, err) + return + } + ctx.JSON(http.StatusOK, gin.H{"data": tag}) +} + +// GetTag godoc +// @Security ApiKey +// @Summary Get a tag by id +// @Description Get a tag by id +// @Tags Tag +// @Accept json +// @Produce json +// @Param id path string true "id of the tag" +// @Success 200 {object} types.Response{database.Tag} "OK" +// @Failure 400 {object} types.APIBadRequest "Bad request" +// @Failure 500 {object} types.APIInternalServerError "Internal server error" +// @Router /tag/{id} [get] +func (t *TagsHandler) GetTagByID(ctx *gin.Context) { + userName := httpbase.GetCurrentUser(ctx) + if userName == "" { + httpbase.UnauthorizedError(ctx, component.ErrUserNotFound) + return + } + id, err := strconv.ParseInt(ctx.Param("id"), 10, 64) + if err != nil { + slog.Error("Bad request format", slog.Any("error", err)) + httpbase.BadRequest(ctx, err.Error()) + return + } + tag, err := t.tc.GetTagByID(ctx, userName, id) + if err != nil { + slog.Error("Failed to get tag", slog.Int64("id", id), slog.Any("error", err)) + httpbase.ServerError(ctx, err) + return + } + ctx.JSON(http.StatusOK, gin.H{"data": tag}) +} + +// UpdateTag godoc +// @Security ApiKey +// @Summary Update a tag by id +// @Description Update a tag by id +// @Tags Tag +// @Accept json +// @Produce json +// @Param id path string true "id of the tag" +// @Param body body types.UpdateTag true "body" +// @Success 200 {object} types.Response{database.Tag} "OK" +// @Failure 400 {object} types.APIBadRequest "Bad request" +// @Failure 500 {object} types.APIInternalServerError "Internal server error" +// @Router /tag/{id} [put] +func (t *TagsHandler) UpdateTag(ctx *gin.Context) { + userName := httpbase.GetCurrentUser(ctx) + if userName == "" { + httpbase.UnauthorizedError(ctx, component.ErrUserNotFound) + return + } + id, err := strconv.ParseInt(ctx.Param("id"), 10, 64) + if err != nil { + slog.Error("Bad request format", slog.Any("error", err)) + httpbase.BadRequest(ctx, err.Error()) + return + } + var req types.UpdateTag + if err := ctx.ShouldBindJSON(&req); err != nil { + slog.Error("Bad request format", slog.Any("error", err)) + httpbase.BadRequest(ctx, err.Error()) + return + } + tag, err := t.tc.UpdateTag(ctx, userName, id, req) + if err != nil { + slog.Error("Failed to update tag", slog.Int64("id", id), slog.Any("error", err)) + httpbase.ServerError(ctx, err) + return + } + ctx.JSON(http.StatusOK, gin.H{"data": tag}) +} + +// DeleteTag godoc +// @Security ApiKey +// @Summary Delete a tag by id +// @Description Delete a tag by id +// @Tags Tag +// @Accept json +// @Produce json +// @Param id path string true "id of the tag" +// @Success 200 {object} types.Response{} "OK" +// @Failure 400 {object} types.APIBadRequest "Bad request" +// @Failure 500 {object} types.APIInternalServerError "Internal server error" +// @Router /tag/{id} [delete] +func (t *TagsHandler) DeleteTag(ctx *gin.Context) { + userName := httpbase.GetCurrentUser(ctx) + if userName == "" { + httpbase.UnauthorizedError(ctx, component.ErrUserNotFound) + return + } + id, err := strconv.ParseInt(ctx.Param("id"), 10, 64) + if err != nil { + slog.Error("Bad request format", slog.Any("error", err)) + httpbase.BadRequest(ctx, err.Error()) + return + } + err = t.tc.DeleteTag(ctx, userName, id) + if err != nil { + slog.Error("Failed to delete tag", slog.Int64("id", id), slog.Any("error", err)) + httpbase.ServerError(ctx, err) + return + } + ctx.JSON(http.StatusOK, nil) +} diff --git a/api/handler/tag_test.go b/api/handler/tag_test.go new file mode 100644 index 00000000..ed47c226 --- /dev/null +++ b/api/handler/tag_test.go @@ -0,0 +1,197 @@ +package handler + +import ( + "bytes" + "encoding/json" + "net/http" + "net/http/httptest" + "net/url" + "testing" + + "github.com/gin-gonic/gin" + "github.com/stretchr/testify/mock" + "github.com/stretchr/testify/require" + mockcom "opencsg.com/csghub-server/_mocks/opencsg.com/csghub-server/component" + "opencsg.com/csghub-server/api/httpbase" + "opencsg.com/csghub-server/builder/store/database" + "opencsg.com/csghub-server/common/types" + "opencsg.com/csghub-server/component" +) + +func NewTestTagHandler( + tagComp component.TagComponent, +) (*TagsHandler, error) { + return &TagsHandler{ + tc: tagComp, + }, nil +} + +func TestTagHandler_AllTags(t *testing.T) { + var tags []*database.Tag + tags = append(tags, &database.Tag{ID: 1, Name: "test1"}) + + values := url.Values{} + values.Add("category", "testcate") + values.Add("scope", "testscope") + req := httptest.NewRequest("get", "/api/v1/tags?"+values.Encode(), nil) + + hr := httptest.NewRecorder() + ginContext, _ := gin.CreateTestContext(hr) + ginContext.Request = req + + tagComp := mockcom.NewMockTagComponent(t) + tagComp.EXPECT().AllTagsByScopeAndCategory(ginContext, "testscope", "testcate").Return(tags, nil) + + tagHandler, err := NewTestTagHandler(tagComp) + require.Nil(t, err) + + tagHandler.AllTags(ginContext) + + require.Equal(t, http.StatusOK, hr.Code) + + var resp httpbase.R + + err = json.Unmarshal(hr.Body.Bytes(), &resp) + require.Nil(t, err) + + require.Equal(t, 0, resp.Code) + require.Equal(t, "", resp.Msg) + require.NotNil(t, resp.Data) +} + +func TestTagHandler_CreateTag(t *testing.T) { + username := "testuser" + data := types.CreateTag{ + Name: "testtag", + Scope: "testscope", + Category: "testcategory", + } + + reqBody, _ := json.Marshal(data) + + req := httptest.NewRequest("post", "/api/v1/tags", bytes.NewBuffer(reqBody)) + + hr := httptest.NewRecorder() + ginContext, _ := gin.CreateTestContext(hr) + ginContext.Set("currentUser", username) + ginContext.Request = req + + tagComp := mockcom.NewMockTagComponent(t) + tagComp.EXPECT().CreateTag(ginContext, username, mock.Anything).Return(&database.Tag{ID: 1, Name: "testtag"}, nil) + + tagHandler, err := NewTestTagHandler(tagComp) + require.Nil(t, err) + + tagHandler.CreateTag(ginContext) + + require.Equal(t, http.StatusOK, hr.Code) + + var resp httpbase.R + + err = json.Unmarshal(hr.Body.Bytes(), &resp) + require.Nil(t, err) + + require.Equal(t, 0, resp.Code) + require.Equal(t, "", resp.Msg) + require.NotNil(t, resp.Data) +} + +func TestTagHandler_GetTagByID(t *testing.T) { + username := "testuser" + + req := httptest.NewRequest("get", "/api/v1/tags/1", nil) + + hr := httptest.NewRecorder() + ginContext, _ := gin.CreateTestContext(hr) + ginContext.Set("currentUser", username) + ginContext.AddParam("id", "1") + ginContext.Request = req + + tagComp := mockcom.NewMockTagComponent(t) + tagComp.EXPECT().GetTagByID(ginContext, username, int64(1)).Return(&database.Tag{ID: 1, Name: "test1"}, nil) + + tagHandler, err := NewTestTagHandler(tagComp) + require.Nil(t, err) + + tagHandler.GetTagByID(ginContext) + + require.Equal(t, http.StatusOK, hr.Code) + + var resp httpbase.R + + err = json.Unmarshal(hr.Body.Bytes(), &resp) + require.Nil(t, err) + + require.Equal(t, 0, resp.Code) + require.Equal(t, "", resp.Msg) + require.NotNil(t, resp.Data) +} + +func TestTagHandler_UpdateTag(t *testing.T) { + username := "testuser" + data := types.UpdateTag{ + Name: "testtag", + Scope: "testscope", + Category: "testcategory", + } + + reqBody, _ := json.Marshal(data) + + req := httptest.NewRequest("put", "/api/v1/tags/1", bytes.NewBuffer(reqBody)) + + hr := httptest.NewRecorder() + ginContext, _ := gin.CreateTestContext(hr) + ginContext.Set("currentUser", username) + ginContext.AddParam("id", "1") + ginContext.Request = req + + tagComp := mockcom.NewMockTagComponent(t) + tagComp.EXPECT().UpdateTag(ginContext, username, int64(1), mock.Anything).Return(&database.Tag{ID: 1, Name: "testtag"}, nil) + + tagHandler, err := NewTestTagHandler(tagComp) + require.Nil(t, err) + + tagHandler.UpdateTag(ginContext) + + require.Equal(t, http.StatusOK, hr.Code) + + var resp httpbase.R + + err = json.Unmarshal(hr.Body.Bytes(), &resp) + require.Nil(t, err) + + require.Equal(t, 0, resp.Code) + require.Equal(t, "", resp.Msg) + require.NotNil(t, resp.Data) +} + +func TestTagHandler_DeleteTag(t *testing.T) { + username := "testuser" + + req := httptest.NewRequest("delete", "/api/v1/tags/1", nil) + + hr := httptest.NewRecorder() + ginContext, _ := gin.CreateTestContext(hr) + ginContext.Set("currentUser", username) + ginContext.AddParam("id", "1") + ginContext.Request = req + + tagComp := mockcom.NewMockTagComponent(t) + tagComp.EXPECT().DeleteTag(ginContext, username, int64(1)).Return(nil) + + tagHandler, err := NewTestTagHandler(tagComp) + require.Nil(t, err) + + tagHandler.DeleteTag(ginContext) + + require.Equal(t, http.StatusOK, hr.Code) + + var resp httpbase.R + + err = json.Unmarshal(hr.Body.Bytes(), &resp) + require.Nil(t, err) + + require.Equal(t, 0, resp.Code) + require.Equal(t, "", resp.Msg) + require.Nil(t, resp.Data) +} diff --git a/api/router/api.go b/api/router/api.go index 9b01c3e8..3ba007c5 100644 --- a/api/router/api.go +++ b/api/router/api.go @@ -222,15 +222,13 @@ func NewRouter(config *config.Config, enableSwagger bool) (*gin.Engine, error) { apiGroup.PUT("/organization/:namespace/members/:username", userProxyHandler.ProxyToApi("/api/v1/organization/%s/members/%s", "namespace", "username")) apiGroup.DELETE("/organization/:namespace/members/:username", userProxyHandler.ProxyToApi("/api/v1/organization/%s/members/%s", "namespace", "username")) } + // Tag tagCtrl, err := handler.NewTagHandler(config) if err != nil { return nil, fmt.Errorf("error creating tag controller:%w", err) } - apiGroup.GET("/tags", tagCtrl.AllTags) - // apiGroup.POST("/tag", tagCtrl.NewTag) - // apiGroup.PUT("/tag", tagCtrl.UpdateTag) - // apiGroup.DELETE("/tag", tagCtrl.DeleteTag) + createTagsRoutes(apiGroup, tagCtrl) // JWT token apiGroup.POST("/jwt/token", needAPIKey, userProxyHandler.Proxy) @@ -783,3 +781,14 @@ func createPromptRoutes(apiGroup *gin.RouterGroup, promptHandler *handler.Prompt promptGrp.POST("/:namespace/:name/update_downloads", promptHandler.UpdateDownloads) } } + +func createTagsRoutes(apiGroup *gin.RouterGroup, tagHandler *handler.TagsHandler) { + tagsGrp := apiGroup.Group("/tags") + { + tagsGrp.GET("", tagHandler.AllTags) + tagsGrp.POST("", tagHandler.CreateTag) + tagsGrp.GET("/:id", tagHandler.GetTagByID) + tagsGrp.PUT("/:id", tagHandler.UpdateTag) + tagsGrp.DELETE("/:id", tagHandler.DeleteTag) + } +} diff --git a/builder/store/database/tag.go b/builder/store/database/tag.go index e0ced7b6..2194afc4 100644 --- a/builder/store/database/tag.go +++ b/builder/store/database/tag.go @@ -40,6 +40,9 @@ type TagStore interface { RemoveRepoTags(ctx context.Context, repoID int64, tagIDs []int64) (err error) FindOrCreate(ctx context.Context, tag Tag) (*Tag, error) FindTag(ctx context.Context, name, scope, category string) (*Tag, error) + FindTagByID(ctx context.Context, id int64) (*Tag, error) + UpdateTagByID(ctx context.Context, tag *Tag) (*Tag, error) + DeleteTagByID(ctx context.Context, id int64) error } func NewTagStore() TagStore { @@ -393,3 +396,47 @@ func (ts *tagStoreImpl) FindTag(ctx context.Context, name, scope, category strin } return &tag, nil } + +// find tag by id +func (ts *tagStoreImpl) FindTagByID(ctx context.Context, id int64) (*Tag, error) { + var tag Tag + err := ts.db.Operator.Core.NewSelect(). + Model(&tag). + Where("id = ?", id). + Scan(ctx) + if err != nil { + return nil, fmt.Errorf("select tag by id %d error: %w", id, err) + } + return &tag, nil +} + +func (ts *tagStoreImpl) UpdateTagByID(ctx context.Context, tag *Tag) (*Tag, error) { + _, err := ts.db.Operator.Core.NewUpdate(). + Model(tag).WherePK().Exec(ctx) + if err != nil { + return nil, fmt.Errorf("update tag by id %d error: %w", tag.ID, err) + } + return tag, nil +} + +func (ts *tagStoreImpl) DeleteTagByID(ctx context.Context, id int64) error { + err := ts.db.Operator.Core.RunInTx(ctx, nil, func(ctx context.Context, tx bun.Tx) error { + _, err := tx.NewDelete(). + Model(&Tag{}). + Where("id = ?", id). + Exec(ctx) + if err != nil { + return fmt.Errorf("delete tag by id %d error: %w", id, err) + } + _, err = tx.NewDelete(). + Model(&RepositoryTag{}). + Where("tag_id = ?", id). + Exec(ctx) + if err != nil { + return fmt.Errorf("delete repository_tag by tag_id %d error: %w", id, err) + } + return nil + }) + return err + +} diff --git a/builder/store/database/tag_test.go b/builder/store/database/tag_test.go index 777b4f4c..27d3bc0f 100644 --- a/builder/store/database/tag_test.go +++ b/builder/store/database/tag_test.go @@ -542,3 +542,86 @@ func TestTagStore_RemoveRepoTags(t *testing.T) { require.EqualValues(t, repoTags[0], tag1) } + +func TestTagStore_FindTagByID(t *testing.T) { + db := tests.InitTestDB() + defer db.Close() + + ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) + defer cancel() + + var err error + ts := database.NewTagStoreWithDB(db) + t1, err := ts.CreateTag(ctx, "task", "tag_"+uuid.NewString(), "", database.ModelTagScope) + require.Empty(t, err) + require.NotEmpty(t, t1.ID) + + tag, err := ts.FindTagByID(ctx, t1.ID) + require.Empty(t, err) + require.Equal(t, tag.ID, t1.ID) + require.Equal(t, tag.Name, t1.Name) +} + +func TestTagStore_UpdateTagByID(t *testing.T) { + db := tests.InitTestDB() + defer db.Close() + + ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) + defer cancel() + + ts := database.NewTagStoreWithDB(db) + + t1, err := ts.CreateTag(ctx, "task", "tag_"+uuid.NewString(), "", database.ModelTagScope) + require.Empty(t, err) + require.NotEmpty(t, t1.ID) + + newName := "new_tag_" + uuid.NewString() + + t1.Name = newName + + tag, err := ts.UpdateTagByID(ctx, &t1) + require.Empty(t, err) + require.Equal(t, tag.Name, newName) + +} + +func TestTagStore_DeleteTagByID(t *testing.T) { + db := tests.InitTestDB() + defer db.Close() + + ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) + defer cancel() + + ts := database.NewTagStoreWithDB(db) + tag1, err := ts.CreateTag(ctx, "task", "tag_"+uuid.NewString(), "", database.ModelTagScope) + require.Empty(t, err) + require.NotEmpty(t, tag1.ID) + + // insert a new repo with tags + rs := database.NewRepoStoreWithDB(db) + userName := "user_name_" + uuid.NewString() + repoName := "repo_name_" + uuid.NewString() + repo, err := rs.CreateRepo(ctx, database.Repository{ + UserID: 1, + Path: fmt.Sprintf("%s/%s", userName, repoName), + GitPath: fmt.Sprintf("models_%s/%s", userName, repoName), + Name: repoName, + Nickname: "", + Description: "", + Private: false, + RepositoryType: types.ModelRepo, + }) + require.Empty(t, err) + require.NotNil(t, repo) + + // set repo tags + err = ts.UpsertRepoTags(ctx, repo.ID, []int64{}, []int64{tag1.ID}) + require.Empty(t, err) + + err = ts.DeleteTagByID(ctx, tag1.ID) + require.Empty(t, err) + + _, err = ts.FindTagByID(ctx, tag1.ID) + require.NotEmpty(t, err) + +} diff --git a/common/types/tag.go b/common/types/tag.go index 7e20074f..d85fa91c 100644 --- a/common/types/tag.go +++ b/common/types/tag.go @@ -1,6 +1,8 @@ package types -import "time" +import ( + "time" +) type RepoTag struct { Name string `json:"name"` @@ -22,3 +24,14 @@ const ( LanguageCategory TagCategory = "language" EvaluationCategory TagCategory = "evaluation" ) + +type CreateTag struct { + Name string `json:"name" binding:"required"` + Category string `json:"category" binding:"required"` + Group string `json:"group"` + Scope string `json:"scope" binding:"required"` + BuiltIn bool `json:"built_in"` + ShowName string `json:"show_name"` +} + +type UpdateTag CreateTag diff --git a/component/tag.go b/component/tag.go index 2c95e4f3..44c41506 100644 --- a/component/tag.go +++ b/component/tag.go @@ -20,6 +20,10 @@ type TagComponent interface { UpdateMetaTags(ctx context.Context, tagScope database.TagScope, namespace, name, content string) ([]*database.RepositoryTag, error) UpdateLibraryTags(ctx context.Context, tagScope database.TagScope, namespace, name, oldFilePath, newFilePath string) error UpdateRepoTagsByCategory(ctx context.Context, tagScope database.TagScope, repoID int64, category string, tagNames []string) error + CreateTag(ctx context.Context, username string, req types.CreateTag) (*database.Tag, error) + GetTagByID(ctx context.Context, username string, id int64) (*database.Tag, error) + UpdateTag(ctx context.Context, username string, id int64, req types.UpdateTag) (*database.Tag, error) + DeleteTag(ctx context.Context, username string, id int64) error } func NewTagComponent(config *config.Config) (TagComponent, error) { @@ -29,6 +33,7 @@ func NewTagComponent(config *config.Config) (TagComponent, error) { if config.SensitiveCheck.Enable { tc.sensitiveChecker = rpc.NewModerationSvcHttpClient(fmt.Sprintf("%s:%d", config.Moderation.Host, config.Moderation.Port)) } + tc.userStore = database.NewUserStore() return tc, nil } @@ -36,6 +41,7 @@ type tagComponentImpl struct { tagStore database.TagStore repoStore database.RepoStore sensitiveChecker rpc.ModerationSvcClient + userStore database.UserStore } func (tc *tagComponentImpl) AllTagsByScopeAndCategory(ctx context.Context, scope string, category string) ([]*database.Tag, error) { @@ -187,3 +193,101 @@ func (c *tagComponentImpl) UpdateRepoTagsByCategory(ctx context.Context, tagScop } return c.tagStore.UpsertRepoTags(ctx, repoID, oldTagIDs, tagIDs) } + +func (c *tagComponentImpl) CreateTag(ctx context.Context, username string, req types.CreateTag) (*database.Tag, error) { + user, err := c.userStore.FindByUsername(ctx, username) + if err != nil { + return nil, fmt.Errorf("failed to get user, error: %w", err) + } + if !user.CanAdmin() { + return nil, fmt.Errorf("user %s do not allowed create tag", username) + } + + if c.sensitiveChecker != nil { + result, err := c.sensitiveChecker.PassTextCheck(ctx, string(sensitive.ScenarioNicknameDetection), req.Name) + if err != nil { + return nil, fmt.Errorf("failed to check tag name sensitivity, error: %w", err) + } + if result.IsSensitive { + return nil, fmt.Errorf("tag name contains sensitive words") + } + } + + newTag := database.Tag{ + Name: req.Name, + Category: req.Category, + Group: req.Group, + Scope: database.TagScope(req.Scope), + BuiltIn: req.BuiltIn, + } + + tag, err := c.tagStore.FindOrCreate(ctx, newTag) + if err != nil { + return nil, fmt.Errorf("failed to create tag, error: %w", err) + } + return tag, nil +} + +func (c *tagComponentImpl) GetTagByID(ctx context.Context, username string, id int64) (*database.Tag, error) { + user, err := c.userStore.FindByUsername(ctx, username) + if err != nil { + return nil, fmt.Errorf("failed to get user, error: %w", err) + } + if !user.CanAdmin() { + return nil, fmt.Errorf("user %s do not allowed create tag", username) + } + tag, err := c.tagStore.FindTagByID(ctx, id) + if err != nil { + return nil, fmt.Errorf("failed to get tag id %d, error: %w", id, err) + } + return tag, nil +} + +func (c *tagComponentImpl) UpdateTag(ctx context.Context, username string, id int64, req types.UpdateTag) (*database.Tag, error) { + user, err := c.userStore.FindByUsername(ctx, username) + if err != nil { + return nil, fmt.Errorf("failed to get user, error: %w", err) + } + if !user.CanAdmin() { + return nil, fmt.Errorf("user %s do not allowed create tag", username) + } + + if c.sensitiveChecker != nil { + result, err := c.sensitiveChecker.PassTextCheck(ctx, string(sensitive.ScenarioNicknameDetection), req.Name) + if err != nil { + return nil, fmt.Errorf("failed to check tag name sensitivity, error: %w", err) + } + if result.IsSensitive { + return nil, fmt.Errorf("tag name contains sensitive words") + } + } + + tag := &database.Tag{ + ID: id, + Category: req.Category, + Name: req.Name, + Group: req.Group, + Scope: database.TagScope(req.Scope), + BuiltIn: req.BuiltIn, + } + newTag, err := c.tagStore.UpdateTagByID(ctx, tag) + if err != nil { + return nil, fmt.Errorf("failed to update tag id %d, error: %w", id, err) + } + return newTag, nil +} + +func (c *tagComponentImpl) DeleteTag(ctx context.Context, username string, id int64) error { + user, err := c.userStore.FindByUsername(ctx, username) + if err != nil { + return fmt.Errorf("failed to get user, error: %w", err) + } + if !user.CanAdmin() { + return fmt.Errorf("user %s do not allowed create tag", username) + } + err = c.tagStore.DeleteTagByID(ctx, id) + if err != nil { + return fmt.Errorf("failed to delete tag id %d, error: %w", id, err) + } + return nil +} diff --git a/component/tag_test.go b/component/tag_test.go index 2f53aa30..acfa168d 100644 --- a/component/tag_test.go +++ b/component/tag_test.go @@ -4,11 +4,158 @@ import ( "context" "testing" + "github.com/stretchr/testify/mock" "github.com/stretchr/testify/require" + "opencsg.com/csghub-server/builder/rpc" "opencsg.com/csghub-server/builder/store/database" "opencsg.com/csghub-server/common/types" ) +func TestTagComponent_CreateTag(t *testing.T) { + ctx := context.TODO() + + username := "testUser" + + req := types.CreateTag{ + Name: "my first tag", + Category: "testCategory", + Group: "testGroup", + Scope: "testScope", + BuiltIn: true, + } + + newTag := database.Tag{ + Name: req.Name, + Category: req.Category, + Group: req.Group, + Scope: database.TagScope(req.Scope), + BuiltIn: req.BuiltIn, + } + + t.Run("admin", func(t *testing.T) { + tc := initializeTestTagComponent(ctx, t) + + tc.mocks.stores.UserMock().EXPECT().FindByUsername(ctx, username).Return(database.User{UUID: "testUUID", RoleMask: "admin"}, nil) + tc.mocks.stores.TagMock().EXPECT().FindOrCreate(ctx, newTag).Return(&newTag, nil) + tc.mocks.moderationClient.EXPECT().PassTextCheck(ctx, mock.Anything, req.Name).Return(&rpc.CheckResult{ + IsSensitive: false, + }, nil) + + tag, err := tc.CreateTag(ctx, username, req) + require.Nil(t, err) + require.Equal(t, req.Name, tag.Name) + require.Equal(t, true, tag.BuiltIn) + }) + + t.Run("non-admin", func(t *testing.T) { + tc := initializeTestTagComponent(ctx, t) + + tc.mocks.stores.UserMock().EXPECT().FindByUsername(ctx, username).Return(database.User{UUID: "testUUID", RoleMask: "persion"}, nil) + + tag, err := tc.CreateTag(ctx, username, req) + require.NotNil(t, err) + require.Nil(t, tag) + }) +} + +func TestTagComponent_GetTagByID(t *testing.T) { + ctx := context.TODO() + username := "testUser" + t.Run("admin", func(t *testing.T) { + tc := initializeTestTagComponent(ctx, t) + tc.mocks.stores.UserMock().EXPECT().FindByUsername(ctx, username).Return(database.User{UUID: "testUUID", RoleMask: "admin"}, nil) + tc.mocks.stores.TagMock().EXPECT().FindTagByID(ctx, int64(1)).Return(&database.Tag{ID: int64(1), Name: "test-tag"}, nil) + + tag, err := tc.GetTagByID(ctx, username, int64(1)) + require.Nil(t, err) + require.Equal(t, int64(1), tag.ID) + require.Equal(t, "test-tag", tag.Name) + }) + + t.Run("non-admin", func(t *testing.T) { + tc := initializeTestTagComponent(ctx, t) + + tc.mocks.stores.UserMock().EXPECT().FindByUsername(ctx, username).Return(database.User{UUID: "testUUID", RoleMask: "person"}, nil) + + tag, err := tc.GetTagByID(ctx, username, int64(1)) + require.NotNil(t, err) + require.Nil(t, tag) + }) +} + +func TestTagComponent_UpdateTag(t *testing.T) { + ctx := context.TODO() + + username := "testUser" + + req := types.UpdateTag{ + Name: "testTag", + Category: "testCategory", + Group: "testGroup", + Scope: "testScope", + BuiltIn: true, + } + + newTag := database.Tag{ + ID: int64(1), + Name: req.Name, + Category: req.Category, + Group: req.Group, + Scope: database.TagScope(req.Scope), + BuiltIn: req.BuiltIn, + } + + t.Run("admin", func(t *testing.T) { + tc := initializeTestTagComponent(ctx, t) + + tc.mocks.stores.UserMock().EXPECT().FindByUsername(ctx, username).Return(database.User{UUID: "testUUID", RoleMask: "admin"}, nil) + tc.mocks.stores.TagMock().EXPECT().UpdateTagByID(ctx, &newTag).Return(&newTag, nil) + tc.mocks.moderationClient.EXPECT().PassTextCheck(ctx, mock.Anything, req.Name).Return(&rpc.CheckResult{ + IsSensitive: false, + }, nil) + + tag, err := tc.UpdateTag(ctx, username, int64(1), req) + require.Nil(t, err) + require.Equal(t, req.Name, tag.Name) + require.Equal(t, true, tag.BuiltIn) + }) + + t.Run("non-admin", func(t *testing.T) { + tc := initializeTestTagComponent(ctx, t) + + tc.mocks.stores.UserMock().EXPECT().FindByUsername(ctx, username).Return(database.User{UUID: "testUUID", RoleMask: "persion"}, nil) + + tag, err := tc.UpdateTag(ctx, username, int64(1), req) + require.NotNil(t, err) + require.Nil(t, tag) + }) +} + +func TestTagComponent_DeleteTag(t *testing.T) { + ctx := context.TODO() + + username := "testUser" + + t.Run("admin", func(t *testing.T) { + tc := initializeTestTagComponent(ctx, t) + + tc.mocks.stores.UserMock().EXPECT().FindByUsername(ctx, username).Return(database.User{UUID: "testUUID", RoleMask: "admin"}, nil) + tc.mocks.stores.TagMock().EXPECT().DeleteTagByID(ctx, int64(1)).Return(nil) + + err := tc.DeleteTag(ctx, username, int64(1)) + require.Nil(t, err) + }) + + t.Run("non-admin", func(t *testing.T) { + tc := initializeTestTagComponent(ctx, t) + + tc.mocks.stores.UserMock().EXPECT().FindByUsername(ctx, username).Return(database.User{UUID: "testUUID", RoleMask: "persion"}, nil) + + err := tc.DeleteTag(ctx, username, int64(1)) + require.NotNil(t, err) + }) +} + func TestTagComponent_AllTagsByScopeAndCategory(t *testing.T) { ctx := context.TODO() tc := initializeTestTagComponent(ctx, t) diff --git a/component/wireset.go b/component/wireset.go index 6723031f..e9394a66 100644 --- a/component/wireset.go +++ b/component/wireset.go @@ -465,6 +465,7 @@ func NewTestTagComponent(config *config.Config, stores *tests.MockStores, sensit tagStore: stores.Tag, repoStore: stores.Repo, sensitiveChecker: sensitiveChecker, + userStore: stores.User, } } diff --git a/docs/docs.go b/docs/docs.go index 3b75dac4..4fe703f2 100644 --- a/docs/docs.go +++ b/docs/docs.go @@ -10022,6 +10022,158 @@ const docTemplate = `{ } } }, + "/tag/{id}": { + "get": { + "security": [ + { + "ApiKey": [] + } + ], + "description": "Get a tag by id", + "consumes": [ + "application/json" + ], + "produces": [ + "application/json" + ], + "tags": [ + "Tag" + ], + "summary": "Get a tag by id", + "parameters": [ + { + "type": "string", + "description": "id of the tag", + "name": "id", + "in": "path", + "required": true + } + ], + "responses": { + "200": { + "description": "OK", + "schema": { + "$ref": "#/definitions/types.Response" + } + }, + "400": { + "description": "Bad request", + "schema": { + "$ref": "#/definitions/types.APIBadRequest" + } + }, + "500": { + "description": "Internal server error", + "schema": { + "$ref": "#/definitions/types.APIInternalServerError" + } + } + } + }, + "put": { + "security": [ + { + "ApiKey": [] + } + ], + "description": "Update a tag by id", + "consumes": [ + "application/json" + ], + "produces": [ + "application/json" + ], + "tags": [ + "Tag" + ], + "summary": "Update a tag by id", + "parameters": [ + { + "type": "string", + "description": "id of the tag", + "name": "id", + "in": "path", + "required": true + }, + { + "description": "body", + "name": "body", + "in": "body", + "required": true, + "schema": { + "$ref": "#/definitions/types.UpdateTag" + } + } + ], + "responses": { + "200": { + "description": "OK", + "schema": { + "$ref": "#/definitions/types.Response" + } + }, + "400": { + "description": "Bad request", + "schema": { + "$ref": "#/definitions/types.APIBadRequest" + } + }, + "500": { + "description": "Internal server error", + "schema": { + "$ref": "#/definitions/types.APIInternalServerError" + } + } + } + }, + "delete": { + "security": [ + { + "ApiKey": [] + } + ], + "description": "Delete a tag by id", + "consumes": [ + "application/json" + ], + "produces": [ + "application/json" + ], + "tags": [ + "Tag" + ], + "summary": "Delete a tag by id", + "parameters": [ + { + "type": "string", + "description": "id of the tag", + "name": "id", + "in": "path", + "required": true + } + ], + "responses": { + "200": { + "description": "OK", + "schema": { + "$ref": "#/definitions/types.Response" + } + }, + "400": { + "description": "Bad request", + "schema": { + "$ref": "#/definitions/types.APIBadRequest" + } + }, + "500": { + "description": "Internal server error", + "schema": { + "$ref": "#/definitions/types.APIInternalServerError" + } + } + } + } + }, "/tags": { "get": { "security": [ @@ -10029,7 +10181,7 @@ const docTemplate = `{ "ApiKey": [] } ], - "description": "get all tags", + "description": "Get all tags", "consumes": [ "application/json" ], @@ -10074,15 +10226,67 @@ const docTemplate = `{ "items": { "$ref": "#/definitions/database.Tag" } - }, - "total": { - "type": "integer" } } } ] } }, + "400": { + "description": "Bad request", + "schema": { + "$ref": "#/definitions/types.APIBadRequest" + } + }, + "500": { + "description": "Internal server error", + "schema": { + "$ref": "#/definitions/types.APIInternalServerError" + } + } + } + }, + "post": { + "security": [ + { + "ApiKey": [] + } + ], + "description": "Create new tag", + "consumes": [ + "application/json" + ], + "produces": [ + "application/json" + ], + "tags": [ + "Tag" + ], + "summary": "Create new tag", + "parameters": [ + { + "description": "body", + "name": "body", + "in": "body", + "required": true, + "schema": { + "$ref": "#/definitions/types.CreateTag" + } + } + ], + "responses": { + "200": { + "description": "OK", + "schema": { + "$ref": "#/definitions/types.Response" + } + }, + "400": { + "description": "Bad request", + "schema": { + "$ref": "#/definitions/types.APIBadRequest" + } + }, "500": { "description": "Internal server error", "schema": { @@ -16919,6 +17123,34 @@ const docTemplate = `{ } } }, + "types.CreateTag": { + "type": "object", + "required": [ + "category", + "name", + "scope" + ], + "properties": { + "built_in": { + "type": "boolean" + }, + "category": { + "type": "string" + }, + "group": { + "type": "string" + }, + "name": { + "type": "string" + }, + "scope": { + "type": "string" + }, + "show_name": { + "type": "string" + } + } + }, "types.CreateUserTokenRequest": { "type": "object", "properties": { @@ -18556,6 +18788,34 @@ const docTemplate = `{ } } }, + "types.UpdateTag": { + "type": "object", + "required": [ + "category", + "name", + "scope" + ], + "properties": { + "built_in": { + "type": "boolean" + }, + "category": { + "type": "string" + }, + "group": { + "type": "string" + }, + "name": { + "type": "string" + }, + "scope": { + "type": "string" + }, + "show_name": { + "type": "string" + } + } + }, "types.UpdateUserRequest": { "type": "object", "properties": { diff --git a/docs/swagger.json b/docs/swagger.json index c5dfdef4..b0e864e6 100644 --- a/docs/swagger.json +++ b/docs/swagger.json @@ -10011,6 +10011,158 @@ } } }, + "/tag/{id}": { + "get": { + "security": [ + { + "ApiKey": [] + } + ], + "description": "Get a tag by id", + "consumes": [ + "application/json" + ], + "produces": [ + "application/json" + ], + "tags": [ + "Tag" + ], + "summary": "Get a tag by id", + "parameters": [ + { + "type": "string", + "description": "id of the tag", + "name": "id", + "in": "path", + "required": true + } + ], + "responses": { + "200": { + "description": "OK", + "schema": { + "$ref": "#/definitions/types.Response" + } + }, + "400": { + "description": "Bad request", + "schema": { + "$ref": "#/definitions/types.APIBadRequest" + } + }, + "500": { + "description": "Internal server error", + "schema": { + "$ref": "#/definitions/types.APIInternalServerError" + } + } + } + }, + "put": { + "security": [ + { + "ApiKey": [] + } + ], + "description": "Update a tag by id", + "consumes": [ + "application/json" + ], + "produces": [ + "application/json" + ], + "tags": [ + "Tag" + ], + "summary": "Update a tag by id", + "parameters": [ + { + "type": "string", + "description": "id of the tag", + "name": "id", + "in": "path", + "required": true + }, + { + "description": "body", + "name": "body", + "in": "body", + "required": true, + "schema": { + "$ref": "#/definitions/types.UpdateTag" + } + } + ], + "responses": { + "200": { + "description": "OK", + "schema": { + "$ref": "#/definitions/types.Response" + } + }, + "400": { + "description": "Bad request", + "schema": { + "$ref": "#/definitions/types.APIBadRequest" + } + }, + "500": { + "description": "Internal server error", + "schema": { + "$ref": "#/definitions/types.APIInternalServerError" + } + } + } + }, + "delete": { + "security": [ + { + "ApiKey": [] + } + ], + "description": "Delete a tag by id", + "consumes": [ + "application/json" + ], + "produces": [ + "application/json" + ], + "tags": [ + "Tag" + ], + "summary": "Delete a tag by id", + "parameters": [ + { + "type": "string", + "description": "id of the tag", + "name": "id", + "in": "path", + "required": true + } + ], + "responses": { + "200": { + "description": "OK", + "schema": { + "$ref": "#/definitions/types.Response" + } + }, + "400": { + "description": "Bad request", + "schema": { + "$ref": "#/definitions/types.APIBadRequest" + } + }, + "500": { + "description": "Internal server error", + "schema": { + "$ref": "#/definitions/types.APIInternalServerError" + } + } + } + } + }, "/tags": { "get": { "security": [ @@ -10018,7 +10170,7 @@ "ApiKey": [] } ], - "description": "get all tags", + "description": "Get all tags", "consumes": [ "application/json" ], @@ -10063,15 +10215,67 @@ "items": { "$ref": "#/definitions/database.Tag" } - }, - "total": { - "type": "integer" } } } ] } }, + "400": { + "description": "Bad request", + "schema": { + "$ref": "#/definitions/types.APIBadRequest" + } + }, + "500": { + "description": "Internal server error", + "schema": { + "$ref": "#/definitions/types.APIInternalServerError" + } + } + } + }, + "post": { + "security": [ + { + "ApiKey": [] + } + ], + "description": "Create new tag", + "consumes": [ + "application/json" + ], + "produces": [ + "application/json" + ], + "tags": [ + "Tag" + ], + "summary": "Create new tag", + "parameters": [ + { + "description": "body", + "name": "body", + "in": "body", + "required": true, + "schema": { + "$ref": "#/definitions/types.CreateTag" + } + } + ], + "responses": { + "200": { + "description": "OK", + "schema": { + "$ref": "#/definitions/types.Response" + } + }, + "400": { + "description": "Bad request", + "schema": { + "$ref": "#/definitions/types.APIBadRequest" + } + }, "500": { "description": "Internal server error", "schema": { @@ -16908,6 +17112,34 @@ } } }, + "types.CreateTag": { + "type": "object", + "required": [ + "category", + "name", + "scope" + ], + "properties": { + "built_in": { + "type": "boolean" + }, + "category": { + "type": "string" + }, + "group": { + "type": "string" + }, + "name": { + "type": "string" + }, + "scope": { + "type": "string" + }, + "show_name": { + "type": "string" + } + } + }, "types.CreateUserTokenRequest": { "type": "object", "properties": { @@ -18545,6 +18777,34 @@ } } }, + "types.UpdateTag": { + "type": "object", + "required": [ + "category", + "name", + "scope" + ], + "properties": { + "built_in": { + "type": "boolean" + }, + "category": { + "type": "string" + }, + "group": { + "type": "string" + }, + "name": { + "type": "string" + }, + "scope": { + "type": "string" + }, + "show_name": { + "type": "string" + } + } + }, "types.UpdateUserRequest": { "type": "object", "properties": { diff --git a/docs/swagger.yaml b/docs/swagger.yaml index c8f611fc..4ba5c8dc 100644 --- a/docs/swagger.yaml +++ b/docs/swagger.yaml @@ -1236,6 +1236,25 @@ definitions: required: - token type: object + types.CreateTag: + properties: + built_in: + type: boolean + category: + type: string + group: + type: string + name: + type: string + scope: + type: string + show_name: + type: string + required: + - category + - name + - scope + type: object types.CreateUserTokenRequest: properties: application: @@ -2343,6 +2362,25 @@ definitions: version: type: string type: object + types.UpdateTag: + properties: + built_in: + type: boolean + category: + type: string + group: + type: string + name: + type: string + scope: + type: string + show_name: + type: string + required: + - category + - name + - scope + type: object types.UpdateUserRequest: properties: avatar: @@ -10477,11 +10515,108 @@ paths: summary: Get latest version tags: - Sync + /tag/{id}: + delete: + consumes: + - application/json + description: Delete a tag by id + parameters: + - description: id of the tag + in: path + name: id + required: true + type: string + produces: + - application/json + responses: + "200": + description: OK + schema: + $ref: '#/definitions/types.Response' + "400": + description: Bad request + schema: + $ref: '#/definitions/types.APIBadRequest' + "500": + description: Internal server error + schema: + $ref: '#/definitions/types.APIInternalServerError' + security: + - ApiKey: [] + summary: Delete a tag by id + tags: + - Tag + get: + consumes: + - application/json + description: Get a tag by id + parameters: + - description: id of the tag + in: path + name: id + required: true + type: string + produces: + - application/json + responses: + "200": + description: OK + schema: + $ref: '#/definitions/types.Response' + "400": + description: Bad request + schema: + $ref: '#/definitions/types.APIBadRequest' + "500": + description: Internal server error + schema: + $ref: '#/definitions/types.APIInternalServerError' + security: + - ApiKey: [] + summary: Get a tag by id + tags: + - Tag + put: + consumes: + - application/json + description: Update a tag by id + parameters: + - description: id of the tag + in: path + name: id + required: true + type: string + - description: body + in: body + name: body + required: true + schema: + $ref: '#/definitions/types.UpdateTag' + produces: + - application/json + responses: + "200": + description: OK + schema: + $ref: '#/definitions/types.Response' + "400": + description: Bad request + schema: + $ref: '#/definitions/types.APIBadRequest' + "500": + description: Internal server error + schema: + $ref: '#/definitions/types.APIInternalServerError' + security: + - ApiKey: [] + summary: Update a tag by id + tags: + - Tag /tags: get: consumes: - application/json - description: get all tags + description: Get all tags parameters: - description: category name in: query @@ -10507,9 +10642,11 @@ paths: items: $ref: '#/definitions/database.Tag' type: array - total: - type: integer type: object + "400": + description: Bad request + schema: + $ref: '#/definitions/types.APIBadRequest' "500": description: Internal server error schema: @@ -10519,6 +10656,37 @@ paths: summary: Get all tags tags: - Tag + post: + consumes: + - application/json + description: Create new tag + parameters: + - description: body + in: body + name: body + required: true + schema: + $ref: '#/definitions/types.CreateTag' + produces: + - application/json + responses: + "200": + description: OK + schema: + $ref: '#/definitions/types.Response' + "400": + description: Bad request + schema: + $ref: '#/definitions/types.APIBadRequest' + "500": + description: Internal server error + schema: + $ref: '#/definitions/types.APIInternalServerError' + security: + - ApiKey: [] + summary: Create new tag + tags: + - Tag /telemetry/usage: post: consumes: From bf0d012e492785f59d1463a4a3d14c8524b60901 Mon Sep 17 00:00:00 2001 From: SeanHH86 <154984842+SeanHH86@users.noreply.github.com> Date: Thu, 19 Dec 2024 11:00:16 +0800 Subject: [PATCH 12/34] [Tag] Update swagger doc (#216) Co-authored-by: Haihui.Wang --- api/handler/tag.go | 6 +- docs/docs.go | 136 ++++++++++++++++++++++----------------------- docs/swagger.json | 136 ++++++++++++++++++++++----------------------- docs/swagger.yaml | 104 +++++++++++++++++----------------- 4 files changed, 191 insertions(+), 191 deletions(-) diff --git a/api/handler/tag.go b/api/handler/tag.go index 53a10ebc..16689e50 100644 --- a/api/handler/tag.go +++ b/api/handler/tag.go @@ -99,7 +99,7 @@ func (t *TagsHandler) CreateTag(ctx *gin.Context) { // @Success 200 {object} types.Response{database.Tag} "OK" // @Failure 400 {object} types.APIBadRequest "Bad request" // @Failure 500 {object} types.APIInternalServerError "Internal server error" -// @Router /tag/{id} [get] +// @Router /tags/{id} [get] func (t *TagsHandler) GetTagByID(ctx *gin.Context) { userName := httpbase.GetCurrentUser(ctx) if userName == "" { @@ -133,7 +133,7 @@ func (t *TagsHandler) GetTagByID(ctx *gin.Context) { // @Success 200 {object} types.Response{database.Tag} "OK" // @Failure 400 {object} types.APIBadRequest "Bad request" // @Failure 500 {object} types.APIInternalServerError "Internal server error" -// @Router /tag/{id} [put] +// @Router /tags/{id} [put] func (t *TagsHandler) UpdateTag(ctx *gin.Context) { userName := httpbase.GetCurrentUser(ctx) if userName == "" { @@ -172,7 +172,7 @@ func (t *TagsHandler) UpdateTag(ctx *gin.Context) { // @Success 200 {object} types.Response{} "OK" // @Failure 400 {object} types.APIBadRequest "Bad request" // @Failure 500 {object} types.APIInternalServerError "Internal server error" -// @Router /tag/{id} [delete] +// @Router /tags/{id} [delete] func (t *TagsHandler) DeleteTag(ctx *gin.Context) { userName := httpbase.GetCurrentUser(ctx) if userName == "" { diff --git a/docs/docs.go b/docs/docs.go index 4fe703f2..b2611626 100644 --- a/docs/docs.go +++ b/docs/docs.go @@ -10022,14 +10022,14 @@ const docTemplate = `{ } } }, - "/tag/{id}": { + "/tags": { "get": { "security": [ { "ApiKey": [] } ], - "description": "Get a tag by id", + "description": "Get all tags", "consumes": [ "application/json" ], @@ -10039,21 +10039,45 @@ const docTemplate = `{ "tags": [ "Tag" ], - "summary": "Get a tag by id", + "summary": "Get all tags", "parameters": [ { "type": "string", - "description": "id of the tag", - "name": "id", - "in": "path", - "required": true + "description": "category name", + "name": "category", + "in": "query" + }, + { + "enum": [ + "model", + "dataset" + ], + "type": "string", + "description": "scope name", + "name": "scope", + "in": "query" } ], "responses": { "200": { - "description": "OK", + "description": "tags", "schema": { - "$ref": "#/definitions/types.Response" + "allOf": [ + { + "$ref": "#/definitions/types.ResponseWithTotal" + }, + { + "type": "object", + "properties": { + "data": { + "type": "array", + "items": { + "$ref": "#/definitions/database.Tag" + } + } + } + } + ] } }, "400": { @@ -10070,13 +10094,13 @@ const docTemplate = `{ } } }, - "put": { + "post": { "security": [ { "ApiKey": [] } ], - "description": "Update a tag by id", + "description": "Create new tag", "consumes": [ "application/json" ], @@ -10086,22 +10110,15 @@ const docTemplate = `{ "tags": [ "Tag" ], - "summary": "Update a tag by id", + "summary": "Create new tag", "parameters": [ - { - "type": "string", - "description": "id of the tag", - "name": "id", - "in": "path", - "required": true - }, { "description": "body", "name": "body", "in": "body", "required": true, "schema": { - "$ref": "#/definitions/types.UpdateTag" + "$ref": "#/definitions/types.CreateTag" } } ], @@ -10125,14 +10142,16 @@ const docTemplate = `{ } } } - }, - "delete": { + } + }, + "/tags/{id}": { + "get": { "security": [ { "ApiKey": [] } ], - "description": "Delete a tag by id", + "description": "Get a tag by id", "consumes": [ "application/json" ], @@ -10142,7 +10161,7 @@ const docTemplate = `{ "tags": [ "Tag" ], - "summary": "Delete a tag by id", + "summary": "Get a tag by id", "parameters": [ { "type": "string", @@ -10172,16 +10191,14 @@ const docTemplate = `{ } } } - } - }, - "/tags": { - "get": { + }, + "put": { "security": [ { "ApiKey": [] } ], - "description": "Get all tags", + "description": "Update a tag by id", "consumes": [ "application/json" ], @@ -10191,45 +10208,30 @@ const docTemplate = `{ "tags": [ "Tag" ], - "summary": "Get all tags", + "summary": "Update a tag by id", "parameters": [ { "type": "string", - "description": "category name", - "name": "category", - "in": "query" + "description": "id of the tag", + "name": "id", + "in": "path", + "required": true }, { - "enum": [ - "model", - "dataset" - ], - "type": "string", - "description": "scope name", - "name": "scope", - "in": "query" + "description": "body", + "name": "body", + "in": "body", + "required": true, + "schema": { + "$ref": "#/definitions/types.UpdateTag" + } } ], "responses": { "200": { - "description": "tags", + "description": "OK", "schema": { - "allOf": [ - { - "$ref": "#/definitions/types.ResponseWithTotal" - }, - { - "type": "object", - "properties": { - "data": { - "type": "array", - "items": { - "$ref": "#/definitions/database.Tag" - } - } - } - } - ] + "$ref": "#/definitions/types.Response" } }, "400": { @@ -10246,13 +10248,13 @@ const docTemplate = `{ } } }, - "post": { + "delete": { "security": [ { "ApiKey": [] } ], - "description": "Create new tag", + "description": "Delete a tag by id", "consumes": [ "application/json" ], @@ -10262,16 +10264,14 @@ const docTemplate = `{ "tags": [ "Tag" ], - "summary": "Create new tag", + "summary": "Delete a tag by id", "parameters": [ { - "description": "body", - "name": "body", - "in": "body", - "required": true, - "schema": { - "$ref": "#/definitions/types.CreateTag" - } + "type": "string", + "description": "id of the tag", + "name": "id", + "in": "path", + "required": true } ], "responses": { diff --git a/docs/swagger.json b/docs/swagger.json index b0e864e6..196b34e1 100644 --- a/docs/swagger.json +++ b/docs/swagger.json @@ -10011,14 +10011,14 @@ } } }, - "/tag/{id}": { + "/tags": { "get": { "security": [ { "ApiKey": [] } ], - "description": "Get a tag by id", + "description": "Get all tags", "consumes": [ "application/json" ], @@ -10028,21 +10028,45 @@ "tags": [ "Tag" ], - "summary": "Get a tag by id", + "summary": "Get all tags", "parameters": [ { "type": "string", - "description": "id of the tag", - "name": "id", - "in": "path", - "required": true + "description": "category name", + "name": "category", + "in": "query" + }, + { + "enum": [ + "model", + "dataset" + ], + "type": "string", + "description": "scope name", + "name": "scope", + "in": "query" } ], "responses": { "200": { - "description": "OK", + "description": "tags", "schema": { - "$ref": "#/definitions/types.Response" + "allOf": [ + { + "$ref": "#/definitions/types.ResponseWithTotal" + }, + { + "type": "object", + "properties": { + "data": { + "type": "array", + "items": { + "$ref": "#/definitions/database.Tag" + } + } + } + } + ] } }, "400": { @@ -10059,13 +10083,13 @@ } } }, - "put": { + "post": { "security": [ { "ApiKey": [] } ], - "description": "Update a tag by id", + "description": "Create new tag", "consumes": [ "application/json" ], @@ -10075,22 +10099,15 @@ "tags": [ "Tag" ], - "summary": "Update a tag by id", + "summary": "Create new tag", "parameters": [ - { - "type": "string", - "description": "id of the tag", - "name": "id", - "in": "path", - "required": true - }, { "description": "body", "name": "body", "in": "body", "required": true, "schema": { - "$ref": "#/definitions/types.UpdateTag" + "$ref": "#/definitions/types.CreateTag" } } ], @@ -10114,14 +10131,16 @@ } } } - }, - "delete": { + } + }, + "/tags/{id}": { + "get": { "security": [ { "ApiKey": [] } ], - "description": "Delete a tag by id", + "description": "Get a tag by id", "consumes": [ "application/json" ], @@ -10131,7 +10150,7 @@ "tags": [ "Tag" ], - "summary": "Delete a tag by id", + "summary": "Get a tag by id", "parameters": [ { "type": "string", @@ -10161,16 +10180,14 @@ } } } - } - }, - "/tags": { - "get": { + }, + "put": { "security": [ { "ApiKey": [] } ], - "description": "Get all tags", + "description": "Update a tag by id", "consumes": [ "application/json" ], @@ -10180,45 +10197,30 @@ "tags": [ "Tag" ], - "summary": "Get all tags", + "summary": "Update a tag by id", "parameters": [ { "type": "string", - "description": "category name", - "name": "category", - "in": "query" + "description": "id of the tag", + "name": "id", + "in": "path", + "required": true }, { - "enum": [ - "model", - "dataset" - ], - "type": "string", - "description": "scope name", - "name": "scope", - "in": "query" + "description": "body", + "name": "body", + "in": "body", + "required": true, + "schema": { + "$ref": "#/definitions/types.UpdateTag" + } } ], "responses": { "200": { - "description": "tags", + "description": "OK", "schema": { - "allOf": [ - { - "$ref": "#/definitions/types.ResponseWithTotal" - }, - { - "type": "object", - "properties": { - "data": { - "type": "array", - "items": { - "$ref": "#/definitions/database.Tag" - } - } - } - } - ] + "$ref": "#/definitions/types.Response" } }, "400": { @@ -10235,13 +10237,13 @@ } } }, - "post": { + "delete": { "security": [ { "ApiKey": [] } ], - "description": "Create new tag", + "description": "Delete a tag by id", "consumes": [ "application/json" ], @@ -10251,16 +10253,14 @@ "tags": [ "Tag" ], - "summary": "Create new tag", + "summary": "Delete a tag by id", "parameters": [ { - "description": "body", - "name": "body", - "in": "body", - "required": true, - "schema": { - "$ref": "#/definitions/types.CreateTag" - } + "type": "string", + "description": "id of the tag", + "name": "id", + "in": "path", + "required": true } ], "responses": { diff --git a/docs/swagger.yaml b/docs/swagger.yaml index 4ba5c8dc..1209dabb 100644 --- a/docs/swagger.yaml +++ b/docs/swagger.yaml @@ -10515,24 +10515,37 @@ paths: summary: Get latest version tags: - Sync - /tag/{id}: - delete: + /tags: + get: consumes: - application/json - description: Delete a tag by id + description: Get all tags parameters: - - description: id of the tag - in: path - name: id - required: true + - description: category name + in: query + name: category + type: string + - description: scope name + enum: + - model + - dataset + in: query + name: scope type: string produces: - application/json responses: "200": - description: OK + description: tags schema: - $ref: '#/definitions/types.Response' + allOf: + - $ref: '#/definitions/types.ResponseWithTotal' + - properties: + data: + items: + $ref: '#/definitions/database.Tag' + type: array + type: object "400": description: Bad request schema: @@ -10543,19 +10556,20 @@ paths: $ref: '#/definitions/types.APIInternalServerError' security: - ApiKey: [] - summary: Delete a tag by id + summary: Get all tags tags: - Tag - get: + post: consumes: - application/json - description: Get a tag by id + description: Create new tag parameters: - - description: id of the tag - in: path - name: id + - description: body + in: body + name: body required: true - type: string + schema: + $ref: '#/definitions/types.CreateTag' produces: - application/json responses: @@ -10573,25 +10587,20 @@ paths: $ref: '#/definitions/types.APIInternalServerError' security: - ApiKey: [] - summary: Get a tag by id + summary: Create new tag tags: - Tag - put: + /tags/{id}: + delete: consumes: - application/json - description: Update a tag by id + description: Delete a tag by id parameters: - description: id of the tag in: path name: id required: true type: string - - description: body - in: body - name: body - required: true - schema: - $ref: '#/definitions/types.UpdateTag' produces: - application/json responses: @@ -10609,40 +10618,26 @@ paths: $ref: '#/definitions/types.APIInternalServerError' security: - ApiKey: [] - summary: Update a tag by id + summary: Delete a tag by id tags: - Tag - /tags: get: consumes: - application/json - description: Get all tags + description: Get a tag by id parameters: - - description: category name - in: query - name: category - type: string - - description: scope name - enum: - - model - - dataset - in: query - name: scope + - description: id of the tag + in: path + name: id + required: true type: string produces: - application/json responses: "200": - description: tags + description: OK schema: - allOf: - - $ref: '#/definitions/types.ResponseWithTotal' - - properties: - data: - items: - $ref: '#/definitions/database.Tag' - type: array - type: object + $ref: '#/definitions/types.Response' "400": description: Bad request schema: @@ -10653,20 +10648,25 @@ paths: $ref: '#/definitions/types.APIInternalServerError' security: - ApiKey: [] - summary: Get all tags + summary: Get a tag by id tags: - Tag - post: + put: consumes: - application/json - description: Create new tag + description: Update a tag by id parameters: + - description: id of the tag + in: path + name: id + required: true + type: string - description: body in: body name: body required: true schema: - $ref: '#/definitions/types.CreateTag' + $ref: '#/definitions/types.UpdateTag' produces: - application/json responses: @@ -10684,7 +10684,7 @@ paths: $ref: '#/definitions/types.APIInternalServerError' security: - ApiKey: [] - summary: Create new tag + summary: Update a tag by id tags: - Tag /telemetry/usage: From c37d7651e8614514f9fef7b91506e47e6efc4a75 Mon Sep 17 00:00:00 2001 From: Lei Da Date: Wed, 11 Dec 2024 14:11:34 +0800 Subject: [PATCH 13/34] remove unused llm infer client --- .mockery.yaml | 4 - .../builder/inference/mock_Client.go | 150 -------------- api/handler/model.go | 43 ---- api/router/api.go | 3 - builder/inference/client.go | 28 --- builder/inference/init.go | 6 - builder/inference/llm_infer.go | 190 ------------------ builder/inference/rllm_types.go | 18 -- common/config/config.go | 4 - common/config/config.toml.example | 3 - component/model.go | 23 --- 11 files changed, 472 deletions(-) delete mode 100644 _mocks/opencsg.com/csghub-server/builder/inference/mock_Client.go delete mode 100644 builder/inference/client.go delete mode 100644 builder/inference/init.go delete mode 100644 builder/inference/llm_infer.go delete mode 100644 builder/inference/rllm_types.go diff --git a/.mockery.yaml b/.mockery.yaml index 630d57a4..312ef32b 100644 --- a/.mockery.yaml +++ b/.mockery.yaml @@ -75,10 +75,6 @@ packages: config: interfaces: Deployer: - opencsg.com/csghub-server/builder/inference: - config: - interfaces: - Client: opencsg.com/csghub-server/accounting/component: config: interfaces: diff --git a/_mocks/opencsg.com/csghub-server/builder/inference/mock_Client.go b/_mocks/opencsg.com/csghub-server/builder/inference/mock_Client.go deleted file mode 100644 index 14db257f..00000000 --- a/_mocks/opencsg.com/csghub-server/builder/inference/mock_Client.go +++ /dev/null @@ -1,150 +0,0 @@ -// Code generated by mockery v2.49.1. DO NOT EDIT. - -package inference - -import ( - mock "github.com/stretchr/testify/mock" - inference "opencsg.com/csghub-server/builder/inference" -) - -// MockClient is an autogenerated mock type for the Client type -type MockClient struct { - mock.Mock -} - -type MockClient_Expecter struct { - mock *mock.Mock -} - -func (_m *MockClient) EXPECT() *MockClient_Expecter { - return &MockClient_Expecter{mock: &_m.Mock} -} - -// GetModelInfo provides a mock function with given fields: id -func (_m *MockClient) GetModelInfo(id inference.ModelID) (inference.ModelInfo, error) { - ret := _m.Called(id) - - if len(ret) == 0 { - panic("no return value specified for GetModelInfo") - } - - var r0 inference.ModelInfo - var r1 error - if rf, ok := ret.Get(0).(func(inference.ModelID) (inference.ModelInfo, error)); ok { - return rf(id) - } - if rf, ok := ret.Get(0).(func(inference.ModelID) inference.ModelInfo); ok { - r0 = rf(id) - } else { - r0 = ret.Get(0).(inference.ModelInfo) - } - - if rf, ok := ret.Get(1).(func(inference.ModelID) error); ok { - r1 = rf(id) - } else { - r1 = ret.Error(1) - } - - return r0, r1 -} - -// MockClient_GetModelInfo_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'GetModelInfo' -type MockClient_GetModelInfo_Call struct { - *mock.Call -} - -// GetModelInfo is a helper method to define mock.On call -// - id inference.ModelID -func (_e *MockClient_Expecter) GetModelInfo(id interface{}) *MockClient_GetModelInfo_Call { - return &MockClient_GetModelInfo_Call{Call: _e.mock.On("GetModelInfo", id)} -} - -func (_c *MockClient_GetModelInfo_Call) Run(run func(id inference.ModelID)) *MockClient_GetModelInfo_Call { - _c.Call.Run(func(args mock.Arguments) { - run(args[0].(inference.ModelID)) - }) - return _c -} - -func (_c *MockClient_GetModelInfo_Call) Return(_a0 inference.ModelInfo, _a1 error) *MockClient_GetModelInfo_Call { - _c.Call.Return(_a0, _a1) - return _c -} - -func (_c *MockClient_GetModelInfo_Call) RunAndReturn(run func(inference.ModelID) (inference.ModelInfo, error)) *MockClient_GetModelInfo_Call { - _c.Call.Return(run) - return _c -} - -// Predict provides a mock function with given fields: id, req -func (_m *MockClient) Predict(id inference.ModelID, req *inference.PredictRequest) (*inference.PredictResponse, error) { - ret := _m.Called(id, req) - - if len(ret) == 0 { - panic("no return value specified for Predict") - } - - var r0 *inference.PredictResponse - var r1 error - if rf, ok := ret.Get(0).(func(inference.ModelID, *inference.PredictRequest) (*inference.PredictResponse, error)); ok { - return rf(id, req) - } - if rf, ok := ret.Get(0).(func(inference.ModelID, *inference.PredictRequest) *inference.PredictResponse); ok { - r0 = rf(id, req) - } else { - if ret.Get(0) != nil { - r0 = ret.Get(0).(*inference.PredictResponse) - } - } - - if rf, ok := ret.Get(1).(func(inference.ModelID, *inference.PredictRequest) error); ok { - r1 = rf(id, req) - } else { - r1 = ret.Error(1) - } - - return r0, r1 -} - -// MockClient_Predict_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Predict' -type MockClient_Predict_Call struct { - *mock.Call -} - -// Predict is a helper method to define mock.On call -// - id inference.ModelID -// - req *inference.PredictRequest -func (_e *MockClient_Expecter) Predict(id interface{}, req interface{}) *MockClient_Predict_Call { - return &MockClient_Predict_Call{Call: _e.mock.On("Predict", id, req)} -} - -func (_c *MockClient_Predict_Call) Run(run func(id inference.ModelID, req *inference.PredictRequest)) *MockClient_Predict_Call { - _c.Call.Run(func(args mock.Arguments) { - run(args[0].(inference.ModelID), args[1].(*inference.PredictRequest)) - }) - return _c -} - -func (_c *MockClient_Predict_Call) Return(_a0 *inference.PredictResponse, _a1 error) *MockClient_Predict_Call { - _c.Call.Return(_a0, _a1) - return _c -} - -func (_c *MockClient_Predict_Call) RunAndReturn(run func(inference.ModelID, *inference.PredictRequest) (*inference.PredictResponse, error)) *MockClient_Predict_Call { - _c.Call.Return(run) - return _c -} - -// NewMockClient creates a new instance of MockClient. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. -// The first argument is typically a *testing.T value. -func NewMockClient(t interface { - mock.TestingT - Cleanup(func()) -}) *MockClient { - mock := &MockClient{} - mock.Mock.Test(t) - - t.Cleanup(func() { mock.AssertExpectations(t) }) - - return mock -} diff --git a/api/handler/model.go b/api/handler/model.go index 9d489f7a..8f0df98d 100644 --- a/api/handler/model.go +++ b/api/handler/model.go @@ -482,49 +482,6 @@ func (h *ModelHandler) DelDatasetRelation(ctx *gin.Context) { httpbase.OK(ctx, nil) } -// Predict godoc -// @Security ApiKey -// @Summary Invoke model prediction -// @Description invoke model prediction -// @Tags Model -// @Accept json -// @Produce json -// @Param namespace path string true "namespace" -// @Param name path string true "name" -// @Param current_user query string false "current user" -// @Param body body types.ModelPredictReq true "input for model prediction" -// @Success 200 {object} string "OK" -// @Failure 400 {object} types.APIBadRequest "Bad request" -// @Failure 500 {object} types.APIInternalServerError "Internal server error" -// @Router /models/{namespace}/{name}/predict [post] -func (h *ModelHandler) Predict(ctx *gin.Context) { - var req types.ModelPredictReq - namespace, name, err := common.GetNamespaceAndNameFromContext(ctx) - if err != nil { - slog.Error("Bad request format", "error", err) - httpbase.BadRequest(ctx, err.Error()) - return - } - - if err := ctx.ShouldBindJSON(&req); err != nil { - slog.Error("Bad request format", "error", err) - httpbase.BadRequest(ctx, err.Error()) - return - } - - req.Name = name - req.Namespace = namespace - - resp, err := h.c.Predict(ctx, &req) - if err != nil { - slog.Error("fail to call predict", slog.String("error", err.Error())) - httpbase.ServerError(ctx, err) - return - } - - httpbase.OK(ctx, resp) -} - func parseTagReqs(ctx *gin.Context) (tags []types.TagReq) { tagCategories := ctx.QueryArray("tag_category") tagNames := ctx.QueryArray("tag_name") diff --git a/api/router/api.go b/api/router/api.go index 3ba007c5..6fc3ae6a 100644 --- a/api/router/api.go +++ b/api/router/api.go @@ -435,9 +435,6 @@ func createModelRoutes(config *config.Config, apiGroup *gin.RouterGroup, needAPI modelsGroup.POST("/:namespace/:name/update_downloads", middleware.RepoType(types.ModelRepo), repoCommonHandler.UpdateDownloads) modelsGroup.PUT("/:namespace/:name/incr_downloads", middleware.RepoType(types.ModelRepo), repoCommonHandler.IncrDownloads) modelsGroup.POST("/:namespace/:name/upload_file", middleware.RepoType(types.ModelRepo), repoCommonHandler.UploadFile) - // invoke model endpoint to do pediction - modelsGroup.POST("/:namespace/:name/predict", modelHandler.Predict) - modelsGroup.POST("/:namespace/:name/mirror", middleware.RepoType(types.ModelRepo), repoCommonHandler.CreateMirror) modelsGroup.GET("/:namespace/:name/mirror", middleware.RepoType(types.ModelRepo), repoCommonHandler.GetMirror) modelsGroup.PUT("/:namespace/:name/mirror", middleware.RepoType(types.ModelRepo), repoCommonHandler.UpdateMirror) diff --git a/builder/inference/client.go b/builder/inference/client.go deleted file mode 100644 index f80a5054..00000000 --- a/builder/inference/client.go +++ /dev/null @@ -1,28 +0,0 @@ -package inference - -type Client interface { - Predict(id ModelID, req *PredictRequest) (*PredictResponse, error) - GetModelInfo(id ModelID) (ModelInfo, error) -} - -type PredictRequest struct { - Prompt string `json:"prompt"` -} - -type PredictResponse struct { - GeneratedText string `json:"generated_text"` - NumInputTokens int `json:"num_input_tokens"` - NumInputTokensBatch int `json:"num_input_tokens_batch"` - NumGeneratedTokens int `json:"num_generated_tokens"` - NumGeneratedTokensBatch int `json:"num_generated_tokens_batch"` - PreprocessingTime float64 `json:"preprocessing_time"` - GenerationTime float64 `json:"generation_time"` - PostprocessingTime float64 `json:"postprocessing_time"` - GenerationTimePerToken float64 `json:"generation_time_per_token"` - GenerationTimePerTokenBatch float64 `json:"generation_time_per_token_batch"` - NumTotalTokens int `json:"num_total_tokens"` - NumTotalTokensBatch int `json:"num_total_tokens_batch"` - TotalTime float64 `json:"total_time"` - TotalTimePerToken float64 `json:"total_time_per_token"` - TotalTimePerTokenBatch float64 `json:"total_time_per_token_batch"` -} diff --git a/builder/inference/init.go b/builder/inference/init.go deleted file mode 100644 index 411a1393..00000000 --- a/builder/inference/init.go +++ /dev/null @@ -1,6 +0,0 @@ -package inference - -import "opencsg.com/csghub-server/common/config" - -func Init(config *config.Config) { -} diff --git a/builder/inference/llm_infer.go b/builder/inference/llm_infer.go deleted file mode 100644 index ba01774d..00000000 --- a/builder/inference/llm_infer.go +++ /dev/null @@ -1,190 +0,0 @@ -package inference - -import ( - "bytes" - "encoding/json" - "errors" - "fmt" - "hash/fnv" - "io" - "log/slog" - "net/http" - "net/url" - "strings" - "time" -) - -type ModelID struct { - Owner, Name string - // reserved, keep empty string "" - Version string -} - -func (m ModelID) Hash() uint64 { - f := fnv.New64() - f.Write([]byte(m.Owner)) - f.Write([]byte(":")) - f.Write([]byte(m.Name)) - f.Write([]byte(":")) - f.Write([]byte(m.Version)) - return f.Sum64() -} - -var _ Client = (*llmInferClient)(nil) - -type ModelInfo struct { - Endpoint string - // deploy,running,failed etc - Status string - // ModelID.Hash() - HashID uint64 -} - -type llmInferClient struct { - lastUpdate time.Time - hc *http.Client - modelInfos map[uint64]ModelInfo - serverAddr string -} - -func NewInferClient(addr string) Client { - hc := http.DefaultClient - hc.Timeout = time.Minute - return &llmInferClient{ - hc: hc, - modelInfos: make(map[uint64]ModelInfo), - serverAddr: addr, - } -} - -func (c *llmInferClient) Predict(id ModelID, req *PredictRequest) (*PredictResponse, error) { - s, err := c.GetModelInfo(id) - if err != nil { - return nil, fmt.Errorf("failed to get model info,error:%w", err) - } - - { - // for test only, as inference service is not ready - if id.Owner == "test_user_name" && id.Name == "test_model_name" { - return &PredictResponse{GeneratedText: "this is a test predict result."}, nil - } - } - return c.CallPredict(s.Endpoint, req) -} - -// ListServing call inference service to ge all serving models -func (c *llmInferClient) ListServing() (map[uint64]ModelInfo, error) { - defer func() { - // for test only - testModelID := ModelID{ - Owner: "test_user_name", - Name: "test_model_name", - Version: "", - } - c.modelInfos[testModelID.Hash()] = ModelInfo{ - HashID: testModelID.Hash(), - Endpoint: "http://localhost:8080/test_user_name/test_model_name", - Status: "running", - } - }() - - // use local cache first - if expire := time.Since(c.lastUpdate).Seconds(); expire < 30 { - slog.Info("use cached model infos", slog.Float64("expire", expire)) - return c.modelInfos, nil - } - - api, _ := url.JoinPath(c.serverAddr, "/api/list_serving") - req, _ := http.NewRequest(http.MethodGet, api, nil) - req.Header.Set("user-name", "default") - resp, err := c.hc.Do(req) - if err != nil { - slog.Error("fail to call list serving api", slog.Any("err", err)) - return c.modelInfos, fmt.Errorf("fail to call list serving api,%w", err) - } - defer resp.Body.Close() - llmInfos := make(map[string]LlmModelInfo) - err = json.NewDecoder(resp.Body).Decode(&llmInfos) - if err != nil { - slog.Error("fail to decode list serving response", slog.Any("err", err)) - return c.modelInfos, fmt.Errorf("fail to decode list serving response,%w", err) - } - - slog.Debug("llmResp", slog.Any("map", llmInfos)) - if len(llmInfos) > 0 { - c.updateModelInfos(llmInfos) - } - return c.modelInfos, nil -} - -func (c *llmInferClient) updateModelInfos(llmInfos map[string]LlmModelInfo) { - tmp := make(map[uint64]ModelInfo) - for _, v := range llmInfos { - for modelName, endpoint := range v.URL { - // example: THUDM/chatglm3-6b - owner, name, _ := strings.Cut(modelName, "/") - mid := ModelID{ - Owner: owner, - Name: name, - Version: "", - } - slog.Debug("get model info", slog.Any("mid", mid), slog.String("endpoint", endpoint)) - // endpoint = strings.Replace(endpoint, "http://0.0.0.0:8000", c.serverAddr, 1) - parsedUrl, _ := url.Parse(endpoint) - endpoint, _ = url.JoinPath(c.serverAddr, parsedUrl.RequestURI()) - slog.Debug("replace llm endpoint with new domain", slog.String("new_endpoint", endpoint)) - var status string - if len(v.Status) > 0 { - for _, vs := range v.Status { - status = vs.ApplicationStatus - break - } - } - mi := ModelInfo{ - Endpoint: endpoint, - Status: status, - HashID: mid.Hash(), - } - tmp[mi.HashID] = mi - // only one url - break - } - } - c.modelInfos = tmp - c.lastUpdate = time.Now() -} - -func (c *llmInferClient) GetModelInfo(id ModelID) (ModelInfo, error) { - list, err := c.ListServing() - if err != nil { - return ModelInfo{}, err - } - - if s, ok := list[id.Hash()]; ok { - return s, nil - } - - return ModelInfo{}, errors.New("model info not found by id") -} - -func (c *llmInferClient) CallPredict(url string, req *PredictRequest) (*PredictResponse, error) { - var body bytes.Buffer - err := json.NewEncoder(&body).Encode(req) - if err != nil { - return nil, err - } - resp, err := c.hc.Post(url, "application/json", &body) - if err != nil { - return nil, fmt.Errorf("failed to send http request,error: %w", err) - } - defer resp.Body.Close() - - data, err := io.ReadAll(resp.Body) - if err != nil { - return nil, fmt.Errorf("failed to read response body,error: %w", err) - } - - var r PredictResponse - err = json.Unmarshal(data, &r) - return &r, err -} diff --git a/builder/inference/rllm_types.go b/builder/inference/rllm_types.go deleted file mode 100644 index 52297248..00000000 --- a/builder/inference/rllm_types.go +++ /dev/null @@ -1,18 +0,0 @@ -package inference - -type LlmModelInfo struct { - URL map[string]string `json:"url"` - Status map[string]LlmModelInfo_Status `json:"status"` -} - -type LlmModelInfo_Status struct { - /*example: - * { - * "OpenCSG--opencsg-CodeLlama-7b-v0.1": "HEALTHY", - * "RouterDeployment": "HEALTHY" - * } - */ - DeploymentsStatus map[string]string `json:"deployments_status"` - // example: RUNNING - ApplicationStatus string `json:"application_status"` -} diff --git a/common/config/config.go b/common/config/config.go index c69ccc46..bd87c130 100644 --- a/common/config/config.go +++ b/common/config/config.go @@ -109,10 +109,6 @@ type Config struct { ValidHour int `env:"STARHUB_JWT_VALIDATE_HOUR, default=24"` } - Inference struct { - ServerAddr string `env:"STARHUB_SERVER_INFERENCE_SERVER_ADDR, default=http://localhost:8000"` - } - Space struct { BuilderEndpoint string `env:"STARHUB_SERVER_SPACE_BUILDER_ENDPOINT, default=http://localhost:8081"` // base url for space api running in k8s cluster diff --git a/common/config/config.toml.example b/common/config/config.toml.example index e1073ea0..f9d95a5b 100644 --- a/common/config/config.toml.example +++ b/common/config/config.toml.example @@ -81,9 +81,6 @@ enable_ssl = true signing_key = "signing-key" valid_hour = 24 -[inference] -server_addr = "http://localhost:8000" - [space] builder_endpoint = "http://localhost:8081" runner_endpoint = "http://localhost:8082" diff --git a/component/model.go b/component/model.go index c06c6e80..cf64ad60 100644 --- a/component/model.go +++ b/component/model.go @@ -13,7 +13,6 @@ import ( "opencsg.com/csghub-server/builder/git" "opencsg.com/csghub-server/builder/git/gitserver" "opencsg.com/csghub-server/builder/git/membership" - "opencsg.com/csghub-server/builder/inference" "opencsg.com/csghub-server/builder/rpc" "opencsg.com/csghub-server/builder/store/database" "opencsg.com/csghub-server/common/config" @@ -75,7 +74,6 @@ type ModelComponent interface { SetRelationDatasets(ctx context.Context, req types.RelationDatasets) error AddRelationDataset(ctx context.Context, req types.RelationDataset) error DelRelationDataset(ctx context.Context, req types.RelationDataset) error - Predict(ctx context.Context, req *types.ModelPredictReq) (*types.ModelPredictResp, error) // create model deploy as inference/serverless Deploy(ctx context.Context, deployReq types.DeployActReq, req types.ModelRunReq) (int64, error) ListModelsByRuntimeFrameworkID(ctx context.Context, currentUser string, per, page int, id int64, deployType int) ([]types.Model, int, error) @@ -97,7 +95,6 @@ func NewModelComponent(config *config.Config) (ModelComponent, error) { c.modelStore = database.NewModelStore() c.repoStore = database.NewRepoStore() c.spaceResourceStore = database.NewSpaceResourceStore() - c.inferClient = inference.NewInferClient(config.Inference.ServerAddr) c.userStore = database.NewUserStore() c.userLikesStore = database.NewUserLikesStore() c.deployer = deploy.NewDeployer() @@ -131,7 +128,6 @@ type modelComponentImpl struct { modelStore database.ModelStore repoStore database.RepoStore spaceResourceStore database.SpaceResourceStore - inferClient inference.Client userStore database.UserStore deployer deploy.Deployer accountingComponent AccountingComponent @@ -871,25 +867,6 @@ func getFilePaths(namespace, repoName, folder string, repoType types.RepositoryT return filePaths, nil } -func (c *modelComponentImpl) Predict(ctx context.Context, req *types.ModelPredictReq) (*types.ModelPredictResp, error) { - mid := inference.ModelID{ - Owner: req.Namespace, - Name: req.Name, - } - inferReq := &inference.PredictRequest{ - Prompt: req.Input, - } - inferResp, err := c.inferClient.Predict(mid, inferReq) - if err != nil { - slog.Error("failed to predict", slog.Any("req", *inferReq), slog.Any("model", mid), slog.String("error", err.Error())) - return nil, err - } - resp := &types.ModelPredictResp{ - Content: inferResp.GeneratedText, - } - return resp, nil -} - // create model deploy as inference/serverless func (c *modelComponentImpl) Deploy(ctx context.Context, deployReq types.DeployActReq, req types.ModelRunReq) (int64, error) { m, err := c.modelStore.FindByPath(ctx, deployReq.Namespace, deployReq.Name) From 2b1d8eb5192515d42ba09b9e1c9d14293411500f Mon Sep 17 00:00:00 2001 From: yiling Date: Thu, 19 Dec 2024 18:23:16 +0800 Subject: [PATCH 14/34] refactor model component --- .../builder/deploy/mock_Deployer.go | 29 +-- api/handler/model.go | 4 +- builder/deploy/deployer.go | 4 +- common/types/model.go | 1 + common/types/repo.go | 1 + component/model.go | 42 ++-- component/model_ce.go | 33 ++++ component/model_ce_test.go | 72 +++++++ component/model_test.go | 185 ++++++------------ component/space.go | 2 +- 10 files changed, 208 insertions(+), 165 deletions(-) create mode 100644 component/model_ce.go create mode 100644 component/model_ce_test.go diff --git a/_mocks/opencsg.com/csghub-server/builder/deploy/mock_Deployer.go b/_mocks/opencsg.com/csghub-server/builder/deploy/mock_Deployer.go index 954cb181..d1520e81 100644 --- a/_mocks/opencsg.com/csghub-server/builder/deploy/mock_Deployer.go +++ b/_mocks/opencsg.com/csghub-server/builder/deploy/mock_Deployer.go @@ -26,9 +26,9 @@ func (_m *MockDeployer) EXPECT() *MockDeployer_Expecter { return &MockDeployer_Expecter{mock: &_m.Mock} } -// CheckResourceAvailable provides a mock function with given fields: ctx, clusterId, hardWare -func (_m *MockDeployer) CheckResourceAvailable(ctx context.Context, clusterId string, hardWare *types.HardWare) (bool, error) { - ret := _m.Called(ctx, clusterId, hardWare) +// CheckResourceAvailable provides a mock function with given fields: ctx, clusterId, orderDetailID, hardWare +func (_m *MockDeployer) CheckResourceAvailable(ctx context.Context, clusterId string, orderDetailID int64, hardWare *types.HardWare) (bool, error) { + ret := _m.Called(ctx, clusterId, orderDetailID, hardWare) if len(ret) == 0 { panic("no return value specified for CheckResourceAvailable") @@ -36,17 +36,17 @@ func (_m *MockDeployer) CheckResourceAvailable(ctx context.Context, clusterId st var r0 bool var r1 error - if rf, ok := ret.Get(0).(func(context.Context, string, *types.HardWare) (bool, error)); ok { - return rf(ctx, clusterId, hardWare) + if rf, ok := ret.Get(0).(func(context.Context, string, int64, *types.HardWare) (bool, error)); ok { + return rf(ctx, clusterId, orderDetailID, hardWare) } - if rf, ok := ret.Get(0).(func(context.Context, string, *types.HardWare) bool); ok { - r0 = rf(ctx, clusterId, hardWare) + if rf, ok := ret.Get(0).(func(context.Context, string, int64, *types.HardWare) bool); ok { + r0 = rf(ctx, clusterId, orderDetailID, hardWare) } else { r0 = ret.Get(0).(bool) } - if rf, ok := ret.Get(1).(func(context.Context, string, *types.HardWare) error); ok { - r1 = rf(ctx, clusterId, hardWare) + if rf, ok := ret.Get(1).(func(context.Context, string, int64, *types.HardWare) error); ok { + r1 = rf(ctx, clusterId, orderDetailID, hardWare) } else { r1 = ret.Error(1) } @@ -62,14 +62,15 @@ type MockDeployer_CheckResourceAvailable_Call struct { // CheckResourceAvailable is a helper method to define mock.On call // - ctx context.Context // - clusterId string +// - orderDetailID int64 // - hardWare *types.HardWare -func (_e *MockDeployer_Expecter) CheckResourceAvailable(ctx interface{}, clusterId interface{}, hardWare interface{}) *MockDeployer_CheckResourceAvailable_Call { - return &MockDeployer_CheckResourceAvailable_Call{Call: _e.mock.On("CheckResourceAvailable", ctx, clusterId, hardWare)} +func (_e *MockDeployer_Expecter) CheckResourceAvailable(ctx interface{}, clusterId interface{}, orderDetailID interface{}, hardWare interface{}) *MockDeployer_CheckResourceAvailable_Call { + return &MockDeployer_CheckResourceAvailable_Call{Call: _e.mock.On("CheckResourceAvailable", ctx, clusterId, orderDetailID, hardWare)} } -func (_c *MockDeployer_CheckResourceAvailable_Call) Run(run func(ctx context.Context, clusterId string, hardWare *types.HardWare)) *MockDeployer_CheckResourceAvailable_Call { +func (_c *MockDeployer_CheckResourceAvailable_Call) Run(run func(ctx context.Context, clusterId string, orderDetailID int64, hardWare *types.HardWare)) *MockDeployer_CheckResourceAvailable_Call { _c.Call.Run(func(args mock.Arguments) { - run(args[0].(context.Context), args[1].(string), args[2].(*types.HardWare)) + run(args[0].(context.Context), args[1].(string), args[2].(int64), args[3].(*types.HardWare)) }) return _c } @@ -79,7 +80,7 @@ func (_c *MockDeployer_CheckResourceAvailable_Call) Return(_a0 bool, _a1 error) return _c } -func (_c *MockDeployer_CheckResourceAvailable_Call) RunAndReturn(run func(context.Context, string, *types.HardWare) (bool, error)) *MockDeployer_CheckResourceAvailable_Call { +func (_c *MockDeployer_CheckResourceAvailable_Call) RunAndReturn(run func(context.Context, string, int64, *types.HardWare) (bool, error)) *MockDeployer_CheckResourceAvailable_Call { _c.Call.Return(run) return _c } diff --git a/api/handler/model.go b/api/handler/model.go index 8f0df98d..cd0eb829 100644 --- a/api/handler/model.go +++ b/api/handler/model.go @@ -92,7 +92,7 @@ func (h *ModelHandler) Index(ctx *gin.Context) { return } - models, total, err := h.c.Index(ctx, filter, per, page) + models, total, err := h.c.Index(ctx, filter, per, page, false) if err != nil { slog.Error("Failed to get models", slog.Any("error", err)) httpbase.ServerError(ctx, err) @@ -264,7 +264,7 @@ func (h *ModelHandler) Show(ctx *gin.Context) { return } currentUser := httpbase.GetCurrentUser(ctx) - detail, err := h.c.Show(ctx, namespace, name, currentUser) + detail, err := h.c.Show(ctx, namespace, name, currentUser, false) if err != nil { if errors.Is(err, component.ErrUnauthorized) { httpbase.UnauthorizedError(ctx, err) diff --git a/builder/deploy/deployer.go b/builder/deploy/deployer.go index 8e05d81a..98a81886 100644 --- a/builder/deploy/deployer.go +++ b/builder/deploy/deployer.go @@ -38,7 +38,7 @@ type Deployer interface { UpdateCluster(ctx context.Context, data types.ClusterRequest) (*types.UpdateClusterResponse, error) UpdateDeploy(ctx context.Context, dur *types.DeployUpdateReq, deploy *database.Deploy) error StartDeploy(ctx context.Context, deploy *database.Deploy) error - CheckResourceAvailable(ctx context.Context, clusterId string, hardWare *types.HardWare) (bool, error) + CheckResourceAvailable(ctx context.Context, clusterId string, orderDetailID int64, hardWare *types.HardWare) (bool, error) SubmitEvaluation(ctx context.Context, req types.EvaluationReq) (*types.ArgoWorkFlowRes, error) ListEvaluations(context.Context, string, int, int) (*types.ArgoWorkFlowListRes, error) DeleteEvaluation(ctx context.Context, req types.ArgoWorkFlowDeleteReq) error @@ -725,7 +725,7 @@ func getValidSceneType(deployType int) types.SceneType { } } -func (d *deployer) CheckResourceAvailable(ctx context.Context, clusterId string, hardWare *types.HardWare) (bool, error) { +func (d *deployer) CheckResourceAvailable(ctx context.Context, clusterId string, orderDetailID int64, hardWare *types.HardWare) (bool, error) { // backward compatibility for old api if clusterId == "" { clusters, err := d.ListCluster(ctx) diff --git a/common/types/model.go b/common/types/model.go index 028aa4a1..994218dd 100644 --- a/common/types/model.go +++ b/common/types/model.go @@ -214,6 +214,7 @@ type ModelRunReq struct { MaxReplica int `json:"max_replica"` Revision string `json:"revision"` SecureLevel int `json:"secure_level"` + OrderDetailID int64 `json:"order_detail_id"` } var _ SensitiveRequestV2 = (*ModelRunReq)(nil) diff --git a/common/types/repo.go b/common/types/repo.go index ba387aac..03ced317 100644 --- a/common/types/repo.go +++ b/common/types/repo.go @@ -145,6 +145,7 @@ type DeployRepo struct { SKU string `json:"sku,omitempty"` ResourceType string `json:"resource_type,omitempty"` RepoTag string `json:"repo_tag,omitempty"` + OrderDetailID int64 `json:"order_detail_id,omitempty"` } type RuntimeFrameworkReq struct { diff --git a/component/model.go b/component/model.go index cf64ad60..07f518f3 100644 --- a/component/model.go +++ b/component/model.go @@ -63,11 +63,11 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text const LFSPrefix = "version https://git-lfs.github.com/spec/v1" type ModelComponent interface { - Index(ctx context.Context, filter *types.RepoFilter, per, page int) ([]types.Model, int, error) + Index(ctx context.Context, filter *types.RepoFilter, per, page int, needOpWeight bool) ([]*types.Model, int, error) Create(ctx context.Context, req *types.CreateModelReq) (*types.Model, error) Update(ctx context.Context, req *types.UpdateModelReq) (*types.Model, error) Delete(ctx context.Context, namespace, name, currentUser string) error - Show(ctx context.Context, namespace, name, currentUser string) (*types.Model, error) + Show(ctx context.Context, namespace, name, currentUser string, needOpWeight bool) (*types.Model, error) GetServerless(ctx context.Context, namespace, name, currentUser string) (*types.DeployRepo, error) SDKModelInfo(ctx context.Context, namespace, name, ref, currentUser string) (*types.SDKModelInfo, error) Relations(ctx context.Context, namespace, name, currentUser string) (*types.Relations, error) @@ -143,10 +143,10 @@ type modelComponentImpl struct { userSvcClient rpc.UserSvcClient } -func (c *modelComponentImpl) Index(ctx context.Context, filter *types.RepoFilter, per, page int) ([]types.Model, int, error) { +func (c *modelComponentImpl) Index(ctx context.Context, filter *types.RepoFilter, per, page int, needOpWeight bool) ([]*types.Model, int, error) { var ( err error - resModels []types.Model + resModels []*types.Model ) repos, total, err := c.repoComponent.PublicToUser(ctx, types.ModelRepo, filter.Username, filter, per, page) if err != nil { @@ -187,7 +187,7 @@ func (c *modelComponentImpl) Index(ctx context.Context, filter *types.RepoFilter UpdatedAt: tag.UpdatedAt, }) } - resModels = append(resModels, types.Model{ + resModels = append(resModels, &types.Model{ ID: model.ID, Name: repo.Name, Nickname: repo.Nickname, @@ -206,6 +206,9 @@ func (c *modelComponentImpl) Index(ctx context.Context, filter *types.RepoFilter Repository: common.BuildCloneInfo(c.config, model.Repository), }) } + if needOpWeight { + c.addOpWeightToModel(ctx, repoIDs, resModels) + } return resModels, total, nil } @@ -393,7 +396,7 @@ func (c *modelComponentImpl) Delete(ctx context.Context, namespace, name, curren return nil } -func (c *modelComponentImpl) Show(ctx context.Context, namespace, name, currentUser string) (*types.Model, error) { +func (c *modelComponentImpl) Show(ctx context.Context, namespace, name, currentUser string, needOpWeight bool) (*types.Model, error) { var tags []types.RepoTag model, err := c.modelStore.FindByPath(ctx, namespace, name) if err != nil { @@ -459,10 +462,16 @@ func (c *modelComponentImpl) Show(ctx context.Context, namespace, name, currentU BaseModel: model.BaseModel, License: model.Repository.License, MirrorLastUpdatedAt: model.Repository.Mirror.LastUpdatedAt, - - CanWrite: permission.CanWrite, - CanManage: permission.CanAdmin, - Namespace: ns, + CanWrite: permission.CanWrite, + CanManage: permission.CanAdmin, + Namespace: ns, + } + // admin user or owner can see the sensitive check status + if permission.CanAdmin { + resModel.SensitiveCheckStatus = model.Repository.SensitiveCheckStatus.String() + } + if needOpWeight { + c.addOpWeightToModel(ctx, []int64{model.RepositoryID}, []*types.Model{resModel}) } inferences, _ := c.repoRuntimeFrameworkStore.GetByRepoIDsAndType(ctx, model.Repository.ID, types.InferenceType) if len(inferences) > 0 { @@ -922,17 +931,15 @@ func (c *modelComponentImpl) Deploy(ctx context.Context, deployReq types.DeployA return -1, fmt.Errorf("invalid hardware setting, %w", err) } - _, err = c.deployer.CheckResourceAvailable(ctx, req.ClusterID, &hardware) + // resource available only if err is nil, err message should contain + // the reason why resource is unavailable + err = c.resourceAvailable(ctx, resource, req, deployReq, hardware) if err != nil { - return -1, fmt.Errorf("fail to check resource, %w", err) + return -1, err } // choose image - containerImg := frame.FrameCpuImage - if hardware.Gpu.Num != "" { - // use gpu image - containerImg = frame.FrameImage - } + containerImg := c.containerImg(frame, hardware) // create deploy for model return c.deployer.Deploy(ctx, types.DeployRepo{ @@ -957,6 +964,7 @@ func (c *modelComponentImpl) Deploy(ctx context.Context, deployReq types.DeployA Type: deployReq.DeployType, UserUUID: user.UUID, SKU: strconv.FormatInt(resource.ID, 10), + OrderDetailID: req.OrderDetailID, }) } diff --git a/component/model_ce.go b/component/model_ce.go new file mode 100644 index 00000000..b0206b60 --- /dev/null +++ b/component/model_ce.go @@ -0,0 +1,33 @@ +//go:build !ee && !saas + +package component + +import ( + "context" + "fmt" + + "opencsg.com/csghub-server/builder/store/database" + "opencsg.com/csghub-server/common/types" +) + +func (c *modelComponentImpl) addOpWeightToModel(ctx context.Context, repoIDs []int64, resModels []*types.Model) { + +} + +func (c *modelComponentImpl) resourceAvailable(ctx context.Context, resource *database.SpaceResource, req types.ModelRunReq, deployReq types.DeployActReq, hardware types.HardWare) error { + + _, err := c.deployer.CheckResourceAvailable(ctx, req.ClusterID, 0, &hardware) + if err != nil { + return fmt.Errorf("fail to check resource, %w", err) + } + return nil +} + +func (c *modelComponentImpl) containerImg(frame *database.RuntimeFramework, hardware types.HardWare) string { + containerImg := frame.FrameCpuImage + if hardware.Gpu.Num != "" { + containerImg = frame.FrameImage + } + return containerImg + +} diff --git a/component/model_ce_test.go b/component/model_ce_test.go new file mode 100644 index 00000000..6925c5db --- /dev/null +++ b/component/model_ce_test.go @@ -0,0 +1,72 @@ +//go:build !ee && !saas + +package component + +import ( + "context" + "testing" + + "github.com/stretchr/testify/require" + "opencsg.com/csghub-server/builder/store/database" + "opencsg.com/csghub-server/common/types" +) + +func TestModelComponent_Deploy(t *testing.T) { + ctx := context.TODO() + mc := initializeTestModelComponent(ctx, t) + + mc.mocks.stores.ModelMock().EXPECT().FindByPath(ctx, "ns", "n").Return(&database.Model{ + RepositoryID: int64(123), + Repository: &database.Repository{ + ID: 1, + Path: "foo", + }, + }, nil) + mc.mocks.stores.DeployTaskMock().EXPECT().GetServerlessDeployByRepID(ctx, int64(1)).Return( + nil, nil, + ) + mc.mocks.stores.UserMock().EXPECT().FindByUsername(ctx, "user").Return(database.User{ + RoleMask: "admin", + }, nil) + mc.mocks.stores.RuntimeFrameworkMock().EXPECT().FindEnabledByID(ctx, int64(11)).Return( + &database.RuntimeFramework{}, nil, + ) + mc.mocks.components.repo.EXPECT().IsAdminRole(database.User{ + RoleMask: "admin", + }).Return(true) + mc.mocks.stores.SpaceResourceMock().EXPECT().FindByID(ctx, int64(123)).Return( + &database.SpaceResource{ + ID: 123, + Resources: `{"memory": "foo"}`, + }, nil, + ) + + mc.mocks.deployer.EXPECT().CheckResourceAvailable(ctx, "cluster", int64(0), &types.HardWare{ + Memory: "foo", + }).Return(true, nil) + mc.mocks.deployer.EXPECT().Deploy(ctx, types.DeployRepo{ + DeployName: "dp", + Path: "foo", + Hardware: "{\"memory\": \"foo\"}", + Annotation: "{\"hub-res-name\":\"ns/n\",\"hub-res-type\":\"model\"}", + ClusterID: "cluster", + RepoID: 1, + SKU: "123", + Type: types.ServerlessType, + }).Return(111, nil) + + id, err := mc.Deploy(ctx, types.DeployActReq{ + Namespace: "ns", + Name: "n", + CurrentUser: "user", + DeployType: types.ServerlessType, + }, types.ModelRunReq{ + RuntimeFrameworkID: 11, + ResourceID: 123, + ClusterID: "cluster", + DeployName: "dp", + }) + require.Nil(t, err) + require.Equal(t, int64(111), id) + +} diff --git a/component/model_test.go b/component/model_test.go index 3f49f915..58bf152e 100644 --- a/component/model_test.go +++ b/component/model_test.go @@ -13,45 +13,45 @@ import ( "opencsg.com/csghub-server/common/types" ) -// func TestModelComponent_Index(t *testing.T) { -// ctx := context.TODO() -// mc := initializeTestModelComponent(ctx, t) - -// filter := &types.RepoFilter{Username: "user"} -// mc.mocks.components.repo.EXPECT().PublicToUser(ctx, types.ModelRepo, "user", filter, 10, 1).Return( -// []*database.Repository{ -// {ID: 1, Name: "r1", Tags: []database.Tag{{Name: "t1"}}}, -// {ID: 2, Name: "r2", Tags: []database.Tag{{Name: "t2"}}}, -// }, 100, nil, -// ) - -// mc.mocks.stores.ModelMock().EXPECT().ByRepoIDs(ctx, []int64{1, 2}).Return([]database.Model{ -// {RepositoryID: 1, ID: 11, Repository: &database.Repository{}}, -// {RepositoryID: 2, ID: 12, Repository: &database.Repository{}}, -// }, nil) - -// data, total, err := mc.Index(ctx, filter, 10, 1) -// require.Nil(t, err) -// require.Equal(t, 100, total) - -// require.Equal(t, []*types.Model{ -// { -// ID: 11, Name: "r1", Tags: []types.RepoTag{{Name: "t1"}}, RepositoryID: 1, -// Repository: types.Repository{ -// HTTPCloneURL: "https://foo.com/s/.git", -// SSHCloneURL: "test@127.0.0.1:s/.git", -// }, -// }, -// { -// ID: 12, Name: "r2", Tags: []types.RepoTag{{Name: "t2"}}, RepositoryID: 2, -// Repository: types.Repository{ -// HTTPCloneURL: "https://foo.com/s/.git", -// SSHCloneURL: "test@127.0.0.1:s/.git", -// }, -// }, -// }, data) - -// } +func TestModelComponent_Index(t *testing.T) { + ctx := context.TODO() + mc := initializeTestModelComponent(ctx, t) + + filter := &types.RepoFilter{Username: "user"} + mc.mocks.components.repo.EXPECT().PublicToUser(ctx, types.ModelRepo, "user", filter, 10, 1).Return( + []*database.Repository{ + {ID: 1, Name: "r1", Tags: []database.Tag{{Name: "t1"}}}, + {ID: 2, Name: "r2", Tags: []database.Tag{{Name: "t2"}}}, + }, 100, nil, + ) + + mc.mocks.stores.ModelMock().EXPECT().ByRepoIDs(ctx, []int64{1, 2}).Return([]database.Model{ + {RepositoryID: 1, ID: 11, Repository: &database.Repository{}}, + {RepositoryID: 2, ID: 12, Repository: &database.Repository{}}, + }, nil) + + data, total, err := mc.Index(ctx, filter, 10, 1, false) + require.Nil(t, err) + require.Equal(t, 100, total) + + require.Equal(t, []*types.Model{ + { + ID: 11, Name: "r1", Tags: []types.RepoTag{{Name: "t1"}}, RepositoryID: 1, + Repository: types.Repository{ + HTTPCloneURL: "https://foo.com/s/.git", + SSHCloneURL: "test@127.0.0.1:s/.git", + }, + }, + { + ID: 12, Name: "r2", Tags: []types.RepoTag{{Name: "t2"}}, RepositoryID: 2, + Repository: types.Repository{ + HTTPCloneURL: "https://foo.com/s/.git", + SSHCloneURL: "test@127.0.0.1:s/.git", + }, + }, + }, data) + +} func TestModelComponent_Create(t *testing.T) { ctx := context.TODO() @@ -216,17 +216,18 @@ func TestModelComponent_Show(t *testing.T) { ctx, int64(123), mock.Anything, ).Return([]database.RepositoriesRuntimeFramework{{}}, nil) - model, err := mc.Show(ctx, "ns", "n", "user") + model, err := mc.Show(ctx, "ns", "n", "user", false) require.Nil(t, err) require.Equal(t, &types.Model{ - ID: 1, - Name: "n", - Namespace: &types.Namespace{Path: "ns"}, - UserLikes: true, - RepositoryID: 123, - CanManage: true, - User: &types.User{}, - Path: "foo/bar", + ID: 1, + Name: "n", + Namespace: &types.Namespace{Path: "ns"}, + UserLikes: true, + RepositoryID: 123, + CanManage: true, + User: &types.User{}, + Path: "foo/bar", + SensitiveCheckStatus: "Pending", Repository: types.Repository{ HTTPCloneURL: "https://foo.com/s/foo/bar.git", SSHCloneURL: "test@127.0.0.1:s/foo/bar.git", @@ -457,83 +458,6 @@ func TestModelComponent_DeleteRelationDataset(t *testing.T) { require.Nil(t, err) } -// func TestModelComponent_Predict(t *testing.T) { -// ctx := context.TODO() -// mc := initializeTestModelComponent(ctx, t) - -// resp, err := mc.Predict(ctx, &types.ModelPredictReq{ -// Namespace: "ns", -// Name: "n", -// Input: "foo", -// CurrentUser: "user", -// }) -// require.Nil(t, err) -// require.Equal(t, &types.ModelPredictResp{ -// Content: "abcd", -// }, resp) - -// } - -// func TestModelComponent_Deploy(t *testing.T) { -// ctx := context.TODO() -// mc := initializeTestModelComponent(ctx, t) - -// mc.mocks.stores.ModelMock().EXPECT().FindByPath(ctx, "ns", "n").Return(&database.Model{ -// RepositoryID: int64(123), -// Repository: &database.Repository{ -// ID: 1, -// Path: "foo", -// }, -// }, nil) -// mc.mocks.stores.DeployTaskMock().EXPECT().GetServerlessDeployByRepID(ctx, int64(1)).Return( -// nil, nil, -// ) -// mc.mocks.stores.UserMock().EXPECT().FindByUsername(ctx, "user").Return(database.User{ -// RoleMask: "admin", -// }, nil) -// mc.mocks.stores.RuntimeFrameworkMock().EXPECT().FindEnabledByID(ctx, int64(11)).Return( -// &database.RuntimeFramework{}, nil, -// ) -// mc.mocks.components.repo.EXPECT().IsAdminRole(database.User{ -// RoleMask: "admin", -// }).Return(true) -// mc.mocks.stores.SpaceResourceMock().EXPECT().FindByID(ctx, int64(123)).Return( -// &database.SpaceResource{ -// ID: 123, -// Resources: `{"memory": "foo"}`, -// }, nil, -// ) - -// mc.mocks.deployer.EXPECT().CheckResourceAvailable(ctx, int64(0), &types.HardWare{ -// Memory: "foo", -// }).Return(true, nil) -// mc.mocks.deployer.EXPECT().Deploy(ctx, types.DeployRepo{ -// DeployName: "dp", -// Path: "foo", -// Hardware: "{\"memory\": \"foo\"}", -// Annotation: "{\"hub-res-name\":\"ns/n\",\"hub-res-type\":\"model\"}", -// ClusterID: "cluster", -// RepoID: 1, -// SKU: "123", -// Type: types.ServerlessType, -// }).Return(111, nil) - -// id, err := mc.Deploy(ctx, types.DeployActReq{ -// Namespace: "ns", -// Name: "n", -// CurrentUser: "user", -// DeployType: types.ServerlessType, -// }, types.ModelRunReq{ -// RuntimeFrameworkID: 11, -// ResourceID: 123, -// ClusterID: "cluster", -// DeployName: "dp", -// }) -// require.Nil(t, err) -// require.Equal(t, int64(111), id) - -// } - func TestModelComponent_ListModelsByRuntimeFrameworkID(t *testing.T) { ctx := context.TODO() mc := initializeTestModelComponent(ctx, t) @@ -562,22 +486,25 @@ func TestModelComponent_SetRuntimeFrameworkModes(t *testing.T) { mc := initializeTestModelComponent(ctx, t) mc.mocks.stores.RuntimeFrameworkMock().EXPECT().FindByID(ctx, int64(1)).Return( - &database.RuntimeFramework{ - ID: 1, - }, nil, + &database.RuntimeFramework{}, nil, ) mc.mocks.stores.ModelMock().EXPECT().ListByPath(ctx, []string{"a", "b"}).Return( []database.Model{ {RepositoryID: 1, Repository: &database.Repository{ID: 1, Path: "m1/foo"}}, + {RepositoryID: 2, Repository: &database.Repository{ID: 2, Path: "m2/foo"}}, }, nil, ) rftags := []*database.Tag{{Name: "t1"}, {Name: "t2"}} mc.mocks.stores.TagMock().EXPECT().GetTagsByScopeAndCategories( ctx, database.TagScope("model"), []string{"runtime_framework", "resource"}, ).Return(rftags, nil) + mc.mocks.stores.RepoRuntimeFrameworkMock().EXPECT().GetByIDsAndType( ctx, int64(1), int64(1), 1, - ).Return(nil, nil) + ).Return([]database.RepositoriesRuntimeFramework{}, nil) + mc.mocks.stores.RepoRuntimeFrameworkMock().EXPECT().GetByIDsAndType( + ctx, int64(1), int64(2), 1, + ).Return([]database.RepositoriesRuntimeFramework{{}}, nil) mc.mocks.stores.RepoRuntimeFrameworkMock().EXPECT().Add(ctx, int64(1), int64(1), 1).Return(nil) mc.mocks.components.runtimeArchitecture.EXPECT().AddRuntimeFrameworkTag( diff --git a/component/space.go b/component/space.go index d91912a5..422d41b8 100644 --- a/component/space.go +++ b/component/space.go @@ -125,7 +125,7 @@ func (c *spaceComponentImpl) Create(ctx context.Context, req types.CreateSpaceRe if err != nil { return nil, fmt.Errorf("invalid hardware setting, %w", err) } - _, err = c.deployer.CheckResourceAvailable(ctx, req.ClusterID, &hardware) + _, err = c.deployer.CheckResourceAvailable(ctx, req.ClusterID, 0, &hardware) if err != nil { return nil, fmt.Errorf("fail to check resource, %w", err) } From 4714f8850a761b0a2e4e8480ad306b34fea6f806 Mon Sep 17 00:00:00 2001 From: yiling Date: Mon, 23 Dec 2024 16:13:17 +0800 Subject: [PATCH 15/34] update mocks --- .mockery.yaml | 13 + .../deploy/imagebuilder/mock_Builder.go | 214 ++++ .../builder/deploy/imagerunner/mock_Runner.go | 984 ++++++++++++++++++ .../deploy/scheduler/mock_Scheduler.go | 123 +++ 4 files changed, 1334 insertions(+) create mode 100644 _mocks/opencsg.com/csghub-server/builder/deploy/imagebuilder/mock_Builder.go create mode 100644 _mocks/opencsg.com/csghub-server/builder/deploy/imagerunner/mock_Runner.go create mode 100644 _mocks/opencsg.com/csghub-server/builder/deploy/scheduler/mock_Scheduler.go diff --git a/.mockery.yaml b/.mockery.yaml index 312ef32b..cbda4ba7 100644 --- a/.mockery.yaml +++ b/.mockery.yaml @@ -75,6 +75,19 @@ packages: config: interfaces: Deployer: + opencsg.com/csghub-server/builder/deploy/scheduler: + config: + interfaces: + Scheduler: + opencsg.com/csghub-server/builder/deploy/imagerunner: + config: + interfaces: + Runner: + opencsg.com/csghub-server/builder/deploy/imagebuilder: + config: + interfaces: + Builder: + opencsg.com/csghub-server/accounting/component: config: interfaces: diff --git a/_mocks/opencsg.com/csghub-server/builder/deploy/imagebuilder/mock_Builder.go b/_mocks/opencsg.com/csghub-server/builder/deploy/imagebuilder/mock_Builder.go new file mode 100644 index 00000000..ca4a0a85 --- /dev/null +++ b/_mocks/opencsg.com/csghub-server/builder/deploy/imagebuilder/mock_Builder.go @@ -0,0 +1,214 @@ +// Code generated by mockery v2.49.1. DO NOT EDIT. + +package imagebuilder + +import ( + context "context" + + mock "github.com/stretchr/testify/mock" + imagebuilder "opencsg.com/csghub-server/builder/deploy/imagebuilder" +) + +// MockBuilder is an autogenerated mock type for the Builder type +type MockBuilder struct { + mock.Mock +} + +type MockBuilder_Expecter struct { + mock *mock.Mock +} + +func (_m *MockBuilder) EXPECT() *MockBuilder_Expecter { + return &MockBuilder_Expecter{mock: &_m.Mock} +} + +// Build provides a mock function with given fields: _a0, _a1 +func (_m *MockBuilder) Build(_a0 context.Context, _a1 *imagebuilder.BuildRequest) (*imagebuilder.BuildResponse, error) { + ret := _m.Called(_a0, _a1) + + if len(ret) == 0 { + panic("no return value specified for Build") + } + + var r0 *imagebuilder.BuildResponse + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, *imagebuilder.BuildRequest) (*imagebuilder.BuildResponse, error)); ok { + return rf(_a0, _a1) + } + if rf, ok := ret.Get(0).(func(context.Context, *imagebuilder.BuildRequest) *imagebuilder.BuildResponse); ok { + r0 = rf(_a0, _a1) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*imagebuilder.BuildResponse) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, *imagebuilder.BuildRequest) error); ok { + r1 = rf(_a0, _a1) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// MockBuilder_Build_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Build' +type MockBuilder_Build_Call struct { + *mock.Call +} + +// Build is a helper method to define mock.On call +// - _a0 context.Context +// - _a1 *imagebuilder.BuildRequest +func (_e *MockBuilder_Expecter) Build(_a0 interface{}, _a1 interface{}) *MockBuilder_Build_Call { + return &MockBuilder_Build_Call{Call: _e.mock.On("Build", _a0, _a1)} +} + +func (_c *MockBuilder_Build_Call) Run(run func(_a0 context.Context, _a1 *imagebuilder.BuildRequest)) *MockBuilder_Build_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(*imagebuilder.BuildRequest)) + }) + return _c +} + +func (_c *MockBuilder_Build_Call) Return(_a0 *imagebuilder.BuildResponse, _a1 error) *MockBuilder_Build_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *MockBuilder_Build_Call) RunAndReturn(run func(context.Context, *imagebuilder.BuildRequest) (*imagebuilder.BuildResponse, error)) *MockBuilder_Build_Call { + _c.Call.Return(run) + return _c +} + +// Logs provides a mock function with given fields: _a0, _a1 +func (_m *MockBuilder) Logs(_a0 context.Context, _a1 *imagebuilder.LogsRequest) (<-chan string, error) { + ret := _m.Called(_a0, _a1) + + if len(ret) == 0 { + panic("no return value specified for Logs") + } + + var r0 <-chan string + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, *imagebuilder.LogsRequest) (<-chan string, error)); ok { + return rf(_a0, _a1) + } + if rf, ok := ret.Get(0).(func(context.Context, *imagebuilder.LogsRequest) <-chan string); ok { + r0 = rf(_a0, _a1) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(<-chan string) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, *imagebuilder.LogsRequest) error); ok { + r1 = rf(_a0, _a1) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// MockBuilder_Logs_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Logs' +type MockBuilder_Logs_Call struct { + *mock.Call +} + +// Logs is a helper method to define mock.On call +// - _a0 context.Context +// - _a1 *imagebuilder.LogsRequest +func (_e *MockBuilder_Expecter) Logs(_a0 interface{}, _a1 interface{}) *MockBuilder_Logs_Call { + return &MockBuilder_Logs_Call{Call: _e.mock.On("Logs", _a0, _a1)} +} + +func (_c *MockBuilder_Logs_Call) Run(run func(_a0 context.Context, _a1 *imagebuilder.LogsRequest)) *MockBuilder_Logs_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(*imagebuilder.LogsRequest)) + }) + return _c +} + +func (_c *MockBuilder_Logs_Call) Return(_a0 <-chan string, _a1 error) *MockBuilder_Logs_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *MockBuilder_Logs_Call) RunAndReturn(run func(context.Context, *imagebuilder.LogsRequest) (<-chan string, error)) *MockBuilder_Logs_Call { + _c.Call.Return(run) + return _c +} + +// Status provides a mock function with given fields: _a0, _a1 +func (_m *MockBuilder) Status(_a0 context.Context, _a1 *imagebuilder.StatusRequest) (*imagebuilder.StatusResponse, error) { + ret := _m.Called(_a0, _a1) + + if len(ret) == 0 { + panic("no return value specified for Status") + } + + var r0 *imagebuilder.StatusResponse + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, *imagebuilder.StatusRequest) (*imagebuilder.StatusResponse, error)); ok { + return rf(_a0, _a1) + } + if rf, ok := ret.Get(0).(func(context.Context, *imagebuilder.StatusRequest) *imagebuilder.StatusResponse); ok { + r0 = rf(_a0, _a1) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*imagebuilder.StatusResponse) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, *imagebuilder.StatusRequest) error); ok { + r1 = rf(_a0, _a1) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// MockBuilder_Status_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Status' +type MockBuilder_Status_Call struct { + *mock.Call +} + +// Status is a helper method to define mock.On call +// - _a0 context.Context +// - _a1 *imagebuilder.StatusRequest +func (_e *MockBuilder_Expecter) Status(_a0 interface{}, _a1 interface{}) *MockBuilder_Status_Call { + return &MockBuilder_Status_Call{Call: _e.mock.On("Status", _a0, _a1)} +} + +func (_c *MockBuilder_Status_Call) Run(run func(_a0 context.Context, _a1 *imagebuilder.StatusRequest)) *MockBuilder_Status_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(*imagebuilder.StatusRequest)) + }) + return _c +} + +func (_c *MockBuilder_Status_Call) Return(_a0 *imagebuilder.StatusResponse, _a1 error) *MockBuilder_Status_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *MockBuilder_Status_Call) RunAndReturn(run func(context.Context, *imagebuilder.StatusRequest) (*imagebuilder.StatusResponse, error)) *MockBuilder_Status_Call { + _c.Call.Return(run) + return _c +} + +// NewMockBuilder creates a new instance of MockBuilder. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. +// The first argument is typically a *testing.T value. +func NewMockBuilder(t interface { + mock.TestingT + Cleanup(func()) +}) *MockBuilder { + mock := &MockBuilder{} + mock.Mock.Test(t) + + t.Cleanup(func() { mock.AssertExpectations(t) }) + + return mock +} diff --git a/_mocks/opencsg.com/csghub-server/builder/deploy/imagerunner/mock_Runner.go b/_mocks/opencsg.com/csghub-server/builder/deploy/imagerunner/mock_Runner.go new file mode 100644 index 00000000..2bb02537 --- /dev/null +++ b/_mocks/opencsg.com/csghub-server/builder/deploy/imagerunner/mock_Runner.go @@ -0,0 +1,984 @@ +// Code generated by mockery v2.49.1. DO NOT EDIT. + +package imagerunner + +import ( + context "context" + + httpbase "opencsg.com/csghub-server/api/httpbase" + + mock "github.com/stretchr/testify/mock" + + types "opencsg.com/csghub-server/common/types" +) + +// MockRunner is an autogenerated mock type for the Runner type +type MockRunner struct { + mock.Mock +} + +type MockRunner_Expecter struct { + mock *mock.Mock +} + +func (_m *MockRunner) EXPECT() *MockRunner_Expecter { + return &MockRunner_Expecter{mock: &_m.Mock} +} + +// DeleteWorkFlow provides a mock function with given fields: _a0, _a1 +func (_m *MockRunner) DeleteWorkFlow(_a0 context.Context, _a1 types.ArgoWorkFlowDeleteReq) (*httpbase.R, error) { + ret := _m.Called(_a0, _a1) + + if len(ret) == 0 { + panic("no return value specified for DeleteWorkFlow") + } + + var r0 *httpbase.R + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, types.ArgoWorkFlowDeleteReq) (*httpbase.R, error)); ok { + return rf(_a0, _a1) + } + if rf, ok := ret.Get(0).(func(context.Context, types.ArgoWorkFlowDeleteReq) *httpbase.R); ok { + r0 = rf(_a0, _a1) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*httpbase.R) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, types.ArgoWorkFlowDeleteReq) error); ok { + r1 = rf(_a0, _a1) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// MockRunner_DeleteWorkFlow_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'DeleteWorkFlow' +type MockRunner_DeleteWorkFlow_Call struct { + *mock.Call +} + +// DeleteWorkFlow is a helper method to define mock.On call +// - _a0 context.Context +// - _a1 types.ArgoWorkFlowDeleteReq +func (_e *MockRunner_Expecter) DeleteWorkFlow(_a0 interface{}, _a1 interface{}) *MockRunner_DeleteWorkFlow_Call { + return &MockRunner_DeleteWorkFlow_Call{Call: _e.mock.On("DeleteWorkFlow", _a0, _a1)} +} + +func (_c *MockRunner_DeleteWorkFlow_Call) Run(run func(_a0 context.Context, _a1 types.ArgoWorkFlowDeleteReq)) *MockRunner_DeleteWorkFlow_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(types.ArgoWorkFlowDeleteReq)) + }) + return _c +} + +func (_c *MockRunner_DeleteWorkFlow_Call) Return(_a0 *httpbase.R, _a1 error) *MockRunner_DeleteWorkFlow_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *MockRunner_DeleteWorkFlow_Call) RunAndReturn(run func(context.Context, types.ArgoWorkFlowDeleteReq) (*httpbase.R, error)) *MockRunner_DeleteWorkFlow_Call { + _c.Call.Return(run) + return _c +} + +// Exist provides a mock function with given fields: _a0, _a1 +func (_m *MockRunner) Exist(_a0 context.Context, _a1 *types.CheckRequest) (*types.StatusResponse, error) { + ret := _m.Called(_a0, _a1) + + if len(ret) == 0 { + panic("no return value specified for Exist") + } + + var r0 *types.StatusResponse + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, *types.CheckRequest) (*types.StatusResponse, error)); ok { + return rf(_a0, _a1) + } + if rf, ok := ret.Get(0).(func(context.Context, *types.CheckRequest) *types.StatusResponse); ok { + r0 = rf(_a0, _a1) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*types.StatusResponse) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, *types.CheckRequest) error); ok { + r1 = rf(_a0, _a1) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// MockRunner_Exist_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Exist' +type MockRunner_Exist_Call struct { + *mock.Call +} + +// Exist is a helper method to define mock.On call +// - _a0 context.Context +// - _a1 *types.CheckRequest +func (_e *MockRunner_Expecter) Exist(_a0 interface{}, _a1 interface{}) *MockRunner_Exist_Call { + return &MockRunner_Exist_Call{Call: _e.mock.On("Exist", _a0, _a1)} +} + +func (_c *MockRunner_Exist_Call) Run(run func(_a0 context.Context, _a1 *types.CheckRequest)) *MockRunner_Exist_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(*types.CheckRequest)) + }) + return _c +} + +func (_c *MockRunner_Exist_Call) Return(_a0 *types.StatusResponse, _a1 error) *MockRunner_Exist_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *MockRunner_Exist_Call) RunAndReturn(run func(context.Context, *types.CheckRequest) (*types.StatusResponse, error)) *MockRunner_Exist_Call { + _c.Call.Return(run) + return _c +} + +// GetClusterById provides a mock function with given fields: ctx, clusterId +func (_m *MockRunner) GetClusterById(ctx context.Context, clusterId string) (*types.ClusterResponse, error) { + ret := _m.Called(ctx, clusterId) + + if len(ret) == 0 { + panic("no return value specified for GetClusterById") + } + + var r0 *types.ClusterResponse + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, string) (*types.ClusterResponse, error)); ok { + return rf(ctx, clusterId) + } + if rf, ok := ret.Get(0).(func(context.Context, string) *types.ClusterResponse); ok { + r0 = rf(ctx, clusterId) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*types.ClusterResponse) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, string) error); ok { + r1 = rf(ctx, clusterId) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// MockRunner_GetClusterById_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'GetClusterById' +type MockRunner_GetClusterById_Call struct { + *mock.Call +} + +// GetClusterById is a helper method to define mock.On call +// - ctx context.Context +// - clusterId string +func (_e *MockRunner_Expecter) GetClusterById(ctx interface{}, clusterId interface{}) *MockRunner_GetClusterById_Call { + return &MockRunner_GetClusterById_Call{Call: _e.mock.On("GetClusterById", ctx, clusterId)} +} + +func (_c *MockRunner_GetClusterById_Call) Run(run func(ctx context.Context, clusterId string)) *MockRunner_GetClusterById_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(string)) + }) + return _c +} + +func (_c *MockRunner_GetClusterById_Call) Return(_a0 *types.ClusterResponse, _a1 error) *MockRunner_GetClusterById_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *MockRunner_GetClusterById_Call) RunAndReturn(run func(context.Context, string) (*types.ClusterResponse, error)) *MockRunner_GetClusterById_Call { + _c.Call.Return(run) + return _c +} + +// GetReplica provides a mock function with given fields: _a0, _a1 +func (_m *MockRunner) GetReplica(_a0 context.Context, _a1 *types.StatusRequest) (*types.ReplicaResponse, error) { + ret := _m.Called(_a0, _a1) + + if len(ret) == 0 { + panic("no return value specified for GetReplica") + } + + var r0 *types.ReplicaResponse + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, *types.StatusRequest) (*types.ReplicaResponse, error)); ok { + return rf(_a0, _a1) + } + if rf, ok := ret.Get(0).(func(context.Context, *types.StatusRequest) *types.ReplicaResponse); ok { + r0 = rf(_a0, _a1) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*types.ReplicaResponse) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, *types.StatusRequest) error); ok { + r1 = rf(_a0, _a1) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// MockRunner_GetReplica_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'GetReplica' +type MockRunner_GetReplica_Call struct { + *mock.Call +} + +// GetReplica is a helper method to define mock.On call +// - _a0 context.Context +// - _a1 *types.StatusRequest +func (_e *MockRunner_Expecter) GetReplica(_a0 interface{}, _a1 interface{}) *MockRunner_GetReplica_Call { + return &MockRunner_GetReplica_Call{Call: _e.mock.On("GetReplica", _a0, _a1)} +} + +func (_c *MockRunner_GetReplica_Call) Run(run func(_a0 context.Context, _a1 *types.StatusRequest)) *MockRunner_GetReplica_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(*types.StatusRequest)) + }) + return _c +} + +func (_c *MockRunner_GetReplica_Call) Return(_a0 *types.ReplicaResponse, _a1 error) *MockRunner_GetReplica_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *MockRunner_GetReplica_Call) RunAndReturn(run func(context.Context, *types.StatusRequest) (*types.ReplicaResponse, error)) *MockRunner_GetReplica_Call { + _c.Call.Return(run) + return _c +} + +// GetWorkFlow provides a mock function with given fields: _a0, _a1 +func (_m *MockRunner) GetWorkFlow(_a0 context.Context, _a1 types.ArgoWorkFlowDeleteReq) (*types.ArgoWorkFlowRes, error) { + ret := _m.Called(_a0, _a1) + + if len(ret) == 0 { + panic("no return value specified for GetWorkFlow") + } + + var r0 *types.ArgoWorkFlowRes + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, types.ArgoWorkFlowDeleteReq) (*types.ArgoWorkFlowRes, error)); ok { + return rf(_a0, _a1) + } + if rf, ok := ret.Get(0).(func(context.Context, types.ArgoWorkFlowDeleteReq) *types.ArgoWorkFlowRes); ok { + r0 = rf(_a0, _a1) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*types.ArgoWorkFlowRes) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, types.ArgoWorkFlowDeleteReq) error); ok { + r1 = rf(_a0, _a1) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// MockRunner_GetWorkFlow_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'GetWorkFlow' +type MockRunner_GetWorkFlow_Call struct { + *mock.Call +} + +// GetWorkFlow is a helper method to define mock.On call +// - _a0 context.Context +// - _a1 types.ArgoWorkFlowDeleteReq +func (_e *MockRunner_Expecter) GetWorkFlow(_a0 interface{}, _a1 interface{}) *MockRunner_GetWorkFlow_Call { + return &MockRunner_GetWorkFlow_Call{Call: _e.mock.On("GetWorkFlow", _a0, _a1)} +} + +func (_c *MockRunner_GetWorkFlow_Call) Run(run func(_a0 context.Context, _a1 types.ArgoWorkFlowDeleteReq)) *MockRunner_GetWorkFlow_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(types.ArgoWorkFlowDeleteReq)) + }) + return _c +} + +func (_c *MockRunner_GetWorkFlow_Call) Return(_a0 *types.ArgoWorkFlowRes, _a1 error) *MockRunner_GetWorkFlow_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *MockRunner_GetWorkFlow_Call) RunAndReturn(run func(context.Context, types.ArgoWorkFlowDeleteReq) (*types.ArgoWorkFlowRes, error)) *MockRunner_GetWorkFlow_Call { + _c.Call.Return(run) + return _c +} + +// InstanceLogs provides a mock function with given fields: _a0, _a1 +func (_m *MockRunner) InstanceLogs(_a0 context.Context, _a1 *types.InstanceLogsRequest) (<-chan string, error) { + ret := _m.Called(_a0, _a1) + + if len(ret) == 0 { + panic("no return value specified for InstanceLogs") + } + + var r0 <-chan string + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, *types.InstanceLogsRequest) (<-chan string, error)); ok { + return rf(_a0, _a1) + } + if rf, ok := ret.Get(0).(func(context.Context, *types.InstanceLogsRequest) <-chan string); ok { + r0 = rf(_a0, _a1) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(<-chan string) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, *types.InstanceLogsRequest) error); ok { + r1 = rf(_a0, _a1) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// MockRunner_InstanceLogs_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'InstanceLogs' +type MockRunner_InstanceLogs_Call struct { + *mock.Call +} + +// InstanceLogs is a helper method to define mock.On call +// - _a0 context.Context +// - _a1 *types.InstanceLogsRequest +func (_e *MockRunner_Expecter) InstanceLogs(_a0 interface{}, _a1 interface{}) *MockRunner_InstanceLogs_Call { + return &MockRunner_InstanceLogs_Call{Call: _e.mock.On("InstanceLogs", _a0, _a1)} +} + +func (_c *MockRunner_InstanceLogs_Call) Run(run func(_a0 context.Context, _a1 *types.InstanceLogsRequest)) *MockRunner_InstanceLogs_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(*types.InstanceLogsRequest)) + }) + return _c +} + +func (_c *MockRunner_InstanceLogs_Call) Return(_a0 <-chan string, _a1 error) *MockRunner_InstanceLogs_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *MockRunner_InstanceLogs_Call) RunAndReturn(run func(context.Context, *types.InstanceLogsRequest) (<-chan string, error)) *MockRunner_InstanceLogs_Call { + _c.Call.Return(run) + return _c +} + +// ListCluster provides a mock function with given fields: ctx +func (_m *MockRunner) ListCluster(ctx context.Context) ([]types.ClusterResponse, error) { + ret := _m.Called(ctx) + + if len(ret) == 0 { + panic("no return value specified for ListCluster") + } + + var r0 []types.ClusterResponse + var r1 error + if rf, ok := ret.Get(0).(func(context.Context) ([]types.ClusterResponse, error)); ok { + return rf(ctx) + } + if rf, ok := ret.Get(0).(func(context.Context) []types.ClusterResponse); ok { + r0 = rf(ctx) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).([]types.ClusterResponse) + } + } + + if rf, ok := ret.Get(1).(func(context.Context) error); ok { + r1 = rf(ctx) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// MockRunner_ListCluster_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'ListCluster' +type MockRunner_ListCluster_Call struct { + *mock.Call +} + +// ListCluster is a helper method to define mock.On call +// - ctx context.Context +func (_e *MockRunner_Expecter) ListCluster(ctx interface{}) *MockRunner_ListCluster_Call { + return &MockRunner_ListCluster_Call{Call: _e.mock.On("ListCluster", ctx)} +} + +func (_c *MockRunner_ListCluster_Call) Run(run func(ctx context.Context)) *MockRunner_ListCluster_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context)) + }) + return _c +} + +func (_c *MockRunner_ListCluster_Call) Return(_a0 []types.ClusterResponse, _a1 error) *MockRunner_ListCluster_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *MockRunner_ListCluster_Call) RunAndReturn(run func(context.Context) ([]types.ClusterResponse, error)) *MockRunner_ListCluster_Call { + _c.Call.Return(run) + return _c +} + +// ListWorkFlows provides a mock function with given fields: _a0, _a1, _a2, _a3 +func (_m *MockRunner) ListWorkFlows(_a0 context.Context, _a1 string, _a2 int, _a3 int) (*types.ArgoWorkFlowListRes, error) { + ret := _m.Called(_a0, _a1, _a2, _a3) + + if len(ret) == 0 { + panic("no return value specified for ListWorkFlows") + } + + var r0 *types.ArgoWorkFlowListRes + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, string, int, int) (*types.ArgoWorkFlowListRes, error)); ok { + return rf(_a0, _a1, _a2, _a3) + } + if rf, ok := ret.Get(0).(func(context.Context, string, int, int) *types.ArgoWorkFlowListRes); ok { + r0 = rf(_a0, _a1, _a2, _a3) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*types.ArgoWorkFlowListRes) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, string, int, int) error); ok { + r1 = rf(_a0, _a1, _a2, _a3) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// MockRunner_ListWorkFlows_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'ListWorkFlows' +type MockRunner_ListWorkFlows_Call struct { + *mock.Call +} + +// ListWorkFlows is a helper method to define mock.On call +// - _a0 context.Context +// - _a1 string +// - _a2 int +// - _a3 int +func (_e *MockRunner_Expecter) ListWorkFlows(_a0 interface{}, _a1 interface{}, _a2 interface{}, _a3 interface{}) *MockRunner_ListWorkFlows_Call { + return &MockRunner_ListWorkFlows_Call{Call: _e.mock.On("ListWorkFlows", _a0, _a1, _a2, _a3)} +} + +func (_c *MockRunner_ListWorkFlows_Call) Run(run func(_a0 context.Context, _a1 string, _a2 int, _a3 int)) *MockRunner_ListWorkFlows_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(string), args[2].(int), args[3].(int)) + }) + return _c +} + +func (_c *MockRunner_ListWorkFlows_Call) Return(_a0 *types.ArgoWorkFlowListRes, _a1 error) *MockRunner_ListWorkFlows_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *MockRunner_ListWorkFlows_Call) RunAndReturn(run func(context.Context, string, int, int) (*types.ArgoWorkFlowListRes, error)) *MockRunner_ListWorkFlows_Call { + _c.Call.Return(run) + return _c +} + +// Logs provides a mock function with given fields: _a0, _a1 +func (_m *MockRunner) Logs(_a0 context.Context, _a1 *types.LogsRequest) (<-chan string, error) { + ret := _m.Called(_a0, _a1) + + if len(ret) == 0 { + panic("no return value specified for Logs") + } + + var r0 <-chan string + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, *types.LogsRequest) (<-chan string, error)); ok { + return rf(_a0, _a1) + } + if rf, ok := ret.Get(0).(func(context.Context, *types.LogsRequest) <-chan string); ok { + r0 = rf(_a0, _a1) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(<-chan string) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, *types.LogsRequest) error); ok { + r1 = rf(_a0, _a1) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// MockRunner_Logs_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Logs' +type MockRunner_Logs_Call struct { + *mock.Call +} + +// Logs is a helper method to define mock.On call +// - _a0 context.Context +// - _a1 *types.LogsRequest +func (_e *MockRunner_Expecter) Logs(_a0 interface{}, _a1 interface{}) *MockRunner_Logs_Call { + return &MockRunner_Logs_Call{Call: _e.mock.On("Logs", _a0, _a1)} +} + +func (_c *MockRunner_Logs_Call) Run(run func(_a0 context.Context, _a1 *types.LogsRequest)) *MockRunner_Logs_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(*types.LogsRequest)) + }) + return _c +} + +func (_c *MockRunner_Logs_Call) Return(_a0 <-chan string, _a1 error) *MockRunner_Logs_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *MockRunner_Logs_Call) RunAndReturn(run func(context.Context, *types.LogsRequest) (<-chan string, error)) *MockRunner_Logs_Call { + _c.Call.Return(run) + return _c +} + +// Purge provides a mock function with given fields: _a0, _a1 +func (_m *MockRunner) Purge(_a0 context.Context, _a1 *types.PurgeRequest) (*types.PurgeResponse, error) { + ret := _m.Called(_a0, _a1) + + if len(ret) == 0 { + panic("no return value specified for Purge") + } + + var r0 *types.PurgeResponse + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, *types.PurgeRequest) (*types.PurgeResponse, error)); ok { + return rf(_a0, _a1) + } + if rf, ok := ret.Get(0).(func(context.Context, *types.PurgeRequest) *types.PurgeResponse); ok { + r0 = rf(_a0, _a1) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*types.PurgeResponse) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, *types.PurgeRequest) error); ok { + r1 = rf(_a0, _a1) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// MockRunner_Purge_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Purge' +type MockRunner_Purge_Call struct { + *mock.Call +} + +// Purge is a helper method to define mock.On call +// - _a0 context.Context +// - _a1 *types.PurgeRequest +func (_e *MockRunner_Expecter) Purge(_a0 interface{}, _a1 interface{}) *MockRunner_Purge_Call { + return &MockRunner_Purge_Call{Call: _e.mock.On("Purge", _a0, _a1)} +} + +func (_c *MockRunner_Purge_Call) Run(run func(_a0 context.Context, _a1 *types.PurgeRequest)) *MockRunner_Purge_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(*types.PurgeRequest)) + }) + return _c +} + +func (_c *MockRunner_Purge_Call) Return(_a0 *types.PurgeResponse, _a1 error) *MockRunner_Purge_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *MockRunner_Purge_Call) RunAndReturn(run func(context.Context, *types.PurgeRequest) (*types.PurgeResponse, error)) *MockRunner_Purge_Call { + _c.Call.Return(run) + return _c +} + +// Run provides a mock function with given fields: _a0, _a1 +func (_m *MockRunner) Run(_a0 context.Context, _a1 *types.RunRequest) (*types.RunResponse, error) { + ret := _m.Called(_a0, _a1) + + if len(ret) == 0 { + panic("no return value specified for Run") + } + + var r0 *types.RunResponse + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, *types.RunRequest) (*types.RunResponse, error)); ok { + return rf(_a0, _a1) + } + if rf, ok := ret.Get(0).(func(context.Context, *types.RunRequest) *types.RunResponse); ok { + r0 = rf(_a0, _a1) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*types.RunResponse) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, *types.RunRequest) error); ok { + r1 = rf(_a0, _a1) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// MockRunner_Run_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Run' +type MockRunner_Run_Call struct { + *mock.Call +} + +// Run is a helper method to define mock.On call +// - _a0 context.Context +// - _a1 *types.RunRequest +func (_e *MockRunner_Expecter) Run(_a0 interface{}, _a1 interface{}) *MockRunner_Run_Call { + return &MockRunner_Run_Call{Call: _e.mock.On("Run", _a0, _a1)} +} + +func (_c *MockRunner_Run_Call) Run(run func(_a0 context.Context, _a1 *types.RunRequest)) *MockRunner_Run_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(*types.RunRequest)) + }) + return _c +} + +func (_c *MockRunner_Run_Call) Return(_a0 *types.RunResponse, _a1 error) *MockRunner_Run_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *MockRunner_Run_Call) RunAndReturn(run func(context.Context, *types.RunRequest) (*types.RunResponse, error)) *MockRunner_Run_Call { + _c.Call.Return(run) + return _c +} + +// Status provides a mock function with given fields: _a0, _a1 +func (_m *MockRunner) Status(_a0 context.Context, _a1 *types.StatusRequest) (*types.StatusResponse, error) { + ret := _m.Called(_a0, _a1) + + if len(ret) == 0 { + panic("no return value specified for Status") + } + + var r0 *types.StatusResponse + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, *types.StatusRequest) (*types.StatusResponse, error)); ok { + return rf(_a0, _a1) + } + if rf, ok := ret.Get(0).(func(context.Context, *types.StatusRequest) *types.StatusResponse); ok { + r0 = rf(_a0, _a1) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*types.StatusResponse) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, *types.StatusRequest) error); ok { + r1 = rf(_a0, _a1) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// MockRunner_Status_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Status' +type MockRunner_Status_Call struct { + *mock.Call +} + +// Status is a helper method to define mock.On call +// - _a0 context.Context +// - _a1 *types.StatusRequest +func (_e *MockRunner_Expecter) Status(_a0 interface{}, _a1 interface{}) *MockRunner_Status_Call { + return &MockRunner_Status_Call{Call: _e.mock.On("Status", _a0, _a1)} +} + +func (_c *MockRunner_Status_Call) Run(run func(_a0 context.Context, _a1 *types.StatusRequest)) *MockRunner_Status_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(*types.StatusRequest)) + }) + return _c +} + +func (_c *MockRunner_Status_Call) Return(_a0 *types.StatusResponse, _a1 error) *MockRunner_Status_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *MockRunner_Status_Call) RunAndReturn(run func(context.Context, *types.StatusRequest) (*types.StatusResponse, error)) *MockRunner_Status_Call { + _c.Call.Return(run) + return _c +} + +// StatusAll provides a mock function with given fields: _a0 +func (_m *MockRunner) StatusAll(_a0 context.Context) (map[string]types.StatusResponse, error) { + ret := _m.Called(_a0) + + if len(ret) == 0 { + panic("no return value specified for StatusAll") + } + + var r0 map[string]types.StatusResponse + var r1 error + if rf, ok := ret.Get(0).(func(context.Context) (map[string]types.StatusResponse, error)); ok { + return rf(_a0) + } + if rf, ok := ret.Get(0).(func(context.Context) map[string]types.StatusResponse); ok { + r0 = rf(_a0) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(map[string]types.StatusResponse) + } + } + + if rf, ok := ret.Get(1).(func(context.Context) error); ok { + r1 = rf(_a0) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// MockRunner_StatusAll_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'StatusAll' +type MockRunner_StatusAll_Call struct { + *mock.Call +} + +// StatusAll is a helper method to define mock.On call +// - _a0 context.Context +func (_e *MockRunner_Expecter) StatusAll(_a0 interface{}) *MockRunner_StatusAll_Call { + return &MockRunner_StatusAll_Call{Call: _e.mock.On("StatusAll", _a0)} +} + +func (_c *MockRunner_StatusAll_Call) Run(run func(_a0 context.Context)) *MockRunner_StatusAll_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context)) + }) + return _c +} + +func (_c *MockRunner_StatusAll_Call) Return(_a0 map[string]types.StatusResponse, _a1 error) *MockRunner_StatusAll_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *MockRunner_StatusAll_Call) RunAndReturn(run func(context.Context) (map[string]types.StatusResponse, error)) *MockRunner_StatusAll_Call { + _c.Call.Return(run) + return _c +} + +// Stop provides a mock function with given fields: _a0, _a1 +func (_m *MockRunner) Stop(_a0 context.Context, _a1 *types.StopRequest) (*types.StopResponse, error) { + ret := _m.Called(_a0, _a1) + + if len(ret) == 0 { + panic("no return value specified for Stop") + } + + var r0 *types.StopResponse + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, *types.StopRequest) (*types.StopResponse, error)); ok { + return rf(_a0, _a1) + } + if rf, ok := ret.Get(0).(func(context.Context, *types.StopRequest) *types.StopResponse); ok { + r0 = rf(_a0, _a1) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*types.StopResponse) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, *types.StopRequest) error); ok { + r1 = rf(_a0, _a1) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// MockRunner_Stop_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Stop' +type MockRunner_Stop_Call struct { + *mock.Call +} + +// Stop is a helper method to define mock.On call +// - _a0 context.Context +// - _a1 *types.StopRequest +func (_e *MockRunner_Expecter) Stop(_a0 interface{}, _a1 interface{}) *MockRunner_Stop_Call { + return &MockRunner_Stop_Call{Call: _e.mock.On("Stop", _a0, _a1)} +} + +func (_c *MockRunner_Stop_Call) Run(run func(_a0 context.Context, _a1 *types.StopRequest)) *MockRunner_Stop_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(*types.StopRequest)) + }) + return _c +} + +func (_c *MockRunner_Stop_Call) Return(_a0 *types.StopResponse, _a1 error) *MockRunner_Stop_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *MockRunner_Stop_Call) RunAndReturn(run func(context.Context, *types.StopRequest) (*types.StopResponse, error)) *MockRunner_Stop_Call { + _c.Call.Return(run) + return _c +} + +// SubmitWorkFlow provides a mock function with given fields: _a0, _a1 +func (_m *MockRunner) SubmitWorkFlow(_a0 context.Context, _a1 *types.ArgoWorkFlowReq) (*types.ArgoWorkFlowRes, error) { + ret := _m.Called(_a0, _a1) + + if len(ret) == 0 { + panic("no return value specified for SubmitWorkFlow") + } + + var r0 *types.ArgoWorkFlowRes + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, *types.ArgoWorkFlowReq) (*types.ArgoWorkFlowRes, error)); ok { + return rf(_a0, _a1) + } + if rf, ok := ret.Get(0).(func(context.Context, *types.ArgoWorkFlowReq) *types.ArgoWorkFlowRes); ok { + r0 = rf(_a0, _a1) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*types.ArgoWorkFlowRes) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, *types.ArgoWorkFlowReq) error); ok { + r1 = rf(_a0, _a1) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// MockRunner_SubmitWorkFlow_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'SubmitWorkFlow' +type MockRunner_SubmitWorkFlow_Call struct { + *mock.Call +} + +// SubmitWorkFlow is a helper method to define mock.On call +// - _a0 context.Context +// - _a1 *types.ArgoWorkFlowReq +func (_e *MockRunner_Expecter) SubmitWorkFlow(_a0 interface{}, _a1 interface{}) *MockRunner_SubmitWorkFlow_Call { + return &MockRunner_SubmitWorkFlow_Call{Call: _e.mock.On("SubmitWorkFlow", _a0, _a1)} +} + +func (_c *MockRunner_SubmitWorkFlow_Call) Run(run func(_a0 context.Context, _a1 *types.ArgoWorkFlowReq)) *MockRunner_SubmitWorkFlow_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(*types.ArgoWorkFlowReq)) + }) + return _c +} + +func (_c *MockRunner_SubmitWorkFlow_Call) Return(_a0 *types.ArgoWorkFlowRes, _a1 error) *MockRunner_SubmitWorkFlow_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *MockRunner_SubmitWorkFlow_Call) RunAndReturn(run func(context.Context, *types.ArgoWorkFlowReq) (*types.ArgoWorkFlowRes, error)) *MockRunner_SubmitWorkFlow_Call { + _c.Call.Return(run) + return _c +} + +// UpdateCluster provides a mock function with given fields: ctx, data +func (_m *MockRunner) UpdateCluster(ctx context.Context, data *types.ClusterRequest) (*types.UpdateClusterResponse, error) { + ret := _m.Called(ctx, data) + + if len(ret) == 0 { + panic("no return value specified for UpdateCluster") + } + + var r0 *types.UpdateClusterResponse + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, *types.ClusterRequest) (*types.UpdateClusterResponse, error)); ok { + return rf(ctx, data) + } + if rf, ok := ret.Get(0).(func(context.Context, *types.ClusterRequest) *types.UpdateClusterResponse); ok { + r0 = rf(ctx, data) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*types.UpdateClusterResponse) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, *types.ClusterRequest) error); ok { + r1 = rf(ctx, data) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// MockRunner_UpdateCluster_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'UpdateCluster' +type MockRunner_UpdateCluster_Call struct { + *mock.Call +} + +// UpdateCluster is a helper method to define mock.On call +// - ctx context.Context +// - data *types.ClusterRequest +func (_e *MockRunner_Expecter) UpdateCluster(ctx interface{}, data interface{}) *MockRunner_UpdateCluster_Call { + return &MockRunner_UpdateCluster_Call{Call: _e.mock.On("UpdateCluster", ctx, data)} +} + +func (_c *MockRunner_UpdateCluster_Call) Run(run func(ctx context.Context, data *types.ClusterRequest)) *MockRunner_UpdateCluster_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(*types.ClusterRequest)) + }) + return _c +} + +func (_c *MockRunner_UpdateCluster_Call) Return(_a0 *types.UpdateClusterResponse, _a1 error) *MockRunner_UpdateCluster_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *MockRunner_UpdateCluster_Call) RunAndReturn(run func(context.Context, *types.ClusterRequest) (*types.UpdateClusterResponse, error)) *MockRunner_UpdateCluster_Call { + _c.Call.Return(run) + return _c +} + +// NewMockRunner creates a new instance of MockRunner. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. +// The first argument is typically a *testing.T value. +func NewMockRunner(t interface { + mock.TestingT + Cleanup(func()) +}) *MockRunner { + mock := &MockRunner{} + mock.Mock.Test(t) + + t.Cleanup(func() { mock.AssertExpectations(t) }) + + return mock +} diff --git a/_mocks/opencsg.com/csghub-server/builder/deploy/scheduler/mock_Scheduler.go b/_mocks/opencsg.com/csghub-server/builder/deploy/scheduler/mock_Scheduler.go new file mode 100644 index 00000000..f68ed3ca --- /dev/null +++ b/_mocks/opencsg.com/csghub-server/builder/deploy/scheduler/mock_Scheduler.go @@ -0,0 +1,123 @@ +// Code generated by mockery v2.49.1. DO NOT EDIT. + +package scheduler + +import mock "github.com/stretchr/testify/mock" + +// MockScheduler is an autogenerated mock type for the Scheduler type +type MockScheduler struct { + mock.Mock +} + +type MockScheduler_Expecter struct { + mock *mock.Mock +} + +func (_m *MockScheduler) EXPECT() *MockScheduler_Expecter { + return &MockScheduler_Expecter{mock: &_m.Mock} +} + +// Queue provides a mock function with given fields: deployTaskID +func (_m *MockScheduler) Queue(deployTaskID int64) error { + ret := _m.Called(deployTaskID) + + if len(ret) == 0 { + panic("no return value specified for Queue") + } + + var r0 error + if rf, ok := ret.Get(0).(func(int64) error); ok { + r0 = rf(deployTaskID) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// MockScheduler_Queue_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Queue' +type MockScheduler_Queue_Call struct { + *mock.Call +} + +// Queue is a helper method to define mock.On call +// - deployTaskID int64 +func (_e *MockScheduler_Expecter) Queue(deployTaskID interface{}) *MockScheduler_Queue_Call { + return &MockScheduler_Queue_Call{Call: _e.mock.On("Queue", deployTaskID)} +} + +func (_c *MockScheduler_Queue_Call) Run(run func(deployTaskID int64)) *MockScheduler_Queue_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(int64)) + }) + return _c +} + +func (_c *MockScheduler_Queue_Call) Return(_a0 error) *MockScheduler_Queue_Call { + _c.Call.Return(_a0) + return _c +} + +func (_c *MockScheduler_Queue_Call) RunAndReturn(run func(int64) error) *MockScheduler_Queue_Call { + _c.Call.Return(run) + return _c +} + +// Run provides a mock function with given fields: +func (_m *MockScheduler) Run() error { + ret := _m.Called() + + if len(ret) == 0 { + panic("no return value specified for Run") + } + + var r0 error + if rf, ok := ret.Get(0).(func() error); ok { + r0 = rf() + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// MockScheduler_Run_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Run' +type MockScheduler_Run_Call struct { + *mock.Call +} + +// Run is a helper method to define mock.On call +func (_e *MockScheduler_Expecter) Run() *MockScheduler_Run_Call { + return &MockScheduler_Run_Call{Call: _e.mock.On("Run")} +} + +func (_c *MockScheduler_Run_Call) Run(run func()) *MockScheduler_Run_Call { + _c.Call.Run(func(args mock.Arguments) { + run() + }) + return _c +} + +func (_c *MockScheduler_Run_Call) Return(_a0 error) *MockScheduler_Run_Call { + _c.Call.Return(_a0) + return _c +} + +func (_c *MockScheduler_Run_Call) RunAndReturn(run func() error) *MockScheduler_Run_Call { + _c.Call.Return(run) + return _c +} + +// NewMockScheduler creates a new instance of MockScheduler. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. +// The first argument is typically a *testing.T value. +func NewMockScheduler(t interface { + mock.TestingT + Cleanup(func()) +}) *MockScheduler { + mock := &MockScheduler{} + mock.Mock.Test(t) + + t.Cleanup(func() { mock.AssertExpectations(t) }) + + return mock +} From ec188e36bebb23244525ce091bbd00b1b9a2f9fe Mon Sep 17 00:00:00 2001 From: yiling Date: Mon, 23 Dec 2024 16:35:48 +0800 Subject: [PATCH 16/34] sync deployer with enterprise --- builder/deploy/deploy_ce_test.go | 138 ++++++ builder/deploy/deployer.go | 224 ++++----- builder/deploy/deployer_ce.go | 143 ++++++ builder/deploy/deployer_test.go | 827 +++++++++++++++++++++++++++++++ builder/deploy/init.go | 4 +- 5 files changed, 1200 insertions(+), 136 deletions(-) create mode 100644 builder/deploy/deploy_ce_test.go create mode 100644 builder/deploy/deployer_ce.go create mode 100644 builder/deploy/deployer_test.go diff --git a/builder/deploy/deploy_ce_test.go b/builder/deploy/deploy_ce_test.go new file mode 100644 index 00000000..57ef4087 --- /dev/null +++ b/builder/deploy/deploy_ce_test.go @@ -0,0 +1,138 @@ +//go:build !ee && !saas + +package deploy + +import ( + "context" + "testing" + "time" + + "github.com/bwmarrin/snowflake" + "github.com/stretchr/testify/mock" + "github.com/stretchr/testify/require" + mockbuilder "opencsg.com/csghub-server/_mocks/opencsg.com/csghub-server/builder/deploy/imagebuilder" + mockrunner "opencsg.com/csghub-server/_mocks/opencsg.com/csghub-server/builder/deploy/imagerunner" + mockScheduler "opencsg.com/csghub-server/_mocks/opencsg.com/csghub-server/builder/deploy/scheduler" + mockdb "opencsg.com/csghub-server/_mocks/opencsg.com/csghub-server/builder/store/database" + "opencsg.com/csghub-server/builder/deploy/common" + "opencsg.com/csghub-server/builder/store/database" + "opencsg.com/csghub-server/common/tests" + "opencsg.com/csghub-server/common/types" +) + +func newTestDeployer(t *testing.T) *testDepolyerWithMocks { + mockStores := tests.NewMockStores(t) + node, err := snowflake.NewNode(1) + require.NoError(t, err) + s := &testDepolyerWithMocks{ + deployer: &deployer{ + deployTaskStore: mockStores.DeployTask, + spaceStore: mockStores.Space, + spaceResourceStore: mockStores.SpaceResource, + runtimeFrameworkStore: mockStores.RuntimeFramework, + userStore: mockStores.User, + snowflakeNode: node, + }, + } + s.mocks.stores = mockStores + s.mocks.scheduler = mockScheduler.NewMockScheduler(t) + s.scheduler = s.mocks.scheduler + s.mocks.builder = mockbuilder.NewMockBuilder(t) + s.imageBuilder = s.mocks.builder + s.mocks.runner = mockrunner.NewMockRunner(t) + s.imageRunner = s.mocks.runner + return s +} + +func TestDeployer_Stop(t *testing.T) { + dr := types.DeployRepo{ + SpaceID: 0, + DeployID: 1, + UserUUID: "1", + Path: "namespace/name", + Type: types.InferenceType, + } + + mockRunner := mockrunner.NewMockRunner(t) + mockRunner.EXPECT().Stop(mock.Anything, mock.Anything).Return(&types.StopResponse{}, nil) + + d := &deployer{ + imageRunner: mockRunner, + } + err := d.Stop(context.TODO(), dr) + require.Nil(t, err) +} + +func TestDeployer_StartDeploy(t *testing.T) { + dbdeploy := database.Deploy{ + ID: 1, + UserUUID: "1", + } + + mockTaskStore := mockdb.NewMockDeployTaskStore(t) + //make a copy to compare the status + dbdeployUpdate := dbdeploy + dbdeployUpdate.Status = common.Pending + mockTaskStore.EXPECT().UpdateDeploy(mock.Anything, &dbdeployUpdate).Return(nil) + + buildTask := database.DeployTask{ + DeployID: dbdeploy.ID, + TaskType: 1, + } + mockTaskStore.EXPECT().CreateDeployTask(mock.Anything, &buildTask).Return(nil) + + mockSch := mockScheduler.NewMockScheduler(t) + mockSch.EXPECT().Queue(mock.Anything).Return(nil) + + node, _ := snowflake.NewNode(1) + + d := &deployer{ + snowflakeNode: node, + deployTaskStore: mockTaskStore, + scheduler: mockSch, + } + err := d.StartDeploy(context.TODO(), &dbdeploy) + + //wait for scheduler to queue task + time.Sleep(time.Second) + + require.Nil(t, err) +} + +func TestDeployer_CheckResourceAvailable(t *testing.T) { + tester := newTestDeployer(t) + ctx := context.TODO() + + tester.mocks.runner.EXPECT().ListCluster(ctx).Return([]types.ClusterResponse{ + {ClusterID: "c1"}, + }, nil) + tester.mocks.runner.EXPECT().GetClusterById(ctx, "c1").Return(&types.ClusterResponse{ + Nodes: map[string]types.NodeResourceInfo{ + "n1": {AvailableMem: 100}, + }, + }, nil) + + v, err := tester.CheckResourceAvailable(ctx, "", 0, &types.HardWare{Memory: "10Gi"}) + require.NoError(t, err) + require.True(t, v) +} + +func TestDeployer_updateEvaluationEnvHardware(t *testing.T) { + + cases := []struct { + hardware types.HardWare + key string + value string + }{ + {types.HardWare{ + Gpu: types.GPU{Num: "1"}, + }, "GPU_NUM", "1"}, + } + + for _, c := range cases { + m := map[string]string{} + updateEvaluationEnvHardware(m, types.EvaluationReq{Hardware: c.hardware}) + require.Equal(t, c.value, m[c.key]) + } + +} diff --git a/builder/deploy/deployer.go b/builder/deploy/deployer.go index 98a81886..0838f217 100644 --- a/builder/deploy/deployer.go +++ b/builder/deploy/deployer.go @@ -12,13 +12,10 @@ import ( "strings" "time" - "github.com/bwmarrin/snowflake" "github.com/google/uuid" "opencsg.com/csghub-server/builder/deploy/common" "opencsg.com/csghub-server/builder/deploy/imagebuilder" - "opencsg.com/csghub-server/builder/deploy/imagerunner" "opencsg.com/csghub-server/builder/deploy/scheduler" - "opencsg.com/csghub-server/builder/event" "opencsg.com/csghub-server/builder/store/database" "opencsg.com/csghub-server/common/types" ) @@ -45,54 +42,6 @@ type Deployer interface { GetEvaluation(ctx context.Context, req types.EvaluationGetReq) (*types.ArgoWorkFlowRes, error) } -var _ Deployer = (*deployer)(nil) - -type deployer struct { - s scheduler.Scheduler - ib imagebuilder.Builder - ir imagerunner.Runner - - store database.DeployTaskStore - spaceStore database.SpaceStore - spaceResourceStore database.SpaceResourceStore - runnerStatuscache map[string]types.StatusResponse - internalRootDomain string - sfNode *snowflake.Node - eventPub *event.EventPublisher - rtfm database.RuntimeFrameworksStore -} - -func newDeployer(s scheduler.Scheduler, ib imagebuilder.Builder, ir imagerunner.Runner) (*deployer, error) { - store := database.NewDeployTaskStore() - node, err := snowflake.NewNode(1) - if err != nil || node == nil { - slog.Error("fail to generate uuid for inference service name", slog.Any("error", err)) - return nil, err - } - d := &deployer{ - s: s, - ib: ib, - ir: ir, - store: store, - spaceStore: database.NewSpaceStore(), - spaceResourceStore: database.NewSpaceResourceStore(), - runnerStatuscache: make(map[string]types.StatusResponse), - sfNode: node, - eventPub: &event.DefaultEventPublisher, - rtfm: database.NewRuntimeFrameworksStore(), - } - - go d.refreshStatus() - go func() { - err = d.s.Run() - if err != nil { - slog.Error("run scheduler failed", slog.Any("error", err)) - } - }() - go d.startAccounting() - return d, nil -} - func (d *deployer) GenerateUniqueSvcName(dr types.DeployRepo) string { uniqueSvcName := "" if dr.Type == types.SpaceType { @@ -106,7 +55,7 @@ func (d *deployer) GenerateUniqueSvcName(dr types.DeployRepo) string { } else { // model inference // generate unique service name from uuid when create new deploy by snowflake - uniqueSvcName = d.sfNode.Generate().Base36() + uniqueSvcName = d.snowflakeNode.Generate().Base36() } return uniqueSvcName } @@ -118,9 +67,9 @@ func (d *deployer) serverlessDeploy(ctx context.Context, dr types.DeployRepo) (* ) slog.Info("do deployer.serverlessDeploy check type", slog.Any("dr.Type", dr.Type)) if dr.Type == types.SpaceType { - deploy, err = d.store.GetLatestDeployBySpaceID(ctx, dr.SpaceID) + deploy, err = d.deployTaskStore.GetLatestDeployBySpaceID(ctx, dr.SpaceID) } else { - deploy, err = d.store.GetServerlessDeployByRepID(ctx, dr.RepoID) + deploy, err = d.deployTaskStore.GetServerlessDeployByRepID(ctx, dr.RepoID) } if err == sql.ErrNoRows { return nil, nil @@ -135,8 +84,18 @@ func (d *deployer) serverlessDeploy(ctx context.Context, dr types.DeployRepo) (* deploy.SKU = dr.SKU // dr.ImageID is not null for nginx space, otherwise it's "" deploy.ImageID = dr.ImageID + deploy.Annotation = dr.Annotation + deploy.Env = dr.Env + deploy.Hardware = dr.Hardware + deploy.RuntimeFramework = dr.RuntimeFramework + deploy.Secret = dr.Secret + deploy.SecureLevel = dr.SecureLevel + deploy.ContainerPort = dr.ContainerPort + deploy.Template = dr.Template + deploy.MinReplica = dr.MinReplica + deploy.MaxReplica = dr.MaxReplica slog.Debug("do deployer.serverlessDeploy", slog.Any("dr", dr), slog.Any("deploy", deploy)) - err = d.store.UpdateDeploy(ctx, deploy) + err = d.deployTaskStore.UpdateDeploy(ctx, deploy) if err != nil { return nil, fmt.Errorf("fail reset deploy image, %w", err) } @@ -174,7 +133,8 @@ func (d *deployer) dedicatedDeploy(ctx context.Context, dr types.DeployRepo) (*d UserUUID: dr.UserUUID, SKU: dr.SKU, } - err := d.store.CreateDeploy(ctx, deploy) + updateDatabaseDeploy(deploy, dr) + err := d.deployTaskStore.CreateDeploy(ctx, deploy) return deploy, err } @@ -203,6 +163,13 @@ func (d *deployer) buildDeploy(ctx context.Context, dr types.DeployRepo) (*datab } func (d *deployer) Deploy(ctx context.Context, dr types.DeployRepo) (int64, error) { + + //check reserved resource + err := d.checkOrderDetail(ctx, dr) + if err != nil { + return -1, err + } + deploy, err := d.buildDeploy(ctx, dr) slog.Info("do deployer.Deploy", slog.Any("dr", dr), slog.Any("deploy", deploy)) if err != nil || deploy == nil { @@ -225,24 +192,20 @@ func (d *deployer) Deploy(ctx context.Context, dr types.DeployRepo) (int64, erro Status: bldTaskStatus, Message: bldTaskMsg, } - err = d.store.CreateDeployTask(ctx, buildTask) + err = d.deployTaskStore.CreateDeployTask(ctx, buildTask) if err != nil { - return -1, fmt.Errorf("failed to create deploy task: %w", err) + return -1, fmt.Errorf("create deploy task failed: %w", err) } runTask := &database.DeployTask{ DeployID: deploy.ID, TaskType: 1, } - err = d.store.CreateDeployTask(ctx, runTask) + err = d.deployTaskStore.CreateDeployTask(ctx, runTask) if err != nil { - return -1, fmt.Errorf("failed to create deploy task: %w", err) + return -1, fmt.Errorf("create deploy task failed: %w", err) } - go func() { - if err := d.s.Queue(buildTask.ID); err != nil { - slog.Error("failed to queue task", slog.Any("error", err)) - } - }() + go func() { _ = d.scheduler.Queue(buildTask.ID) }() return deploy.ID, nil } @@ -250,13 +213,13 @@ func (d *deployer) Deploy(ctx context.Context, dr types.DeployRepo) (int64, erro func (d *deployer) refreshStatus() { for { ctxTimeout, cancel := context.WithTimeout(context.Background(), 3*time.Second) - status, err := d.ir.StatusAll(ctxTimeout) + status, err := d.imageRunner.StatusAll(ctxTimeout) cancel() if err != nil { slog.Error("refresh status all failed", slog.Any("error", err)) } else { - slog.Debug("status all cached", slog.Any("status", d.runnerStatuscache)) - d.runnerStatuscache = status + slog.Debug("status all cached", slog.Any("status", d.runnerStatusCache)) + d.runnerStatusCache = status } time.Sleep(5 * time.Second) @@ -264,14 +227,14 @@ func (d *deployer) refreshStatus() { } func (d *deployer) Status(ctx context.Context, dr types.DeployRepo, needDetails bool) (string, int, []types.Instance, error) { - deploy, err := d.store.GetDeployByID(ctx, dr.DeployID) + deploy, err := d.deployTaskStore.GetDeployByID(ctx, dr.DeployID) if err != nil || deploy == nil { - slog.Error("fail to get deploy by deploy id", slog.Any("DeployID", deploy.ID), slog.Any("error", err)) + slog.Error("fail to get deploy by deploy id", slog.Any("DeployID", dr.DeployID), slog.Any("error", err)) return "", common.Stopped, nil, fmt.Errorf("can't get deploy, %w", err) } svcName := deploy.SvcName // srvName := common.UniqueSpaceAppName(dr.Namespace, dr.Name, dr.SpaceID) - rstatus, found := d.runnerStatuscache[svcName] + rstatus, found := d.runnerStatusCache[svcName] if !found { slog.Debug("status cache miss", slog.String("svc_name", svcName)) if deploy.Status == common.Running { @@ -283,7 +246,7 @@ func (d *deployer) Status(ctx context.Context, dr types.DeployRepo, needDetails deployStatus := rstatus.Code if dr.ModelID > 0 { targetID := dr.DeployID // support model deploy with multi-instance - status, err := d.ir.Status(ctx, &types.StatusRequest{ + status, err := d.imageRunner.Status(ctx, &types.StatusRequest{ ClusterID: dr.ClusterID, OrgName: dr.Namespace, RepoName: dr.Name, @@ -307,13 +270,13 @@ func (d *deployer) Status(ctx context.Context, dr types.DeployRepo, needDetails func (d *deployer) Logs(ctx context.Context, dr types.DeployRepo) (*MultiLogReader, error) { // get latest Deploy - deploy, err := d.store.GetLatestDeployBySpaceID(ctx, dr.SpaceID) + deploy, err := d.deployTaskStore.GetLatestDeployBySpaceID(ctx, dr.SpaceID) if err != nil { return nil, fmt.Errorf("can't get space delopyment,%w", err) } slog.Debug("get logs for space", slog.Any("deploy", deploy), slog.Int64("space_id", dr.SpaceID)) - buildLog, err := d.ib.Logs(ctx, &imagebuilder.LogsRequest{ + buildLog, err := d.imageBuilder.Logs(ctx, &imagebuilder.LogsRequest{ OrgName: dr.Namespace, SpaceName: dr.Name, BuildID: strconv.FormatInt(deploy.ID, 10), @@ -327,7 +290,7 @@ func (d *deployer) Logs(ctx context.Context, dr types.DeployRepo) (*MultiLogRead if dr.SpaceID == 0 { targetID = dr.DeployID // support model deploy with multi-instance } - runLog, err := d.ir.Logs(ctx, &types.LogsRequest{ + runLog, err := d.imageRunner.Logs(ctx, &types.LogsRequest{ ID: targetID, OrgName: dr.Namespace, RepoName: dr.Name, @@ -347,7 +310,7 @@ func (d *deployer) Stop(ctx context.Context, dr types.DeployRepo) error { if dr.SpaceID == 0 { targetID = dr.DeployID // support model deploy with multi-instance } - resp, err := d.ir.Stop(ctx, &types.StopRequest{ + resp, err := d.imageRunner.Stop(ctx, &types.StopRequest{ ID: targetID, OrgName: dr.Namespace, RepoName: dr.Name, @@ -357,6 +320,11 @@ func (d *deployer) Stop(ctx context.Context, dr types.DeployRepo) error { if err != nil { slog.Error("deployer stop deploy", slog.Any("runner_resp", resp), slog.Int64("space_id", dr.SpaceID), slog.Any("deploy_id", dr.DeployID), slog.Any("error", err)) } + // release resource if it's a order case + err = d.releaseUserResourceByOrder(ctx, dr) + if err != nil { + return err + } return err } @@ -365,7 +333,7 @@ func (d *deployer) Purge(ctx context.Context, dr types.DeployRepo) error { if dr.SpaceID == 0 { targetID = dr.DeployID // support model deploy with multi-instance } - resp, err := d.ir.Purge(ctx, &types.PurgeRequest{ + resp, err := d.imageRunner.Purge(ctx, &types.PurgeRequest{ ID: targetID, OrgName: dr.Namespace, RepoName: dr.Name, @@ -417,7 +385,7 @@ func (d *deployer) Exist(ctx context.Context, dr types.DeployRepo) (bool, error) SvcName: dr.SvcName, ClusterID: dr.ClusterID, } - resp, err := d.ir.Exist(ctx, req) + resp, err := d.imageRunner.Exist(ctx, req) if err != nil { slog.Error("fail to check deploy", slog.Any("req", req), slog.Any("error", err)) return true, err @@ -447,7 +415,7 @@ func (d *deployer) GetReplica(ctx context.Context, dr types.DeployRepo) (int, in ClusterID: dr.ClusterID, SvcName: dr.SvcName, } - resp, err := d.ir.GetReplica(ctx, req) + resp, err := d.imageRunner.GetReplica(ctx, req) if err != nil { slog.Warn("fail to get deploy replica with error", slog.Any("req", req), slog.Any("error", err)) return 0, 0, []types.Instance{}, err @@ -462,7 +430,7 @@ func (d *deployer) InstanceLogs(ctx context.Context, dr types.DeployRepo) (*Mult if dr.SpaceID == 0 { targetID = dr.DeployID // support model deploy with multi-instance } - runLog, err := d.ir.InstanceLogs(ctx, &types.InstanceLogsRequest{ + runLog, err := d.imageRunner.InstanceLogs(ctx, &types.InstanceLogsRequest{ ID: targetID, OrgName: dr.Namespace, RepoName: dr.Name, @@ -479,7 +447,7 @@ func (d *deployer) InstanceLogs(ctx context.Context, dr types.DeployRepo) (*Mult } func (d *deployer) ListCluster(ctx context.Context) ([]types.ClusterRes, error) { - resp, err := d.ir.ListCluster(ctx) + resp, err := d.imageRunner.ListCluster(ctx) if err != nil { return nil, err } @@ -501,13 +469,15 @@ func (d *deployer) ListCluster(ctx context.Context) ([]types.ClusterRes, error) } func (d *deployer) GetClusterById(ctx context.Context, clusterId string) (*types.ClusterRes, error) { - resp, err := d.ir.GetClusterById(ctx, clusterId) + resp, err := d.imageRunner.GetClusterById(ctx, clusterId) if err != nil { return nil, err } - resources := make([]types.NodeResourceInfo, 0) - for _, node := range resp.Nodes { - resources = append(resources, node) + + // get reserved resources + resources, err := d.getResources(ctx, clusterId, resp) + if err != nil { + return nil, err } result := types.ClusterRes{ ClusterID: resp.ClusterID, @@ -520,8 +490,7 @@ func (d *deployer) GetClusterById(ctx context.Context, clusterId string) (*types } func (d *deployer) UpdateCluster(ctx context.Context, data types.ClusterRequest) (*types.UpdateClusterResponse, error) { - resp, err := d.ir.UpdateCluster(ctx, &data) - return (*types.UpdateClusterResponse)(resp), err + return d.imageRunner.UpdateCluster(ctx, &data) } // UpdateDeploy implements Deployer. @@ -534,7 +503,7 @@ func (d *deployer) UpdateDeploy(ctx context.Context, dur *types.DeployUpdateReq, ) if dur.RuntimeFrameworkID != nil { - frame, err = d.rtfm.FindEnabledByID(ctx, *dur.RuntimeFrameworkID) + frame, err = d.runtimeFrameworkStore.FindEnabledByID(ctx, *dur.RuntimeFrameworkID) if err != nil || frame == nil { return fmt.Errorf("can't find available runtime framework %v, %w", *dur.RuntimeFrameworkID, err) } @@ -567,11 +536,7 @@ func (d *deployer) UpdateDeploy(ctx context.Context, dur *types.DeployUpdateReq, if frame != nil { // choose image - containerImg := frame.FrameCpuImage - if hardware != nil && hardware.Gpu.Num != "" { - // use gpu image - containerImg = frame.FrameImage - } + containerImg := containerImage(hardware, frame) deploy.ImageID = containerImg deploy.RuntimeFramework = frame.FrameName deploy.ContainerPort = frame.ContainerPort @@ -601,7 +566,7 @@ func (d *deployer) UpdateDeploy(ctx context.Context, dur *types.DeployUpdateReq, } // update deploy table - err = d.store.UpdateDeploy(ctx, deploy) + err = d.deployTaskStore.UpdateDeploy(ctx, deploy) if err != nil { return fmt.Errorf("failed to update deploy, %w", err) } @@ -612,7 +577,7 @@ func (d *deployer) UpdateDeploy(ctx context.Context, dur *types.DeployUpdateReq, func (d *deployer) StartDeploy(ctx context.Context, deploy *database.Deploy) error { deploy.Status = common.Pending // update deploy table - err := d.store.UpdateDeploy(ctx, deploy) + err := d.deployTaskStore.UpdateDeploy(ctx, deploy) if err != nil { return fmt.Errorf("failed to update deploy, %w", err) } @@ -622,19 +587,20 @@ func (d *deployer) StartDeploy(ctx context.Context, deploy *database.Deploy) err DeployID: deploy.ID, TaskType: 1, } - err = d.store.CreateDeployTask(ctx, runTask) + err = d.deployTaskStore.CreateDeployTask(ctx, runTask) if err != nil { - return fmt.Errorf("failed to create deploy task: %w", err) + return fmt.Errorf("create deploy task failed: %w", err) } - go func() { _ = d.s.Queue(runTask.ID) }() + go func() { _ = d.scheduler.Queue(runTask.ID) }() - return nil -} + // update resource if it's a order case + err = d.updateUserResourceByOrder(ctx, deploy) + if err != nil { + return err + } -// accounting timer -func (d *deployer) startAccounting() { - d.startAccountingMetering() + return nil } func (d *deployer) getResourceMap() map[string]string { @@ -652,19 +618,19 @@ func (d *deployer) getResourceMap() map[string]string { return resources } -func (d *deployer) startAccountingMetering() { +func (d *deployer) startAcctFeeing() { for { resMap := d.getResourceMap() - slog.Debug("get resources map and runnerStatuscache", slog.Any("resMap", resMap), slog.Any("runnerStatuscache", d.runnerStatuscache)) - for _, svc := range d.runnerStatuscache { - d.startAccountingRequestMeter(resMap, svc) + slog.Debug("get resources map", slog.Any("resMap", resMap)) + for _, svc := range d.runnerStatusCache { + d.startAcctRequestFee(resMap, svc) } // accounting interval in min, get from env config time.Sleep(time.Duration(d.eventPub.SyncInterval) * time.Minute) } } -func (d *deployer) startAccountingRequestMeter(resMap map[string]string, svcRes types.StatusResponse) { +func (d *deployer) startAcctRequestFee(resMap map[string]string, svcRes types.StatusResponse) { // ignore not ready svc if svcRes.Code != common.Running { return @@ -684,6 +650,8 @@ func (d *deployer) startAccountingRequestMeter(resMap map[string]string, svcRes slog.Error("invalid deploy type of service for metering", slog.Any("svcRes", svcRes)) return } + + extra := startAcctRequestFeeExtra(svcRes) event := types.METERING_EVENT{ Uuid: uuid.New(), UserUUID: svcRes.UserID, @@ -695,7 +663,7 @@ func (d *deployer) startAccountingRequestMeter(resMap map[string]string, svcRes ResourceName: resName, CustomerID: svcRes.ServiceName, CreatedAt: time.Now(), - Extra: "", + Extra: extra, } str, err := json.Marshal(event) if err != nil { @@ -741,6 +709,10 @@ func (d *deployer) CheckResourceAvailable(ctx context.Context, clusterId string, if err != nil { return false, err } + err = d.checkOrderDetailByID(ctx, orderDetailID) + if err != nil { + return false, err + } if !CheckResource(clusterResources, hardWare) { return false, fmt.Errorf("required resource is not enough") } @@ -756,24 +728,7 @@ func CheckResource(clusterResources *types.ClusterRes, hardware *types.HardWare) } for _, node := range clusterResources.Resources { if float32(mem) <= node.AvailableMem { - if hardware.Gpu.Num != "" { - gpu, err := strconv.Atoi(hardware.Gpu.Num) - if err != nil { - slog.Error("failed to parse hardware gpu ", slog.Any("error", err)) - return false - } - cpu, err := strconv.Atoi(hardware.Cpu.Num) - if err != nil { - slog.Error("failed to parse hardware cpu ", slog.Any("error", err)) - return false - - } - if gpu <= int(node.AvailableXPU) && hardware.Gpu.Type == node.XPUModel && cpu <= int(node.AvailableCPU) { - return true - } - } else { - return true - } + return checkNodeResource(node, hardware) } } return false @@ -789,9 +744,8 @@ func (d *deployer) SubmitEvaluation(ctx context.Context, req types.EvaluationReq env["ACCESS_TOKEN"] = req.Token env["HF_ENDPOINT"] = req.DownloadEndpoint - if req.Hardware.Gpu.Num != "" { - env["GPU_NUM"] = req.Hardware.Gpu.Num - } + updateEvaluationEnvHardware(env, req) + templates := []types.ArgoFlowTemplate{} templates = append(templates, types.ArgoFlowTemplate{ Name: "evaluation", @@ -800,7 +754,7 @@ func (d *deployer) SubmitEvaluation(ctx context.Context, req types.EvaluationReq Image: req.Image, }, ) - uniqueFlowName := d.sfNode.Generate().Base36() + uniqueFlowName := d.snowflakeNode.Generate().Base36() flowReq := &types.ArgoWorkFlowReq{ TaskName: req.TaskName, TaskId: uniqueFlowName, @@ -821,14 +775,14 @@ func (d *deployer) SubmitEvaluation(ctx context.Context, req types.EvaluationReq if req.ResourceId == 0 { flowReq.ShareMode = true } - return d.ir.SubmitWorkFlow(ctx, flowReq) + return d.imageRunner.SubmitWorkFlow(ctx, flowReq) } func (d *deployer) ListEvaluations(ctx context.Context, username string, per int, page int) (*types.ArgoWorkFlowListRes, error) { - return d.ir.ListWorkFlows(ctx, username, per, page) + return d.imageRunner.ListWorkFlows(ctx, username, per, page) } func (d *deployer) DeleteEvaluation(ctx context.Context, req types.ArgoWorkFlowDeleteReq) error { - _, err := d.ir.DeleteWorkFlow(ctx, req) + _, err := d.imageRunner.DeleteWorkFlow(ctx, req) if err != nil { return err } @@ -836,7 +790,7 @@ func (d *deployer) DeleteEvaluation(ctx context.Context, req types.ArgoWorkFlowD } func (d *deployer) GetEvaluation(ctx context.Context, req types.EvaluationGetReq) (*types.ArgoWorkFlowRes, error) { - wf, err := d.ir.GetWorkFlow(ctx, req) + wf, err := d.imageRunner.GetWorkFlow(ctx, req) if err != nil { return nil, err } diff --git a/builder/deploy/deployer_ce.go b/builder/deploy/deployer_ce.go new file mode 100644 index 00000000..ddf6f5b8 --- /dev/null +++ b/builder/deploy/deployer_ce.go @@ -0,0 +1,143 @@ +//go:build !ee && !saas + +package deploy + +import ( + "context" + "log/slog" + "strconv" + + "github.com/bwmarrin/snowflake" + "opencsg.com/csghub-server/builder/deploy/imagebuilder" + "opencsg.com/csghub-server/builder/deploy/imagerunner" + "opencsg.com/csghub-server/builder/deploy/scheduler" + "opencsg.com/csghub-server/builder/event" + "opencsg.com/csghub-server/builder/store/database" + "opencsg.com/csghub-server/common/types" +) + +type deployer struct { + scheduler scheduler.Scheduler + imageBuilder imagebuilder.Builder + imageRunner imagerunner.Runner + + deployTaskStore database.DeployTaskStore + spaceStore database.SpaceStore + spaceResourceStore database.SpaceResourceStore + runnerStatusCache map[string]types.StatusResponse + internalRootDomain string + snowflakeNode *snowflake.Node + eventPub *event.EventPublisher + runtimeFrameworkStore database.RuntimeFrameworksStore + deployConfig DeployConfig + userStore database.UserStore +} + +func newDeployer(s scheduler.Scheduler, ib imagebuilder.Builder, ir imagerunner.Runner, c DeployConfig) (*deployer, error) { + store := database.NewDeployTaskStore() + node, err := snowflake.NewNode(1) + if err != nil || node == nil { + slog.Error("fail to generate uuid for inference service name", slog.Any("error", err)) + return nil, err + } + d := &deployer{ + scheduler: s, + imageBuilder: ib, + imageRunner: ir, + deployTaskStore: store, + spaceStore: database.NewSpaceStore(), + spaceResourceStore: database.NewSpaceResourceStore(), + runnerStatusCache: make(map[string]types.StatusResponse), + snowflakeNode: node, + eventPub: &event.DefaultEventPublisher, + runtimeFrameworkStore: database.NewRuntimeFrameworksStore(), + deployConfig: c, + userStore: database.NewUserStore(), + } + + go d.refreshStatus() + d.startJobs() + return d, nil +} + +func (d *deployer) checkOrderDetailByID(ctx context.Context, id int64) error { + return nil +} + +func (d *deployer) checkOrderDetail(ctx context.Context, dr types.DeployRepo) error { + return nil +} + +func (d *deployer) updateUserResourceByOrder(ctx context.Context, deploy *database.Deploy) error { + return nil +} + +func (d *deployer) releaseUserResourceByOrder(ctx context.Context, dr types.DeployRepo) error { + return nil +} + +func containerImage(hardware *types.HardWare, frame *database.RuntimeFramework) string { + // use gpu image + if hardware.Gpu.Num != "" { + return frame.FrameImage + } + return frame.FrameCpuImage +} + +func (d *deployer) startAccounting() { + d.startAcctFeeing() +} + +func checkNodeResource(node types.NodeResourceInfo, hardware *types.HardWare) bool { + if hardware.Gpu.Num != "" { + gpu, err := strconv.Atoi(hardware.Gpu.Num) + if err != nil { + slog.Error("failed to parse hardware gpu ", slog.Any("error", err)) + return false + } + cpu, err := strconv.Atoi(hardware.Cpu.Num) + if err != nil { + slog.Error("failed to parse hardware cpu ", slog.Any("error", err)) + return false + + } + if gpu <= int(node.AvailableXPU) && hardware.Gpu.Type == node.XPUModel && cpu <= int(node.AvailableCPU) { + return true + } + } else { + return true + } + return false +} + +func (d *deployer) startJobs() { + go func() { + err := d.scheduler.Run() + if err != nil { + slog.Error("run scheduler failed", slog.Any("error", err)) + } + }() + go d.startAccounting() +} + +func (d *deployer) getResources(ctx context.Context, clusterId string, clusterResponse *types.ClusterResponse) ([]types.NodeResourceInfo, error) { + resources := make([]types.NodeResourceInfo, 0) + for _, node := range clusterResponse.Nodes { + resources = append(resources, node) + } + return resources, nil + +} + +func startAcctRequestFeeExtra(res types.StatusResponse) string { + return "" +} + +func updateDatabaseDeploy(dp *database.Deploy, dt types.DeployRepo) { +} + +func updateEvaluationEnvHardware(env map[string]string, req types.EvaluationReq) { + if req.Hardware.Gpu.Num != "" { + env["GPU_NUM"] = req.Hardware.Gpu.Num + } +} diff --git a/builder/deploy/deployer_test.go b/builder/deploy/deployer_test.go new file mode 100644 index 00000000..bc8e1508 --- /dev/null +++ b/builder/deploy/deployer_test.go @@ -0,0 +1,827 @@ +package deploy + +import ( + "context" + "database/sql" + "errors" + "strings" + "testing" + "time" + + "github.com/bwmarrin/snowflake" + "github.com/stretchr/testify/mock" + "github.com/stretchr/testify/require" + mockbuilder "opencsg.com/csghub-server/_mocks/opencsg.com/csghub-server/builder/deploy/imagebuilder" + mockrunner "opencsg.com/csghub-server/_mocks/opencsg.com/csghub-server/builder/deploy/imagerunner" + mockScheduler "opencsg.com/csghub-server/_mocks/opencsg.com/csghub-server/builder/deploy/scheduler" + mockdb "opencsg.com/csghub-server/_mocks/opencsg.com/csghub-server/builder/store/database" + + "opencsg.com/csghub-server/builder/deploy/common" + "opencsg.com/csghub-server/builder/deploy/scheduler" + "opencsg.com/csghub-server/builder/store/database" + "opencsg.com/csghub-server/common/tests" + "opencsg.com/csghub-server/common/types" +) + +type testDepolyerWithMocks struct { + *deployer + mocks struct { + stores *tests.MockStores + scheduler *mockScheduler.MockScheduler + builder *mockbuilder.MockBuilder + runner *mockrunner.MockRunner + } +} + +func TestDeployer_GenerateUniqueSvcName(t *testing.T) { + dr := types.DeployRepo{ + Path: "namespace/name", + } + + node, _ := snowflake.NewNode(1) + d := &deployer{ + snowflakeNode: node, + } + + dr.Type = types.SpaceType + name := d.GenerateUniqueSvcName(dr) + require.True(t, strings.HasPrefix(name, "u")) + + dr.Type = types.ServerlessType + name = d.GenerateUniqueSvcName(dr) + require.True(t, strings.HasPrefix(name, "s")) + + dr.Type = types.InferenceType + name = d.GenerateUniqueSvcName(dr) + require.False(t, strings.Contains(name, "-")) + +} + +func TestDeployer_serverlessDeploy(t *testing.T) { + t.Run("deploy space", func(t *testing.T) { + var oldDeploy database.Deploy + oldDeploy.ID = 1 + + dr := types.DeployRepo{ + SpaceID: 1, + Type: types.SpaceType, + UserUUID: "1", + SKU: "1", + ImageID: "image:1", + Annotation: "test annotation", + Env: "test env", + RuntimeFramework: "test runtime framework", + Secret: "test secret", + SecureLevel: 1, + ContainerPort: 8000, + Template: "test template", + MinReplica: 1, + MaxReplica: 2, + } + + newDeploy := oldDeploy + newDeploy.UserUUID = dr.UserUUID + newDeploy.SKU = dr.SKU + newDeploy.ImageID = dr.ImageID + newDeploy.Annotation = dr.Annotation + newDeploy.Env = dr.Env + newDeploy.RuntimeFramework = dr.RuntimeFramework + newDeploy.Secret = dr.Secret + newDeploy.SecureLevel = dr.SecureLevel + newDeploy.ContainerPort = dr.ContainerPort + newDeploy.Template = dr.Template + newDeploy.MinReplica = dr.MinReplica + newDeploy.MaxReplica = dr.MaxReplica + + mockTaskStore := mockdb.NewMockDeployTaskStore(t) + mockTaskStore.EXPECT().GetLatestDeployBySpaceID(mock.Anything, dr.SpaceID).Return(&oldDeploy, nil) + mockTaskStore.EXPECT().UpdateDeploy(mock.Anything, &newDeploy).Return(nil) + + d := &deployer{ + deployTaskStore: mockTaskStore, + } + dbdeploy, err := d.serverlessDeploy(context.TODO(), dr) + require.Nil(t, err) + require.Equal(t, *dbdeploy, newDeploy) + }) + + t.Run("deploy model", func(t *testing.T) { + var oldDeploy database.Deploy + oldDeploy.ID = 1 + + dr := types.DeployRepo{ + RepoID: 1, + Type: types.InferenceType, + UserUUID: "1", + SKU: "1", + ImageID: "image:1", + Annotation: "test annotation", + Env: "test env", + RuntimeFramework: "test runtime framework", + Secret: "test secret", + SecureLevel: 1, + ContainerPort: 8000, + Template: "test template", + MinReplica: 1, + MaxReplica: 2, + } + + newDeploy := oldDeploy + newDeploy.UserUUID = dr.UserUUID + newDeploy.SKU = dr.SKU + newDeploy.ImageID = dr.ImageID + newDeploy.Annotation = dr.Annotation + newDeploy.Env = dr.Env + newDeploy.RuntimeFramework = dr.RuntimeFramework + newDeploy.Secret = dr.Secret + newDeploy.SecureLevel = dr.SecureLevel + newDeploy.ContainerPort = dr.ContainerPort + newDeploy.Template = dr.Template + newDeploy.MinReplica = dr.MinReplica + newDeploy.MaxReplica = dr.MaxReplica + + mockTaskStore := mockdb.NewMockDeployTaskStore(t) + mockTaskStore.EXPECT().GetServerlessDeployByRepID(mock.Anything, dr.RepoID).Return(&oldDeploy, nil) + mockTaskStore.EXPECT().UpdateDeploy(mock.Anything, &newDeploy).Return(nil) + + d := &deployer{ + deployTaskStore: mockTaskStore, + } + dbdeploy, err := d.serverlessDeploy(context.TODO(), dr) + require.Nil(t, err) + require.Equal(t, *dbdeploy, newDeploy) + }) +} + +func TestDeployer_dedicatedDeploy(t *testing.T) { + dr := types.DeployRepo{ + Path: "namespace/name", + Type: types.InferenceType, + } + + mockTaskStore := mockdb.NewMockDeployTaskStore(t) + mockTaskStore.EXPECT().CreateDeploy(mock.Anything, mock.Anything).Return(nil) + + node, _ := snowflake.NewNode(1) + d := &deployer{ + snowflakeNode: node, + deployTaskStore: mockTaskStore, + } + + _, err := d.dedicatedDeploy(context.TODO(), dr) + require.Nil(t, err) + +} + +func TestDeployer_Deploy(t *testing.T) { + + t.Run("use on-demand resource and skip build task", func(t *testing.T) { + dr := types.DeployRepo{ + OrderDetailID: 0, + UserUUID: "1", + Path: "namespace/name", + Type: types.InferenceType, + ImageID: "image:1", + } + + buildTask := database.DeployTask{ + TaskType: 0, + Status: scheduler.BuildSkip, + Message: "Skip", + } + runTask := database.DeployTask{ + TaskType: 1, + } + + mockTaskStore := mockdb.NewMockDeployTaskStore(t) + mockTaskStore.EXPECT().CreateDeploy(mock.Anything, mock.Anything).Return(nil) + mockTaskStore.EXPECT().CreateDeployTask(mock.Anything, &buildTask).Return(nil) + // RunAndReturn(func(ctx context.Context, dt *database.DeployTask) error { + // return nil + // }) + mockTaskStore.EXPECT().CreateDeployTask(mock.Anything, &runTask).Return(nil) + + mockSch := mockScheduler.NewMockScheduler(t) + mockSch.EXPECT().Queue(mock.Anything).Return(nil) + + node, _ := snowflake.NewNode(1) + + d := &deployer{ + snowflakeNode: node, + deployTaskStore: mockTaskStore, + scheduler: mockSch, + } + + _, err := d.Deploy(context.TODO(), dr) + // wait for scheduler.Queue to be called + time.Sleep(time.Second) + require.Nil(t, err) + }) +} + +func TestDeployer_Status(t *testing.T) { + t.Run("no deploy", func(t *testing.T) { + dr := types.DeployRepo{ + OrderDetailID: 1, + UserUUID: "1", + Path: "namespace/name", + Type: types.InferenceType, + } + + mockDeployTaskStore := mockdb.NewMockDeployTaskStore(t) + mockDeployTaskStore.EXPECT().GetDeployByID(mock.Anything, dr.DeployID). + Return(nil, nil) + + d := &deployer{ + deployTaskStore: mockDeployTaskStore, + } + + svcName, deployStatus, instances, err := d.Status(context.TODO(), dr, false) + require.NotNil(t, err) + require.Equal(t, "", svcName) + require.Equal(t, common.Stopped, deployStatus) + require.Nil(t, instances) + + }) + t.Run("cache miss and running", func(t *testing.T) { + dr := types.DeployRepo{ + DeployID: 1, + OrderDetailID: 1, + UserUUID: "1", + Path: "namespace/name", + Type: types.InferenceType, + } + deploy := &database.Deploy{ + Status: common.Running, + SvcName: "svc", + } + + mockDeployTaskStore := mockdb.NewMockDeployTaskStore(t) + mockDeployTaskStore.EXPECT().GetDeployByID(mock.Anything, dr.DeployID). + Return(deploy, nil) + + d := &deployer{ + deployTaskStore: mockDeployTaskStore, + } + d.runnerStatusCache = make(map[string]types.StatusResponse) + + svcName, deployStatus, instances, err := d.Status(context.TODO(), dr, false) + require.Nil(t, err) + require.Equal(t, "svc", svcName) + require.Equal(t, common.Stopped, deployStatus) + require.Nil(t, instances) + + }) + + t.Run("cache miss and not running", func(t *testing.T) { + dr := types.DeployRepo{ + DeployID: 1, + OrderDetailID: 1, + UserUUID: "1", + Path: "namespace/name", + Type: types.InferenceType, + } + deploy := &database.Deploy{ + Status: common.BuildSuccess, + SvcName: "svc", + } + + mockDeployTaskStore := mockdb.NewMockDeployTaskStore(t) + mockDeployTaskStore.EXPECT().GetDeployByID(mock.Anything, dr.DeployID). + Return(deploy, nil) + + d := &deployer{ + deployTaskStore: mockDeployTaskStore, + } + d.runnerStatusCache = make(map[string]types.StatusResponse) + + svcName, deployStatus, instances, err := d.Status(context.TODO(), dr, false) + require.Nil(t, err) + require.Equal(t, "svc", svcName) + require.Equal(t, common.BuildSuccess, deployStatus) + require.Nil(t, instances) + + }) + + t.Run("cache hit and running", func(t *testing.T) { + dr := types.DeployRepo{ + DeployID: 1, + OrderDetailID: 1, + UserUUID: "1", + Path: "namespace/name", + Type: types.InferenceType, + ModelID: 1, + } + // build success status in db + deploy := &database.Deploy{ + Status: common.BuildSuccess, + SvcName: "svc", + } + + mockDeployTaskStore := mockdb.NewMockDeployTaskStore(t) + mockDeployTaskStore.EXPECT().GetDeployByID(mock.Anything, dr.DeployID). + Return(deploy, nil) + + mockRunner := mockrunner.NewMockRunner(t) + mockRunner.EXPECT().Status(mock.Anything, mock.Anything). + Return(&types.StatusResponse{ + DeployID: 1, + UserID: "", + // running status from the runner (latest) + Code: int(common.Running), + Message: "", + Endpoint: "http://localhost", + Instances: []types.Instance{{ + Name: "instance1", + Status: "ready", + }}, + Replica: 1, + DeployType: 0, + ServiceName: "svc", + DeploySku: "", + }, nil) + + d := &deployer{ + deployTaskStore: mockDeployTaskStore, + imageRunner: mockRunner, + } + d.runnerStatusCache = make(map[string]types.StatusResponse) + // deploying status in cache + d.runnerStatusCache["svc"] = types.StatusResponse{ + DeployID: 1, + UserID: "", + Code: int(common.Deploying), + Message: "", + } + + svcName, deployStatus, instances, err := d.Status(context.TODO(), dr, false) + require.Nil(t, err) + require.Equal(t, "svc", svcName) + require.Equal(t, common.Running, deployStatus) + require.Len(t, instances, 1) + + }) +} + +func TestDeployer_Logs(t *testing.T) { + t.Run("no deploy", func(t *testing.T) { + dr := types.DeployRepo{ + OrderDetailID: 1, + UserUUID: "1", + Path: "namespace/name", + Type: types.InferenceType, + SpaceID: 1, + } + + mockDeployTaskStore := mockdb.NewMockDeployTaskStore(t) + mockDeployTaskStore.EXPECT().GetLatestDeployBySpaceID(mock.Anything, dr.SpaceID). + Return(nil, sql.ErrNoRows) + + d := &deployer{ + deployTaskStore: mockDeployTaskStore, + } + + lreader, err := d.Logs(context.TODO(), dr) + require.NotNil(t, err) + require.Nil(t, lreader) + + }) + t.Run("get log reader", func(t *testing.T) { + dr := types.DeployRepo{ + SpaceID: 1, + DeployID: 1, + OrderDetailID: 1, + UserUUID: "1", + Path: "namespace/name", + Type: types.InferenceType, + } + deploy := &database.Deploy{ + Status: common.Running, + SvcName: "svc", + } + + mockDeployTaskStore := mockdb.NewMockDeployTaskStore(t) + mockDeployTaskStore.EXPECT().GetLatestDeployBySpaceID(mock.Anything, dr.SpaceID). + Return(deploy, nil) + + mockRunner := mockrunner.NewMockRunner(t) + mockRunner.EXPECT().Logs(mock.Anything, mock.Anything).Return(nil, nil) + mockBuilder := mockbuilder.NewMockBuilder(t) + mockBuilder.EXPECT().Logs(mock.Anything, mock.Anything).Return(nil, nil) + d := &deployer{ + deployTaskStore: mockDeployTaskStore, + imageRunner: mockRunner, + imageBuilder: mockBuilder, + } + + lreader, err := d.Logs(context.TODO(), dr) + require.Nil(t, err) + require.NotNil(t, lreader) + + }) +} + +func TestDeployer_Purge(t *testing.T) { + dr := types.DeployRepo{ + SpaceID: 0, + DeployID: 1, + OrderDetailID: 1, + UserUUID: "1", + Path: "namespace/name", + Type: types.InferenceType, + } + + mockRunner := mockrunner.NewMockRunner(t) + mockRunner.EXPECT().Purge(mock.Anything, mock.Anything).Return(&types.PurgeResponse{}, nil) + + d := &deployer{ + imageRunner: mockRunner, + } + err := d.Purge(context.TODO(), dr) + require.Nil(t, err) +} + +// func TestDeployer_Wakeup(t *testing.T) { +// startDNSServer() +// dr := types.DeployRepo{ +// SpaceID: 0, +// DeployID: 1, +// OrderDetailID: 1, +// UserUUID: "1", +// Path: "namespace/name", +// Type: types.InferenceType, +// SvcName: "svc", +// } + +// s := httptest.NewUnstartedServer(&wakeupHandler{}) +// s.Config.Addr = "svc.internal.example.com:51000" +// s.Config.ListenAndServe() +// // s.Start() +// defer s.Close() + +// d := &deployer{ +// internalRootDomain: "internal.example.com:51000", +// } +// err := d.Wakeup(context.TODO(), dr) +// require.Nil(t, err) +// } + +// type wakeupHandler struct { +// } + +// func (h *wakeupHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { +// w.WriteHeader(http.StatusOK) +// } + +// func startDNSServer() (string, error) { +// dnsServerAddr := ":53" // DNS 服务器的端口 +// go func() { +// // Set up a new DNS server +// dnsServer := &dns.Server{Addr: dnsServerAddr, Net: "udp"} + +// dns.HandleFunc("svc.internal.example.com", func(w dns.ResponseWriter, r *dns.Msg) { +// m := new(dns.Msg) +// m.SetReply(r) +// m.Authoritative = true +// m.Answer = append(m.Answer, &dns.A{ +// Hdr: dns.RR_Header{ +// Name: "svc.internal.example.com", +// Rrtype: dns.TypeA, +// Class: dns.ClassINET, +// Ttl: 60, +// }, +// A: net.ParseIP("127.0.0.1"), +// }) +// w.WriteMsg(m) +// }) + +// if err := dnsServer.ListenAndServe(); err != nil { +// panic(err) +// } +// }() +// return dnsServerAddr, nil +// } + +func TestDeployer_Exists(t *testing.T) { + dr := types.DeployRepo{ + SpaceID: 0, + DeployID: 1, + OrderDetailID: 1, + UserUUID: "1", + Path: "namespace/name", + Type: types.InferenceType, + } + + t.Run("fail to check", func(t *testing.T) { + mockRunner := mockrunner.NewMockRunner(t) + mockRunner.EXPECT().Exist(mock.Anything, mock.Anything). + Return(&types.StatusResponse{ + Code: -1, + }, nil) + + d := &deployer{ + imageRunner: mockRunner, + } + resp, err := d.Exist(context.TODO(), dr) + require.NotNil(t, err) + require.True(t, resp) + }) + t.Run("success to check", func(t *testing.T) { + mockRunner := mockrunner.NewMockRunner(t) + mockRunner.EXPECT().Exist(mock.Anything, mock.Anything). + Return(&types.StatusResponse{ + Code: 1, + }, nil) + + d := &deployer{ + imageRunner: mockRunner, + } + resp, err := d.Exist(context.TODO(), dr) + require.Nil(t, err) + require.True(t, resp) + }) + + t.Run("service not exist", func(t *testing.T) { + mockRunner := mockrunner.NewMockRunner(t) + mockRunner.EXPECT().Exist(mock.Anything, mock.Anything). + Return(&types.StatusResponse{ + Code: 2, + }, nil) + + d := &deployer{ + imageRunner: mockRunner, + } + resp, err := d.Exist(context.TODO(), dr) + require.Nil(t, err) + require.False(t, resp) + }) +} + +func TestDeployer_GetReplica(t *testing.T) { + dr := types.DeployRepo{ + SpaceID: 0, + DeployID: 1, + OrderDetailID: 1, + UserUUID: "1", + Path: "namespace/name", + Type: types.InferenceType, + } + + t.Run("fail to check", func(t *testing.T) { + mockRunner := mockrunner.NewMockRunner(t) + mockRunner.EXPECT().GetReplica(mock.Anything, mock.Anything). + Return(nil, errors.New("error")) + + d := &deployer{ + imageRunner: mockRunner, + } + actualReplica, desiredReplica, instances, err := d.GetReplica(context.TODO(), dr) + require.NotNil(t, err) + require.Equal(t, 0, actualReplica) + require.Equal(t, 0, desiredReplica) + require.Equal(t, 0, len(instances)) + }) + + t.Run("success", func(t *testing.T) { + mockRunner := mockrunner.NewMockRunner(t) + mockRunner.EXPECT().GetReplica(mock.Anything, mock.Anything). + Return(&types.ReplicaResponse{ + DeployID: 1, + Code: 1, + Message: "", + ActualReplica: 1, + DesiredReplica: 1, + Instances: []types.Instance{ + { + Name: "name1", + Status: "running", + }, + }, + }, nil) + + d := &deployer{ + imageRunner: mockRunner, + } + actualReplica, desiredReplica, instances, err := d.GetReplica(context.TODO(), dr) + require.Nil(t, err) + require.Equal(t, 1, actualReplica) + require.Equal(t, 1, desiredReplica) + require.Equal(t, 1, len(instances)) + }) +} + +func TestDeployer_InstanceLogs(t *testing.T) { + dr := types.DeployRepo{ + SpaceID: 0, + DeployID: 1, + OrderDetailID: 1, + UserUUID: "1", + Path: "namespace/name", + Type: types.InferenceType, + } + + mockRunner := mockrunner.NewMockRunner(t) + runLog := make(chan string) + defer close(runLog) + mockRunner.EXPECT().InstanceLogs(mock.Anything, mock.Anything). + Return(runLog, nil) + + d := &deployer{ + imageRunner: mockRunner, + } + lreader, err := d.InstanceLogs(context.TODO(), dr) + require.Nil(t, err) + require.Nil(t, lreader.buildLogs) + require.NotNil(t, lreader.RunLog()) +} + +func TestDeployer_ListCluster(t *testing.T) { + + clusterResp := []types.ClusterResponse{ + { + ClusterID: "cluster1", + Region: "us-east-1", + Zone: "us-east-1a", + Provider: "aws", + Enable: false, + Nodes: map[string]types.NodeResourceInfo{ + "node1": { + NodeName: "node1", + XPUModel: "", + TotalCPU: 1, + AvailableCPU: 1, + TotalXPU: 0, + AvailableXPU: 0, + GPUVendor: "nvidia", + TotalMem: 0, + AvailableMem: 128, + XPUCapacityLabel: "", + }, + }, + }, + } + mockRunner := mockrunner.NewMockRunner(t) + mockRunner.EXPECT().ListCluster(mock.Anything). + Return(clusterResp, nil) + + d := &deployer{ + imageRunner: mockRunner, + } + clusterResources, err := d.ListCluster(context.TODO()) + require.Nil(t, err) + require.Len(t, clusterResources, 1) + require.Len(t, clusterResources[0].Resources, 1) +} + +func TestDeployer_UpdateDeploy(t *testing.T) { + var runtimeFrameworkID int64 = 1 + var ResourceID int64 = 1 + var deployName = "deploy1" + var env = "{}" + var MinReplica = 1 + var MaxReplica = 1 + var Revision = "1" + var SecureLevel = 1 + var ClusterID = "cluster1" + dur := &types.DeployUpdateReq{ + RuntimeFrameworkID: &runtimeFrameworkID, + ResourceID: &ResourceID, + DeployName: &deployName, + Env: &env, + MinReplica: &MinReplica, + MaxReplica: &MaxReplica, + Revision: &Revision, + SecureLevel: &SecureLevel, + ClusterID: &ClusterID, + } + + dbdeploy := &database.Deploy{} + + mockRTFM := mockdb.NewMockRuntimeFrameworksStore(t) + mockRTFM.EXPECT().FindEnabledByID(mock.Anything, runtimeFrameworkID). + Return(&database.RuntimeFramework{ + FrameImage: "gpu_image", + FrameName: "gpu", + ContainerPort: 8000, + }, nil) + mockSpaceResourceStore := mockdb.NewMockSpaceResourceStore(t) + mockSpaceResourceStore.EXPECT().FindByID(mock.Anything, ResourceID). + Return(&database.SpaceResource{ + ID: ResourceID, + Resources: `{ "gpu": { "type": "A10", "num": "1", "resource_name": "nvidia.com/gpu", "labels": { "aliyun.accelerator/nvidia_name": "NVIDIA-A10" } }, "cpu": { "type": "Intel", "num": "12" }, "memory": "46Gi" }`, + }, nil) + + mockDeployTaskStore := mockdb.NewMockDeployTaskStore(t) + mockDeployTaskStore.EXPECT().UpdateDeploy(mock.Anything, mock.Anything).Return(nil) + d := &deployer{ + runtimeFrameworkStore: mockRTFM, + deployTaskStore: mockDeployTaskStore, + spaceResourceStore: mockSpaceResourceStore, + } + err := d.UpdateDeploy(context.TODO(), dur, dbdeploy) + require.Nil(t, err) +} + +func TestDeployer_getResourceMap(t *testing.T) { + mockSpaceResourceStore := mockdb.NewMockSpaceResourceStore(t) + mockSpaceResourceStore.EXPECT().FindAll(mock.Anything). + Return([]database.SpaceResource{ + {ID: 1, Name: "Resource1"}, + {ID: 2, Name: "Resource2"}, + }, nil) + + d := &deployer{ + spaceResourceStore: mockSpaceResourceStore, + } + resources := d.getResourceMap() + require.Equal(t, map[string]string{ + "1": "Resource1", + "2": "Resource2", + }, resources) +} + +func TestDeployer_CheckResource(t *testing.T) { + + cases := []struct { + hardware *types.HardWare + available bool + }{ + {&types.HardWare{}, true}, + {&types.HardWare{ + Gpu: types.GPU{Num: "1", Type: "t1"}, + Cpu: types.CPU{Num: "2"}, + }, true}, + {&types.HardWare{ + Gpu: types.GPU{Num: "1", Type: "t2"}, + Cpu: types.CPU{Num: "2"}, + }, false}, + {&types.HardWare{ + Gpu: types.GPU{Num: "15", Type: "t1"}, + Cpu: types.CPU{Num: "2"}, + }, false}, + {&types.HardWare{ + Gpu: types.GPU{Num: "1", Type: "t1"}, + Cpu: types.CPU{Num: "20"}, + }, false}, + } + + for _, c := range cases { + c.hardware.Memory = "1Gi" + v := CheckResource(&types.ClusterRes{ + Resources: []types.NodeResourceInfo{ + {AvailableXPU: 10, XPUModel: "t1", AvailableCPU: 10, AvailableMem: 10000}, + }, + }, c.hardware) + require.Equal(t, c.available, v, c.hardware) + } + +} + +func TestDeployer_SubmitEvaluation(t *testing.T) { + tester := newTestDeployer(t) + ctx := context.TODO() + + tester.mocks.runner.EXPECT().SubmitWorkFlow(ctx, mock.Anything).RunAndReturn( + func(ctx context.Context, awfr *types.ArgoWorkFlowReq) (*types.ArgoWorkFlowRes, error) { + require.Equal(t, map[string]string{ + "REVISION": "main", + "MODEL_ID": "m1", + "DATASET_IDS": "", + "ACCESS_TOKEN": "k", + "HF_ENDPOINT": "dl", + }, awfr.Templates[0].Env) + return &types.ArgoWorkFlowRes{ID: 1}, nil + }, + ) + resp, err := tester.SubmitEvaluation(ctx, types.EvaluationReq{ + ModelId: "m1", + Token: "k", + DownloadEndpoint: "dl", + }) + require.NoError(t, err) + require.Equal(t, &types.ArgoWorkFlowRes{ID: 1}, resp) +} + +func TestDeployer_ListEvaluations(t *testing.T) { + tester := newTestDeployer(t) + ctx := context.TODO() + + tester.mocks.runner.EXPECT().ListWorkFlows(ctx, "user", 10, 1).Return( + &types.ArgoWorkFlowListRes{Total: 100}, nil, + ) + r, err := tester.ListEvaluations(ctx, "user", 10, 1) + require.NoError(t, err) + require.Equal(t, &types.ArgoWorkFlowListRes{Total: 100}, r) +} + +func TestDeployer_GetEvaluation(t *testing.T) { + tester := newTestDeployer(t) + ctx := context.TODO() + + tester.mocks.runner.EXPECT().GetWorkFlow(ctx, types.EvaluationGetReq{}).Return( + &types.ArgoWorkFlowRes{ID: 100}, nil, + ) + r, err := tester.GetEvaluation(ctx, types.EvaluationGetReq{}) + require.NoError(t, err) + require.Equal(t, &types.ArgoWorkFlowRes{ID: 100}, r) +} diff --git a/builder/deploy/init.go b/builder/deploy/init.go index 736404ae..ded6d3e5 100644 --- a/builder/deploy/init.go +++ b/builder/deploy/init.go @@ -14,6 +14,8 @@ var ( defaultDeployer Deployer ) +type DeployConfig = common.DeployConfig + func Init(c common.DeployConfig) error { // ib := imagebuilder.NewLocalBuilder() ib, err := imagebuilder.NewRemoteBuilder(c.ImageBuilderURL) @@ -26,7 +28,7 @@ func Init(c common.DeployConfig) error { } fifoScheduler = scheduler.NewFIFOScheduler(ib, ir, c) - deployer, err := newDeployer(fifoScheduler, ib, ir) + deployer, err := newDeployer(fifoScheduler, ib, ir, c) if err != nil { return fmt.Errorf("failed to create deployer:%w", err) } From 8bd0791f7910f44e69bb51354278dd1487881aba Mon Sep 17 00:00:00 2001 From: yiling Date: Mon, 23 Dec 2024 17:26:53 +0800 Subject: [PATCH 17/34] remove unused fields --- common/types/model.go | 1 - common/types/repo.go | 1 - component/model.go | 7 ++++--- component/model_ce.go | 4 ++++ component/model_ce_test.go | 22 ++++++++++++++++++++++ 5 files changed, 30 insertions(+), 5 deletions(-) diff --git a/common/types/model.go b/common/types/model.go index 994218dd..028aa4a1 100644 --- a/common/types/model.go +++ b/common/types/model.go @@ -214,7 +214,6 @@ type ModelRunReq struct { MaxReplica int `json:"max_replica"` Revision string `json:"revision"` SecureLevel int `json:"secure_level"` - OrderDetailID int64 `json:"order_detail_id"` } var _ SensitiveRequestV2 = (*ModelRunReq)(nil) diff --git a/common/types/repo.go b/common/types/repo.go index 03ced317..ba387aac 100644 --- a/common/types/repo.go +++ b/common/types/repo.go @@ -145,7 +145,6 @@ type DeployRepo struct { SKU string `json:"sku,omitempty"` ResourceType string `json:"resource_type,omitempty"` RepoTag string `json:"repo_tag,omitempty"` - OrderDetailID int64 `json:"order_detail_id,omitempty"` } type RuntimeFrameworkReq struct { diff --git a/component/model.go b/component/model.go index 07f518f3..ef7da6b7 100644 --- a/component/model.go +++ b/component/model.go @@ -942,7 +942,7 @@ func (c *modelComponentImpl) Deploy(ctx context.Context, deployReq types.DeployA containerImg := c.containerImg(frame, hardware) // create deploy for model - return c.deployer.Deploy(ctx, types.DeployRepo{ + dp := types.DeployRepo{ DeployName: req.DeployName, SpaceID: 0, Path: m.Repository.Path, @@ -964,8 +964,9 @@ func (c *modelComponentImpl) Deploy(ctx context.Context, deployReq types.DeployA Type: deployReq.DeployType, UserUUID: user.UUID, SKU: strconv.FormatInt(resource.ID, 10), - OrderDetailID: req.OrderDetailID, - }) + } + dp = modelRunUpdateDeployRepo(dp, req) + return c.deployer.Deploy(ctx, dp) } func (c *modelComponentImpl) ListModelsByRuntimeFrameworkID(ctx context.Context, currentUser string, per, page int, id int64, deployType int) ([]types.Model, int, error) { diff --git a/component/model_ce.go b/component/model_ce.go index b0206b60..b04c2cdb 100644 --- a/component/model_ce.go +++ b/component/model_ce.go @@ -31,3 +31,7 @@ func (c *modelComponentImpl) containerImg(frame *database.RuntimeFramework, hard return containerImg } + +func modelRunUpdateDeployRepo(dp types.DeployRepo, req types.ModelRunReq) types.DeployRepo { + return dp +} diff --git a/component/model_ce_test.go b/component/model_ce_test.go index 6925c5db..54b2e688 100644 --- a/component/model_ce_test.go +++ b/component/model_ce_test.go @@ -70,3 +70,25 @@ func TestModelComponent_Deploy(t *testing.T) { require.Equal(t, int64(111), id) } + +func TestModelComponent_containerImg(t *testing.T) { + ctx := context.TODO() + mc := initializeTestModelComponent(ctx, t) + + cases := []struct { + hd types.HardWare + img string + }{ + {hd: types.HardWare{}, img: "cpu"}, + {hd: types.HardWare{Gpu: types.GPU{}}, img: "cpu"}, + {hd: types.HardWare{Gpu: types.GPU{Num: "1"}}, img: "gpu"}, + } + + for _, c := range cases { + v := mc.containerImg(&database.RuntimeFramework{ + FrameImage: "gpu", + FrameCpuImage: "cpu", + }, c.hd) + require.Equal(t, c.img, v) + } +} From ed8c34cda0d444bc3dd447718037ae3bfef8b10c Mon Sep 17 00:00:00 2001 From: yiling Date: Mon, 23 Dec 2024 17:31:05 +0800 Subject: [PATCH 18/34] fix deployer tests --- builder/deploy/deployer_test.go | 170 +++++++++----------------------- 1 file changed, 49 insertions(+), 121 deletions(-) diff --git a/builder/deploy/deployer_test.go b/builder/deploy/deployer_test.go index bc8e1508..07ac662c 100644 --- a/builder/deploy/deployer_test.go +++ b/builder/deploy/deployer_test.go @@ -177,11 +177,10 @@ func TestDeployer_Deploy(t *testing.T) { t.Run("use on-demand resource and skip build task", func(t *testing.T) { dr := types.DeployRepo{ - OrderDetailID: 0, - UserUUID: "1", - Path: "namespace/name", - Type: types.InferenceType, - ImageID: "image:1", + UserUUID: "1", + Path: "namespace/name", + Type: types.InferenceType, + ImageID: "image:1", } buildTask := database.DeployTask{ @@ -222,10 +221,9 @@ func TestDeployer_Deploy(t *testing.T) { func TestDeployer_Status(t *testing.T) { t.Run("no deploy", func(t *testing.T) { dr := types.DeployRepo{ - OrderDetailID: 1, - UserUUID: "1", - Path: "namespace/name", - Type: types.InferenceType, + UserUUID: "1", + Path: "namespace/name", + Type: types.InferenceType, } mockDeployTaskStore := mockdb.NewMockDeployTaskStore(t) @@ -245,11 +243,10 @@ func TestDeployer_Status(t *testing.T) { }) t.Run("cache miss and running", func(t *testing.T) { dr := types.DeployRepo{ - DeployID: 1, - OrderDetailID: 1, - UserUUID: "1", - Path: "namespace/name", - Type: types.InferenceType, + DeployID: 1, + UserUUID: "1", + Path: "namespace/name", + Type: types.InferenceType, } deploy := &database.Deploy{ Status: common.Running, @@ -275,11 +272,10 @@ func TestDeployer_Status(t *testing.T) { t.Run("cache miss and not running", func(t *testing.T) { dr := types.DeployRepo{ - DeployID: 1, - OrderDetailID: 1, - UserUUID: "1", - Path: "namespace/name", - Type: types.InferenceType, + DeployID: 1, + UserUUID: "1", + Path: "namespace/name", + Type: types.InferenceType, } deploy := &database.Deploy{ Status: common.BuildSuccess, @@ -305,12 +301,11 @@ func TestDeployer_Status(t *testing.T) { t.Run("cache hit and running", func(t *testing.T) { dr := types.DeployRepo{ - DeployID: 1, - OrderDetailID: 1, - UserUUID: "1", - Path: "namespace/name", - Type: types.InferenceType, - ModelID: 1, + DeployID: 1, + UserUUID: "1", + Path: "namespace/name", + Type: types.InferenceType, + ModelID: 1, } // build success status in db deploy := &database.Deploy{ @@ -366,11 +361,10 @@ func TestDeployer_Status(t *testing.T) { func TestDeployer_Logs(t *testing.T) { t.Run("no deploy", func(t *testing.T) { dr := types.DeployRepo{ - OrderDetailID: 1, - UserUUID: "1", - Path: "namespace/name", - Type: types.InferenceType, - SpaceID: 1, + UserUUID: "1", + Path: "namespace/name", + Type: types.InferenceType, + SpaceID: 1, } mockDeployTaskStore := mockdb.NewMockDeployTaskStore(t) @@ -388,12 +382,11 @@ func TestDeployer_Logs(t *testing.T) { }) t.Run("get log reader", func(t *testing.T) { dr := types.DeployRepo{ - SpaceID: 1, - DeployID: 1, - OrderDetailID: 1, - UserUUID: "1", - Path: "namespace/name", - Type: types.InferenceType, + SpaceID: 1, + DeployID: 1, + UserUUID: "1", + Path: "namespace/name", + Type: types.InferenceType, } deploy := &database.Deploy{ Status: common.Running, @@ -423,12 +416,11 @@ func TestDeployer_Logs(t *testing.T) { func TestDeployer_Purge(t *testing.T) { dr := types.DeployRepo{ - SpaceID: 0, - DeployID: 1, - OrderDetailID: 1, - UserUUID: "1", - Path: "namespace/name", - Type: types.InferenceType, + SpaceID: 0, + DeployID: 1, + UserUUID: "1", + Path: "namespace/name", + Type: types.InferenceType, } mockRunner := mockrunner.NewMockRunner(t) @@ -441,75 +433,13 @@ func TestDeployer_Purge(t *testing.T) { require.Nil(t, err) } -// func TestDeployer_Wakeup(t *testing.T) { -// startDNSServer() -// dr := types.DeployRepo{ -// SpaceID: 0, -// DeployID: 1, -// OrderDetailID: 1, -// UserUUID: "1", -// Path: "namespace/name", -// Type: types.InferenceType, -// SvcName: "svc", -// } - -// s := httptest.NewUnstartedServer(&wakeupHandler{}) -// s.Config.Addr = "svc.internal.example.com:51000" -// s.Config.ListenAndServe() -// // s.Start() -// defer s.Close() - -// d := &deployer{ -// internalRootDomain: "internal.example.com:51000", -// } -// err := d.Wakeup(context.TODO(), dr) -// require.Nil(t, err) -// } - -// type wakeupHandler struct { -// } - -// func (h *wakeupHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { -// w.WriteHeader(http.StatusOK) -// } - -// func startDNSServer() (string, error) { -// dnsServerAddr := ":53" // DNS 服务器的端口 -// go func() { -// // Set up a new DNS server -// dnsServer := &dns.Server{Addr: dnsServerAddr, Net: "udp"} - -// dns.HandleFunc("svc.internal.example.com", func(w dns.ResponseWriter, r *dns.Msg) { -// m := new(dns.Msg) -// m.SetReply(r) -// m.Authoritative = true -// m.Answer = append(m.Answer, &dns.A{ -// Hdr: dns.RR_Header{ -// Name: "svc.internal.example.com", -// Rrtype: dns.TypeA, -// Class: dns.ClassINET, -// Ttl: 60, -// }, -// A: net.ParseIP("127.0.0.1"), -// }) -// w.WriteMsg(m) -// }) - -// if err := dnsServer.ListenAndServe(); err != nil { -// panic(err) -// } -// }() -// return dnsServerAddr, nil -// } - func TestDeployer_Exists(t *testing.T) { dr := types.DeployRepo{ - SpaceID: 0, - DeployID: 1, - OrderDetailID: 1, - UserUUID: "1", - Path: "namespace/name", - Type: types.InferenceType, + SpaceID: 0, + DeployID: 1, + UserUUID: "1", + Path: "namespace/name", + Type: types.InferenceType, } t.Run("fail to check", func(t *testing.T) { @@ -559,12 +489,11 @@ func TestDeployer_Exists(t *testing.T) { func TestDeployer_GetReplica(t *testing.T) { dr := types.DeployRepo{ - SpaceID: 0, - DeployID: 1, - OrderDetailID: 1, - UserUUID: "1", - Path: "namespace/name", - Type: types.InferenceType, + SpaceID: 0, + DeployID: 1, + UserUUID: "1", + Path: "namespace/name", + Type: types.InferenceType, } t.Run("fail to check", func(t *testing.T) { @@ -612,12 +541,11 @@ func TestDeployer_GetReplica(t *testing.T) { func TestDeployer_InstanceLogs(t *testing.T) { dr := types.DeployRepo{ - SpaceID: 0, - DeployID: 1, - OrderDetailID: 1, - UserUUID: "1", - Path: "namespace/name", - Type: types.InferenceType, + SpaceID: 0, + DeployID: 1, + UserUUID: "1", + Path: "namespace/name", + Type: types.InferenceType, } mockRunner := mockrunner.NewMockRunner(t) From 7cb7c73fcb33496588c5d717913f5ceaffd69ca3 Mon Sep 17 00:00:00 2001 From: "yiling.ji" Date: Wed, 18 Dec 2024 03:11:14 +0000 Subject: [PATCH 19/34] Merge branch 'feature/handler_tests' into 'main' Add repo handler tests See merge request product/starhub/starhub-server!731 --- api/handler/helper_test.go | 128 ++++ api/handler/repo.go | 12 +- api/handler/repo_test.go | 1172 ++++++++++++++++++++++++++++++++++-- api/httpbase/user.go | 4 +- 4 files changed, 1244 insertions(+), 72 deletions(-) create mode 100644 api/handler/helper_test.go diff --git a/api/handler/helper_test.go b/api/handler/helper_test.go new file mode 100644 index 00000000..4c62d995 --- /dev/null +++ b/api/handler/helper_test.go @@ -0,0 +1,128 @@ +package handler + +import ( + "bytes" + "encoding/json" + "io" + "mime/multipart" + "net/http" + "net/http/httptest" + "net/url" + "testing" + + "github.com/gin-gonic/gin" + "github.com/spf13/cast" + "github.com/stretchr/testify/require" + "opencsg.com/csghub-server/api/httpbase" + "opencsg.com/csghub-server/component" +) + +type GinTester struct { + ginHandler gin.HandlerFunc + ctx *gin.Context + response *httptest.ResponseRecorder + OKText string // text of httpbase.OK +} + +func NewGinTester() *GinTester { + response := httptest.NewRecorder() + ctx, _ := gin.CreateTestContext(response) + ctx.Request = &http.Request{ + URL: &url.URL{}, + } + + return &GinTester{ + ginHandler: nil, + ctx: ctx, + response: response, + OKText: "OK", + } +} + +func (g *GinTester) Handler(handler gin.HandlerFunc) { + g.ginHandler = handler +} + +func (g *GinTester) Execute() { + g.ginHandler(g.ctx) +} +func (g *GinTester) WithUser() *GinTester { + g.ctx.Set(httpbase.CurrentUserCtxVar, "u") + return g +} + +func (g *GinTester) WithParam(key string, value string) *GinTester { + params := g.ctx.Params + for i, param := range params { + if param.Key == key { + params[i].Value = value + return g + } + } + g.ctx.AddParam(key, value) + return g +} + +func (g *GinTester) WithKV(key string, value any) *GinTester { + g.ctx.Set(key, value) + return g +} + +func (g *GinTester) WithBody(t *testing.T, body any) *GinTester { + b, err := json.Marshal(body) + require.Nil(t, err) + g.ctx.Request.Body = io.NopCloser(bytes.NewBuffer(b)) + return g +} + +func (g *GinTester) WithMultipartForm(mf *multipart.Form) *GinTester { + g.ctx.Request.MultipartForm = mf + return g +} + +func (g *GinTester) WithQuery(key, value string) *GinTester { + q := g.ctx.Request.URL.Query() + q.Add(key, value) + g.ctx.Request.URL.RawQuery = q.Encode() + return g +} + +func (g *GinTester) AddPagination(page int, per int) *GinTester { + g.WithQuery("page", cast.ToString(page)) + g.WithQuery("per", cast.ToString(per)) + return g +} + +func (g *GinTester) ResponseEq(t *testing.T, code int, msg string, expected any) { + var r = struct { + Msg string `json:"msg"` + Data any `json:"data,omitempty"` + }{ + Msg: msg, + Data: expected, + } + b, err := json.Marshal(r) + require.NoError(t, err) + require.Equal(t, code, g.response.Code, g.response.Body.String()) + require.JSONEq(t, string(b), g.response.Body.String()) + +} + +func (g *GinTester) ResponseEqSimple(t *testing.T, code int, expected any) { + b, err := json.Marshal(expected) + require.NoError(t, err) + require.Equal(t, code, g.response.Code) + require.JSONEq(t, string(b), g.response.Body.String()) + +} + +func (g *GinTester) RequireUser(t *testing.T) { + // use a tmp ctx to test no user case + tmp := NewGinTester() + tmp.ctx.Params = g.ctx.Params + g.ginHandler(tmp.ctx) + tmp.ResponseEq(t, http.StatusUnauthorized, component.ErrUserNotFound.Error(), nil) + // add user to original test ctx now + _ = g.WithUser() + +} diff --git a/api/handler/repo.go b/api/handler/repo.go index 075673a7..b763ad83 100644 --- a/api/handler/repo.go +++ b/api/handler/repo.go @@ -29,12 +29,14 @@ func NewRepoHandler(config *config.Config) (*RepoHandler, error) { return nil, err } return &RepoHandler{ - c: uc, + c: uc, + deployStatusCheckInterval: 5 * time.Second, }, nil } type RepoHandler struct { - c component.RepoComponent + c component.RepoComponent + deployStatusCheckInterval time.Duration } // CreateRepoFile godoc @@ -521,7 +523,7 @@ func (h *RepoHandler) Tags(ctx *gin.Context) { func (h *RepoHandler) UpdateTags(ctx *gin.Context) { currentUser := httpbase.GetCurrentUser(ctx) if currentUser == "" { - httpbase.UnauthorizedError(ctx, httpbase.ErrorNeedLogin) + httpbase.UnauthorizedError(ctx, component.ErrUserNotFound) return } namespace, name, err := common.GetNamespaceAndNameFromContext(ctx) @@ -1695,7 +1697,7 @@ func (h *RepoHandler) DeployStatus(ctx *gin.Context) { slog.Info("deploy handler status request context done", slog.Any("error", ctx.Request.Context().Err())) return default: - time.Sleep(time.Second * 5) + time.Sleep(h.deployStatusCheckInterval) // user http request context instead of gin context, so that server knows the life cycle of the request _, status, instances, err := h.c.DeployStatus(ctx.Request.Context(), repoType, namespace, name, deployID) if err != nil { @@ -2136,7 +2138,7 @@ func (h *RepoHandler) ServerlessStatus(ctx *gin.Context) { slog.Info("deploy handler status request context done", slog.Any("error", ctx.Request.Context().Err())) return default: - time.Sleep(time.Second * 5) + time.Sleep(h.deployStatusCheckInterval) // user http request context instead of gin context, so that server knows the life cycle of the request _, status, instances, err := h.c.DeployStatus(ctx.Request.Context(), types.ModelRepo, namespace, name, deployID) if err != nil { diff --git a/api/handler/repo_test.go b/api/handler/repo_test.go index 6d252166..e935995e 100644 --- a/api/handler/repo_test.go +++ b/api/handler/repo_test.go @@ -1,132 +1,1172 @@ package handler import ( + "bytes" + "context" "encoding/json" "errors" + "io" + "mime/multipart" "net/http" "net/http/httptest" + "strings" "testing" + "time" + "github.com/alibabacloud-go/tea/tea" "github.com/gin-gonic/gin" "github.com/google/uuid" "github.com/stretchr/testify/mock" "github.com/stretchr/testify/require" mockcomponent "opencsg.com/csghub-server/_mocks/opencsg.com/csghub-server/component" + "opencsg.com/csghub-server/builder/deploy" + "opencsg.com/csghub-server/builder/store/database" "opencsg.com/csghub-server/common/types" "opencsg.com/csghub-server/component" ) +type RepoTester struct { + *GinTester + handler *RepoHandler + mocks struct { + repo *mockcomponent.MockRepoComponent + } +} + +func NewRepoTester(t *testing.T) *RepoTester { + tester := &RepoTester{GinTester: NewGinTester()} + tester.mocks.repo = mockcomponent.NewMockRepoComponent(t) + tester.handler = &RepoHandler{tester.mocks.repo, 0} + tester.WithParam("name", "r") + tester.WithParam("namespace", "u") + return tester + +} + +func (rt *RepoTester) WithHandleFunc(fn func(rp *RepoHandler) gin.HandlerFunc) *RepoTester { + rt.ginHandler = fn(rt.handler) + return rt +} + +func TestRepoHandler_CreateFile(t *testing.T) { + tester := NewRepoTester(t).WithHandleFunc(func(rp *RepoHandler) gin.HandlerFunc { + return rp.CreateFile + }) + tester.RequireUser(t) + + tester.mocks.repo.EXPECT().CreateFile(tester.ctx, &types.CreateFileReq{ + Message: "foo", + Branch: "main", + Content: "bar", + Username: "u", + Namespace: "u", + Name: "r", + CurrentUser: "u", + FilePath: "foo", + }).Return(&types.CreateFileResp{}, nil) + tester.WithParam("file_path", "foo") + req := &types.CreateFileReq{ + Message: "foo", + Branch: "main", + Content: "bar", + } + tester.WithBody(t, req) + + tester.Execute() + tester.ResponseEq(t, http.StatusOK, tester.OKText, &types.CreateFileResp{}) + +} + +func TestRepoHandler_UpdateFile(t *testing.T) { + tester := NewRepoTester(t).WithHandleFunc(func(rp *RepoHandler) gin.HandlerFunc { + return rp.UpdateFile + }) + tester.RequireUser(t) + + tester.mocks.repo.EXPECT().UpdateFile(tester.ctx, &types.UpdateFileReq{ + Message: "foo", + Branch: "main", + Content: "bar", + Username: "u", + Namespace: "u", + Name: "r", + CurrentUser: "u", + FilePath: "foo", + }).Return(&types.UpdateFileResp{}, nil) + tester.WithParam("file_path", "foo") + req := &types.CreateFileReq{ + Message: "foo", + Branch: "main", + Content: "bar", + } + tester.WithBody(t, req) + + tester.Execute() + tester.ResponseEq(t, http.StatusOK, tester.OKText, &types.UpdateFileResp{}) + +} + +func TestRepoHandler_Commits(t *testing.T) { + tester := NewRepoTester(t).WithHandleFunc(func(rp *RepoHandler) gin.HandlerFunc { + return rp.Commits + }) + + tester.WithUser() + tester.mocks.repo.EXPECT().Commits(tester.ctx, &types.GetCommitsReq{ + Namespace: "u", + Name: "r", + Ref: "main", + Page: 1, + Per: 10, + RepoType: types.ModelRepo, + CurrentUser: "u", + }).Return([]types.Commit{{ID: "c1"}}, &types.RepoPageOpts{Total: 100, PageCount: 1}, nil) + tester.WithParam("file_path", "foo") + tester.WithQuery("ref", "main") + tester.AddPagination(1, 10) + tester.WithKV("repo_type", types.ModelRepo) + + tester.Execute() + tester.ResponseEq(t, http.StatusOK, tester.OKText, gin.H{ + "commits": []types.Commit{{ID: "c1"}}, + "total": 100, + "page_count": 1, + }) + +} + func TestRepoHandler_LastCommit(t *testing.T) { t.Run("forbidden", func(t *testing.T) { - comp := mockcomponent.NewMockRepoComponent(t) - h := &RepoHandler{comp} - - response := httptest.NewRecorder() - ginc, _ := gin.CreateTestContext(response) - ginc.AddParam("namespace", "user_name_1") - ginc.AddParam("name", "repo_name_1") + tester := NewRepoTester(t).WithHandleFunc(func(rp *RepoHandler) gin.HandlerFunc { + return rp.LastCommit + }) //user does not have permission to access repo - comp.EXPECT().LastCommit(mock.Anything, mock.Anything).Return(nil, component.ErrForbidden).Once() - h.LastCommit(ginc) - require.Equal(t, http.StatusForbidden, response.Code) + tester.mocks.repo.EXPECT().LastCommit(mock.Anything, mock.Anything).Return(nil, component.ErrForbidden).Once() + tester.Execute() + require.Equal(t, http.StatusForbidden, tester.response.Code) }) t.Run("server error", func(t *testing.T) { - comp := mockcomponent.NewMockRepoComponent(t) - h := &RepoHandler{comp} - - response := httptest.NewRecorder() - ginc, _ := gin.CreateTestContext(response) - ginc.AddParam("namespace", "user_name_1") - ginc.AddParam("name", "repo_name_1") - + tester := NewRepoTester(t).WithHandleFunc(func(rp *RepoHandler) gin.HandlerFunc { + return rp.LastCommit + }) commit := &types.Commit{} - comp.EXPECT().LastCommit(mock.Anything, mock.Anything).Return(commit, errors.New("custome error")).Once() - h.LastCommit(ginc) - require.Equal(t, http.StatusInternalServerError, response.Code) + tester.mocks.repo.EXPECT().LastCommit(mock.Anything, mock.Anything).Return(commit, errors.New("custome error")).Once() + tester.Execute() + require.Equal(t, http.StatusInternalServerError, tester.response.Code) }) t.Run("success", func(t *testing.T) { - comp := mockcomponent.NewMockRepoComponent(t) - h := &RepoHandler{comp} - - response := httptest.NewRecorder() - ginc, _ := gin.CreateTestContext(response) - ginc.AddParam("namespace", "user_name_1") - ginc.AddParam("name", "repo_name_1") + tester := NewRepoTester(t).WithHandleFunc(func(rp *RepoHandler) gin.HandlerFunc { + return rp.LastCommit + }) commit := &types.Commit{} - commit.AuthorName = "user_name_1" + commit.AuthorName = "u" commit.ID = uuid.New().String() - comp.EXPECT().LastCommit(mock.Anything, mock.Anything).Return(commit, nil).Once() - h.LastCommit(ginc) - require.Equal(t, http.StatusOK, response.Code) + tester.mocks.repo.EXPECT().LastCommit(mock.Anything, mock.Anything).Return(commit, nil).Once() + tester.Execute() + require.Equal(t, http.StatusOK, tester.response.Code) var r = struct { Code int `json:"code,omitempty"` Msg string `json:"msg"` Data *types.Commit `json:"data,omitempty"` }{} - err := json.Unmarshal(response.Body.Bytes(), &r) + err := json.Unmarshal(tester.response.Body.Bytes(), &r) require.Empty(t, err) require.Equal(t, commit.ID, r.Data.ID) }) } -func TestRepoHandler_Tree(t *testing.T) { - t.Run("forbidden", func(t *testing.T) { - comp := mockcomponent.NewMockRepoComponent(t) - h := &RepoHandler{comp} +func TestRepoHandler_FileRaw(t *testing.T) { + tester := NewRepoTester(t).WithHandleFunc(func(rp *RepoHandler) gin.HandlerFunc { + return rp.FileRaw + }) - response := httptest.NewRecorder() - ginc, _ := gin.CreateTestContext(response) - ginc.AddParam("namespace", "user_name_1") - ginc.AddParam("name", "repo_name_1") + tester.WithUser() + tester.mocks.repo.EXPECT().FileRaw(tester.ctx, &types.GetFileReq{ + Namespace: "u", + Name: "r", + Ref: "main", + RepoType: types.ModelRepo, + CurrentUser: "u", + Path: "foo", + }).Return("data", nil) + tester.WithParam("file_path", "foo") + tester.WithQuery("ref", "main") + tester.WithKV("repo_type", types.ModelRepo) - //user does not have permission to access repo - comp.EXPECT().Tree(mock.Anything, mock.Anything).Return(nil, component.ErrForbidden).Once() - h.Tree(ginc) - require.Equal(t, http.StatusForbidden, response.Code) + tester.Execute() + tester.ResponseEq(t, http.StatusOK, tester.OKText, "data") + +} + +func TestRepoHandler_FileInfo(t *testing.T) { + tester := NewRepoTester(t).WithHandleFunc(func(rp *RepoHandler) gin.HandlerFunc { + return rp.FileInfo }) - t.Run("server error", func(t *testing.T) { - comp := mockcomponent.NewMockRepoComponent(t) - h := &RepoHandler{comp} + tester.WithUser() + tester.mocks.repo.EXPECT().FileInfo(tester.ctx, &types.GetFileReq{ + Namespace: "u", + Name: "r", + Ref: "main", + RepoType: types.ModelRepo, + Path: "foo", + CurrentUser: "u", + }).Return(&types.File{Name: "foo.go"}, nil) + tester.WithParam("file_path", "foo") + tester.WithQuery("ref", "main") + tester.WithKV("repo_type", types.ModelRepo) - response := httptest.NewRecorder() - ginc, _ := gin.CreateTestContext(response) - ginc.AddParam("namespace", "user_name_1") - ginc.AddParam("name", "repo_name_1") + tester.Execute() + tester.ResponseEq(t, http.StatusOK, tester.OKText, &types.File{Name: "foo.go"}) - comp.EXPECT().Tree(mock.Anything, mock.Anything).Return(nil, errors.New("custome error")).Once() - h.Tree(ginc) - require.Equal(t, http.StatusInternalServerError, response.Code) +} + +func TestRepoHandler_DownloadFile(t *testing.T) { + + t.Run("lfs file", func(t *testing.T) { + tester := NewRepoTester(t).WithHandleFunc(func(rp *RepoHandler) gin.HandlerFunc { + return rp.DownloadFile + }) + + tester.WithUser() + tester.mocks.repo.EXPECT().DownloadFile(tester.ctx, &types.GetFileReq{ + Namespace: "u", + Name: "r", + Ref: "main", + RepoType: types.ModelRepo, + CurrentUser: "u", + Path: "foo", + Lfs: true, + }, "u").Return(nil, 100, "foo", nil) + tester.WithParam("file_path", "foo") + tester.WithQuery("ref", "main") + tester.WithQuery("lfs", "true") + tester.WithKV("repo_type", types.ModelRepo) + + tester.Execute() + tester.ResponseEq(t, http.StatusOK, tester.OKText, "foo") }) - t.Run("success", func(t *testing.T) { - comp := mockcomponent.NewMockRepoComponent(t) - h := &RepoHandler{comp} + t.Run("normal file", func(t *testing.T) { + tester := NewRepoTester(t).WithHandleFunc(func(rp *RepoHandler) gin.HandlerFunc { + return rp.DownloadFile + }) + + tester.WithUser() + tester.mocks.repo.EXPECT().DownloadFile(tester.ctx, &types.GetFileReq{ + Namespace: "u", + Name: "r", + Ref: "main", + RepoType: types.ModelRepo, + CurrentUser: "u", + Path: "foo", + }, "u").Return( + io.NopCloser(bytes.NewBuffer([]byte("bar"))), 100, "foo", nil, + ) + tester.WithParam("file_path", "foo") + tester.WithQuery("ref", "main") + tester.WithKV("repo_type", types.ModelRepo) + + tester.Execute() + require.Equal(t, 200, tester.response.Code) + headers := tester.response.Header() + require.Equal(t, "application/octet-stream", headers.Get("Content-Type")) + require.Equal(t, `attachment; filename="foo"`, headers.Get("Content-Disposition")) + require.Equal(t, "100", headers.Get("Content-Length")) + r := tester.response.Body.String() + require.Equal(t, "bar", r) + }) + +} + +func TestRepoHandler_Branches(t *testing.T) { + tester := NewRepoTester(t).WithHandleFunc(func(rp *RepoHandler) gin.HandlerFunc { + return rp.Branches + }) + + tester.WithUser() + tester.mocks.repo.EXPECT().Branches(tester.ctx, &types.GetBranchesReq{ + Namespace: "u", + Name: "r", + RepoType: types.ModelRepo, + CurrentUser: "u", + Per: 10, + Page: 1, + }).Return([]types.Branch{{Name: "main"}}, nil) + tester.WithKV("repo_type", types.ModelRepo) + tester.AddPagination(1, 10) + + tester.Execute() + tester.ResponseEq(t, http.StatusOK, tester.OKText, []types.Branch{{Name: "main"}}) + +} + +func TestRepoHandler_Tags(t *testing.T) { + tester := NewRepoTester(t).WithHandleFunc(func(rp *RepoHandler) gin.HandlerFunc { + return rp.Tags + }) + + tester.WithUser() + tester.mocks.repo.EXPECT().Tags(tester.ctx, &types.GetTagsReq{ + Namespace: "u", + Name: "r", + RepoType: types.ModelRepo, + CurrentUser: "u", + }).Return([]database.Tag{{Name: "main"}}, nil) + tester.WithKV("repo_type", types.ModelRepo) + + tester.Execute() + tester.ResponseEq(t, http.StatusOK, tester.OKText, []database.Tag{{Name: "main"}}) + +} - response := httptest.NewRecorder() - ginc, _ := gin.CreateTestContext(response) - ginc.AddParam("namespace", "user_name_1") - ginc.AddParam("name", "repo_name_1") +func TestRepoHandler_UpdateTags(t *testing.T) { + tester := NewRepoTester(t).WithHandleFunc(func(rp *RepoHandler) gin.HandlerFunc { + return rp.UpdateTags + }) + tester.RequireUser(t) + + tester.mocks.repo.EXPECT().UpdateTags( + tester.ctx, "u", "r", types.ModelRepo, + "cat", "u", []string{"foo", "bar"}, + ).Return(nil) + tester.WithKV("repo_type", types.ModelRepo) + tester.WithParam("category", "cat") + tester.WithBody(t, []string{"foo", "bar"}) + + tester.Execute() + tester.ResponseEq(t, http.StatusOK, tester.OKText, nil) + +} + +func TestRepoHandler_Tree(t *testing.T) { + t.Run("forbidden", func(t *testing.T) { + tester := NewRepoTester(t).WithHandleFunc(func(rp *RepoHandler) gin.HandlerFunc { + return rp.Tree + }) + //user does not have permission to access repo + tester.mocks.repo.EXPECT().Tree(mock.Anything, mock.Anything).Return(nil, component.ErrForbidden).Once() + tester.Execute() + require.Equal(t, http.StatusForbidden, tester.response.Code) + }) + + t.Run("server error", func(t *testing.T) { + tester := NewRepoTester(t).WithHandleFunc(func(rp *RepoHandler) gin.HandlerFunc { + return rp.Tree + }) + tester.mocks.repo.EXPECT().Tree(mock.Anything, mock.Anything).Return(nil, errors.New("custome error")).Once() + tester.Execute() + require.Equal(t, http.StatusInternalServerError, tester.response.Code) + }) + + t.Run("success", func(t *testing.T) { + tester := NewRepoTester(t).WithHandleFunc(func(rp *RepoHandler) gin.HandlerFunc { + return rp.Tree + }) var tree []*types.File - comp.EXPECT().Tree(mock.Anything, mock.Anything).Return(tree, nil).Once() - h.Tree(ginc) - require.Equal(t, http.StatusOK, response.Code) + tester.mocks.repo.EXPECT().Tree(mock.Anything, mock.Anything).Return(tree, nil).Once() + tester.Execute() + require.Equal(t, http.StatusOK, tester.response.Code) var r = struct { Code int `json:"code,omitempty"` Msg string `json:"msg"` Data []*types.File `json:"data,omitempty"` }{} - err := json.Unmarshal(response.Body.Bytes(), &r) + err := json.Unmarshal(tester.response.Body.Bytes(), &r) require.Empty(t, err) require.Equal(t, tree, r.Data) }) } + +func TestRepoHandler_UpdateDownloads(t *testing.T) { + tester := NewRepoTester(t).WithHandleFunc(func(rp *RepoHandler) gin.HandlerFunc { + return rp.UpdateDownloads + }) + + tester.WithUser() + tester.mocks.repo.EXPECT().UpdateDownloads( + tester.ctx, &types.UpdateDownloadsReq{ + Namespace: "u", + Name: "r", + RepoType: types.ModelRepo, + ReqDate: "2002-02-01", + Date: time.Date(2002, 2, 1, 0, 0, 0, 0, time.UTC), + }, + ).Return(nil) + tester.WithKV("repo_type", types.ModelRepo) + tester.WithBody(t, &types.UpdateDownloadsReq{ReqDate: "2002-02-01"}) + + tester.Execute() + tester.ResponseEq(t, http.StatusOK, tester.OKText, nil) + +} + +func TestRepoHandler_IncrDownloads(t *testing.T) { + tester := NewRepoTester(t).WithHandleFunc(func(rp *RepoHandler) gin.HandlerFunc { + return rp.IncrDownloads + }) + + tester.WithUser() + tester.mocks.repo.EXPECT().IncrDownloads( + tester.ctx, types.ModelRepo, "u", "r", + ).Return(nil) + tester.WithKV("repo_type", types.ModelRepo) + tester.WithBody(t, &types.UpdateDownloadsReq{ReqDate: "2002-02-01"}) + + tester.Execute() + tester.ResponseEq(t, http.StatusOK, tester.OKText, nil) + +} + +func TestRepoHandler_UploadFile(t *testing.T) { + tester := NewRepoTester(t).WithHandleFunc(func(rp *RepoHandler) gin.HandlerFunc { + return rp.UploadFile + }) + tester.RequireUser(t) + + bodyBuffer := new(bytes.Buffer) + mw := multipart.NewWriter(bodyBuffer) + err := mw.WriteField("file_path", "foo") + require.NoError(t, err) + err = mw.WriteField("message", "msg") + require.NoError(t, err) + err = mw.WriteField("branch", "main") + require.NoError(t, err) + part, err := mw.CreateFormFile("file", "file") + if err != nil { + t.Fatal(err) + } + file := strings.NewReader(`data`) + _, err = io.Copy(part, file) + require.NoError(t, err) + mw.Close() + req := httptest.NewRequest(http.MethodPost, "/", bodyBuffer) + req.Header.Set("Content-Type", mw.FormDataContentType()) + err = req.ParseMultipartForm(20) + require.NoError(t, err) + tester.WithMultipartForm(req.MultipartForm) + + tester.mocks.repo.EXPECT().UploadFile( + tester.ctx, &types.CreateFileReq{ + Namespace: "u", + Name: "r", + RepoType: types.ModelRepo, + Content: "ZGF0YQ==", + OriginalContent: []byte("data"), + CurrentUser: "u", + Message: "msg", + Branch: "main", + FilePath: "foo", + Username: "u", + }, + ).Return(nil) + tester.WithKV("repo_type", types.ModelRepo) + + tester.Execute() + tester.ResponseEq(t, http.StatusOK, tester.OKText, nil) + +} + +func TestRepoHandler_SDKListFiles(t *testing.T) { + tester := NewRepoTester(t).WithHandleFunc(func(rp *RepoHandler) gin.HandlerFunc { + return rp.SDKListFiles + }) + + tester.WithUser() + tester.WithParam("ref", "main") + tester.WithParam("branch_mapped", "main_main") + tester.WithKV("repo_type", types.ModelRepo) + tester.mocks.repo.EXPECT().SDKListFiles( + tester.ctx, types.ModelRepo, "u", "r", "main_main", "u", + ).Return(&types.SDKFiles{ID: "f1"}, nil) + + tester.Execute() + tester.ResponseEqSimple(t, http.StatusOK, &types.SDKFiles{ID: "f1"}) +} + +func TestRepoHandler_HandleDownload(t *testing.T) { + + t.Run("lfs file", func(t *testing.T) { + tester := NewRepoTester(t).WithHandleFunc(func(rp *RepoHandler) gin.HandlerFunc { + return rp.SDKDownload + }) + + tester.WithUser() + tester.WithParam("ref", "main") + tester.WithParam("branch_mapped", "main_main") + tester.WithParam("file_path", "foo") + tester.WithKV("repo_type", types.ModelRepo) + req := &types.GetFileReq{ + Namespace: "u", + Name: "r", + Path: "foo", + Ref: "main_main", + Lfs: false, + SaveAs: "foo", + RepoType: types.ModelRepo, + } + tester.mocks.repo.EXPECT().IsLfs(tester.ctx, req).Return(true, 100, nil) + reqnew := *req + reqnew.Lfs = true + tester.mocks.repo.EXPECT().SDKDownloadFile(tester.ctx, &reqnew, "u").Return( + nil, 100, "url", nil, + ) + + tester.Execute() + + // redirected + require.Equal(t, http.StatusOK, tester.response.Code) + resp := tester.response.Result() + defer resp.Body.Close() + lc, err := resp.Location() + require.NoError(t, err) + require.Equal(t, "/url", lc.String()) + }) + + t.Run("normal file", func(t *testing.T) { + tester := NewRepoTester(t).WithHandleFunc(func(rp *RepoHandler) gin.HandlerFunc { + return rp.SDKDownload + }) + + tester.WithUser() + tester.WithParam("ref", "main") + tester.WithParam("branch_mapped", "main_main") + tester.WithParam("file_path", "foo") + tester.WithKV("repo_type", types.ModelRepo) + req := &types.GetFileReq{ + Namespace: "u", + Name: "r", + Path: "foo", + Ref: "main_main", + Lfs: false, + SaveAs: "foo", + RepoType: types.ModelRepo, + } + tester.mocks.repo.EXPECT().IsLfs(tester.ctx, req).Return(false, 100, nil) + tester.mocks.repo.EXPECT().SDKDownloadFile(tester.ctx, req, "u").Return( + io.NopCloser(bytes.NewBuffer([]byte("bar"))), 100, "url", nil, + ) + + tester.Execute() + require.Equal(t, 200, tester.response.Code) + headers := tester.response.Header() + require.Equal(t, "application/octet-stream", headers.Get("Content-Type")) + require.Equal(t, `attachment; filename="foo"`, headers.Get("Content-Disposition")) + require.Equal(t, "100", headers.Get("Content-Length")) + r := tester.response.Body.String() + require.Equal(t, "bar", r) + }) +} + +func TestRepoHandler_HeadSDKDownload(t *testing.T) { + tester := NewRepoTester(t).WithHandleFunc(func(rp *RepoHandler) gin.HandlerFunc { + return rp.HeadSDKDownload + }) + + tester.WithUser() + tester.WithParam("file_path", "foo") + tester.WithParam("branch", "main") + tester.WithKV("repo_type", types.ModelRepo) + tester.mocks.repo.EXPECT().HeadDownloadFile( + tester.ctx, &types.GetFileReq{ + Namespace: "u", + Name: "r", + Path: "foo", + Ref: "main", + SaveAs: "foo", + RepoType: types.ModelRepo, + }, "u", + ).Return(&types.File{Size: 100, SHA: "def"}, &types.Commit{ID: "abc"}, nil) + + tester.Execute() + require.Equal(t, 200, tester.response.Code) + headers := tester.response.Header() + require.Equal(t, "100", headers.Get("Content-Length")) + require.Equal(t, "abc", headers.Get("X-Repo-Commit")) + require.Equal(t, "def", headers.Get("ETag")) +} + +func TestRepoHandler_CommitWithDiff(t *testing.T) { + tester := NewRepoTester(t).WithHandleFunc(func(rp *RepoHandler) gin.HandlerFunc { + return rp.CommitWithDiff + }) + + tester.WithUser() + tester.WithParam("commit_id", "foo") + tester.WithKV("repo_type", types.ModelRepo) + tester.mocks.repo.EXPECT().GetCommitWithDiff( + tester.ctx, &types.GetCommitsReq{ + Namespace: "u", + Name: "r", + Ref: "foo", + RepoType: types.ModelRepo, + CurrentUser: "u", + }, + ).Return(&types.CommitResponse{Commit: &types.Commit{ID: "foo"}}, nil) + + tester.Execute() + tester.ResponseEq(t, 200, tester.OKText, &types.CommitResponse{Commit: &types.Commit{ID: "foo"}}) +} + +func TestRepoHandler_CreateMirror(t *testing.T) { + tester := NewRepoTester(t).WithHandleFunc(func(rp *RepoHandler) gin.HandlerFunc { + return rp.CreateMirror + }) + + tester.RequireUser(t) + tester.WithKV("repo_type", types.ModelRepo) + tester.WithBody(t, &types.CreateMirrorReq{ + SourceUrl: "https://foo.com", + MirrorSourceID: 12, + }) + tester.mocks.repo.EXPECT().CreateMirror( + tester.ctx, types.CreateMirrorReq{ + Namespace: "u", + Name: "r", + RepoType: types.ModelRepo, + CurrentUser: "u", + SourceUrl: "https://foo.com", + MirrorSourceID: 12, + }, + ).Return(&database.Mirror{ID: 123}, nil) + + tester.Execute() + tester.ResponseEq(t, 200, tester.OKText, &database.Mirror{ID: 123}) +} + +func TestRepoHandler_MirrorFromSaas(t *testing.T) { + t.Run("valid", func(t *testing.T) { + tester := NewRepoTester(t).WithHandleFunc(func(rp *RepoHandler) gin.HandlerFunc { + return rp.MirrorFromSaas + }) + tester.RequireUser(t) + + tester.WithParam("namespace", types.OpenCSGPrefix+"repo") + tester.WithKV("repo_type", types.ModelRepo) + tester.mocks.repo.EXPECT().MirrorFromSaas( + tester.ctx, "CSG_repo", "r", "u", types.ModelRepo, + ).Return(nil) + + tester.Execute() + tester.ResponseEq(t, 200, tester.OKText, nil) + }) + + t.Run("invalid", func(t *testing.T) { + tester := NewRepoTester(t).WithHandleFunc(func(rp *RepoHandler) gin.HandlerFunc { + return rp.MirrorFromSaas + }) + tester.RequireUser(t) + + tester.WithKV("repo_type", types.ModelRepo) + tester.Execute() + tester.ResponseEq(t, 400, "Repo could not be mirrored", nil) + }) +} + +func TestRepoHandler_GetMirror(t *testing.T) { + tester := NewRepoTester(t).WithHandleFunc(func(rp *RepoHandler) gin.HandlerFunc { + return rp.GetMirror + }) + tester.RequireUser(t) + + tester.WithKV("repo_type", types.ModelRepo) + tester.mocks.repo.EXPECT().GetMirror( + tester.ctx, types.GetMirrorReq{ + Namespace: "u", + Name: "r", + RepoType: types.ModelRepo, + CurrentUser: "u", + }, + ).Return(&database.Mirror{ID: 11}, nil) + + tester.Execute() + tester.ResponseEq(t, 200, tester.OKText, &database.Mirror{ID: 11}) +} + +func TestRepoHandler_UpdateMirror(t *testing.T) { + tester := NewRepoTester(t).WithHandleFunc(func(rp *RepoHandler) gin.HandlerFunc { + return rp.UpdateMirror + }) + tester.RequireUser(t) + + tester.WithKV("repo_type", types.ModelRepo) + tester.WithBody(t, &types.UpdateMirrorReq{ + MirrorSourceID: 123, + SourceUrl: "foo", + }) + tester.mocks.repo.EXPECT().UpdateMirror( + tester.ctx, types.UpdateMirrorReq{ + Namespace: "u", + Name: "r", + RepoType: types.ModelRepo, + CurrentUser: "u", + MirrorSourceID: 123, + SourceUrl: "foo", + SourceRepoPath: "foo", + }, + ).Return(&database.Mirror{ID: 11}, nil) + + tester.Execute() + tester.ResponseEq(t, 200, tester.OKText, &database.Mirror{ID: 11}) +} + +func TestRepoHandler_DeleteMirror(t *testing.T) { + tester := NewRepoTester(t).WithHandleFunc(func(rp *RepoHandler) gin.HandlerFunc { + return rp.DeleteMirror + }) + tester.RequireUser(t) + + tester.WithKV("repo_type", types.ModelRepo) + tester.mocks.repo.EXPECT().DeleteMirror( + tester.ctx, types.DeleteMirrorReq{ + Namespace: "u", + Name: "r", + RepoType: types.ModelRepo, + CurrentUser: "u", + }, + ).Return(nil) + + tester.Execute() + tester.ResponseEq(t, 200, tester.OKText, nil) +} + +func TestRepoHandler_RuntimeFrameworkList(t *testing.T) { + tester := NewRepoTester(t).WithHandleFunc(func(rp *RepoHandler) gin.HandlerFunc { + return rp.RuntimeFrameworkList + }) + + tester.WithKV("repo_type", types.ModelRepo) + tester.WithQuery("deploy_type", "1") + tester.mocks.repo.EXPECT().ListRuntimeFramework( + tester.ctx, types.ModelRepo, "u", "r", 1, + ).Return([]types.RuntimeFramework{{FrameName: "f1"}}, nil) + + tester.Execute() + tester.ResponseEq(t, 200, tester.OKText, []types.RuntimeFramework{{FrameName: "f1"}}) +} + +func TestRepoHandler_RuntimeFrameworkCreate(t *testing.T) { + tester := NewRepoTester(t).WithHandleFunc(func(rp *RepoHandler) gin.HandlerFunc { + return rp.RuntimeFrameworkCreate + }) + + tester.WithKV("repo_type", types.ModelRepo) + tester.WithBody(t, &types.RuntimeFrameworkReq{ + FrameName: "f1", + }) + tester.mocks.repo.EXPECT().CreateRuntimeFramework( + tester.ctx, &types.RuntimeFrameworkReq{FrameName: "f1"}, + ).Return(&types.RuntimeFramework{FrameName: "f1"}, nil) + + tester.Execute() + tester.ResponseEq(t, 200, tester.OKText, &types.RuntimeFramework{FrameName: "f1"}) +} + +func TestRepoHandler_RuntimeFrameworkUpdate(t *testing.T) { + tester := NewRepoTester(t).WithHandleFunc(func(rp *RepoHandler) gin.HandlerFunc { + return rp.RuntimeFrameworkUpdate + }) + + tester.WithKV("repo_type", types.ModelRepo) + tester.WithBody(t, &types.RuntimeFrameworkReq{ + FrameName: "f1", + }) + tester.WithParam("id", "1") + tester.mocks.repo.EXPECT().UpdateRuntimeFramework( + tester.ctx, int64(1), &types.RuntimeFrameworkReq{FrameName: "f1"}, + ).Return(&types.RuntimeFramework{FrameName: "f1"}, nil) + + tester.Execute() + tester.ResponseEq(t, 200, tester.OKText, &types.RuntimeFramework{FrameName: "f1"}) +} + +func TestRepoHandler_RuntimeFrameworkDelete(t *testing.T) { + tester := NewRepoTester(t).WithHandleFunc(func(rp *RepoHandler) gin.HandlerFunc { + return rp.RuntimeFrameworkDelete + }) + + tester.WithKV("repo_type", types.ModelRepo) + tester.WithParam("id", "1") + tester.mocks.repo.EXPECT().DeleteRuntimeFramework( + tester.ctx, int64(1), + ).Return(nil) + + tester.Execute() + tester.ResponseEq(t, 200, tester.OKText, nil) +} + +func TestRepoHandler_DeployList(t *testing.T) { + tester := NewRepoTester(t).WithHandleFunc(func(rp *RepoHandler) gin.HandlerFunc { + return rp.DeployList + }) + tester.RequireUser(t) + + tester.WithKV("repo_type", types.ModelRepo) + tester.mocks.repo.EXPECT().ListDeploy( + tester.ctx, types.ModelRepo, "u", "r", "u", + ).Return([]types.DeployRepo{{DeployName: "n"}}, nil) + + tester.Execute() + tester.ResponseEq(t, 200, tester.OKText, []types.DeployRepo{{DeployName: "n"}}) +} + +func TestRepoHandler_DeployDetail(t *testing.T) { + tester := NewRepoTester(t).WithHandleFunc(func(rp *RepoHandler) gin.HandlerFunc { + return rp.DeployDetail + }) + tester.RequireUser(t) + + tester.WithKV("repo_type", types.ModelRepo) + tester.WithParam("id", "1") + tester.mocks.repo.EXPECT().DeployDetail( + tester.ctx, types.DeployActReq{ + RepoType: types.ModelRepo, + Namespace: "u", + Name: "r", + CurrentUser: "u", + DeployID: 1, + DeployType: types.InferenceType, + }, + ).Return(&types.DeployRepo{DeployName: "n"}, nil) + + tester.Execute() + tester.ResponseEq(t, 200, tester.OKText, &types.DeployRepo{DeployName: "n"}) +} + +func TestRepoHandler_DeployInstanceLogs(t *testing.T) { + tester := NewRepoTester(t).WithHandleFunc(func(rp *RepoHandler) gin.HandlerFunc { + return rp.DeployInstanceLogs + }) + tester.RequireUser(t) + + tester.WithKV("repo_type", types.ModelRepo) + tester.WithParam("id", "1") + tester.WithParam("instance", "ii") + runlogChan := make(chan string) + tester.mocks.repo.EXPECT().DeployInstanceLogs( + mock.Anything, types.DeployActReq{ + RepoType: types.ModelRepo, + Namespace: "u", + Name: "r", + CurrentUser: "u", + DeployID: 1, + DeployType: types.InferenceType, + InstanceName: "ii", + }, + ).Return(deploy.NewMultiLogReader(nil, runlogChan), nil) + cc, cancel := context.WithCancel(tester.ctx.Request.Context()) + tester.ctx.Request = tester.ctx.Request.WithContext(cc) + go func() { + runlogChan <- "foo" + runlogChan <- "bar" + close(runlogChan) + cancel() + }() + + tester.Execute() + require.Equal(t, 200, tester.response.Code) + headers := tester.response.Header() + require.Equal(t, "text/event-stream", headers.Get("Content-Type")) + require.Equal(t, "no-cache", headers.Get("Cache-Control")) + require.Equal(t, "keep-alive", headers.Get("Connection")) + require.Equal(t, "chunked", headers.Get("Transfer-Encoding")) + require.Equal( + t, "event:Container\ndata:foo\n\nevent:Container\ndata:bar\n\n", + tester.response.Body.String(), + ) + +} + +func TestRepoHandler_DeployStatus(t *testing.T) { + tester := NewRepoTester(t).WithHandleFunc(func(rp *RepoHandler) gin.HandlerFunc { + return rp.DeployStatus + }) + tester.handler.deployStatusCheckInterval = 0 + tester.RequireUser(t) + + tester.WithKV("repo_type", types.ModelRepo) + tester.WithParam("id", "1") + tester.mocks.repo.EXPECT().AllowAccessDeploy( + tester.ctx, types.DeployActReq{ + RepoType: types.ModelRepo, + Namespace: "u", + Name: "r", + CurrentUser: "u", + DeployID: 1, + DeployType: types.InferenceType, + }, + ).Return(true, nil) + cc, cancel := context.WithCancel(tester.ctx.Request.Context()) + tester.ctx.Request = tester.ctx.Request.WithContext(cc) + tester.mocks.repo.EXPECT().DeployStatus( + mock.Anything, types.ModelRepo, "u", "r", int64(1), + ).Return("", "s1", []types.Instance{{Name: "i1"}}, nil).Once() + tester.mocks.repo.EXPECT().DeployStatus( + mock.Anything, types.ModelRepo, "u", "r", int64(1), + ).RunAndReturn(func(ctx context.Context, rt types.RepositoryType, s1, s2 string, i int64) (string, string, []types.Instance, error) { + cancel() + return "", "s3", nil, nil + }).Once() + + tester.Execute() + require.Equal(t, 200, tester.response.Code) + headers := tester.response.Header() + require.Equal(t, "text/event-stream", headers.Get("Content-Type")) + require.Equal(t, "no-cache", headers.Get("Cache-Control")) + require.Equal(t, "keep-alive", headers.Get("Connection")) + require.Equal(t, "chunked", headers.Get("Transfer-Encoding")) + require.Equal( + t, "event:status\ndata:{\"status\":\"s1\",\"details\":[{\"name\":\"i1\",\"status\":\"\"}]}\n\nevent:status\ndata:{\"status\":\"s3\",\"details\":null}\n\n", + tester.response.Body.String(), + ) + +} + +func TestRepoHandler_SyncMirror(t *testing.T) { + tester := NewRepoTester(t).WithHandleFunc(func(rp *RepoHandler) gin.HandlerFunc { + return rp.SyncMirror + }) + tester.RequireUser(t) + + tester.WithKV("repo_type", types.ModelRepo) + tester.WithParam("id", "1") + tester.mocks.repo.EXPECT().SyncMirror( + tester.ctx, types.ModelRepo, "u", "r", "u", + ).Return(nil) + + tester.Execute() + tester.ResponseEq(t, 200, tester.OKText, nil) +} + +func TestRepoHandler_MirrorProgress(t *testing.T) { + tester := NewRepoTester(t).WithHandleFunc(func(rp *RepoHandler) gin.HandlerFunc { + return rp.MirrorProgress + }) + tester.RequireUser(t) + + tester.WithKV("repo_type", types.ModelRepo) + tester.WithParam("id", "1") + tester.mocks.repo.EXPECT().MirrorProgress( + tester.ctx, types.ModelRepo, "u", "r", "u", + ).Return([]types.LFSSyncProgressResp{{Oid: "o1"}}, nil) + + tester.Execute() + tester.ResponseEq( + t, 200, tester.OKText, []types.LFSSyncProgressResp{{Oid: "o1"}}, + ) +} + +func TestRepoHandler_DeployUpdate(t *testing.T) { + t.Run("not admin", func(t *testing.T) { + + tester := NewRepoTester(t).WithHandleFunc(func(rp *RepoHandler) gin.HandlerFunc { + return rp.DeployUpdate + }) + tester.RequireUser(t) + + tester.WithKV("repo_type", types.ModelRepo) + tester.WithParam("id", "1") + tester.mocks.repo.EXPECT().AllowAdminAccess(tester.ctx, types.ModelRepo, "u", "r", "u").Return(false, nil) + tester.Execute() + tester.ResponseEq( + t, 401, "user not allowed to update deploy", nil, + ) + }) + + t.Run("admin", func(t *testing.T) { + + tester := NewRepoTester(t).WithHandleFunc(func(rp *RepoHandler) gin.HandlerFunc { + return rp.DeployUpdate + }) + tester.RequireUser(t) + + tester.WithKV("repo_type", types.ModelRepo) + tester.WithParam("id", "1") + tester.mocks.repo.EXPECT().AllowAdminAccess(tester.ctx, types.ModelRepo, "u", "r", "u").Return(true, nil) + tester.WithBody(t, &types.DeployUpdateReq{ + MinReplica: tea.Int(1), + MaxReplica: tea.Int(5), + }) + tester.mocks.repo.EXPECT().DeployUpdate(tester.ctx, types.DeployActReq{ + RepoType: types.ModelRepo, + Namespace: "u", + Name: "r", + CurrentUser: "u", + DeployID: 1, + DeployType: types.InferenceType, + }, &types.DeployUpdateReq{ + MinReplica: tea.Int(1), + MaxReplica: tea.Int(5), + }).Return(nil) + tester.Execute() + tester.ResponseEq(t, 200, tester.OKText, nil) + }) +} + +func TestRepoHandler_RuntimeFrameworkListWithType(t *testing.T) { + tester := NewRepoTester(t).WithHandleFunc(func(rp *RepoHandler) gin.HandlerFunc { + return rp.RuntimeFrameworkListWithType + }) + tester.RequireUser(t) + + tester.WithKV("repo_type", types.ModelRepo) + tester.mocks.repo.EXPECT().ListRuntimeFrameworkWithType( + tester.ctx, types.InferenceType, + ).Return([]types.RuntimeFramework{{FrameName: "f1"}}, nil) + + tester.Execute() + tester.ResponseEq( + t, 200, tester.OKText, []types.RuntimeFramework{{FrameName: "f1"}}, + ) +} + +func TestRepoHandler_ServerlessDetail(t *testing.T) { + tester := NewRepoTester(t).WithHandleFunc(func(rp *RepoHandler) gin.HandlerFunc { + return rp.ServerlessDetail + }) + tester.RequireUser(t) + + tester.WithParam("id", "1") + tester.mocks.repo.EXPECT().DeployDetail( + tester.ctx, types.DeployActReq{ + RepoType: types.ModelRepo, + Namespace: "u", + Name: "r", + CurrentUser: "u", + DeployID: 1, + DeployType: types.ServerlessType, + }, + ).Return(&types.DeployRepo{Name: "r"}, nil) + + tester.Execute() + tester.ResponseEq( + t, 200, tester.OKText, &types.DeployRepo{Name: "r"}, + ) +} + +func TestRepoHandler_ServerlessLogs(t *testing.T) { + tester := NewRepoTester(t).WithHandleFunc(func(rp *RepoHandler) gin.HandlerFunc { + return rp.ServerlessLogs + }) + + tester.RequireUser(t) + + tester.WithKV("repo_type", types.ModelRepo) + tester.WithParam("id", "1") + tester.WithParam("instance", "ii") + runlogChan := make(chan string) + tester.mocks.repo.EXPECT().DeployInstanceLogs( + mock.Anything, types.DeployActReq{ + RepoType: types.ModelRepo, + Namespace: "u", + Name: "r", + CurrentUser: "u", + DeployID: 1, + DeployType: types.ServerlessType, + InstanceName: "ii", + }, + ).Return(deploy.NewMultiLogReader(nil, runlogChan), nil) + cc, cancel := context.WithCancel(tester.ctx.Request.Context()) + tester.ctx.Request = tester.ctx.Request.WithContext(cc) + go func() { + runlogChan <- "foo" + runlogChan <- "bar" + close(runlogChan) + cancel() + }() + + tester.Execute() + require.Equal(t, 200, tester.response.Code) + headers := tester.response.Header() + require.Equal(t, "text/event-stream", headers.Get("Content-Type")) + require.Equal(t, "no-cache", headers.Get("Cache-Control")) + require.Equal(t, "keep-alive", headers.Get("Connection")) + require.Equal(t, "chunked", headers.Get("Transfer-Encoding")) + require.Equal( + t, "event:Container\ndata:foo\n\nevent:Container\ndata:bar\n\n", + tester.response.Body.String(), + ) + +} + +func TestRepoHandler_ServerlessStatus(t *testing.T) { + tester := NewRepoTester(t).WithHandleFunc(func(rp *RepoHandler) gin.HandlerFunc { + return rp.ServerlessStatus + }) + tester.handler.deployStatusCheckInterval = 0 + tester.RequireUser(t) + + tester.WithKV("repo_type", types.ModelRepo) + tester.WithParam("id", "1") + tester.mocks.repo.EXPECT().AllowAccessDeploy( + tester.ctx, types.DeployActReq{ + RepoType: types.ModelRepo, + Namespace: "u", + Name: "r", + CurrentUser: "u", + DeployID: 1, + DeployType: types.ServerlessType, + }, + ).Return(true, nil) + cc, cancel := context.WithCancel(tester.ctx.Request.Context()) + tester.ctx.Request = tester.ctx.Request.WithContext(cc) + tester.mocks.repo.EXPECT().DeployStatus( + mock.Anything, types.ModelRepo, "u", "r", int64(1), + ).Return("", "s1", []types.Instance{{Name: "i1"}}, nil).Once() + tester.mocks.repo.EXPECT().DeployStatus( + mock.Anything, types.ModelRepo, "u", "r", int64(1), + ).RunAndReturn(func(ctx context.Context, rt types.RepositoryType, s1, s2 string, i int64) (string, string, []types.Instance, error) { + cancel() + return "", "s3", nil, nil + }).Once() + + tester.Execute() + require.Equal(t, 200, tester.response.Code) + headers := tester.response.Header() + require.Equal(t, "text/event-stream", headers.Get("Content-Type")) + require.Equal(t, "no-cache", headers.Get("Cache-Control")) + require.Equal(t, "keep-alive", headers.Get("Connection")) + require.Equal(t, "chunked", headers.Get("Transfer-Encoding")) + require.Equal( + t, "event:status\ndata:{\"status\":\"s1\",\"details\":[{\"name\":\"i1\",\"status\":\"\"}]}\n\nevent:status\ndata:{\"status\":\"s3\",\"details\":null}\n\n", + tester.response.Body.String(), + ) + +} + +func TestRepoHandler_ServelessUpdate(t *testing.T) { + + tester := NewRepoTester(t).WithHandleFunc(func(rp *RepoHandler) gin.HandlerFunc { + return rp.ServerlessUpdate + }) + tester.RequireUser(t) + + tester.WithKV("repo_type", types.ModelRepo) + tester.WithParam("id", "1") + tester.WithBody(t, &types.DeployUpdateReq{ + MinReplica: tea.Int(1), + MaxReplica: tea.Int(5), + }) + tester.mocks.repo.EXPECT().DeployUpdate(tester.ctx, types.DeployActReq{ + RepoType: types.ModelRepo, + Namespace: "u", + Name: "r", + CurrentUser: "u", + DeployID: 1, + DeployType: types.ServerlessType, + }, &types.DeployUpdateReq{ + MinReplica: tea.Int(1), + MaxReplica: tea.Int(5), + }).Return(nil) + tester.Execute() + tester.ResponseEq(t, 200, tester.OKText, nil) +} diff --git a/api/httpbase/user.go b/api/httpbase/user.go index 8ef33fce..f94de028 100644 --- a/api/httpbase/user.go +++ b/api/httpbase/user.go @@ -1,6 +1,8 @@ package httpbase -import "github.com/gin-gonic/gin" +import ( + "github.com/gin-gonic/gin" +) const ( CurrentUserCtxVar = "currentUser" From c9dcff355813cca7a926399feb7617e5c8066a8c Mon Sep 17 00:00:00 2001 From: yiling Date: Tue, 24 Dec 2024 17:30:10 +0800 Subject: [PATCH 20/34] fix repo handler test --- api/handler/repo_test.go | 23 ++--------------------- 1 file changed, 2 insertions(+), 21 deletions(-) diff --git a/api/handler/repo_test.go b/api/handler/repo_test.go index e935995e..557e694d 100644 --- a/api/handler/repo_test.go +++ b/api/handler/repo_test.go @@ -585,14 +585,13 @@ func TestRepoHandler_HeadSDKDownload(t *testing.T) { SaveAs: "foo", RepoType: types.ModelRepo, }, "u", - ).Return(&types.File{Size: 100, SHA: "def"}, &types.Commit{ID: "abc"}, nil) + ).Return(&types.File{Size: 100, SHA: "def"}, nil) tester.Execute() require.Equal(t, 200, tester.response.Code) headers := tester.response.Header() require.Equal(t, "100", headers.Get("Content-Length")) - require.Equal(t, "abc", headers.Get("X-Repo-Commit")) - require.Equal(t, "def", headers.Get("ETag")) + require.Equal(t, "def", headers.Get("X-Repo-Commit")) } func TestRepoHandler_CommitWithDiff(t *testing.T) { @@ -947,24 +946,6 @@ func TestRepoHandler_SyncMirror(t *testing.T) { tester.ResponseEq(t, 200, tester.OKText, nil) } -func TestRepoHandler_MirrorProgress(t *testing.T) { - tester := NewRepoTester(t).WithHandleFunc(func(rp *RepoHandler) gin.HandlerFunc { - return rp.MirrorProgress - }) - tester.RequireUser(t) - - tester.WithKV("repo_type", types.ModelRepo) - tester.WithParam("id", "1") - tester.mocks.repo.EXPECT().MirrorProgress( - tester.ctx, types.ModelRepo, "u", "r", "u", - ).Return([]types.LFSSyncProgressResp{{Oid: "o1"}}, nil) - - tester.Execute() - tester.ResponseEq( - t, 200, tester.OKText, []types.LFSSyncProgressResp{{Oid: "o1"}}, - ) -} - func TestRepoHandler_DeployUpdate(t *testing.T) { t.Run("not admin", func(t *testing.T) { From 37188c7215297dc3d276c15c1260f06bb774a561 Mon Sep 17 00:00:00 2001 From: "yiling.ji" Date: Thu, 19 Dec 2024 07:52:23 +0000 Subject: [PATCH 21/34] Merge branch 'feature/handler_tests' into 'main' Add code handler tests and fix prompt component cycle import See merge request product/starhub/starhub-server!741 --- .mockery.yaml | 2 + .../component/mock_CodeComponent.go | 460 ++++++ .../component/mock_PromptComponent.go | 1444 +++++++++++++++++ api/handler/code.go | 24 +- api/handler/code_test.go | 149 ++ api/handler/prompt.go | 4 +- builder/sensitive/aliyun_green.go | 16 +- common/types/prompt.go | 56 +- component/prompt.go | 87 +- component/prompt_test.go | 16 +- 10 files changed, 2149 insertions(+), 109 deletions(-) create mode 100644 _mocks/opencsg.com/csghub-server/component/mock_CodeComponent.go create mode 100644 _mocks/opencsg.com/csghub-server/component/mock_PromptComponent.go create mode 100644 api/handler/code_test.go diff --git a/.mockery.yaml b/.mockery.yaml index cbda4ba7..64fb75b7 100644 --- a/.mockery.yaml +++ b/.mockery.yaml @@ -17,6 +17,8 @@ packages: SpaceComponent: RuntimeArchitectureComponent: SensitiveComponent: + CodeComponent: + PromptComponent: opencsg.com/csghub-server/user/component: config: interfaces: diff --git a/_mocks/opencsg.com/csghub-server/component/mock_CodeComponent.go b/_mocks/opencsg.com/csghub-server/component/mock_CodeComponent.go new file mode 100644 index 00000000..536c106c --- /dev/null +++ b/_mocks/opencsg.com/csghub-server/component/mock_CodeComponent.go @@ -0,0 +1,460 @@ +// Code generated by mockery v2.49.1. DO NOT EDIT. + +package component + +import ( + context "context" + + mock "github.com/stretchr/testify/mock" + types "opencsg.com/csghub-server/common/types" +) + +// MockCodeComponent is an autogenerated mock type for the CodeComponent type +type MockCodeComponent struct { + mock.Mock +} + +type MockCodeComponent_Expecter struct { + mock *mock.Mock +} + +func (_m *MockCodeComponent) EXPECT() *MockCodeComponent_Expecter { + return &MockCodeComponent_Expecter{mock: &_m.Mock} +} + +// Create provides a mock function with given fields: ctx, req +func (_m *MockCodeComponent) Create(ctx context.Context, req *types.CreateCodeReq) (*types.Code, error) { + ret := _m.Called(ctx, req) + + if len(ret) == 0 { + panic("no return value specified for Create") + } + + var r0 *types.Code + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, *types.CreateCodeReq) (*types.Code, error)); ok { + return rf(ctx, req) + } + if rf, ok := ret.Get(0).(func(context.Context, *types.CreateCodeReq) *types.Code); ok { + r0 = rf(ctx, req) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*types.Code) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, *types.CreateCodeReq) error); ok { + r1 = rf(ctx, req) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// MockCodeComponent_Create_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Create' +type MockCodeComponent_Create_Call struct { + *mock.Call +} + +// Create is a helper method to define mock.On call +// - ctx context.Context +// - req *types.CreateCodeReq +func (_e *MockCodeComponent_Expecter) Create(ctx interface{}, req interface{}) *MockCodeComponent_Create_Call { + return &MockCodeComponent_Create_Call{Call: _e.mock.On("Create", ctx, req)} +} + +func (_c *MockCodeComponent_Create_Call) Run(run func(ctx context.Context, req *types.CreateCodeReq)) *MockCodeComponent_Create_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(*types.CreateCodeReq)) + }) + return _c +} + +func (_c *MockCodeComponent_Create_Call) Return(_a0 *types.Code, _a1 error) *MockCodeComponent_Create_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *MockCodeComponent_Create_Call) RunAndReturn(run func(context.Context, *types.CreateCodeReq) (*types.Code, error)) *MockCodeComponent_Create_Call { + _c.Call.Return(run) + return _c +} + +// Delete provides a mock function with given fields: ctx, namespace, name, currentUser +func (_m *MockCodeComponent) Delete(ctx context.Context, namespace string, name string, currentUser string) error { + ret := _m.Called(ctx, namespace, name, currentUser) + + if len(ret) == 0 { + panic("no return value specified for Delete") + } + + var r0 error + if rf, ok := ret.Get(0).(func(context.Context, string, string, string) error); ok { + r0 = rf(ctx, namespace, name, currentUser) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// MockCodeComponent_Delete_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Delete' +type MockCodeComponent_Delete_Call struct { + *mock.Call +} + +// Delete is a helper method to define mock.On call +// - ctx context.Context +// - namespace string +// - name string +// - currentUser string +func (_e *MockCodeComponent_Expecter) Delete(ctx interface{}, namespace interface{}, name interface{}, currentUser interface{}) *MockCodeComponent_Delete_Call { + return &MockCodeComponent_Delete_Call{Call: _e.mock.On("Delete", ctx, namespace, name, currentUser)} +} + +func (_c *MockCodeComponent_Delete_Call) Run(run func(ctx context.Context, namespace string, name string, currentUser string)) *MockCodeComponent_Delete_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(string), args[2].(string), args[3].(string)) + }) + return _c +} + +func (_c *MockCodeComponent_Delete_Call) Return(_a0 error) *MockCodeComponent_Delete_Call { + _c.Call.Return(_a0) + return _c +} + +func (_c *MockCodeComponent_Delete_Call) RunAndReturn(run func(context.Context, string, string, string) error) *MockCodeComponent_Delete_Call { + _c.Call.Return(run) + return _c +} + +// Index provides a mock function with given fields: ctx, filter, per, page +func (_m *MockCodeComponent) Index(ctx context.Context, filter *types.RepoFilter, per int, page int) ([]types.Code, int, error) { + ret := _m.Called(ctx, filter, per, page) + + if len(ret) == 0 { + panic("no return value specified for Index") + } + + var r0 []types.Code + var r1 int + var r2 error + if rf, ok := ret.Get(0).(func(context.Context, *types.RepoFilter, int, int) ([]types.Code, int, error)); ok { + return rf(ctx, filter, per, page) + } + if rf, ok := ret.Get(0).(func(context.Context, *types.RepoFilter, int, int) []types.Code); ok { + r0 = rf(ctx, filter, per, page) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).([]types.Code) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, *types.RepoFilter, int, int) int); ok { + r1 = rf(ctx, filter, per, page) + } else { + r1 = ret.Get(1).(int) + } + + if rf, ok := ret.Get(2).(func(context.Context, *types.RepoFilter, int, int) error); ok { + r2 = rf(ctx, filter, per, page) + } else { + r2 = ret.Error(2) + } + + return r0, r1, r2 +} + +// MockCodeComponent_Index_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Index' +type MockCodeComponent_Index_Call struct { + *mock.Call +} + +// Index is a helper method to define mock.On call +// - ctx context.Context +// - filter *types.RepoFilter +// - per int +// - page int +func (_e *MockCodeComponent_Expecter) Index(ctx interface{}, filter interface{}, per interface{}, page interface{}) *MockCodeComponent_Index_Call { + return &MockCodeComponent_Index_Call{Call: _e.mock.On("Index", ctx, filter, per, page)} +} + +func (_c *MockCodeComponent_Index_Call) Run(run func(ctx context.Context, filter *types.RepoFilter, per int, page int)) *MockCodeComponent_Index_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(*types.RepoFilter), args[2].(int), args[3].(int)) + }) + return _c +} + +func (_c *MockCodeComponent_Index_Call) Return(_a0 []types.Code, _a1 int, _a2 error) *MockCodeComponent_Index_Call { + _c.Call.Return(_a0, _a1, _a2) + return _c +} + +func (_c *MockCodeComponent_Index_Call) RunAndReturn(run func(context.Context, *types.RepoFilter, int, int) ([]types.Code, int, error)) *MockCodeComponent_Index_Call { + _c.Call.Return(run) + return _c +} + +// OrgCodes provides a mock function with given fields: ctx, req +func (_m *MockCodeComponent) OrgCodes(ctx context.Context, req *types.OrgCodesReq) ([]types.Code, int, error) { + ret := _m.Called(ctx, req) + + if len(ret) == 0 { + panic("no return value specified for OrgCodes") + } + + var r0 []types.Code + var r1 int + var r2 error + if rf, ok := ret.Get(0).(func(context.Context, *types.OrgCodesReq) ([]types.Code, int, error)); ok { + return rf(ctx, req) + } + if rf, ok := ret.Get(0).(func(context.Context, *types.OrgCodesReq) []types.Code); ok { + r0 = rf(ctx, req) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).([]types.Code) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, *types.OrgCodesReq) int); ok { + r1 = rf(ctx, req) + } else { + r1 = ret.Get(1).(int) + } + + if rf, ok := ret.Get(2).(func(context.Context, *types.OrgCodesReq) error); ok { + r2 = rf(ctx, req) + } else { + r2 = ret.Error(2) + } + + return r0, r1, r2 +} + +// MockCodeComponent_OrgCodes_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'OrgCodes' +type MockCodeComponent_OrgCodes_Call struct { + *mock.Call +} + +// OrgCodes is a helper method to define mock.On call +// - ctx context.Context +// - req *types.OrgCodesReq +func (_e *MockCodeComponent_Expecter) OrgCodes(ctx interface{}, req interface{}) *MockCodeComponent_OrgCodes_Call { + return &MockCodeComponent_OrgCodes_Call{Call: _e.mock.On("OrgCodes", ctx, req)} +} + +func (_c *MockCodeComponent_OrgCodes_Call) Run(run func(ctx context.Context, req *types.OrgCodesReq)) *MockCodeComponent_OrgCodes_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(*types.OrgCodesReq)) + }) + return _c +} + +func (_c *MockCodeComponent_OrgCodes_Call) Return(_a0 []types.Code, _a1 int, _a2 error) *MockCodeComponent_OrgCodes_Call { + _c.Call.Return(_a0, _a1, _a2) + return _c +} + +func (_c *MockCodeComponent_OrgCodes_Call) RunAndReturn(run func(context.Context, *types.OrgCodesReq) ([]types.Code, int, error)) *MockCodeComponent_OrgCodes_Call { + _c.Call.Return(run) + return _c +} + +// Relations provides a mock function with given fields: ctx, namespace, name, currentUser +func (_m *MockCodeComponent) Relations(ctx context.Context, namespace string, name string, currentUser string) (*types.Relations, error) { + ret := _m.Called(ctx, namespace, name, currentUser) + + if len(ret) == 0 { + panic("no return value specified for Relations") + } + + var r0 *types.Relations + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, string, string, string) (*types.Relations, error)); ok { + return rf(ctx, namespace, name, currentUser) + } + if rf, ok := ret.Get(0).(func(context.Context, string, string, string) *types.Relations); ok { + r0 = rf(ctx, namespace, name, currentUser) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*types.Relations) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, string, string, string) error); ok { + r1 = rf(ctx, namespace, name, currentUser) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// MockCodeComponent_Relations_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Relations' +type MockCodeComponent_Relations_Call struct { + *mock.Call +} + +// Relations is a helper method to define mock.On call +// - ctx context.Context +// - namespace string +// - name string +// - currentUser string +func (_e *MockCodeComponent_Expecter) Relations(ctx interface{}, namespace interface{}, name interface{}, currentUser interface{}) *MockCodeComponent_Relations_Call { + return &MockCodeComponent_Relations_Call{Call: _e.mock.On("Relations", ctx, namespace, name, currentUser)} +} + +func (_c *MockCodeComponent_Relations_Call) Run(run func(ctx context.Context, namespace string, name string, currentUser string)) *MockCodeComponent_Relations_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(string), args[2].(string), args[3].(string)) + }) + return _c +} + +func (_c *MockCodeComponent_Relations_Call) Return(_a0 *types.Relations, _a1 error) *MockCodeComponent_Relations_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *MockCodeComponent_Relations_Call) RunAndReturn(run func(context.Context, string, string, string) (*types.Relations, error)) *MockCodeComponent_Relations_Call { + _c.Call.Return(run) + return _c +} + +// Show provides a mock function with given fields: ctx, namespace, name, currentUser +func (_m *MockCodeComponent) Show(ctx context.Context, namespace string, name string, currentUser string) (*types.Code, error) { + ret := _m.Called(ctx, namespace, name, currentUser) + + if len(ret) == 0 { + panic("no return value specified for Show") + } + + var r0 *types.Code + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, string, string, string) (*types.Code, error)); ok { + return rf(ctx, namespace, name, currentUser) + } + if rf, ok := ret.Get(0).(func(context.Context, string, string, string) *types.Code); ok { + r0 = rf(ctx, namespace, name, currentUser) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*types.Code) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, string, string, string) error); ok { + r1 = rf(ctx, namespace, name, currentUser) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// MockCodeComponent_Show_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Show' +type MockCodeComponent_Show_Call struct { + *mock.Call +} + +// Show is a helper method to define mock.On call +// - ctx context.Context +// - namespace string +// - name string +// - currentUser string +func (_e *MockCodeComponent_Expecter) Show(ctx interface{}, namespace interface{}, name interface{}, currentUser interface{}) *MockCodeComponent_Show_Call { + return &MockCodeComponent_Show_Call{Call: _e.mock.On("Show", ctx, namespace, name, currentUser)} +} + +func (_c *MockCodeComponent_Show_Call) Run(run func(ctx context.Context, namespace string, name string, currentUser string)) *MockCodeComponent_Show_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(string), args[2].(string), args[3].(string)) + }) + return _c +} + +func (_c *MockCodeComponent_Show_Call) Return(_a0 *types.Code, _a1 error) *MockCodeComponent_Show_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *MockCodeComponent_Show_Call) RunAndReturn(run func(context.Context, string, string, string) (*types.Code, error)) *MockCodeComponent_Show_Call { + _c.Call.Return(run) + return _c +} + +// Update provides a mock function with given fields: ctx, req +func (_m *MockCodeComponent) Update(ctx context.Context, req *types.UpdateCodeReq) (*types.Code, error) { + ret := _m.Called(ctx, req) + + if len(ret) == 0 { + panic("no return value specified for Update") + } + + var r0 *types.Code + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, *types.UpdateCodeReq) (*types.Code, error)); ok { + return rf(ctx, req) + } + if rf, ok := ret.Get(0).(func(context.Context, *types.UpdateCodeReq) *types.Code); ok { + r0 = rf(ctx, req) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*types.Code) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, *types.UpdateCodeReq) error); ok { + r1 = rf(ctx, req) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// MockCodeComponent_Update_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Update' +type MockCodeComponent_Update_Call struct { + *mock.Call +} + +// Update is a helper method to define mock.On call +// - ctx context.Context +// - req *types.UpdateCodeReq +func (_e *MockCodeComponent_Expecter) Update(ctx interface{}, req interface{}) *MockCodeComponent_Update_Call { + return &MockCodeComponent_Update_Call{Call: _e.mock.On("Update", ctx, req)} +} + +func (_c *MockCodeComponent_Update_Call) Run(run func(ctx context.Context, req *types.UpdateCodeReq)) *MockCodeComponent_Update_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(*types.UpdateCodeReq)) + }) + return _c +} + +func (_c *MockCodeComponent_Update_Call) Return(_a0 *types.Code, _a1 error) *MockCodeComponent_Update_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *MockCodeComponent_Update_Call) RunAndReturn(run func(context.Context, *types.UpdateCodeReq) (*types.Code, error)) *MockCodeComponent_Update_Call { + _c.Call.Return(run) + return _c +} + +// NewMockCodeComponent creates a new instance of MockCodeComponent. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. +// The first argument is typically a *testing.T value. +func NewMockCodeComponent(t interface { + mock.TestingT + Cleanup(func()) +}) *MockCodeComponent { + mock := &MockCodeComponent{} + mock.Mock.Test(t) + + t.Cleanup(func() { mock.AssertExpectations(t) }) + + return mock +} diff --git a/_mocks/opencsg.com/csghub-server/component/mock_PromptComponent.go b/_mocks/opencsg.com/csghub-server/component/mock_PromptComponent.go new file mode 100644 index 00000000..a6e65a19 --- /dev/null +++ b/_mocks/opencsg.com/csghub-server/component/mock_PromptComponent.go @@ -0,0 +1,1444 @@ +// Code generated by mockery v2.49.1. DO NOT EDIT. + +package component + +import ( + context "context" + + gitserver "opencsg.com/csghub-server/builder/git/gitserver" + database "opencsg.com/csghub-server/builder/store/database" + + mock "github.com/stretchr/testify/mock" + + types "opencsg.com/csghub-server/common/types" +) + +// MockPromptComponent is an autogenerated mock type for the PromptComponent type +type MockPromptComponent struct { + mock.Mock +} + +type MockPromptComponent_Expecter struct { + mock *mock.Mock +} + +func (_m *MockPromptComponent) EXPECT() *MockPromptComponent_Expecter { + return &MockPromptComponent_Expecter{mock: &_m.Mock} +} + +// AddRelationModel provides a mock function with given fields: ctx, req +func (_m *MockPromptComponent) AddRelationModel(ctx context.Context, req types.RelationModel) error { + ret := _m.Called(ctx, req) + + if len(ret) == 0 { + panic("no return value specified for AddRelationModel") + } + + var r0 error + if rf, ok := ret.Get(0).(func(context.Context, types.RelationModel) error); ok { + r0 = rf(ctx, req) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// MockPromptComponent_AddRelationModel_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'AddRelationModel' +type MockPromptComponent_AddRelationModel_Call struct { + *mock.Call +} + +// AddRelationModel is a helper method to define mock.On call +// - ctx context.Context +// - req types.RelationModel +func (_e *MockPromptComponent_Expecter) AddRelationModel(ctx interface{}, req interface{}) *MockPromptComponent_AddRelationModel_Call { + return &MockPromptComponent_AddRelationModel_Call{Call: _e.mock.On("AddRelationModel", ctx, req)} +} + +func (_c *MockPromptComponent_AddRelationModel_Call) Run(run func(ctx context.Context, req types.RelationModel)) *MockPromptComponent_AddRelationModel_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(types.RelationModel)) + }) + return _c +} + +func (_c *MockPromptComponent_AddRelationModel_Call) Return(_a0 error) *MockPromptComponent_AddRelationModel_Call { + _c.Call.Return(_a0) + return _c +} + +func (_c *MockPromptComponent_AddRelationModel_Call) RunAndReturn(run func(context.Context, types.RelationModel) error) *MockPromptComponent_AddRelationModel_Call { + _c.Call.Return(run) + return _c +} + +// CreatePrompt provides a mock function with given fields: ctx, req, body +func (_m *MockPromptComponent) CreatePrompt(ctx context.Context, req types.PromptReq, body *types.CreatePromptReq) (*types.Prompt, error) { + ret := _m.Called(ctx, req, body) + + if len(ret) == 0 { + panic("no return value specified for CreatePrompt") + } + + var r0 *types.Prompt + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, types.PromptReq, *types.CreatePromptReq) (*types.Prompt, error)); ok { + return rf(ctx, req, body) + } + if rf, ok := ret.Get(0).(func(context.Context, types.PromptReq, *types.CreatePromptReq) *types.Prompt); ok { + r0 = rf(ctx, req, body) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*types.Prompt) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, types.PromptReq, *types.CreatePromptReq) error); ok { + r1 = rf(ctx, req, body) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// MockPromptComponent_CreatePrompt_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'CreatePrompt' +type MockPromptComponent_CreatePrompt_Call struct { + *mock.Call +} + +// CreatePrompt is a helper method to define mock.On call +// - ctx context.Context +// - req types.PromptReq +// - body *types.CreatePromptReq +func (_e *MockPromptComponent_Expecter) CreatePrompt(ctx interface{}, req interface{}, body interface{}) *MockPromptComponent_CreatePrompt_Call { + return &MockPromptComponent_CreatePrompt_Call{Call: _e.mock.On("CreatePrompt", ctx, req, body)} +} + +func (_c *MockPromptComponent_CreatePrompt_Call) Run(run func(ctx context.Context, req types.PromptReq, body *types.CreatePromptReq)) *MockPromptComponent_CreatePrompt_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(types.PromptReq), args[2].(*types.CreatePromptReq)) + }) + return _c +} + +func (_c *MockPromptComponent_CreatePrompt_Call) Return(_a0 *types.Prompt, _a1 error) *MockPromptComponent_CreatePrompt_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *MockPromptComponent_CreatePrompt_Call) RunAndReturn(run func(context.Context, types.PromptReq, *types.CreatePromptReq) (*types.Prompt, error)) *MockPromptComponent_CreatePrompt_Call { + _c.Call.Return(run) + return _c +} + +// CreatePromptRepo provides a mock function with given fields: ctx, req +func (_m *MockPromptComponent) CreatePromptRepo(ctx context.Context, req *types.CreatePromptRepoReq) (*types.PromptRes, error) { + ret := _m.Called(ctx, req) + + if len(ret) == 0 { + panic("no return value specified for CreatePromptRepo") + } + + var r0 *types.PromptRes + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, *types.CreatePromptRepoReq) (*types.PromptRes, error)); ok { + return rf(ctx, req) + } + if rf, ok := ret.Get(0).(func(context.Context, *types.CreatePromptRepoReq) *types.PromptRes); ok { + r0 = rf(ctx, req) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*types.PromptRes) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, *types.CreatePromptRepoReq) error); ok { + r1 = rf(ctx, req) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// MockPromptComponent_CreatePromptRepo_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'CreatePromptRepo' +type MockPromptComponent_CreatePromptRepo_Call struct { + *mock.Call +} + +// CreatePromptRepo is a helper method to define mock.On call +// - ctx context.Context +// - req *types.CreatePromptRepoReq +func (_e *MockPromptComponent_Expecter) CreatePromptRepo(ctx interface{}, req interface{}) *MockPromptComponent_CreatePromptRepo_Call { + return &MockPromptComponent_CreatePromptRepo_Call{Call: _e.mock.On("CreatePromptRepo", ctx, req)} +} + +func (_c *MockPromptComponent_CreatePromptRepo_Call) Run(run func(ctx context.Context, req *types.CreatePromptRepoReq)) *MockPromptComponent_CreatePromptRepo_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(*types.CreatePromptRepoReq)) + }) + return _c +} + +func (_c *MockPromptComponent_CreatePromptRepo_Call) Return(_a0 *types.PromptRes, _a1 error) *MockPromptComponent_CreatePromptRepo_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *MockPromptComponent_CreatePromptRepo_Call) RunAndReturn(run func(context.Context, *types.CreatePromptRepoReq) (*types.PromptRes, error)) *MockPromptComponent_CreatePromptRepo_Call { + _c.Call.Return(run) + return _c +} + +// DelRelationModel provides a mock function with given fields: ctx, req +func (_m *MockPromptComponent) DelRelationModel(ctx context.Context, req types.RelationModel) error { + ret := _m.Called(ctx, req) + + if len(ret) == 0 { + panic("no return value specified for DelRelationModel") + } + + var r0 error + if rf, ok := ret.Get(0).(func(context.Context, types.RelationModel) error); ok { + r0 = rf(ctx, req) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// MockPromptComponent_DelRelationModel_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'DelRelationModel' +type MockPromptComponent_DelRelationModel_Call struct { + *mock.Call +} + +// DelRelationModel is a helper method to define mock.On call +// - ctx context.Context +// - req types.RelationModel +func (_e *MockPromptComponent_Expecter) DelRelationModel(ctx interface{}, req interface{}) *MockPromptComponent_DelRelationModel_Call { + return &MockPromptComponent_DelRelationModel_Call{Call: _e.mock.On("DelRelationModel", ctx, req)} +} + +func (_c *MockPromptComponent_DelRelationModel_Call) Run(run func(ctx context.Context, req types.RelationModel)) *MockPromptComponent_DelRelationModel_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(types.RelationModel)) + }) + return _c +} + +func (_c *MockPromptComponent_DelRelationModel_Call) Return(_a0 error) *MockPromptComponent_DelRelationModel_Call { + _c.Call.Return(_a0) + return _c +} + +func (_c *MockPromptComponent_DelRelationModel_Call) RunAndReturn(run func(context.Context, types.RelationModel) error) *MockPromptComponent_DelRelationModel_Call { + _c.Call.Return(run) + return _c +} + +// DeletePrompt provides a mock function with given fields: ctx, req +func (_m *MockPromptComponent) DeletePrompt(ctx context.Context, req types.PromptReq) error { + ret := _m.Called(ctx, req) + + if len(ret) == 0 { + panic("no return value specified for DeletePrompt") + } + + var r0 error + if rf, ok := ret.Get(0).(func(context.Context, types.PromptReq) error); ok { + r0 = rf(ctx, req) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// MockPromptComponent_DeletePrompt_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'DeletePrompt' +type MockPromptComponent_DeletePrompt_Call struct { + *mock.Call +} + +// DeletePrompt is a helper method to define mock.On call +// - ctx context.Context +// - req types.PromptReq +func (_e *MockPromptComponent_Expecter) DeletePrompt(ctx interface{}, req interface{}) *MockPromptComponent_DeletePrompt_Call { + return &MockPromptComponent_DeletePrompt_Call{Call: _e.mock.On("DeletePrompt", ctx, req)} +} + +func (_c *MockPromptComponent_DeletePrompt_Call) Run(run func(ctx context.Context, req types.PromptReq)) *MockPromptComponent_DeletePrompt_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(types.PromptReq)) + }) + return _c +} + +func (_c *MockPromptComponent_DeletePrompt_Call) Return(_a0 error) *MockPromptComponent_DeletePrompt_Call { + _c.Call.Return(_a0) + return _c +} + +func (_c *MockPromptComponent_DeletePrompt_Call) RunAndReturn(run func(context.Context, types.PromptReq) error) *MockPromptComponent_DeletePrompt_Call { + _c.Call.Return(run) + return _c +} + +// GetConversation provides a mock function with given fields: ctx, req +func (_m *MockPromptComponent) GetConversation(ctx context.Context, req types.ConversationReq) (*database.PromptConversation, error) { + ret := _m.Called(ctx, req) + + if len(ret) == 0 { + panic("no return value specified for GetConversation") + } + + var r0 *database.PromptConversation + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, types.ConversationReq) (*database.PromptConversation, error)); ok { + return rf(ctx, req) + } + if rf, ok := ret.Get(0).(func(context.Context, types.ConversationReq) *database.PromptConversation); ok { + r0 = rf(ctx, req) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*database.PromptConversation) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, types.ConversationReq) error); ok { + r1 = rf(ctx, req) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// MockPromptComponent_GetConversation_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'GetConversation' +type MockPromptComponent_GetConversation_Call struct { + *mock.Call +} + +// GetConversation is a helper method to define mock.On call +// - ctx context.Context +// - req types.ConversationReq +func (_e *MockPromptComponent_Expecter) GetConversation(ctx interface{}, req interface{}) *MockPromptComponent_GetConversation_Call { + return &MockPromptComponent_GetConversation_Call{Call: _e.mock.On("GetConversation", ctx, req)} +} + +func (_c *MockPromptComponent_GetConversation_Call) Run(run func(ctx context.Context, req types.ConversationReq)) *MockPromptComponent_GetConversation_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(types.ConversationReq)) + }) + return _c +} + +func (_c *MockPromptComponent_GetConversation_Call) Return(_a0 *database.PromptConversation, _a1 error) *MockPromptComponent_GetConversation_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *MockPromptComponent_GetConversation_Call) RunAndReturn(run func(context.Context, types.ConversationReq) (*database.PromptConversation, error)) *MockPromptComponent_GetConversation_Call { + _c.Call.Return(run) + return _c +} + +// GetPrompt provides a mock function with given fields: ctx, req +func (_m *MockPromptComponent) GetPrompt(ctx context.Context, req types.PromptReq) (*types.PromptOutput, error) { + ret := _m.Called(ctx, req) + + if len(ret) == 0 { + panic("no return value specified for GetPrompt") + } + + var r0 *types.PromptOutput + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, types.PromptReq) (*types.PromptOutput, error)); ok { + return rf(ctx, req) + } + if rf, ok := ret.Get(0).(func(context.Context, types.PromptReq) *types.PromptOutput); ok { + r0 = rf(ctx, req) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*types.PromptOutput) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, types.PromptReq) error); ok { + r1 = rf(ctx, req) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// MockPromptComponent_GetPrompt_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'GetPrompt' +type MockPromptComponent_GetPrompt_Call struct { + *mock.Call +} + +// GetPrompt is a helper method to define mock.On call +// - ctx context.Context +// - req types.PromptReq +func (_e *MockPromptComponent_Expecter) GetPrompt(ctx interface{}, req interface{}) *MockPromptComponent_GetPrompt_Call { + return &MockPromptComponent_GetPrompt_Call{Call: _e.mock.On("GetPrompt", ctx, req)} +} + +func (_c *MockPromptComponent_GetPrompt_Call) Run(run func(ctx context.Context, req types.PromptReq)) *MockPromptComponent_GetPrompt_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(types.PromptReq)) + }) + return _c +} + +func (_c *MockPromptComponent_GetPrompt_Call) Return(_a0 *types.PromptOutput, _a1 error) *MockPromptComponent_GetPrompt_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *MockPromptComponent_GetPrompt_Call) RunAndReturn(run func(context.Context, types.PromptReq) (*types.PromptOutput, error)) *MockPromptComponent_GetPrompt_Call { + _c.Call.Return(run) + return _c +} + +// HateConversationMessage provides a mock function with given fields: ctx, req +func (_m *MockPromptComponent) HateConversationMessage(ctx context.Context, req types.ConversationMessageReq) error { + ret := _m.Called(ctx, req) + + if len(ret) == 0 { + panic("no return value specified for HateConversationMessage") + } + + var r0 error + if rf, ok := ret.Get(0).(func(context.Context, types.ConversationMessageReq) error); ok { + r0 = rf(ctx, req) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// MockPromptComponent_HateConversationMessage_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'HateConversationMessage' +type MockPromptComponent_HateConversationMessage_Call struct { + *mock.Call +} + +// HateConversationMessage is a helper method to define mock.On call +// - ctx context.Context +// - req types.ConversationMessageReq +func (_e *MockPromptComponent_Expecter) HateConversationMessage(ctx interface{}, req interface{}) *MockPromptComponent_HateConversationMessage_Call { + return &MockPromptComponent_HateConversationMessage_Call{Call: _e.mock.On("HateConversationMessage", ctx, req)} +} + +func (_c *MockPromptComponent_HateConversationMessage_Call) Run(run func(ctx context.Context, req types.ConversationMessageReq)) *MockPromptComponent_HateConversationMessage_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(types.ConversationMessageReq)) + }) + return _c +} + +func (_c *MockPromptComponent_HateConversationMessage_Call) Return(_a0 error) *MockPromptComponent_HateConversationMessage_Call { + _c.Call.Return(_a0) + return _c +} + +func (_c *MockPromptComponent_HateConversationMessage_Call) RunAndReturn(run func(context.Context, types.ConversationMessageReq) error) *MockPromptComponent_HateConversationMessage_Call { + _c.Call.Return(run) + return _c +} + +// IndexPromptRepo provides a mock function with given fields: ctx, filter, per, page +func (_m *MockPromptComponent) IndexPromptRepo(ctx context.Context, filter *types.RepoFilter, per int, page int) ([]types.PromptRes, int, error) { + ret := _m.Called(ctx, filter, per, page) + + if len(ret) == 0 { + panic("no return value specified for IndexPromptRepo") + } + + var r0 []types.PromptRes + var r1 int + var r2 error + if rf, ok := ret.Get(0).(func(context.Context, *types.RepoFilter, int, int) ([]types.PromptRes, int, error)); ok { + return rf(ctx, filter, per, page) + } + if rf, ok := ret.Get(0).(func(context.Context, *types.RepoFilter, int, int) []types.PromptRes); ok { + r0 = rf(ctx, filter, per, page) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).([]types.PromptRes) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, *types.RepoFilter, int, int) int); ok { + r1 = rf(ctx, filter, per, page) + } else { + r1 = ret.Get(1).(int) + } + + if rf, ok := ret.Get(2).(func(context.Context, *types.RepoFilter, int, int) error); ok { + r2 = rf(ctx, filter, per, page) + } else { + r2 = ret.Error(2) + } + + return r0, r1, r2 +} + +// MockPromptComponent_IndexPromptRepo_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'IndexPromptRepo' +type MockPromptComponent_IndexPromptRepo_Call struct { + *mock.Call +} + +// IndexPromptRepo is a helper method to define mock.On call +// - ctx context.Context +// - filter *types.RepoFilter +// - per int +// - page int +func (_e *MockPromptComponent_Expecter) IndexPromptRepo(ctx interface{}, filter interface{}, per interface{}, page interface{}) *MockPromptComponent_IndexPromptRepo_Call { + return &MockPromptComponent_IndexPromptRepo_Call{Call: _e.mock.On("IndexPromptRepo", ctx, filter, per, page)} +} + +func (_c *MockPromptComponent_IndexPromptRepo_Call) Run(run func(ctx context.Context, filter *types.RepoFilter, per int, page int)) *MockPromptComponent_IndexPromptRepo_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(*types.RepoFilter), args[2].(int), args[3].(int)) + }) + return _c +} + +func (_c *MockPromptComponent_IndexPromptRepo_Call) Return(_a0 []types.PromptRes, _a1 int, _a2 error) *MockPromptComponent_IndexPromptRepo_Call { + _c.Call.Return(_a0, _a1, _a2) + return _c +} + +func (_c *MockPromptComponent_IndexPromptRepo_Call) RunAndReturn(run func(context.Context, *types.RepoFilter, int, int) ([]types.PromptRes, int, error)) *MockPromptComponent_IndexPromptRepo_Call { + _c.Call.Return(run) + return _c +} + +// LikeConversationMessage provides a mock function with given fields: ctx, req +func (_m *MockPromptComponent) LikeConversationMessage(ctx context.Context, req types.ConversationMessageReq) error { + ret := _m.Called(ctx, req) + + if len(ret) == 0 { + panic("no return value specified for LikeConversationMessage") + } + + var r0 error + if rf, ok := ret.Get(0).(func(context.Context, types.ConversationMessageReq) error); ok { + r0 = rf(ctx, req) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// MockPromptComponent_LikeConversationMessage_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'LikeConversationMessage' +type MockPromptComponent_LikeConversationMessage_Call struct { + *mock.Call +} + +// LikeConversationMessage is a helper method to define mock.On call +// - ctx context.Context +// - req types.ConversationMessageReq +func (_e *MockPromptComponent_Expecter) LikeConversationMessage(ctx interface{}, req interface{}) *MockPromptComponent_LikeConversationMessage_Call { + return &MockPromptComponent_LikeConversationMessage_Call{Call: _e.mock.On("LikeConversationMessage", ctx, req)} +} + +func (_c *MockPromptComponent_LikeConversationMessage_Call) Run(run func(ctx context.Context, req types.ConversationMessageReq)) *MockPromptComponent_LikeConversationMessage_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(types.ConversationMessageReq)) + }) + return _c +} + +func (_c *MockPromptComponent_LikeConversationMessage_Call) Return(_a0 error) *MockPromptComponent_LikeConversationMessage_Call { + _c.Call.Return(_a0) + return _c +} + +func (_c *MockPromptComponent_LikeConversationMessage_Call) RunAndReturn(run func(context.Context, types.ConversationMessageReq) error) *MockPromptComponent_LikeConversationMessage_Call { + _c.Call.Return(run) + return _c +} + +// ListConversationsByUserID provides a mock function with given fields: ctx, currentUser +func (_m *MockPromptComponent) ListConversationsByUserID(ctx context.Context, currentUser string) ([]database.PromptConversation, error) { + ret := _m.Called(ctx, currentUser) + + if len(ret) == 0 { + panic("no return value specified for ListConversationsByUserID") + } + + var r0 []database.PromptConversation + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, string) ([]database.PromptConversation, error)); ok { + return rf(ctx, currentUser) + } + if rf, ok := ret.Get(0).(func(context.Context, string) []database.PromptConversation); ok { + r0 = rf(ctx, currentUser) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).([]database.PromptConversation) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, string) error); ok { + r1 = rf(ctx, currentUser) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// MockPromptComponent_ListConversationsByUserID_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'ListConversationsByUserID' +type MockPromptComponent_ListConversationsByUserID_Call struct { + *mock.Call +} + +// ListConversationsByUserID is a helper method to define mock.On call +// - ctx context.Context +// - currentUser string +func (_e *MockPromptComponent_Expecter) ListConversationsByUserID(ctx interface{}, currentUser interface{}) *MockPromptComponent_ListConversationsByUserID_Call { + return &MockPromptComponent_ListConversationsByUserID_Call{Call: _e.mock.On("ListConversationsByUserID", ctx, currentUser)} +} + +func (_c *MockPromptComponent_ListConversationsByUserID_Call) Run(run func(ctx context.Context, currentUser string)) *MockPromptComponent_ListConversationsByUserID_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(string)) + }) + return _c +} + +func (_c *MockPromptComponent_ListConversationsByUserID_Call) Return(_a0 []database.PromptConversation, _a1 error) *MockPromptComponent_ListConversationsByUserID_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *MockPromptComponent_ListConversationsByUserID_Call) RunAndReturn(run func(context.Context, string) ([]database.PromptConversation, error)) *MockPromptComponent_ListConversationsByUserID_Call { + _c.Call.Return(run) + return _c +} + +// ListPrompt provides a mock function with given fields: ctx, req +func (_m *MockPromptComponent) ListPrompt(ctx context.Context, req types.PromptReq) ([]types.PromptOutput, error) { + ret := _m.Called(ctx, req) + + if len(ret) == 0 { + panic("no return value specified for ListPrompt") + } + + var r0 []types.PromptOutput + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, types.PromptReq) ([]types.PromptOutput, error)); ok { + return rf(ctx, req) + } + if rf, ok := ret.Get(0).(func(context.Context, types.PromptReq) []types.PromptOutput); ok { + r0 = rf(ctx, req) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).([]types.PromptOutput) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, types.PromptReq) error); ok { + r1 = rf(ctx, req) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// MockPromptComponent_ListPrompt_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'ListPrompt' +type MockPromptComponent_ListPrompt_Call struct { + *mock.Call +} + +// ListPrompt is a helper method to define mock.On call +// - ctx context.Context +// - req types.PromptReq +func (_e *MockPromptComponent_Expecter) ListPrompt(ctx interface{}, req interface{}) *MockPromptComponent_ListPrompt_Call { + return &MockPromptComponent_ListPrompt_Call{Call: _e.mock.On("ListPrompt", ctx, req)} +} + +func (_c *MockPromptComponent_ListPrompt_Call) Run(run func(ctx context.Context, req types.PromptReq)) *MockPromptComponent_ListPrompt_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(types.PromptReq)) + }) + return _c +} + +func (_c *MockPromptComponent_ListPrompt_Call) Return(_a0 []types.PromptOutput, _a1 error) *MockPromptComponent_ListPrompt_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *MockPromptComponent_ListPrompt_Call) RunAndReturn(run func(context.Context, types.PromptReq) ([]types.PromptOutput, error)) *MockPromptComponent_ListPrompt_Call { + _c.Call.Return(run) + return _c +} + +// NewConversation provides a mock function with given fields: ctx, req +func (_m *MockPromptComponent) NewConversation(ctx context.Context, req types.ConversationTitleReq) (*database.PromptConversation, error) { + ret := _m.Called(ctx, req) + + if len(ret) == 0 { + panic("no return value specified for NewConversation") + } + + var r0 *database.PromptConversation + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, types.ConversationTitleReq) (*database.PromptConversation, error)); ok { + return rf(ctx, req) + } + if rf, ok := ret.Get(0).(func(context.Context, types.ConversationTitleReq) *database.PromptConversation); ok { + r0 = rf(ctx, req) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*database.PromptConversation) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, types.ConversationTitleReq) error); ok { + r1 = rf(ctx, req) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// MockPromptComponent_NewConversation_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'NewConversation' +type MockPromptComponent_NewConversation_Call struct { + *mock.Call +} + +// NewConversation is a helper method to define mock.On call +// - ctx context.Context +// - req types.ConversationTitleReq +func (_e *MockPromptComponent_Expecter) NewConversation(ctx interface{}, req interface{}) *MockPromptComponent_NewConversation_Call { + return &MockPromptComponent_NewConversation_Call{Call: _e.mock.On("NewConversation", ctx, req)} +} + +func (_c *MockPromptComponent_NewConversation_Call) Run(run func(ctx context.Context, req types.ConversationTitleReq)) *MockPromptComponent_NewConversation_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(types.ConversationTitleReq)) + }) + return _c +} + +func (_c *MockPromptComponent_NewConversation_Call) Return(_a0 *database.PromptConversation, _a1 error) *MockPromptComponent_NewConversation_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *MockPromptComponent_NewConversation_Call) RunAndReturn(run func(context.Context, types.ConversationTitleReq) (*database.PromptConversation, error)) *MockPromptComponent_NewConversation_Call { + _c.Call.Return(run) + return _c +} + +// OrgPrompts provides a mock function with given fields: ctx, req +func (_m *MockPromptComponent) OrgPrompts(ctx context.Context, req *types.OrgPromptsReq) ([]types.PromptRes, int, error) { + ret := _m.Called(ctx, req) + + if len(ret) == 0 { + panic("no return value specified for OrgPrompts") + } + + var r0 []types.PromptRes + var r1 int + var r2 error + if rf, ok := ret.Get(0).(func(context.Context, *types.OrgPromptsReq) ([]types.PromptRes, int, error)); ok { + return rf(ctx, req) + } + if rf, ok := ret.Get(0).(func(context.Context, *types.OrgPromptsReq) []types.PromptRes); ok { + r0 = rf(ctx, req) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).([]types.PromptRes) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, *types.OrgPromptsReq) int); ok { + r1 = rf(ctx, req) + } else { + r1 = ret.Get(1).(int) + } + + if rf, ok := ret.Get(2).(func(context.Context, *types.OrgPromptsReq) error); ok { + r2 = rf(ctx, req) + } else { + r2 = ret.Error(2) + } + + return r0, r1, r2 +} + +// MockPromptComponent_OrgPrompts_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'OrgPrompts' +type MockPromptComponent_OrgPrompts_Call struct { + *mock.Call +} + +// OrgPrompts is a helper method to define mock.On call +// - ctx context.Context +// - req *types.OrgPromptsReq +func (_e *MockPromptComponent_Expecter) OrgPrompts(ctx interface{}, req interface{}) *MockPromptComponent_OrgPrompts_Call { + return &MockPromptComponent_OrgPrompts_Call{Call: _e.mock.On("OrgPrompts", ctx, req)} +} + +func (_c *MockPromptComponent_OrgPrompts_Call) Run(run func(ctx context.Context, req *types.OrgPromptsReq)) *MockPromptComponent_OrgPrompts_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(*types.OrgPromptsReq)) + }) + return _c +} + +func (_c *MockPromptComponent_OrgPrompts_Call) Return(_a0 []types.PromptRes, _a1 int, _a2 error) *MockPromptComponent_OrgPrompts_Call { + _c.Call.Return(_a0, _a1, _a2) + return _c +} + +func (_c *MockPromptComponent_OrgPrompts_Call) RunAndReturn(run func(context.Context, *types.OrgPromptsReq) ([]types.PromptRes, int, error)) *MockPromptComponent_OrgPrompts_Call { + _c.Call.Return(run) + return _c +} + +// ParseJsonFile provides a mock function with given fields: ctx, req +func (_m *MockPromptComponent) ParseJsonFile(ctx context.Context, req gitserver.GetRepoInfoByPathReq) (*types.PromptOutput, error) { + ret := _m.Called(ctx, req) + + if len(ret) == 0 { + panic("no return value specified for ParseJsonFile") + } + + var r0 *types.PromptOutput + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, gitserver.GetRepoInfoByPathReq) (*types.PromptOutput, error)); ok { + return rf(ctx, req) + } + if rf, ok := ret.Get(0).(func(context.Context, gitserver.GetRepoInfoByPathReq) *types.PromptOutput); ok { + r0 = rf(ctx, req) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*types.PromptOutput) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, gitserver.GetRepoInfoByPathReq) error); ok { + r1 = rf(ctx, req) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// MockPromptComponent_ParseJsonFile_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'ParseJsonFile' +type MockPromptComponent_ParseJsonFile_Call struct { + *mock.Call +} + +// ParseJsonFile is a helper method to define mock.On call +// - ctx context.Context +// - req gitserver.GetRepoInfoByPathReq +func (_e *MockPromptComponent_Expecter) ParseJsonFile(ctx interface{}, req interface{}) *MockPromptComponent_ParseJsonFile_Call { + return &MockPromptComponent_ParseJsonFile_Call{Call: _e.mock.On("ParseJsonFile", ctx, req)} +} + +func (_c *MockPromptComponent_ParseJsonFile_Call) Run(run func(ctx context.Context, req gitserver.GetRepoInfoByPathReq)) *MockPromptComponent_ParseJsonFile_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(gitserver.GetRepoInfoByPathReq)) + }) + return _c +} + +func (_c *MockPromptComponent_ParseJsonFile_Call) Return(_a0 *types.PromptOutput, _a1 error) *MockPromptComponent_ParseJsonFile_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *MockPromptComponent_ParseJsonFile_Call) RunAndReturn(run func(context.Context, gitserver.GetRepoInfoByPathReq) (*types.PromptOutput, error)) *MockPromptComponent_ParseJsonFile_Call { + _c.Call.Return(run) + return _c +} + +// Relations provides a mock function with given fields: ctx, namespace, name, currentUser +func (_m *MockPromptComponent) Relations(ctx context.Context, namespace string, name string, currentUser string) (*types.Relations, error) { + ret := _m.Called(ctx, namespace, name, currentUser) + + if len(ret) == 0 { + panic("no return value specified for Relations") + } + + var r0 *types.Relations + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, string, string, string) (*types.Relations, error)); ok { + return rf(ctx, namespace, name, currentUser) + } + if rf, ok := ret.Get(0).(func(context.Context, string, string, string) *types.Relations); ok { + r0 = rf(ctx, namespace, name, currentUser) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*types.Relations) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, string, string, string) error); ok { + r1 = rf(ctx, namespace, name, currentUser) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// MockPromptComponent_Relations_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Relations' +type MockPromptComponent_Relations_Call struct { + *mock.Call +} + +// Relations is a helper method to define mock.On call +// - ctx context.Context +// - namespace string +// - name string +// - currentUser string +func (_e *MockPromptComponent_Expecter) Relations(ctx interface{}, namespace interface{}, name interface{}, currentUser interface{}) *MockPromptComponent_Relations_Call { + return &MockPromptComponent_Relations_Call{Call: _e.mock.On("Relations", ctx, namespace, name, currentUser)} +} + +func (_c *MockPromptComponent_Relations_Call) Run(run func(ctx context.Context, namespace string, name string, currentUser string)) *MockPromptComponent_Relations_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(string), args[2].(string), args[3].(string)) + }) + return _c +} + +func (_c *MockPromptComponent_Relations_Call) Return(_a0 *types.Relations, _a1 error) *MockPromptComponent_Relations_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *MockPromptComponent_Relations_Call) RunAndReturn(run func(context.Context, string, string, string) (*types.Relations, error)) *MockPromptComponent_Relations_Call { + _c.Call.Return(run) + return _c +} + +// RemoveConversation provides a mock function with given fields: ctx, req +func (_m *MockPromptComponent) RemoveConversation(ctx context.Context, req types.ConversationReq) error { + ret := _m.Called(ctx, req) + + if len(ret) == 0 { + panic("no return value specified for RemoveConversation") + } + + var r0 error + if rf, ok := ret.Get(0).(func(context.Context, types.ConversationReq) error); ok { + r0 = rf(ctx, req) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// MockPromptComponent_RemoveConversation_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'RemoveConversation' +type MockPromptComponent_RemoveConversation_Call struct { + *mock.Call +} + +// RemoveConversation is a helper method to define mock.On call +// - ctx context.Context +// - req types.ConversationReq +func (_e *MockPromptComponent_Expecter) RemoveConversation(ctx interface{}, req interface{}) *MockPromptComponent_RemoveConversation_Call { + return &MockPromptComponent_RemoveConversation_Call{Call: _e.mock.On("RemoveConversation", ctx, req)} +} + +func (_c *MockPromptComponent_RemoveConversation_Call) Run(run func(ctx context.Context, req types.ConversationReq)) *MockPromptComponent_RemoveConversation_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(types.ConversationReq)) + }) + return _c +} + +func (_c *MockPromptComponent_RemoveConversation_Call) Return(_a0 error) *MockPromptComponent_RemoveConversation_Call { + _c.Call.Return(_a0) + return _c +} + +func (_c *MockPromptComponent_RemoveConversation_Call) RunAndReturn(run func(context.Context, types.ConversationReq) error) *MockPromptComponent_RemoveConversation_Call { + _c.Call.Return(run) + return _c +} + +// RemoveRepo provides a mock function with given fields: ctx, namespace, name, currentUser +func (_m *MockPromptComponent) RemoveRepo(ctx context.Context, namespace string, name string, currentUser string) error { + ret := _m.Called(ctx, namespace, name, currentUser) + + if len(ret) == 0 { + panic("no return value specified for RemoveRepo") + } + + var r0 error + if rf, ok := ret.Get(0).(func(context.Context, string, string, string) error); ok { + r0 = rf(ctx, namespace, name, currentUser) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// MockPromptComponent_RemoveRepo_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'RemoveRepo' +type MockPromptComponent_RemoveRepo_Call struct { + *mock.Call +} + +// RemoveRepo is a helper method to define mock.On call +// - ctx context.Context +// - namespace string +// - name string +// - currentUser string +func (_e *MockPromptComponent_Expecter) RemoveRepo(ctx interface{}, namespace interface{}, name interface{}, currentUser interface{}) *MockPromptComponent_RemoveRepo_Call { + return &MockPromptComponent_RemoveRepo_Call{Call: _e.mock.On("RemoveRepo", ctx, namespace, name, currentUser)} +} + +func (_c *MockPromptComponent_RemoveRepo_Call) Run(run func(ctx context.Context, namespace string, name string, currentUser string)) *MockPromptComponent_RemoveRepo_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(string), args[2].(string), args[3].(string)) + }) + return _c +} + +func (_c *MockPromptComponent_RemoveRepo_Call) Return(_a0 error) *MockPromptComponent_RemoveRepo_Call { + _c.Call.Return(_a0) + return _c +} + +func (_c *MockPromptComponent_RemoveRepo_Call) RunAndReturn(run func(context.Context, string, string, string) error) *MockPromptComponent_RemoveRepo_Call { + _c.Call.Return(run) + return _c +} + +// SaveGeneratedText provides a mock function with given fields: ctx, req +func (_m *MockPromptComponent) SaveGeneratedText(ctx context.Context, req types.Conversation) (*database.PromptConversationMessage, error) { + ret := _m.Called(ctx, req) + + if len(ret) == 0 { + panic("no return value specified for SaveGeneratedText") + } + + var r0 *database.PromptConversationMessage + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, types.Conversation) (*database.PromptConversationMessage, error)); ok { + return rf(ctx, req) + } + if rf, ok := ret.Get(0).(func(context.Context, types.Conversation) *database.PromptConversationMessage); ok { + r0 = rf(ctx, req) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*database.PromptConversationMessage) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, types.Conversation) error); ok { + r1 = rf(ctx, req) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// MockPromptComponent_SaveGeneratedText_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'SaveGeneratedText' +type MockPromptComponent_SaveGeneratedText_Call struct { + *mock.Call +} + +// SaveGeneratedText is a helper method to define mock.On call +// - ctx context.Context +// - req types.Conversation +func (_e *MockPromptComponent_Expecter) SaveGeneratedText(ctx interface{}, req interface{}) *MockPromptComponent_SaveGeneratedText_Call { + return &MockPromptComponent_SaveGeneratedText_Call{Call: _e.mock.On("SaveGeneratedText", ctx, req)} +} + +func (_c *MockPromptComponent_SaveGeneratedText_Call) Run(run func(ctx context.Context, req types.Conversation)) *MockPromptComponent_SaveGeneratedText_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(types.Conversation)) + }) + return _c +} + +func (_c *MockPromptComponent_SaveGeneratedText_Call) Return(_a0 *database.PromptConversationMessage, _a1 error) *MockPromptComponent_SaveGeneratedText_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *MockPromptComponent_SaveGeneratedText_Call) RunAndReturn(run func(context.Context, types.Conversation) (*database.PromptConversationMessage, error)) *MockPromptComponent_SaveGeneratedText_Call { + _c.Call.Return(run) + return _c +} + +// SetRelationModels provides a mock function with given fields: ctx, req +func (_m *MockPromptComponent) SetRelationModels(ctx context.Context, req types.RelationModels) error { + ret := _m.Called(ctx, req) + + if len(ret) == 0 { + panic("no return value specified for SetRelationModels") + } + + var r0 error + if rf, ok := ret.Get(0).(func(context.Context, types.RelationModels) error); ok { + r0 = rf(ctx, req) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// MockPromptComponent_SetRelationModels_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'SetRelationModels' +type MockPromptComponent_SetRelationModels_Call struct { + *mock.Call +} + +// SetRelationModels is a helper method to define mock.On call +// - ctx context.Context +// - req types.RelationModels +func (_e *MockPromptComponent_Expecter) SetRelationModels(ctx interface{}, req interface{}) *MockPromptComponent_SetRelationModels_Call { + return &MockPromptComponent_SetRelationModels_Call{Call: _e.mock.On("SetRelationModels", ctx, req)} +} + +func (_c *MockPromptComponent_SetRelationModels_Call) Run(run func(ctx context.Context, req types.RelationModels)) *MockPromptComponent_SetRelationModels_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(types.RelationModels)) + }) + return _c +} + +func (_c *MockPromptComponent_SetRelationModels_Call) Return(_a0 error) *MockPromptComponent_SetRelationModels_Call { + _c.Call.Return(_a0) + return _c +} + +func (_c *MockPromptComponent_SetRelationModels_Call) RunAndReturn(run func(context.Context, types.RelationModels) error) *MockPromptComponent_SetRelationModels_Call { + _c.Call.Return(run) + return _c +} + +// Show provides a mock function with given fields: ctx, namespace, name, currentUser +func (_m *MockPromptComponent) Show(ctx context.Context, namespace string, name string, currentUser string) (*types.PromptRes, error) { + ret := _m.Called(ctx, namespace, name, currentUser) + + if len(ret) == 0 { + panic("no return value specified for Show") + } + + var r0 *types.PromptRes + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, string, string, string) (*types.PromptRes, error)); ok { + return rf(ctx, namespace, name, currentUser) + } + if rf, ok := ret.Get(0).(func(context.Context, string, string, string) *types.PromptRes); ok { + r0 = rf(ctx, namespace, name, currentUser) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*types.PromptRes) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, string, string, string) error); ok { + r1 = rf(ctx, namespace, name, currentUser) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// MockPromptComponent_Show_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Show' +type MockPromptComponent_Show_Call struct { + *mock.Call +} + +// Show is a helper method to define mock.On call +// - ctx context.Context +// - namespace string +// - name string +// - currentUser string +func (_e *MockPromptComponent_Expecter) Show(ctx interface{}, namespace interface{}, name interface{}, currentUser interface{}) *MockPromptComponent_Show_Call { + return &MockPromptComponent_Show_Call{Call: _e.mock.On("Show", ctx, namespace, name, currentUser)} +} + +func (_c *MockPromptComponent_Show_Call) Run(run func(ctx context.Context, namespace string, name string, currentUser string)) *MockPromptComponent_Show_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(string), args[2].(string), args[3].(string)) + }) + return _c +} + +func (_c *MockPromptComponent_Show_Call) Return(_a0 *types.PromptRes, _a1 error) *MockPromptComponent_Show_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *MockPromptComponent_Show_Call) RunAndReturn(run func(context.Context, string, string, string) (*types.PromptRes, error)) *MockPromptComponent_Show_Call { + _c.Call.Return(run) + return _c +} + +// SubmitMessage provides a mock function with given fields: ctx, req +func (_m *MockPromptComponent) SubmitMessage(ctx context.Context, req types.ConversationReq) (<-chan string, error) { + ret := _m.Called(ctx, req) + + if len(ret) == 0 { + panic("no return value specified for SubmitMessage") + } + + var r0 <-chan string + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, types.ConversationReq) (<-chan string, error)); ok { + return rf(ctx, req) + } + if rf, ok := ret.Get(0).(func(context.Context, types.ConversationReq) <-chan string); ok { + r0 = rf(ctx, req) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(<-chan string) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, types.ConversationReq) error); ok { + r1 = rf(ctx, req) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// MockPromptComponent_SubmitMessage_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'SubmitMessage' +type MockPromptComponent_SubmitMessage_Call struct { + *mock.Call +} + +// SubmitMessage is a helper method to define mock.On call +// - ctx context.Context +// - req types.ConversationReq +func (_e *MockPromptComponent_Expecter) SubmitMessage(ctx interface{}, req interface{}) *MockPromptComponent_SubmitMessage_Call { + return &MockPromptComponent_SubmitMessage_Call{Call: _e.mock.On("SubmitMessage", ctx, req)} +} + +func (_c *MockPromptComponent_SubmitMessage_Call) Run(run func(ctx context.Context, req types.ConversationReq)) *MockPromptComponent_SubmitMessage_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(types.ConversationReq)) + }) + return _c +} + +func (_c *MockPromptComponent_SubmitMessage_Call) Return(_a0 <-chan string, _a1 error) *MockPromptComponent_SubmitMessage_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *MockPromptComponent_SubmitMessage_Call) RunAndReturn(run func(context.Context, types.ConversationReq) (<-chan string, error)) *MockPromptComponent_SubmitMessage_Call { + _c.Call.Return(run) + return _c +} + +// UpdateConversation provides a mock function with given fields: ctx, req +func (_m *MockPromptComponent) UpdateConversation(ctx context.Context, req types.ConversationTitleReq) (*database.PromptConversation, error) { + ret := _m.Called(ctx, req) + + if len(ret) == 0 { + panic("no return value specified for UpdateConversation") + } + + var r0 *database.PromptConversation + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, types.ConversationTitleReq) (*database.PromptConversation, error)); ok { + return rf(ctx, req) + } + if rf, ok := ret.Get(0).(func(context.Context, types.ConversationTitleReq) *database.PromptConversation); ok { + r0 = rf(ctx, req) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*database.PromptConversation) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, types.ConversationTitleReq) error); ok { + r1 = rf(ctx, req) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// MockPromptComponent_UpdateConversation_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'UpdateConversation' +type MockPromptComponent_UpdateConversation_Call struct { + *mock.Call +} + +// UpdateConversation is a helper method to define mock.On call +// - ctx context.Context +// - req types.ConversationTitleReq +func (_e *MockPromptComponent_Expecter) UpdateConversation(ctx interface{}, req interface{}) *MockPromptComponent_UpdateConversation_Call { + return &MockPromptComponent_UpdateConversation_Call{Call: _e.mock.On("UpdateConversation", ctx, req)} +} + +func (_c *MockPromptComponent_UpdateConversation_Call) Run(run func(ctx context.Context, req types.ConversationTitleReq)) *MockPromptComponent_UpdateConversation_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(types.ConversationTitleReq)) + }) + return _c +} + +func (_c *MockPromptComponent_UpdateConversation_Call) Return(_a0 *database.PromptConversation, _a1 error) *MockPromptComponent_UpdateConversation_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *MockPromptComponent_UpdateConversation_Call) RunAndReturn(run func(context.Context, types.ConversationTitleReq) (*database.PromptConversation, error)) *MockPromptComponent_UpdateConversation_Call { + _c.Call.Return(run) + return _c +} + +// UpdatePrompt provides a mock function with given fields: ctx, req, body +func (_m *MockPromptComponent) UpdatePrompt(ctx context.Context, req types.PromptReq, body *types.UpdatePromptReq) (*types.Prompt, error) { + ret := _m.Called(ctx, req, body) + + if len(ret) == 0 { + panic("no return value specified for UpdatePrompt") + } + + var r0 *types.Prompt + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, types.PromptReq, *types.UpdatePromptReq) (*types.Prompt, error)); ok { + return rf(ctx, req, body) + } + if rf, ok := ret.Get(0).(func(context.Context, types.PromptReq, *types.UpdatePromptReq) *types.Prompt); ok { + r0 = rf(ctx, req, body) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*types.Prompt) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, types.PromptReq, *types.UpdatePromptReq) error); ok { + r1 = rf(ctx, req, body) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// MockPromptComponent_UpdatePrompt_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'UpdatePrompt' +type MockPromptComponent_UpdatePrompt_Call struct { + *mock.Call +} + +// UpdatePrompt is a helper method to define mock.On call +// - ctx context.Context +// - req types.PromptReq +// - body *types.UpdatePromptReq +func (_e *MockPromptComponent_Expecter) UpdatePrompt(ctx interface{}, req interface{}, body interface{}) *MockPromptComponent_UpdatePrompt_Call { + return &MockPromptComponent_UpdatePrompt_Call{Call: _e.mock.On("UpdatePrompt", ctx, req, body)} +} + +func (_c *MockPromptComponent_UpdatePrompt_Call) Run(run func(ctx context.Context, req types.PromptReq, body *types.UpdatePromptReq)) *MockPromptComponent_UpdatePrompt_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(types.PromptReq), args[2].(*types.UpdatePromptReq)) + }) + return _c +} + +func (_c *MockPromptComponent_UpdatePrompt_Call) Return(_a0 *types.Prompt, _a1 error) *MockPromptComponent_UpdatePrompt_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *MockPromptComponent_UpdatePrompt_Call) RunAndReturn(run func(context.Context, types.PromptReq, *types.UpdatePromptReq) (*types.Prompt, error)) *MockPromptComponent_UpdatePrompt_Call { + _c.Call.Return(run) + return _c +} + +// UpdatePromptRepo provides a mock function with given fields: ctx, req +func (_m *MockPromptComponent) UpdatePromptRepo(ctx context.Context, req *types.UpdatePromptRepoReq) (*types.PromptRes, error) { + ret := _m.Called(ctx, req) + + if len(ret) == 0 { + panic("no return value specified for UpdatePromptRepo") + } + + var r0 *types.PromptRes + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, *types.UpdatePromptRepoReq) (*types.PromptRes, error)); ok { + return rf(ctx, req) + } + if rf, ok := ret.Get(0).(func(context.Context, *types.UpdatePromptRepoReq) *types.PromptRes); ok { + r0 = rf(ctx, req) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*types.PromptRes) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, *types.UpdatePromptRepoReq) error); ok { + r1 = rf(ctx, req) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// MockPromptComponent_UpdatePromptRepo_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'UpdatePromptRepo' +type MockPromptComponent_UpdatePromptRepo_Call struct { + *mock.Call +} + +// UpdatePromptRepo is a helper method to define mock.On call +// - ctx context.Context +// - req *types.UpdatePromptRepoReq +func (_e *MockPromptComponent_Expecter) UpdatePromptRepo(ctx interface{}, req interface{}) *MockPromptComponent_UpdatePromptRepo_Call { + return &MockPromptComponent_UpdatePromptRepo_Call{Call: _e.mock.On("UpdatePromptRepo", ctx, req)} +} + +func (_c *MockPromptComponent_UpdatePromptRepo_Call) Run(run func(ctx context.Context, req *types.UpdatePromptRepoReq)) *MockPromptComponent_UpdatePromptRepo_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(*types.UpdatePromptRepoReq)) + }) + return _c +} + +func (_c *MockPromptComponent_UpdatePromptRepo_Call) Return(_a0 *types.PromptRes, _a1 error) *MockPromptComponent_UpdatePromptRepo_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *MockPromptComponent_UpdatePromptRepo_Call) RunAndReturn(run func(context.Context, *types.UpdatePromptRepoReq) (*types.PromptRes, error)) *MockPromptComponent_UpdatePromptRepo_Call { + _c.Call.Return(run) + return _c +} + +// NewMockPromptComponent creates a new instance of MockPromptComponent. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. +// The first argument is typically a *testing.T value. +func NewMockPromptComponent(t interface { + mock.TestingT + Cleanup(func()) +}) *MockPromptComponent { + mock := &MockPromptComponent{} + mock.Mock.Test(t) + + t.Cleanup(func() { mock.AssertExpectations(t) }) + + return mock +} diff --git a/api/handler/code.go b/api/handler/code.go index 5792c0c0..bca06df0 100644 --- a/api/handler/code.go +++ b/api/handler/code.go @@ -25,14 +25,14 @@ func NewCodeHandler(config *config.Config) (*CodeHandler, error) { return nil, fmt.Errorf("error creating sensitive component:%w", err) } return &CodeHandler{ - c: tc, - sc: sc, + code: tc, + sensitive: sc, }, nil } type CodeHandler struct { - c component.CodeComponent - sc component.SensitiveComponent + code component.CodeComponent + sensitive component.SensitiveComponent } // CreateCode godoc @@ -61,7 +61,7 @@ func (h *CodeHandler) Create(ctx *gin.Context) { return } - _, err := h.sc.CheckRequestV2(ctx, req) + _, err := h.sensitive.CheckRequestV2(ctx, req) if err != nil { slog.Error("failed to check sensitive request", slog.Any("error", err)) httpbase.BadRequest(ctx, fmt.Errorf("sensitive check failed: %w", err).Error()) @@ -69,7 +69,7 @@ func (h *CodeHandler) Create(ctx *gin.Context) { } req.Username = currentUser - code, err := h.c.Create(ctx, req) + code, err := h.code.Create(ctx, req) if err != nil { slog.Error("Failed to create code", slog.Any("error", err)) httpbase.ServerError(ctx, err) @@ -128,7 +128,7 @@ func (h *CodeHandler) Index(ctx *gin.Context) { return } - codes, total, err := h.c.Index(ctx, filter, per, page) + codes, total, err := h.code.Index(ctx, filter, per, page) if err != nil { slog.Error("Failed to get codes", slog.Any("error", err)) httpbase.ServerError(ctx, err) @@ -170,7 +170,7 @@ func (h *CodeHandler) Update(ctx *gin.Context) { return } - _, err := h.sc.CheckRequestV2(ctx, req) + _, err := h.sensitive.CheckRequestV2(ctx, req) if err != nil { slog.Error("failed to check sensitive request", slog.Any("error", err)) httpbase.BadRequest(ctx, fmt.Errorf("sensitive check failed: %w", err).Error()) @@ -187,7 +187,7 @@ func (h *CodeHandler) Update(ctx *gin.Context) { req.Namespace = namespace req.Name = name - code, err := h.c.Update(ctx, req) + code, err := h.code.Update(ctx, req) if err != nil { slog.Error("Failed to update code", slog.Any("error", err)) httpbase.ServerError(ctx, err) @@ -224,7 +224,7 @@ func (h *CodeHandler) Delete(ctx *gin.Context) { httpbase.BadRequest(ctx, err.Error()) return } - err = h.c.Delete(ctx, namespace, name, currentUser) + err = h.code.Delete(ctx, namespace, name, currentUser) if err != nil { slog.Error("Failed to delete code", slog.Any("error", err)) httpbase.ServerError(ctx, err) @@ -256,7 +256,7 @@ func (h *CodeHandler) Show(ctx *gin.Context) { return } currentUser := httpbase.GetCurrentUser(ctx) - detail, err := h.c.Show(ctx, namespace, name, currentUser) + detail, err := h.code.Show(ctx, namespace, name, currentUser) if err != nil { if errors.Is(err, component.ErrUnauthorized) { httpbase.UnauthorizedError(ctx, err) @@ -291,7 +291,7 @@ func (h *CodeHandler) Relations(ctx *gin.Context) { return } currentUser := httpbase.GetCurrentUser(ctx) - detail, err := h.c.Relations(ctx, namespace, name, currentUser) + detail, err := h.code.Relations(ctx, namespace, name, currentUser) if err != nil { if errors.Is(err, component.ErrUnauthorized) { httpbase.UnauthorizedError(ctx, err) diff --git a/api/handler/code_test.go b/api/handler/code_test.go new file mode 100644 index 00000000..09bbe917 --- /dev/null +++ b/api/handler/code_test.go @@ -0,0 +1,149 @@ +package handler + +import ( + "fmt" + "testing" + + "github.com/alibabacloud-go/tea/tea" + "github.com/gin-gonic/gin" + "github.com/stretchr/testify/require" + mockcomponent "opencsg.com/csghub-server/_mocks/opencsg.com/csghub-server/component" + "opencsg.com/csghub-server/common/types" +) + +type CodeTester struct { + *GinTester + handler *CodeHandler + mocks struct { + code *mockcomponent.MockCodeComponent + sensitive *mockcomponent.MockSensitiveComponent + } +} + +func NewCodeTester(t *testing.T) *CodeTester { + tester := &CodeTester{GinTester: NewGinTester()} + tester.mocks.code = mockcomponent.NewMockCodeComponent(t) + tester.mocks.sensitive = mockcomponent.NewMockSensitiveComponent(t) + tester.handler = &CodeHandler{code: tester.mocks.code, sensitive: tester.mocks.sensitive} + tester.WithParam("name", "r") + tester.WithParam("namespace", "u") + return tester + +} + +func (ct *CodeTester) WithHandleFunc(fn func(cp *CodeHandler) gin.HandlerFunc) *CodeTester { + ct.ginHandler = fn(ct.handler) + return ct +} + +func TestCodeHandler_Create(t *testing.T) { + tester := NewCodeTester(t).WithHandleFunc(func(cp *CodeHandler) gin.HandlerFunc { + return cp.Create + }) + tester.RequireUser(t) + + req := &types.CreateCodeReq{CreateRepoReq: types.CreateRepoReq{Name: "c"}} + tester.mocks.sensitive.EXPECT().CheckRequestV2(tester.ctx, req).Return(true, nil) + reqn := *req + reqn.Username = "u" + tester.mocks.code.EXPECT().Create(tester.ctx, &reqn).Return(&types.Code{Name: "c"}, nil) + tester.WithBody(t, req).Execute() + + tester.ResponseEqSimple(t, 200, gin.H{"data": &types.Code{Name: "c"}}) + +} + +func TestCodeHandler_Index(t *testing.T) { + + cases := []struct { + sort string + source string + error bool + }{ + {"most_download", "local", false}, + {"foo", "local", true}, + {"most_download", "bar", true}, + } + + for _, c := range cases { + t.Run(fmt.Sprintf("%+v", c), func(t *testing.T) { + + tester := NewCodeTester(t).WithHandleFunc(func(cp *CodeHandler) gin.HandlerFunc { + return cp.Index + }) + + if !c.error { + tester.mocks.code.EXPECT().Index(tester.ctx, &types.RepoFilter{ + Search: "foo", + Sort: c.sort, + Source: c.source, + }, 10, 1).Return([]types.Code{ + {Name: "cc"}, + }, 100, nil) + } + + tester.AddPagination(1, 10).WithQuery("search", "foo"). + WithQuery("sort", c.sort). + WithQuery("source", c.source).Execute() + + if c.error { + require.Equal(t, 400, tester.response.Code) + } else { + tester.ResponseEqSimple(t, 200, gin.H{ + "data": []types.Code{{Name: "cc"}}, + "total": 100, + }) + } + }) + } +} + +func TestCodeHandler_Update(t *testing.T) { + tester := NewCodeTester(t).WithHandleFunc(func(cp *CodeHandler) gin.HandlerFunc { + return cp.Update + }) + tester.RequireUser(t) + + req := &types.UpdateCodeReq{UpdateRepoReq: types.UpdateRepoReq{Nickname: tea.String("nc")}} + tester.mocks.sensitive.EXPECT().CheckRequestV2(tester.ctx, req).Return(true, nil) + reqn := *req + reqn.Username = "u" + reqn.Name = "r" + reqn.Namespace = "u" + tester.mocks.code.EXPECT().Update(tester.ctx, &reqn).Return(&types.Code{Name: "c"}, nil) + tester.WithBody(t, req).Execute() + + tester.ResponseEq(t, 200, tester.OKText, &types.Code{Name: "c"}) + +} + +func TestCodeHandler_Delete(t *testing.T) { + tester := NewCodeTester(t).WithHandleFunc(func(cp *CodeHandler) gin.HandlerFunc { + return cp.Delete + }) + tester.RequireUser(t) + + tester.mocks.code.EXPECT().Delete(tester.ctx, "u", "r", "u").Return(nil) + tester.Execute() + tester.ResponseEq(t, 200, tester.OKText, nil) +} + +func TestCodeHandler_Show(t *testing.T) { + tester := NewCodeTester(t).WithHandleFunc(func(cp *CodeHandler) gin.HandlerFunc { + return cp.Show + }) + + tester.mocks.code.EXPECT().Show(tester.ctx, "u", "r", "u").Return(&types.Code{Name: "c"}, nil) + tester.WithUser().Execute() + tester.ResponseEq(t, 200, tester.OKText, &types.Code{Name: "c"}) +} + +func TestCodeHandler_Relations(t *testing.T) { + tester := NewCodeTester(t).WithHandleFunc(func(cp *CodeHandler) gin.HandlerFunc { + return cp.Relations + }) + + tester.mocks.code.EXPECT().Relations(tester.ctx, "u", "r", "u").Return(&types.Relations{}, nil) + tester.WithUser().Execute() + tester.ResponseEq(t, 200, tester.OKText, &types.Relations{}) +} diff --git a/api/handler/prompt.go b/api/handler/prompt.go index 0eecbdee..6354c638 100644 --- a/api/handler/prompt.go +++ b/api/handler/prompt.go @@ -229,7 +229,7 @@ func (h *PromptHandler) CreatePrompt(ctx *gin.Context) { return } - var body *component.CreatePromptReq + var body *types.CreatePromptReq if err := ctx.ShouldBindJSON(&body); err != nil { slog.Error("Bad request prompt format", "error", err) httpbase.BadRequest(ctx, err.Error()) @@ -291,7 +291,7 @@ func (h *PromptHandler) UpdatePrompt(ctx *gin.Context) { return } - var body *component.UpdatePromptReq + var body *types.UpdatePromptReq if err := ctx.ShouldBindJSON(&body); err != nil { slog.Error("Bad request prompt format", "error", err) httpbase.BadRequest(ctx, err.Error()) diff --git a/builder/sensitive/aliyun_green.go b/builder/sensitive/aliyun_green.go index 7c06a953..b6041127 100644 --- a/builder/sensitive/aliyun_green.go +++ b/builder/sensitive/aliyun_green.go @@ -15,9 +15,21 @@ import ( "github.com/alibabacloud-go/tea/tea" "github.com/aliyun/alibaba-cloud-sdk-go/services/green" "opencsg.com/csghub-server/common/config" - "opencsg.com/csghub-server/common/utils/common" ) +// copy from common/utils/common to avoid cycle import +func truncString(s string, limit int) string { + if len(s) <= limit { + return s + } + + s1 := []byte(s[:limit]) + s1[limit-1] = '.' + s1[limit-2] = '.' + s1[limit-3] = '.' + return string(s1) +} + type GreenClient interface { TextScan(request *green.TextScanRequest) (response *TextScanResponse, err error) } @@ -149,7 +161,7 @@ func (c *AliyunGreenChecker) PassLargeTextCheck(ctx context.Context, text string } if result.Suggestion == "block" { - slog.Info("block content", slog.String("content", common.TruncString(data.Content, 128)), slog.String("taskId", data.TaskId), + slog.Info("block content", slog.String("content", truncString(data.Content, 128)), slog.String("taskId", data.TaskId), slog.String("aliyun_request_id", resp.RequestID)) return &CheckResult{IsSensitive: true, Reason: result.Label}, nil diff --git a/common/types/prompt.go b/common/types/prompt.go index 7ad62e90..c6a6f641 100644 --- a/common/types/prompt.go +++ b/common/types/prompt.go @@ -2,6 +2,8 @@ package types import ( "time" + + "opencsg.com/csghub-server/builder/sensitive" ) type PromptReq struct { @@ -104,19 +106,6 @@ type PromptRes struct { Namespace *Namespace `json:"namespace"` } -type Prompt struct { - Title string `json:"title" binding:"required"` - Content string `json:"content" binding:"required"` - Language string `json:"language" binding:"required"` - Tags []string `json:"tags"` - Type string `json:"type"` // "text|image|video|audio" - Source string `json:"source"` - Author string `json:"author"` - Time string `json:"time"` - Copyright string `json:"copyright"` - Feedback []string `json:"feedback"` -} - type PromptOutput struct { Prompt FilePath string `json:"file_path"` @@ -131,3 +120,44 @@ type CreatePromptReq struct { type UpdatePromptReq struct { Prompt } + +type Prompt struct { + Title string `json:"title" binding:"required"` + Content string `json:"content" binding:"required"` + Language string `json:"language" binding:"required"` + Tags []string `json:"tags"` + Type string `json:"type"` // "text|image|video|audio" + Source string `json:"source"` + Author string `json:"author"` + Time string `json:"time"` + Copyright string `json:"copyright"` + Feedback []string `json:"feedback"` +} + +func (req *Prompt) GetSensitiveFields() []SensitiveField { + var fields []SensitiveField + fields = append(fields, SensitiveField{ + Name: "title", + Value: func() string { + return req.Title + }, + Scenario: string(sensitive.ScenarioCommentDetection), + }) + fields = append(fields, SensitiveField{ + Name: "content", + Value: func() string { + return req.Content + }, + Scenario: string(sensitive.ScenarioCommentDetection), + }) + if len(req.Source) > 0 { + fields = append(fields, SensitiveField{ + Name: "source", + Value: func() string { + return req.Source + }, + Scenario: string(sensitive.ScenarioCommentDetection), + }) + } + return fields +} diff --git a/component/prompt.go b/component/prompt.go index 6fdadcb2..dbe41c1b 100644 --- a/component/prompt.go +++ b/component/prompt.go @@ -17,7 +17,6 @@ import ( "opencsg.com/csghub-server/builder/git/membership" "opencsg.com/csghub-server/builder/llm" "opencsg.com/csghub-server/builder/rpc" - "opencsg.com/csghub-server/builder/sensitive" "opencsg.com/csghub-server/builder/store/database" "opencsg.com/csghub-server/common/config" "opencsg.com/csghub-server/common/types" @@ -48,11 +47,11 @@ type promptComponentImpl struct { } type PromptComponent interface { - ListPrompt(ctx context.Context, req types.PromptReq) ([]PromptOutput, error) - GetPrompt(ctx context.Context, req types.PromptReq) (*PromptOutput, error) - ParseJsonFile(ctx context.Context, req gitserver.GetRepoInfoByPathReq) (*PromptOutput, error) - CreatePrompt(ctx context.Context, req types.PromptReq, body *CreatePromptReq) (*Prompt, error) - UpdatePrompt(ctx context.Context, req types.PromptReq, body *UpdatePromptReq) (*Prompt, error) + ListPrompt(ctx context.Context, req types.PromptReq) ([]types.PromptOutput, error) + GetPrompt(ctx context.Context, req types.PromptReq) (*types.PromptOutput, error) + ParseJsonFile(ctx context.Context, req gitserver.GetRepoInfoByPathReq) (*types.PromptOutput, error) + CreatePrompt(ctx context.Context, req types.PromptReq, body *types.CreatePromptReq) (*types.Prompt, error) + UpdatePrompt(ctx context.Context, req types.PromptReq, body *types.UpdatePromptReq) (*types.Prompt, error) DeletePrompt(ctx context.Context, req types.PromptReq) error NewConversation(ctx context.Context, req types.ConversationTitleReq) (*database.PromptConversation, error) ListConversationsByUserID(ctx context.Context, currentUser string) ([]database.PromptConversation, error) @@ -104,7 +103,7 @@ func NewPromptComponent(cfg *config.Config) (PromptComponent, error) { }, nil } -func (c *promptComponentImpl) ListPrompt(ctx context.Context, req types.PromptReq) ([]PromptOutput, error) { +func (c *promptComponentImpl) ListPrompt(ctx context.Context, req types.PromptReq) ([]types.PromptOutput, error) { r, err := c.repoStore.FindByPath(ctx, types.PromptRepo, req.Namespace, req.Name) if err != nil { return nil, fmt.Errorf("failed to find prompt set, error: %w", err) @@ -129,9 +128,9 @@ func (c *promptComponentImpl) ListPrompt(ctx context.Context, req types.PromptRe if tree == nil { return nil, fmt.Errorf("failed to find any files") } - var prompts []PromptOutput + var prompts []types.PromptOutput wg := &sync.WaitGroup{} - chPrompts := make(chan *PromptOutput, len(tree)) + chPrompts := make(chan *types.PromptOutput, len(tree)) done := make(chan struct{}, 1) go func() { @@ -177,7 +176,7 @@ func (c *promptComponentImpl) ListPrompt(ctx context.Context, req types.PromptRe return prompts, nil } -func (c *promptComponentImpl) GetPrompt(ctx context.Context, req types.PromptReq) (*PromptOutput, error) { +func (c *promptComponentImpl) GetPrompt(ctx context.Context, req types.PromptReq) (*types.PromptOutput, error) { r, err := c.repoStore.FindByPath(ctx, types.PromptRepo, req.Namespace, req.Name) if err != nil { return nil, fmt.Errorf("failed to find prompt repo, error: %w", err) @@ -207,7 +206,7 @@ func (c *promptComponentImpl) GetPrompt(ctx context.Context, req types.PromptReq return p, nil } -func (c *promptComponentImpl) ParseJsonFile(ctx context.Context, req gitserver.GetRepoInfoByPathReq) (*PromptOutput, error) { +func (c *promptComponentImpl) ParseJsonFile(ctx context.Context, req gitserver.GetRepoInfoByPathReq) (*types.PromptOutput, error) { f, err := c.gitServer.GetRepoFileContents(ctx, req) if err != nil { return nil, fmt.Errorf("failed to get %s contents, cause:%w", req.Path, err) @@ -216,7 +215,7 @@ func (c *promptComponentImpl) ParseJsonFile(ctx context.Context, req gitserver.G if err != nil { return nil, fmt.Errorf("failed to base64 decode %s contents, cause:%w", req.Path, err) } - var prompt Prompt + var prompt types.Prompt err = yaml.Unmarshal(decodedContent, &prompt) if err != nil { return nil, fmt.Errorf("failed to Unmarshal %s contents, cause: %w, decodedContent: %v", req.Path, err, string(decodedContent)) @@ -224,14 +223,14 @@ func (c *promptComponentImpl) ParseJsonFile(ctx context.Context, req gitserver.G if len(prompt.Title) < 1 { prompt.Title = f.Name } - po := PromptOutput{ + po := types.PromptOutput{ Prompt: prompt, FilePath: req.Path, } return &po, nil } -func (c *promptComponentImpl) CreatePrompt(ctx context.Context, req types.PromptReq, body *CreatePromptReq) (*Prompt, error) { +func (c *promptComponentImpl) CreatePrompt(ctx context.Context, req types.PromptReq, body *types.CreatePromptReq) (*types.Prompt, error) { u, err := c.checkPromptRepoPermission(ctx, req) if err != nil { return nil, fmt.Errorf("user do not allowed create prompt") @@ -267,7 +266,7 @@ func (c *promptComponentImpl) CreatePrompt(ctx context.Context, req types.Prompt return &body.Prompt, nil } -func (c *promptComponentImpl) UpdatePrompt(ctx context.Context, req types.PromptReq, body *UpdatePromptReq) (*Prompt, error) { +func (c *promptComponentImpl) UpdatePrompt(ctx context.Context, req types.PromptReq, body *types.UpdatePromptReq) (*types.Prompt, error) { u, err := c.checkPromptRepoPermission(ctx, req) if err != nil { return nil, fmt.Errorf("user do not allowed update prompt") @@ -1161,63 +1160,7 @@ func (c *promptComponentImpl) getRelations(ctx context.Context, repoID int64, cu return rels, nil } -type Prompt struct { - Title string `json:"title" binding:"required"` - Content string `json:"content" binding:"required"` - Language string `json:"language" binding:"required"` - Tags []string `json:"tags"` - Type string `json:"type"` // "text|image|video|audio" - Source string `json:"source"` - Author string `json:"author"` - Time string `json:"time"` - Copyright string `json:"copyright"` - Feedback []string `json:"feedback"` -} - -type PromptOutput struct { - Prompt - FilePath string `json:"file_path"` - CanWrite bool `json:"can_write"` - CanManage bool `json:"can_manage"` -} - -type CreatePromptReq struct { - Prompt -} - -type UpdatePromptReq struct { - Prompt -} - -var _ types.SensitiveRequestV2 = (*Prompt)(nil) - -func (req *Prompt) GetSensitiveFields() []types.SensitiveField { - var fields []types.SensitiveField - fields = append(fields, types.SensitiveField{ - Name: "title", - Value: func() string { - return req.Title - }, - Scenario: string(sensitive.ScenarioCommentDetection), - }) - fields = append(fields, types.SensitiveField{ - Name: "content", - Value: func() string { - return req.Content - }, - Scenario: string(sensitive.ScenarioCommentDetection), - }) - if len(req.Source) > 0 { - fields = append(fields, types.SensitiveField{ - Name: "source", - Value: func() string { - return req.Source - }, - Scenario: string(sensitive.ScenarioCommentDetection), - }) - } - return fields -} +var _ types.SensitiveRequestV2 = (*types.Prompt)(nil) func (c *promptComponentImpl) OrgPrompts(ctx context.Context, req *types.OrgPromptsReq) ([]types.PromptRes, int, error) { var resPrompts []types.PromptRes diff --git a/component/prompt_test.go b/component/prompt_test.go index 30ce73ce..5704c66c 100644 --- a/component/prompt_test.go +++ b/component/prompt_test.go @@ -133,8 +133,8 @@ func TestPromptComponent_CreatePrompt(t *testing.T) { Name: "n", CurrentUser: "foo", Path: "p", - }, &CreatePromptReq{ - Prompt: Prompt{Title: "TEST", Content: "test"}, + }, &types.CreatePromptReq{ + Prompt: types.Prompt{Title: "TEST", Content: "test"}, }) require.NotNil(t, err) return @@ -166,8 +166,8 @@ func TestPromptComponent_CreatePrompt(t *testing.T) { Name: "n", CurrentUser: "foo", Path: "p", - }, &CreatePromptReq{ - Prompt: Prompt{Title: "TEST", Content: "test"}, + }, &types.CreatePromptReq{ + Prompt: types.Prompt{Title: "TEST", Content: "test"}, }) require.Nil(t, err) @@ -204,8 +204,8 @@ func TestPromptComponent_UpdatePrompt(t *testing.T) { Name: "n", CurrentUser: "foo", Path: "TEST.jsonl", - }, &UpdatePromptReq{ - Prompt: Prompt{Title: "TEST.jsonl", Content: "test"}, + }, &types.UpdatePromptReq{ + Prompt: types.Prompt{Title: "TEST.jsonl", Content: "test"}, }) require.NotNil(t, err) return @@ -237,8 +237,8 @@ func TestPromptComponent_UpdatePrompt(t *testing.T) { Name: "n", CurrentUser: "foo", Path: "TEST.jsonl", - }, &UpdatePromptReq{ - Prompt: Prompt{Title: "TEST", Content: "test"}, + }, &types.UpdatePromptReq{ + Prompt: types.Prompt{Title: "TEST", Content: "test"}, }) require.Nil(t, err) From 45af473464c902960c7c556a47d38c3e820d3f8a Mon Sep 17 00:00:00 2001 From: "yiling.ji" Date: Mon, 23 Dec 2024 05:50:30 +0000 Subject: [PATCH 22/34] Merge branch 'feature/handler_tests' into 'main' Add model/user/git-http handler tests See merge request product/starhub/starhub-server!751 --- .mockery.yaml | 3 + .../component/mock_GitHTTPComponent.go | 647 +++++++++ .../component/mock_ModelComponent.go | 1108 ++++++++++++++ .../component/mock_UserComponent.go | 1277 +++++++++++++++++ api/handler/git_http.go | 36 +- api/handler/git_http_test.go | 315 ++++ api/handler/helper_test.go | 26 +- api/handler/model.go | 58 +- api/handler/model_test.go | 561 ++++++++ api/handler/user.go | 52 +- api/handler/user_test.go | 439 ++++++ common/types/user.go | 8 +- component/user.go | 12 +- component/user_test.go | 32 +- 14 files changed, 4470 insertions(+), 104 deletions(-) create mode 100644 _mocks/opencsg.com/csghub-server/component/mock_GitHTTPComponent.go create mode 100644 _mocks/opencsg.com/csghub-server/component/mock_ModelComponent.go create mode 100644 _mocks/opencsg.com/csghub-server/component/mock_UserComponent.go create mode 100644 api/handler/git_http_test.go create mode 100644 api/handler/model_test.go create mode 100644 api/handler/user_test.go diff --git a/.mockery.yaml b/.mockery.yaml index 64fb75b7..04274577 100644 --- a/.mockery.yaml +++ b/.mockery.yaml @@ -19,6 +19,9 @@ packages: SensitiveComponent: CodeComponent: PromptComponent: + ModelComponent: + UserComponent: + GitHTTPComponent: opencsg.com/csghub-server/user/component: config: interfaces: diff --git a/_mocks/opencsg.com/csghub-server/component/mock_GitHTTPComponent.go b/_mocks/opencsg.com/csghub-server/component/mock_GitHTTPComponent.go new file mode 100644 index 00000000..1b10efa4 --- /dev/null +++ b/_mocks/opencsg.com/csghub-server/component/mock_GitHTTPComponent.go @@ -0,0 +1,647 @@ +// Code generated by mockery v2.49.1. DO NOT EDIT. + +package component + +import ( + context "context" + io "io" + + database "opencsg.com/csghub-server/builder/store/database" + + mock "github.com/stretchr/testify/mock" + + types "opencsg.com/csghub-server/common/types" + + url "net/url" +) + +// MockGitHTTPComponent is an autogenerated mock type for the GitHTTPComponent type +type MockGitHTTPComponent struct { + mock.Mock +} + +type MockGitHTTPComponent_Expecter struct { + mock *mock.Mock +} + +func (_m *MockGitHTTPComponent) EXPECT() *MockGitHTTPComponent_Expecter { + return &MockGitHTTPComponent_Expecter{mock: &_m.Mock} +} + +// BuildObjectResponse provides a mock function with given fields: ctx, req, isUpload +func (_m *MockGitHTTPComponent) BuildObjectResponse(ctx context.Context, req types.BatchRequest, isUpload bool) (*types.BatchResponse, error) { + ret := _m.Called(ctx, req, isUpload) + + if len(ret) == 0 { + panic("no return value specified for BuildObjectResponse") + } + + var r0 *types.BatchResponse + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, types.BatchRequest, bool) (*types.BatchResponse, error)); ok { + return rf(ctx, req, isUpload) + } + if rf, ok := ret.Get(0).(func(context.Context, types.BatchRequest, bool) *types.BatchResponse); ok { + r0 = rf(ctx, req, isUpload) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*types.BatchResponse) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, types.BatchRequest, bool) error); ok { + r1 = rf(ctx, req, isUpload) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// MockGitHTTPComponent_BuildObjectResponse_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'BuildObjectResponse' +type MockGitHTTPComponent_BuildObjectResponse_Call struct { + *mock.Call +} + +// BuildObjectResponse is a helper method to define mock.On call +// - ctx context.Context +// - req types.BatchRequest +// - isUpload bool +func (_e *MockGitHTTPComponent_Expecter) BuildObjectResponse(ctx interface{}, req interface{}, isUpload interface{}) *MockGitHTTPComponent_BuildObjectResponse_Call { + return &MockGitHTTPComponent_BuildObjectResponse_Call{Call: _e.mock.On("BuildObjectResponse", ctx, req, isUpload)} +} + +func (_c *MockGitHTTPComponent_BuildObjectResponse_Call) Run(run func(ctx context.Context, req types.BatchRequest, isUpload bool)) *MockGitHTTPComponent_BuildObjectResponse_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(types.BatchRequest), args[2].(bool)) + }) + return _c +} + +func (_c *MockGitHTTPComponent_BuildObjectResponse_Call) Return(_a0 *types.BatchResponse, _a1 error) *MockGitHTTPComponent_BuildObjectResponse_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *MockGitHTTPComponent_BuildObjectResponse_Call) RunAndReturn(run func(context.Context, types.BatchRequest, bool) (*types.BatchResponse, error)) *MockGitHTTPComponent_BuildObjectResponse_Call { + _c.Call.Return(run) + return _c +} + +// CreateLock provides a mock function with given fields: ctx, req +func (_m *MockGitHTTPComponent) CreateLock(ctx context.Context, req types.LfsLockReq) (*database.LfsLock, error) { + ret := _m.Called(ctx, req) + + if len(ret) == 0 { + panic("no return value specified for CreateLock") + } + + var r0 *database.LfsLock + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, types.LfsLockReq) (*database.LfsLock, error)); ok { + return rf(ctx, req) + } + if rf, ok := ret.Get(0).(func(context.Context, types.LfsLockReq) *database.LfsLock); ok { + r0 = rf(ctx, req) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*database.LfsLock) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, types.LfsLockReq) error); ok { + r1 = rf(ctx, req) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// MockGitHTTPComponent_CreateLock_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'CreateLock' +type MockGitHTTPComponent_CreateLock_Call struct { + *mock.Call +} + +// CreateLock is a helper method to define mock.On call +// - ctx context.Context +// - req types.LfsLockReq +func (_e *MockGitHTTPComponent_Expecter) CreateLock(ctx interface{}, req interface{}) *MockGitHTTPComponent_CreateLock_Call { + return &MockGitHTTPComponent_CreateLock_Call{Call: _e.mock.On("CreateLock", ctx, req)} +} + +func (_c *MockGitHTTPComponent_CreateLock_Call) Run(run func(ctx context.Context, req types.LfsLockReq)) *MockGitHTTPComponent_CreateLock_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(types.LfsLockReq)) + }) + return _c +} + +func (_c *MockGitHTTPComponent_CreateLock_Call) Return(_a0 *database.LfsLock, _a1 error) *MockGitHTTPComponent_CreateLock_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *MockGitHTTPComponent_CreateLock_Call) RunAndReturn(run func(context.Context, types.LfsLockReq) (*database.LfsLock, error)) *MockGitHTTPComponent_CreateLock_Call { + _c.Call.Return(run) + return _c +} + +// GitReceivePack provides a mock function with given fields: ctx, req +func (_m *MockGitHTTPComponent) GitReceivePack(ctx context.Context, req types.GitReceivePackReq) error { + ret := _m.Called(ctx, req) + + if len(ret) == 0 { + panic("no return value specified for GitReceivePack") + } + + var r0 error + if rf, ok := ret.Get(0).(func(context.Context, types.GitReceivePackReq) error); ok { + r0 = rf(ctx, req) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// MockGitHTTPComponent_GitReceivePack_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'GitReceivePack' +type MockGitHTTPComponent_GitReceivePack_Call struct { + *mock.Call +} + +// GitReceivePack is a helper method to define mock.On call +// - ctx context.Context +// - req types.GitReceivePackReq +func (_e *MockGitHTTPComponent_Expecter) GitReceivePack(ctx interface{}, req interface{}) *MockGitHTTPComponent_GitReceivePack_Call { + return &MockGitHTTPComponent_GitReceivePack_Call{Call: _e.mock.On("GitReceivePack", ctx, req)} +} + +func (_c *MockGitHTTPComponent_GitReceivePack_Call) Run(run func(ctx context.Context, req types.GitReceivePackReq)) *MockGitHTTPComponent_GitReceivePack_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(types.GitReceivePackReq)) + }) + return _c +} + +func (_c *MockGitHTTPComponent_GitReceivePack_Call) Return(_a0 error) *MockGitHTTPComponent_GitReceivePack_Call { + _c.Call.Return(_a0) + return _c +} + +func (_c *MockGitHTTPComponent_GitReceivePack_Call) RunAndReturn(run func(context.Context, types.GitReceivePackReq) error) *MockGitHTTPComponent_GitReceivePack_Call { + _c.Call.Return(run) + return _c +} + +// GitUploadPack provides a mock function with given fields: ctx, req +func (_m *MockGitHTTPComponent) GitUploadPack(ctx context.Context, req types.GitUploadPackReq) error { + ret := _m.Called(ctx, req) + + if len(ret) == 0 { + panic("no return value specified for GitUploadPack") + } + + var r0 error + if rf, ok := ret.Get(0).(func(context.Context, types.GitUploadPackReq) error); ok { + r0 = rf(ctx, req) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// MockGitHTTPComponent_GitUploadPack_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'GitUploadPack' +type MockGitHTTPComponent_GitUploadPack_Call struct { + *mock.Call +} + +// GitUploadPack is a helper method to define mock.On call +// - ctx context.Context +// - req types.GitUploadPackReq +func (_e *MockGitHTTPComponent_Expecter) GitUploadPack(ctx interface{}, req interface{}) *MockGitHTTPComponent_GitUploadPack_Call { + return &MockGitHTTPComponent_GitUploadPack_Call{Call: _e.mock.On("GitUploadPack", ctx, req)} +} + +func (_c *MockGitHTTPComponent_GitUploadPack_Call) Run(run func(ctx context.Context, req types.GitUploadPackReq)) *MockGitHTTPComponent_GitUploadPack_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(types.GitUploadPackReq)) + }) + return _c +} + +func (_c *MockGitHTTPComponent_GitUploadPack_Call) Return(_a0 error) *MockGitHTTPComponent_GitUploadPack_Call { + _c.Call.Return(_a0) + return _c +} + +func (_c *MockGitHTTPComponent_GitUploadPack_Call) RunAndReturn(run func(context.Context, types.GitUploadPackReq) error) *MockGitHTTPComponent_GitUploadPack_Call { + _c.Call.Return(run) + return _c +} + +// InfoRefs provides a mock function with given fields: ctx, req +func (_m *MockGitHTTPComponent) InfoRefs(ctx context.Context, req types.InfoRefsReq) (io.Reader, error) { + ret := _m.Called(ctx, req) + + if len(ret) == 0 { + panic("no return value specified for InfoRefs") + } + + var r0 io.Reader + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, types.InfoRefsReq) (io.Reader, error)); ok { + return rf(ctx, req) + } + if rf, ok := ret.Get(0).(func(context.Context, types.InfoRefsReq) io.Reader); ok { + r0 = rf(ctx, req) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(io.Reader) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, types.InfoRefsReq) error); ok { + r1 = rf(ctx, req) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// MockGitHTTPComponent_InfoRefs_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'InfoRefs' +type MockGitHTTPComponent_InfoRefs_Call struct { + *mock.Call +} + +// InfoRefs is a helper method to define mock.On call +// - ctx context.Context +// - req types.InfoRefsReq +func (_e *MockGitHTTPComponent_Expecter) InfoRefs(ctx interface{}, req interface{}) *MockGitHTTPComponent_InfoRefs_Call { + return &MockGitHTTPComponent_InfoRefs_Call{Call: _e.mock.On("InfoRefs", ctx, req)} +} + +func (_c *MockGitHTTPComponent_InfoRefs_Call) Run(run func(ctx context.Context, req types.InfoRefsReq)) *MockGitHTTPComponent_InfoRefs_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(types.InfoRefsReq)) + }) + return _c +} + +func (_c *MockGitHTTPComponent_InfoRefs_Call) Return(_a0 io.Reader, _a1 error) *MockGitHTTPComponent_InfoRefs_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *MockGitHTTPComponent_InfoRefs_Call) RunAndReturn(run func(context.Context, types.InfoRefsReq) (io.Reader, error)) *MockGitHTTPComponent_InfoRefs_Call { + _c.Call.Return(run) + return _c +} + +// LfsDownload provides a mock function with given fields: ctx, req +func (_m *MockGitHTTPComponent) LfsDownload(ctx context.Context, req types.DownloadRequest) (*url.URL, error) { + ret := _m.Called(ctx, req) + + if len(ret) == 0 { + panic("no return value specified for LfsDownload") + } + + var r0 *url.URL + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, types.DownloadRequest) (*url.URL, error)); ok { + return rf(ctx, req) + } + if rf, ok := ret.Get(0).(func(context.Context, types.DownloadRequest) *url.URL); ok { + r0 = rf(ctx, req) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*url.URL) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, types.DownloadRequest) error); ok { + r1 = rf(ctx, req) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// MockGitHTTPComponent_LfsDownload_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'LfsDownload' +type MockGitHTTPComponent_LfsDownload_Call struct { + *mock.Call +} + +// LfsDownload is a helper method to define mock.On call +// - ctx context.Context +// - req types.DownloadRequest +func (_e *MockGitHTTPComponent_Expecter) LfsDownload(ctx interface{}, req interface{}) *MockGitHTTPComponent_LfsDownload_Call { + return &MockGitHTTPComponent_LfsDownload_Call{Call: _e.mock.On("LfsDownload", ctx, req)} +} + +func (_c *MockGitHTTPComponent_LfsDownload_Call) Run(run func(ctx context.Context, req types.DownloadRequest)) *MockGitHTTPComponent_LfsDownload_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(types.DownloadRequest)) + }) + return _c +} + +func (_c *MockGitHTTPComponent_LfsDownload_Call) Return(_a0 *url.URL, _a1 error) *MockGitHTTPComponent_LfsDownload_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *MockGitHTTPComponent_LfsDownload_Call) RunAndReturn(run func(context.Context, types.DownloadRequest) (*url.URL, error)) *MockGitHTTPComponent_LfsDownload_Call { + _c.Call.Return(run) + return _c +} + +// LfsUpload provides a mock function with given fields: ctx, body, req +func (_m *MockGitHTTPComponent) LfsUpload(ctx context.Context, body io.ReadCloser, req types.UploadRequest) error { + ret := _m.Called(ctx, body, req) + + if len(ret) == 0 { + panic("no return value specified for LfsUpload") + } + + var r0 error + if rf, ok := ret.Get(0).(func(context.Context, io.ReadCloser, types.UploadRequest) error); ok { + r0 = rf(ctx, body, req) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// MockGitHTTPComponent_LfsUpload_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'LfsUpload' +type MockGitHTTPComponent_LfsUpload_Call struct { + *mock.Call +} + +// LfsUpload is a helper method to define mock.On call +// - ctx context.Context +// - body io.ReadCloser +// - req types.UploadRequest +func (_e *MockGitHTTPComponent_Expecter) LfsUpload(ctx interface{}, body interface{}, req interface{}) *MockGitHTTPComponent_LfsUpload_Call { + return &MockGitHTTPComponent_LfsUpload_Call{Call: _e.mock.On("LfsUpload", ctx, body, req)} +} + +func (_c *MockGitHTTPComponent_LfsUpload_Call) Run(run func(ctx context.Context, body io.ReadCloser, req types.UploadRequest)) *MockGitHTTPComponent_LfsUpload_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(io.ReadCloser), args[2].(types.UploadRequest)) + }) + return _c +} + +func (_c *MockGitHTTPComponent_LfsUpload_Call) Return(_a0 error) *MockGitHTTPComponent_LfsUpload_Call { + _c.Call.Return(_a0) + return _c +} + +func (_c *MockGitHTTPComponent_LfsUpload_Call) RunAndReturn(run func(context.Context, io.ReadCloser, types.UploadRequest) error) *MockGitHTTPComponent_LfsUpload_Call { + _c.Call.Return(run) + return _c +} + +// LfsVerify provides a mock function with given fields: ctx, req, p +func (_m *MockGitHTTPComponent) LfsVerify(ctx context.Context, req types.VerifyRequest, p types.Pointer) error { + ret := _m.Called(ctx, req, p) + + if len(ret) == 0 { + panic("no return value specified for LfsVerify") + } + + var r0 error + if rf, ok := ret.Get(0).(func(context.Context, types.VerifyRequest, types.Pointer) error); ok { + r0 = rf(ctx, req, p) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// MockGitHTTPComponent_LfsVerify_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'LfsVerify' +type MockGitHTTPComponent_LfsVerify_Call struct { + *mock.Call +} + +// LfsVerify is a helper method to define mock.On call +// - ctx context.Context +// - req types.VerifyRequest +// - p types.Pointer +func (_e *MockGitHTTPComponent_Expecter) LfsVerify(ctx interface{}, req interface{}, p interface{}) *MockGitHTTPComponent_LfsVerify_Call { + return &MockGitHTTPComponent_LfsVerify_Call{Call: _e.mock.On("LfsVerify", ctx, req, p)} +} + +func (_c *MockGitHTTPComponent_LfsVerify_Call) Run(run func(ctx context.Context, req types.VerifyRequest, p types.Pointer)) *MockGitHTTPComponent_LfsVerify_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(types.VerifyRequest), args[2].(types.Pointer)) + }) + return _c +} + +func (_c *MockGitHTTPComponent_LfsVerify_Call) Return(_a0 error) *MockGitHTTPComponent_LfsVerify_Call { + _c.Call.Return(_a0) + return _c +} + +func (_c *MockGitHTTPComponent_LfsVerify_Call) RunAndReturn(run func(context.Context, types.VerifyRequest, types.Pointer) error) *MockGitHTTPComponent_LfsVerify_Call { + _c.Call.Return(run) + return _c +} + +// ListLocks provides a mock function with given fields: ctx, req +func (_m *MockGitHTTPComponent) ListLocks(ctx context.Context, req types.ListLFSLockReq) (*types.LFSLockList, error) { + ret := _m.Called(ctx, req) + + if len(ret) == 0 { + panic("no return value specified for ListLocks") + } + + var r0 *types.LFSLockList + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, types.ListLFSLockReq) (*types.LFSLockList, error)); ok { + return rf(ctx, req) + } + if rf, ok := ret.Get(0).(func(context.Context, types.ListLFSLockReq) *types.LFSLockList); ok { + r0 = rf(ctx, req) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*types.LFSLockList) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, types.ListLFSLockReq) error); ok { + r1 = rf(ctx, req) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// MockGitHTTPComponent_ListLocks_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'ListLocks' +type MockGitHTTPComponent_ListLocks_Call struct { + *mock.Call +} + +// ListLocks is a helper method to define mock.On call +// - ctx context.Context +// - req types.ListLFSLockReq +func (_e *MockGitHTTPComponent_Expecter) ListLocks(ctx interface{}, req interface{}) *MockGitHTTPComponent_ListLocks_Call { + return &MockGitHTTPComponent_ListLocks_Call{Call: _e.mock.On("ListLocks", ctx, req)} +} + +func (_c *MockGitHTTPComponent_ListLocks_Call) Run(run func(ctx context.Context, req types.ListLFSLockReq)) *MockGitHTTPComponent_ListLocks_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(types.ListLFSLockReq)) + }) + return _c +} + +func (_c *MockGitHTTPComponent_ListLocks_Call) Return(_a0 *types.LFSLockList, _a1 error) *MockGitHTTPComponent_ListLocks_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *MockGitHTTPComponent_ListLocks_Call) RunAndReturn(run func(context.Context, types.ListLFSLockReq) (*types.LFSLockList, error)) *MockGitHTTPComponent_ListLocks_Call { + _c.Call.Return(run) + return _c +} + +// UnLock provides a mock function with given fields: ctx, req +func (_m *MockGitHTTPComponent) UnLock(ctx context.Context, req types.UnlockLFSReq) (*database.LfsLock, error) { + ret := _m.Called(ctx, req) + + if len(ret) == 0 { + panic("no return value specified for UnLock") + } + + var r0 *database.LfsLock + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, types.UnlockLFSReq) (*database.LfsLock, error)); ok { + return rf(ctx, req) + } + if rf, ok := ret.Get(0).(func(context.Context, types.UnlockLFSReq) *database.LfsLock); ok { + r0 = rf(ctx, req) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*database.LfsLock) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, types.UnlockLFSReq) error); ok { + r1 = rf(ctx, req) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// MockGitHTTPComponent_UnLock_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'UnLock' +type MockGitHTTPComponent_UnLock_Call struct { + *mock.Call +} + +// UnLock is a helper method to define mock.On call +// - ctx context.Context +// - req types.UnlockLFSReq +func (_e *MockGitHTTPComponent_Expecter) UnLock(ctx interface{}, req interface{}) *MockGitHTTPComponent_UnLock_Call { + return &MockGitHTTPComponent_UnLock_Call{Call: _e.mock.On("UnLock", ctx, req)} +} + +func (_c *MockGitHTTPComponent_UnLock_Call) Run(run func(ctx context.Context, req types.UnlockLFSReq)) *MockGitHTTPComponent_UnLock_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(types.UnlockLFSReq)) + }) + return _c +} + +func (_c *MockGitHTTPComponent_UnLock_Call) Return(_a0 *database.LfsLock, _a1 error) *MockGitHTTPComponent_UnLock_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *MockGitHTTPComponent_UnLock_Call) RunAndReturn(run func(context.Context, types.UnlockLFSReq) (*database.LfsLock, error)) *MockGitHTTPComponent_UnLock_Call { + _c.Call.Return(run) + return _c +} + +// VerifyLock provides a mock function with given fields: ctx, req +func (_m *MockGitHTTPComponent) VerifyLock(ctx context.Context, req types.VerifyLFSLockReq) (*types.LFSLockListVerify, error) { + ret := _m.Called(ctx, req) + + if len(ret) == 0 { + panic("no return value specified for VerifyLock") + } + + var r0 *types.LFSLockListVerify + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, types.VerifyLFSLockReq) (*types.LFSLockListVerify, error)); ok { + return rf(ctx, req) + } + if rf, ok := ret.Get(0).(func(context.Context, types.VerifyLFSLockReq) *types.LFSLockListVerify); ok { + r0 = rf(ctx, req) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*types.LFSLockListVerify) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, types.VerifyLFSLockReq) error); ok { + r1 = rf(ctx, req) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// MockGitHTTPComponent_VerifyLock_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'VerifyLock' +type MockGitHTTPComponent_VerifyLock_Call struct { + *mock.Call +} + +// VerifyLock is a helper method to define mock.On call +// - ctx context.Context +// - req types.VerifyLFSLockReq +func (_e *MockGitHTTPComponent_Expecter) VerifyLock(ctx interface{}, req interface{}) *MockGitHTTPComponent_VerifyLock_Call { + return &MockGitHTTPComponent_VerifyLock_Call{Call: _e.mock.On("VerifyLock", ctx, req)} +} + +func (_c *MockGitHTTPComponent_VerifyLock_Call) Run(run func(ctx context.Context, req types.VerifyLFSLockReq)) *MockGitHTTPComponent_VerifyLock_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(types.VerifyLFSLockReq)) + }) + return _c +} + +func (_c *MockGitHTTPComponent_VerifyLock_Call) Return(_a0 *types.LFSLockListVerify, _a1 error) *MockGitHTTPComponent_VerifyLock_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *MockGitHTTPComponent_VerifyLock_Call) RunAndReturn(run func(context.Context, types.VerifyLFSLockReq) (*types.LFSLockListVerify, error)) *MockGitHTTPComponent_VerifyLock_Call { + _c.Call.Return(run) + return _c +} + +// NewMockGitHTTPComponent creates a new instance of MockGitHTTPComponent. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. +// The first argument is typically a *testing.T value. +func NewMockGitHTTPComponent(t interface { + mock.TestingT + Cleanup(func()) +}) *MockGitHTTPComponent { + mock := &MockGitHTTPComponent{} + mock.Mock.Test(t) + + t.Cleanup(func() { mock.AssertExpectations(t) }) + + return mock +} diff --git a/_mocks/opencsg.com/csghub-server/component/mock_ModelComponent.go b/_mocks/opencsg.com/csghub-server/component/mock_ModelComponent.go new file mode 100644 index 00000000..fc5bc4d2 --- /dev/null +++ b/_mocks/opencsg.com/csghub-server/component/mock_ModelComponent.go @@ -0,0 +1,1108 @@ +// Code generated by mockery v2.49.1. DO NOT EDIT. + +package component + +import ( + context "context" + + mock "github.com/stretchr/testify/mock" + database "opencsg.com/csghub-server/builder/store/database" + + types "opencsg.com/csghub-server/common/types" +) + +// MockModelComponent is an autogenerated mock type for the ModelComponent type +type MockModelComponent struct { + mock.Mock +} + +type MockModelComponent_Expecter struct { + mock *mock.Mock +} + +func (_m *MockModelComponent) EXPECT() *MockModelComponent_Expecter { + return &MockModelComponent_Expecter{mock: &_m.Mock} +} + +// AddRelationDataset provides a mock function with given fields: ctx, req +func (_m *MockModelComponent) AddRelationDataset(ctx context.Context, req types.RelationDataset) error { + ret := _m.Called(ctx, req) + + if len(ret) == 0 { + panic("no return value specified for AddRelationDataset") + } + + var r0 error + if rf, ok := ret.Get(0).(func(context.Context, types.RelationDataset) error); ok { + r0 = rf(ctx, req) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// MockModelComponent_AddRelationDataset_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'AddRelationDataset' +type MockModelComponent_AddRelationDataset_Call struct { + *mock.Call +} + +// AddRelationDataset is a helper method to define mock.On call +// - ctx context.Context +// - req types.RelationDataset +func (_e *MockModelComponent_Expecter) AddRelationDataset(ctx interface{}, req interface{}) *MockModelComponent_AddRelationDataset_Call { + return &MockModelComponent_AddRelationDataset_Call{Call: _e.mock.On("AddRelationDataset", ctx, req)} +} + +func (_c *MockModelComponent_AddRelationDataset_Call) Run(run func(ctx context.Context, req types.RelationDataset)) *MockModelComponent_AddRelationDataset_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(types.RelationDataset)) + }) + return _c +} + +func (_c *MockModelComponent_AddRelationDataset_Call) Return(_a0 error) *MockModelComponent_AddRelationDataset_Call { + _c.Call.Return(_a0) + return _c +} + +func (_c *MockModelComponent_AddRelationDataset_Call) RunAndReturn(run func(context.Context, types.RelationDataset) error) *MockModelComponent_AddRelationDataset_Call { + _c.Call.Return(run) + return _c +} + +// Create provides a mock function with given fields: ctx, req +func (_m *MockModelComponent) Create(ctx context.Context, req *types.CreateModelReq) (*types.Model, error) { + ret := _m.Called(ctx, req) + + if len(ret) == 0 { + panic("no return value specified for Create") + } + + var r0 *types.Model + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, *types.CreateModelReq) (*types.Model, error)); ok { + return rf(ctx, req) + } + if rf, ok := ret.Get(0).(func(context.Context, *types.CreateModelReq) *types.Model); ok { + r0 = rf(ctx, req) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*types.Model) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, *types.CreateModelReq) error); ok { + r1 = rf(ctx, req) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// MockModelComponent_Create_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Create' +type MockModelComponent_Create_Call struct { + *mock.Call +} + +// Create is a helper method to define mock.On call +// - ctx context.Context +// - req *types.CreateModelReq +func (_e *MockModelComponent_Expecter) Create(ctx interface{}, req interface{}) *MockModelComponent_Create_Call { + return &MockModelComponent_Create_Call{Call: _e.mock.On("Create", ctx, req)} +} + +func (_c *MockModelComponent_Create_Call) Run(run func(ctx context.Context, req *types.CreateModelReq)) *MockModelComponent_Create_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(*types.CreateModelReq)) + }) + return _c +} + +func (_c *MockModelComponent_Create_Call) Return(_a0 *types.Model, _a1 error) *MockModelComponent_Create_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *MockModelComponent_Create_Call) RunAndReturn(run func(context.Context, *types.CreateModelReq) (*types.Model, error)) *MockModelComponent_Create_Call { + _c.Call.Return(run) + return _c +} + +// DelRelationDataset provides a mock function with given fields: ctx, req +func (_m *MockModelComponent) DelRelationDataset(ctx context.Context, req types.RelationDataset) error { + ret := _m.Called(ctx, req) + + if len(ret) == 0 { + panic("no return value specified for DelRelationDataset") + } + + var r0 error + if rf, ok := ret.Get(0).(func(context.Context, types.RelationDataset) error); ok { + r0 = rf(ctx, req) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// MockModelComponent_DelRelationDataset_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'DelRelationDataset' +type MockModelComponent_DelRelationDataset_Call struct { + *mock.Call +} + +// DelRelationDataset is a helper method to define mock.On call +// - ctx context.Context +// - req types.RelationDataset +func (_e *MockModelComponent_Expecter) DelRelationDataset(ctx interface{}, req interface{}) *MockModelComponent_DelRelationDataset_Call { + return &MockModelComponent_DelRelationDataset_Call{Call: _e.mock.On("DelRelationDataset", ctx, req)} +} + +func (_c *MockModelComponent_DelRelationDataset_Call) Run(run func(ctx context.Context, req types.RelationDataset)) *MockModelComponent_DelRelationDataset_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(types.RelationDataset)) + }) + return _c +} + +func (_c *MockModelComponent_DelRelationDataset_Call) Return(_a0 error) *MockModelComponent_DelRelationDataset_Call { + _c.Call.Return(_a0) + return _c +} + +func (_c *MockModelComponent_DelRelationDataset_Call) RunAndReturn(run func(context.Context, types.RelationDataset) error) *MockModelComponent_DelRelationDataset_Call { + _c.Call.Return(run) + return _c +} + +// Delete provides a mock function with given fields: ctx, namespace, name, currentUser +func (_m *MockModelComponent) Delete(ctx context.Context, namespace string, name string, currentUser string) error { + ret := _m.Called(ctx, namespace, name, currentUser) + + if len(ret) == 0 { + panic("no return value specified for Delete") + } + + var r0 error + if rf, ok := ret.Get(0).(func(context.Context, string, string, string) error); ok { + r0 = rf(ctx, namespace, name, currentUser) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// MockModelComponent_Delete_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Delete' +type MockModelComponent_Delete_Call struct { + *mock.Call +} + +// Delete is a helper method to define mock.On call +// - ctx context.Context +// - namespace string +// - name string +// - currentUser string +func (_e *MockModelComponent_Expecter) Delete(ctx interface{}, namespace interface{}, name interface{}, currentUser interface{}) *MockModelComponent_Delete_Call { + return &MockModelComponent_Delete_Call{Call: _e.mock.On("Delete", ctx, namespace, name, currentUser)} +} + +func (_c *MockModelComponent_Delete_Call) Run(run func(ctx context.Context, namespace string, name string, currentUser string)) *MockModelComponent_Delete_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(string), args[2].(string), args[3].(string)) + }) + return _c +} + +func (_c *MockModelComponent_Delete_Call) Return(_a0 error) *MockModelComponent_Delete_Call { + _c.Call.Return(_a0) + return _c +} + +func (_c *MockModelComponent_Delete_Call) RunAndReturn(run func(context.Context, string, string, string) error) *MockModelComponent_Delete_Call { + _c.Call.Return(run) + return _c +} + +// DeleteRuntimeFrameworkModes provides a mock function with given fields: ctx, deployType, id, paths +func (_m *MockModelComponent) DeleteRuntimeFrameworkModes(ctx context.Context, deployType int, id int64, paths []string) ([]string, error) { + ret := _m.Called(ctx, deployType, id, paths) + + if len(ret) == 0 { + panic("no return value specified for DeleteRuntimeFrameworkModes") + } + + var r0 []string + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, int, int64, []string) ([]string, error)); ok { + return rf(ctx, deployType, id, paths) + } + if rf, ok := ret.Get(0).(func(context.Context, int, int64, []string) []string); ok { + r0 = rf(ctx, deployType, id, paths) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).([]string) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, int, int64, []string) error); ok { + r1 = rf(ctx, deployType, id, paths) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// MockModelComponent_DeleteRuntimeFrameworkModes_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'DeleteRuntimeFrameworkModes' +type MockModelComponent_DeleteRuntimeFrameworkModes_Call struct { + *mock.Call +} + +// DeleteRuntimeFrameworkModes is a helper method to define mock.On call +// - ctx context.Context +// - deployType int +// - id int64 +// - paths []string +func (_e *MockModelComponent_Expecter) DeleteRuntimeFrameworkModes(ctx interface{}, deployType interface{}, id interface{}, paths interface{}) *MockModelComponent_DeleteRuntimeFrameworkModes_Call { + return &MockModelComponent_DeleteRuntimeFrameworkModes_Call{Call: _e.mock.On("DeleteRuntimeFrameworkModes", ctx, deployType, id, paths)} +} + +func (_c *MockModelComponent_DeleteRuntimeFrameworkModes_Call) Run(run func(ctx context.Context, deployType int, id int64, paths []string)) *MockModelComponent_DeleteRuntimeFrameworkModes_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(int), args[2].(int64), args[3].([]string)) + }) + return _c +} + +func (_c *MockModelComponent_DeleteRuntimeFrameworkModes_Call) Return(_a0 []string, _a1 error) *MockModelComponent_DeleteRuntimeFrameworkModes_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *MockModelComponent_DeleteRuntimeFrameworkModes_Call) RunAndReturn(run func(context.Context, int, int64, []string) ([]string, error)) *MockModelComponent_DeleteRuntimeFrameworkModes_Call { + _c.Call.Return(run) + return _c +} + +// Deploy provides a mock function with given fields: ctx, deployReq, req +func (_m *MockModelComponent) Deploy(ctx context.Context, deployReq types.DeployActReq, req types.ModelRunReq) (int64, error) { + ret := _m.Called(ctx, deployReq, req) + + if len(ret) == 0 { + panic("no return value specified for Deploy") + } + + var r0 int64 + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, types.DeployActReq, types.ModelRunReq) (int64, error)); ok { + return rf(ctx, deployReq, req) + } + if rf, ok := ret.Get(0).(func(context.Context, types.DeployActReq, types.ModelRunReq) int64); ok { + r0 = rf(ctx, deployReq, req) + } else { + r0 = ret.Get(0).(int64) + } + + if rf, ok := ret.Get(1).(func(context.Context, types.DeployActReq, types.ModelRunReq) error); ok { + r1 = rf(ctx, deployReq, req) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// MockModelComponent_Deploy_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Deploy' +type MockModelComponent_Deploy_Call struct { + *mock.Call +} + +// Deploy is a helper method to define mock.On call +// - ctx context.Context +// - deployReq types.DeployActReq +// - req types.ModelRunReq +func (_e *MockModelComponent_Expecter) Deploy(ctx interface{}, deployReq interface{}, req interface{}) *MockModelComponent_Deploy_Call { + return &MockModelComponent_Deploy_Call{Call: _e.mock.On("Deploy", ctx, deployReq, req)} +} + +func (_c *MockModelComponent_Deploy_Call) Run(run func(ctx context.Context, deployReq types.DeployActReq, req types.ModelRunReq)) *MockModelComponent_Deploy_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(types.DeployActReq), args[2].(types.ModelRunReq)) + }) + return _c +} + +func (_c *MockModelComponent_Deploy_Call) Return(_a0 int64, _a1 error) *MockModelComponent_Deploy_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *MockModelComponent_Deploy_Call) RunAndReturn(run func(context.Context, types.DeployActReq, types.ModelRunReq) (int64, error)) *MockModelComponent_Deploy_Call { + _c.Call.Return(run) + return _c +} + +// GetServerless provides a mock function with given fields: ctx, namespace, name, currentUser +func (_m *MockModelComponent) GetServerless(ctx context.Context, namespace string, name string, currentUser string) (*types.DeployRepo, error) { + ret := _m.Called(ctx, namespace, name, currentUser) + + if len(ret) == 0 { + panic("no return value specified for GetServerless") + } + + var r0 *types.DeployRepo + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, string, string, string) (*types.DeployRepo, error)); ok { + return rf(ctx, namespace, name, currentUser) + } + if rf, ok := ret.Get(0).(func(context.Context, string, string, string) *types.DeployRepo); ok { + r0 = rf(ctx, namespace, name, currentUser) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*types.DeployRepo) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, string, string, string) error); ok { + r1 = rf(ctx, namespace, name, currentUser) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// MockModelComponent_GetServerless_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'GetServerless' +type MockModelComponent_GetServerless_Call struct { + *mock.Call +} + +// GetServerless is a helper method to define mock.On call +// - ctx context.Context +// - namespace string +// - name string +// - currentUser string +func (_e *MockModelComponent_Expecter) GetServerless(ctx interface{}, namespace interface{}, name interface{}, currentUser interface{}) *MockModelComponent_GetServerless_Call { + return &MockModelComponent_GetServerless_Call{Call: _e.mock.On("GetServerless", ctx, namespace, name, currentUser)} +} + +func (_c *MockModelComponent_GetServerless_Call) Run(run func(ctx context.Context, namespace string, name string, currentUser string)) *MockModelComponent_GetServerless_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(string), args[2].(string), args[3].(string)) + }) + return _c +} + +func (_c *MockModelComponent_GetServerless_Call) Return(_a0 *types.DeployRepo, _a1 error) *MockModelComponent_GetServerless_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *MockModelComponent_GetServerless_Call) RunAndReturn(run func(context.Context, string, string, string) (*types.DeployRepo, error)) *MockModelComponent_GetServerless_Call { + _c.Call.Return(run) + return _c +} + +// Index provides a mock function with given fields: ctx, filter, per, page, needOpWeight +func (_m *MockModelComponent) Index(ctx context.Context, filter *types.RepoFilter, per int, page int, needOpWeight bool) ([]*types.Model, int, error) { + ret := _m.Called(ctx, filter, per, page, needOpWeight) + + if len(ret) == 0 { + panic("no return value specified for Index") + } + + var r0 []*types.Model + var r1 int + var r2 error + if rf, ok := ret.Get(0).(func(context.Context, *types.RepoFilter, int, int, bool) ([]*types.Model, int, error)); ok { + return rf(ctx, filter, per, page, needOpWeight) + } + if rf, ok := ret.Get(0).(func(context.Context, *types.RepoFilter, int, int, bool) []*types.Model); ok { + r0 = rf(ctx, filter, per, page, needOpWeight) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).([]*types.Model) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, *types.RepoFilter, int, int, bool) int); ok { + r1 = rf(ctx, filter, per, page, needOpWeight) + } else { + r1 = ret.Get(1).(int) + } + + if rf, ok := ret.Get(2).(func(context.Context, *types.RepoFilter, int, int, bool) error); ok { + r2 = rf(ctx, filter, per, page, needOpWeight) + } else { + r2 = ret.Error(2) + } + + return r0, r1, r2 +} + +// MockModelComponent_Index_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Index' +type MockModelComponent_Index_Call struct { + *mock.Call +} + +// Index is a helper method to define mock.On call +// - ctx context.Context +// - filter *types.RepoFilter +// - per int +// - page int +// - needOpWeight bool +func (_e *MockModelComponent_Expecter) Index(ctx interface{}, filter interface{}, per interface{}, page interface{}, needOpWeight interface{}) *MockModelComponent_Index_Call { + return &MockModelComponent_Index_Call{Call: _e.mock.On("Index", ctx, filter, per, page, needOpWeight)} +} + +func (_c *MockModelComponent_Index_Call) Run(run func(ctx context.Context, filter *types.RepoFilter, per int, page int, needOpWeight bool)) *MockModelComponent_Index_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(*types.RepoFilter), args[2].(int), args[3].(int), args[4].(bool)) + }) + return _c +} + +func (_c *MockModelComponent_Index_Call) Return(_a0 []*types.Model, _a1 int, _a2 error) *MockModelComponent_Index_Call { + _c.Call.Return(_a0, _a1, _a2) + return _c +} + +func (_c *MockModelComponent_Index_Call) RunAndReturn(run func(context.Context, *types.RepoFilter, int, int, bool) ([]*types.Model, int, error)) *MockModelComponent_Index_Call { + _c.Call.Return(run) + return _c +} + +// ListAllByRuntimeFramework provides a mock function with given fields: ctx, currentUser +func (_m *MockModelComponent) ListAllByRuntimeFramework(ctx context.Context, currentUser string) ([]database.RuntimeFramework, error) { + ret := _m.Called(ctx, currentUser) + + if len(ret) == 0 { + panic("no return value specified for ListAllByRuntimeFramework") + } + + var r0 []database.RuntimeFramework + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, string) ([]database.RuntimeFramework, error)); ok { + return rf(ctx, currentUser) + } + if rf, ok := ret.Get(0).(func(context.Context, string) []database.RuntimeFramework); ok { + r0 = rf(ctx, currentUser) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).([]database.RuntimeFramework) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, string) error); ok { + r1 = rf(ctx, currentUser) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// MockModelComponent_ListAllByRuntimeFramework_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'ListAllByRuntimeFramework' +type MockModelComponent_ListAllByRuntimeFramework_Call struct { + *mock.Call +} + +// ListAllByRuntimeFramework is a helper method to define mock.On call +// - ctx context.Context +// - currentUser string +func (_e *MockModelComponent_Expecter) ListAllByRuntimeFramework(ctx interface{}, currentUser interface{}) *MockModelComponent_ListAllByRuntimeFramework_Call { + return &MockModelComponent_ListAllByRuntimeFramework_Call{Call: _e.mock.On("ListAllByRuntimeFramework", ctx, currentUser)} +} + +func (_c *MockModelComponent_ListAllByRuntimeFramework_Call) Run(run func(ctx context.Context, currentUser string)) *MockModelComponent_ListAllByRuntimeFramework_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(string)) + }) + return _c +} + +func (_c *MockModelComponent_ListAllByRuntimeFramework_Call) Return(_a0 []database.RuntimeFramework, _a1 error) *MockModelComponent_ListAllByRuntimeFramework_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *MockModelComponent_ListAllByRuntimeFramework_Call) RunAndReturn(run func(context.Context, string) ([]database.RuntimeFramework, error)) *MockModelComponent_ListAllByRuntimeFramework_Call { + _c.Call.Return(run) + return _c +} + +// ListModelsByRuntimeFrameworkID provides a mock function with given fields: ctx, currentUser, per, page, id, deployType +func (_m *MockModelComponent) ListModelsByRuntimeFrameworkID(ctx context.Context, currentUser string, per int, page int, id int64, deployType int) ([]types.Model, int, error) { + ret := _m.Called(ctx, currentUser, per, page, id, deployType) + + if len(ret) == 0 { + panic("no return value specified for ListModelsByRuntimeFrameworkID") + } + + var r0 []types.Model + var r1 int + var r2 error + if rf, ok := ret.Get(0).(func(context.Context, string, int, int, int64, int) ([]types.Model, int, error)); ok { + return rf(ctx, currentUser, per, page, id, deployType) + } + if rf, ok := ret.Get(0).(func(context.Context, string, int, int, int64, int) []types.Model); ok { + r0 = rf(ctx, currentUser, per, page, id, deployType) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).([]types.Model) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, string, int, int, int64, int) int); ok { + r1 = rf(ctx, currentUser, per, page, id, deployType) + } else { + r1 = ret.Get(1).(int) + } + + if rf, ok := ret.Get(2).(func(context.Context, string, int, int, int64, int) error); ok { + r2 = rf(ctx, currentUser, per, page, id, deployType) + } else { + r2 = ret.Error(2) + } + + return r0, r1, r2 +} + +// MockModelComponent_ListModelsByRuntimeFrameworkID_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'ListModelsByRuntimeFrameworkID' +type MockModelComponent_ListModelsByRuntimeFrameworkID_Call struct { + *mock.Call +} + +// ListModelsByRuntimeFrameworkID is a helper method to define mock.On call +// - ctx context.Context +// - currentUser string +// - per int +// - page int +// - id int64 +// - deployType int +func (_e *MockModelComponent_Expecter) ListModelsByRuntimeFrameworkID(ctx interface{}, currentUser interface{}, per interface{}, page interface{}, id interface{}, deployType interface{}) *MockModelComponent_ListModelsByRuntimeFrameworkID_Call { + return &MockModelComponent_ListModelsByRuntimeFrameworkID_Call{Call: _e.mock.On("ListModelsByRuntimeFrameworkID", ctx, currentUser, per, page, id, deployType)} +} + +func (_c *MockModelComponent_ListModelsByRuntimeFrameworkID_Call) Run(run func(ctx context.Context, currentUser string, per int, page int, id int64, deployType int)) *MockModelComponent_ListModelsByRuntimeFrameworkID_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(string), args[2].(int), args[3].(int), args[4].(int64), args[5].(int)) + }) + return _c +} + +func (_c *MockModelComponent_ListModelsByRuntimeFrameworkID_Call) Return(_a0 []types.Model, _a1 int, _a2 error) *MockModelComponent_ListModelsByRuntimeFrameworkID_Call { + _c.Call.Return(_a0, _a1, _a2) + return _c +} + +func (_c *MockModelComponent_ListModelsByRuntimeFrameworkID_Call) RunAndReturn(run func(context.Context, string, int, int, int64, int) ([]types.Model, int, error)) *MockModelComponent_ListModelsByRuntimeFrameworkID_Call { + _c.Call.Return(run) + return _c +} + +// ListModelsOfRuntimeFrameworks provides a mock function with given fields: ctx, currentUser, search, sort, per, page, deployType +func (_m *MockModelComponent) ListModelsOfRuntimeFrameworks(ctx context.Context, currentUser string, search string, sort string, per int, page int, deployType int) ([]types.Model, int, error) { + ret := _m.Called(ctx, currentUser, search, sort, per, page, deployType) + + if len(ret) == 0 { + panic("no return value specified for ListModelsOfRuntimeFrameworks") + } + + var r0 []types.Model + var r1 int + var r2 error + if rf, ok := ret.Get(0).(func(context.Context, string, string, string, int, int, int) ([]types.Model, int, error)); ok { + return rf(ctx, currentUser, search, sort, per, page, deployType) + } + if rf, ok := ret.Get(0).(func(context.Context, string, string, string, int, int, int) []types.Model); ok { + r0 = rf(ctx, currentUser, search, sort, per, page, deployType) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).([]types.Model) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, string, string, string, int, int, int) int); ok { + r1 = rf(ctx, currentUser, search, sort, per, page, deployType) + } else { + r1 = ret.Get(1).(int) + } + + if rf, ok := ret.Get(2).(func(context.Context, string, string, string, int, int, int) error); ok { + r2 = rf(ctx, currentUser, search, sort, per, page, deployType) + } else { + r2 = ret.Error(2) + } + + return r0, r1, r2 +} + +// MockModelComponent_ListModelsOfRuntimeFrameworks_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'ListModelsOfRuntimeFrameworks' +type MockModelComponent_ListModelsOfRuntimeFrameworks_Call struct { + *mock.Call +} + +// ListModelsOfRuntimeFrameworks is a helper method to define mock.On call +// - ctx context.Context +// - currentUser string +// - search string +// - sort string +// - per int +// - page int +// - deployType int +func (_e *MockModelComponent_Expecter) ListModelsOfRuntimeFrameworks(ctx interface{}, currentUser interface{}, search interface{}, sort interface{}, per interface{}, page interface{}, deployType interface{}) *MockModelComponent_ListModelsOfRuntimeFrameworks_Call { + return &MockModelComponent_ListModelsOfRuntimeFrameworks_Call{Call: _e.mock.On("ListModelsOfRuntimeFrameworks", ctx, currentUser, search, sort, per, page, deployType)} +} + +func (_c *MockModelComponent_ListModelsOfRuntimeFrameworks_Call) Run(run func(ctx context.Context, currentUser string, search string, sort string, per int, page int, deployType int)) *MockModelComponent_ListModelsOfRuntimeFrameworks_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(string), args[2].(string), args[3].(string), args[4].(int), args[5].(int), args[6].(int)) + }) + return _c +} + +func (_c *MockModelComponent_ListModelsOfRuntimeFrameworks_Call) Return(_a0 []types.Model, _a1 int, _a2 error) *MockModelComponent_ListModelsOfRuntimeFrameworks_Call { + _c.Call.Return(_a0, _a1, _a2) + return _c +} + +func (_c *MockModelComponent_ListModelsOfRuntimeFrameworks_Call) RunAndReturn(run func(context.Context, string, string, string, int, int, int) ([]types.Model, int, error)) *MockModelComponent_ListModelsOfRuntimeFrameworks_Call { + _c.Call.Return(run) + return _c +} + +// OrgModels provides a mock function with given fields: ctx, req +func (_m *MockModelComponent) OrgModels(ctx context.Context, req *types.OrgModelsReq) ([]types.Model, int, error) { + ret := _m.Called(ctx, req) + + if len(ret) == 0 { + panic("no return value specified for OrgModels") + } + + var r0 []types.Model + var r1 int + var r2 error + if rf, ok := ret.Get(0).(func(context.Context, *types.OrgModelsReq) ([]types.Model, int, error)); ok { + return rf(ctx, req) + } + if rf, ok := ret.Get(0).(func(context.Context, *types.OrgModelsReq) []types.Model); ok { + r0 = rf(ctx, req) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).([]types.Model) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, *types.OrgModelsReq) int); ok { + r1 = rf(ctx, req) + } else { + r1 = ret.Get(1).(int) + } + + if rf, ok := ret.Get(2).(func(context.Context, *types.OrgModelsReq) error); ok { + r2 = rf(ctx, req) + } else { + r2 = ret.Error(2) + } + + return r0, r1, r2 +} + +// MockModelComponent_OrgModels_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'OrgModels' +type MockModelComponent_OrgModels_Call struct { + *mock.Call +} + +// OrgModels is a helper method to define mock.On call +// - ctx context.Context +// - req *types.OrgModelsReq +func (_e *MockModelComponent_Expecter) OrgModels(ctx interface{}, req interface{}) *MockModelComponent_OrgModels_Call { + return &MockModelComponent_OrgModels_Call{Call: _e.mock.On("OrgModels", ctx, req)} +} + +func (_c *MockModelComponent_OrgModels_Call) Run(run func(ctx context.Context, req *types.OrgModelsReq)) *MockModelComponent_OrgModels_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(*types.OrgModelsReq)) + }) + return _c +} + +func (_c *MockModelComponent_OrgModels_Call) Return(_a0 []types.Model, _a1 int, _a2 error) *MockModelComponent_OrgModels_Call { + _c.Call.Return(_a0, _a1, _a2) + return _c +} + +func (_c *MockModelComponent_OrgModels_Call) RunAndReturn(run func(context.Context, *types.OrgModelsReq) ([]types.Model, int, error)) *MockModelComponent_OrgModels_Call { + _c.Call.Return(run) + return _c +} + +// Relations provides a mock function with given fields: ctx, namespace, name, currentUser +func (_m *MockModelComponent) Relations(ctx context.Context, namespace string, name string, currentUser string) (*types.Relations, error) { + ret := _m.Called(ctx, namespace, name, currentUser) + + if len(ret) == 0 { + panic("no return value specified for Relations") + } + + var r0 *types.Relations + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, string, string, string) (*types.Relations, error)); ok { + return rf(ctx, namespace, name, currentUser) + } + if rf, ok := ret.Get(0).(func(context.Context, string, string, string) *types.Relations); ok { + r0 = rf(ctx, namespace, name, currentUser) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*types.Relations) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, string, string, string) error); ok { + r1 = rf(ctx, namespace, name, currentUser) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// MockModelComponent_Relations_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Relations' +type MockModelComponent_Relations_Call struct { + *mock.Call +} + +// Relations is a helper method to define mock.On call +// - ctx context.Context +// - namespace string +// - name string +// - currentUser string +func (_e *MockModelComponent_Expecter) Relations(ctx interface{}, namespace interface{}, name interface{}, currentUser interface{}) *MockModelComponent_Relations_Call { + return &MockModelComponent_Relations_Call{Call: _e.mock.On("Relations", ctx, namespace, name, currentUser)} +} + +func (_c *MockModelComponent_Relations_Call) Run(run func(ctx context.Context, namespace string, name string, currentUser string)) *MockModelComponent_Relations_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(string), args[2].(string), args[3].(string)) + }) + return _c +} + +func (_c *MockModelComponent_Relations_Call) Return(_a0 *types.Relations, _a1 error) *MockModelComponent_Relations_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *MockModelComponent_Relations_Call) RunAndReturn(run func(context.Context, string, string, string) (*types.Relations, error)) *MockModelComponent_Relations_Call { + _c.Call.Return(run) + return _c +} + +// SDKModelInfo provides a mock function with given fields: ctx, namespace, name, ref, currentUser +func (_m *MockModelComponent) SDKModelInfo(ctx context.Context, namespace string, name string, ref string, currentUser string) (*types.SDKModelInfo, error) { + ret := _m.Called(ctx, namespace, name, ref, currentUser) + + if len(ret) == 0 { + panic("no return value specified for SDKModelInfo") + } + + var r0 *types.SDKModelInfo + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, string, string, string, string) (*types.SDKModelInfo, error)); ok { + return rf(ctx, namespace, name, ref, currentUser) + } + if rf, ok := ret.Get(0).(func(context.Context, string, string, string, string) *types.SDKModelInfo); ok { + r0 = rf(ctx, namespace, name, ref, currentUser) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*types.SDKModelInfo) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, string, string, string, string) error); ok { + r1 = rf(ctx, namespace, name, ref, currentUser) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// MockModelComponent_SDKModelInfo_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'SDKModelInfo' +type MockModelComponent_SDKModelInfo_Call struct { + *mock.Call +} + +// SDKModelInfo is a helper method to define mock.On call +// - ctx context.Context +// - namespace string +// - name string +// - ref string +// - currentUser string +func (_e *MockModelComponent_Expecter) SDKModelInfo(ctx interface{}, namespace interface{}, name interface{}, ref interface{}, currentUser interface{}) *MockModelComponent_SDKModelInfo_Call { + return &MockModelComponent_SDKModelInfo_Call{Call: _e.mock.On("SDKModelInfo", ctx, namespace, name, ref, currentUser)} +} + +func (_c *MockModelComponent_SDKModelInfo_Call) Run(run func(ctx context.Context, namespace string, name string, ref string, currentUser string)) *MockModelComponent_SDKModelInfo_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(string), args[2].(string), args[3].(string), args[4].(string)) + }) + return _c +} + +func (_c *MockModelComponent_SDKModelInfo_Call) Return(_a0 *types.SDKModelInfo, _a1 error) *MockModelComponent_SDKModelInfo_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *MockModelComponent_SDKModelInfo_Call) RunAndReturn(run func(context.Context, string, string, string, string) (*types.SDKModelInfo, error)) *MockModelComponent_SDKModelInfo_Call { + _c.Call.Return(run) + return _c +} + +// SetRelationDatasets provides a mock function with given fields: ctx, req +func (_m *MockModelComponent) SetRelationDatasets(ctx context.Context, req types.RelationDatasets) error { + ret := _m.Called(ctx, req) + + if len(ret) == 0 { + panic("no return value specified for SetRelationDatasets") + } + + var r0 error + if rf, ok := ret.Get(0).(func(context.Context, types.RelationDatasets) error); ok { + r0 = rf(ctx, req) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// MockModelComponent_SetRelationDatasets_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'SetRelationDatasets' +type MockModelComponent_SetRelationDatasets_Call struct { + *mock.Call +} + +// SetRelationDatasets is a helper method to define mock.On call +// - ctx context.Context +// - req types.RelationDatasets +func (_e *MockModelComponent_Expecter) SetRelationDatasets(ctx interface{}, req interface{}) *MockModelComponent_SetRelationDatasets_Call { + return &MockModelComponent_SetRelationDatasets_Call{Call: _e.mock.On("SetRelationDatasets", ctx, req)} +} + +func (_c *MockModelComponent_SetRelationDatasets_Call) Run(run func(ctx context.Context, req types.RelationDatasets)) *MockModelComponent_SetRelationDatasets_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(types.RelationDatasets)) + }) + return _c +} + +func (_c *MockModelComponent_SetRelationDatasets_Call) Return(_a0 error) *MockModelComponent_SetRelationDatasets_Call { + _c.Call.Return(_a0) + return _c +} + +func (_c *MockModelComponent_SetRelationDatasets_Call) RunAndReturn(run func(context.Context, types.RelationDatasets) error) *MockModelComponent_SetRelationDatasets_Call { + _c.Call.Return(run) + return _c +} + +// SetRuntimeFrameworkModes provides a mock function with given fields: ctx, deployType, id, paths +func (_m *MockModelComponent) SetRuntimeFrameworkModes(ctx context.Context, deployType int, id int64, paths []string) ([]string, error) { + ret := _m.Called(ctx, deployType, id, paths) + + if len(ret) == 0 { + panic("no return value specified for SetRuntimeFrameworkModes") + } + + var r0 []string + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, int, int64, []string) ([]string, error)); ok { + return rf(ctx, deployType, id, paths) + } + if rf, ok := ret.Get(0).(func(context.Context, int, int64, []string) []string); ok { + r0 = rf(ctx, deployType, id, paths) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).([]string) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, int, int64, []string) error); ok { + r1 = rf(ctx, deployType, id, paths) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// MockModelComponent_SetRuntimeFrameworkModes_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'SetRuntimeFrameworkModes' +type MockModelComponent_SetRuntimeFrameworkModes_Call struct { + *mock.Call +} + +// SetRuntimeFrameworkModes is a helper method to define mock.On call +// - ctx context.Context +// - deployType int +// - id int64 +// - paths []string +func (_e *MockModelComponent_Expecter) SetRuntimeFrameworkModes(ctx interface{}, deployType interface{}, id interface{}, paths interface{}) *MockModelComponent_SetRuntimeFrameworkModes_Call { + return &MockModelComponent_SetRuntimeFrameworkModes_Call{Call: _e.mock.On("SetRuntimeFrameworkModes", ctx, deployType, id, paths)} +} + +func (_c *MockModelComponent_SetRuntimeFrameworkModes_Call) Run(run func(ctx context.Context, deployType int, id int64, paths []string)) *MockModelComponent_SetRuntimeFrameworkModes_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(int), args[2].(int64), args[3].([]string)) + }) + return _c +} + +func (_c *MockModelComponent_SetRuntimeFrameworkModes_Call) Return(_a0 []string, _a1 error) *MockModelComponent_SetRuntimeFrameworkModes_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *MockModelComponent_SetRuntimeFrameworkModes_Call) RunAndReturn(run func(context.Context, int, int64, []string) ([]string, error)) *MockModelComponent_SetRuntimeFrameworkModes_Call { + _c.Call.Return(run) + return _c +} + +// Show provides a mock function with given fields: ctx, namespace, name, currentUser, needOpWeight +func (_m *MockModelComponent) Show(ctx context.Context, namespace string, name string, currentUser string, needOpWeight bool) (*types.Model, error) { + ret := _m.Called(ctx, namespace, name, currentUser, needOpWeight) + + if len(ret) == 0 { + panic("no return value specified for Show") + } + + var r0 *types.Model + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, string, string, string, bool) (*types.Model, error)); ok { + return rf(ctx, namespace, name, currentUser, needOpWeight) + } + if rf, ok := ret.Get(0).(func(context.Context, string, string, string, bool) *types.Model); ok { + r0 = rf(ctx, namespace, name, currentUser, needOpWeight) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*types.Model) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, string, string, string, bool) error); ok { + r1 = rf(ctx, namespace, name, currentUser, needOpWeight) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// MockModelComponent_Show_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Show' +type MockModelComponent_Show_Call struct { + *mock.Call +} + +// Show is a helper method to define mock.On call +// - ctx context.Context +// - namespace string +// - name string +// - currentUser string +// - needOpWeight bool +func (_e *MockModelComponent_Expecter) Show(ctx interface{}, namespace interface{}, name interface{}, currentUser interface{}, needOpWeight interface{}) *MockModelComponent_Show_Call { + return &MockModelComponent_Show_Call{Call: _e.mock.On("Show", ctx, namespace, name, currentUser, needOpWeight)} +} + +func (_c *MockModelComponent_Show_Call) Run(run func(ctx context.Context, namespace string, name string, currentUser string, needOpWeight bool)) *MockModelComponent_Show_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(string), args[2].(string), args[3].(string), args[4].(bool)) + }) + return _c +} + +func (_c *MockModelComponent_Show_Call) Return(_a0 *types.Model, _a1 error) *MockModelComponent_Show_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *MockModelComponent_Show_Call) RunAndReturn(run func(context.Context, string, string, string, bool) (*types.Model, error)) *MockModelComponent_Show_Call { + _c.Call.Return(run) + return _c +} + +// Update provides a mock function with given fields: ctx, req +func (_m *MockModelComponent) Update(ctx context.Context, req *types.UpdateModelReq) (*types.Model, error) { + ret := _m.Called(ctx, req) + + if len(ret) == 0 { + panic("no return value specified for Update") + } + + var r0 *types.Model + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, *types.UpdateModelReq) (*types.Model, error)); ok { + return rf(ctx, req) + } + if rf, ok := ret.Get(0).(func(context.Context, *types.UpdateModelReq) *types.Model); ok { + r0 = rf(ctx, req) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*types.Model) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, *types.UpdateModelReq) error); ok { + r1 = rf(ctx, req) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// MockModelComponent_Update_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Update' +type MockModelComponent_Update_Call struct { + *mock.Call +} + +// Update is a helper method to define mock.On call +// - ctx context.Context +// - req *types.UpdateModelReq +func (_e *MockModelComponent_Expecter) Update(ctx interface{}, req interface{}) *MockModelComponent_Update_Call { + return &MockModelComponent_Update_Call{Call: _e.mock.On("Update", ctx, req)} +} + +func (_c *MockModelComponent_Update_Call) Run(run func(ctx context.Context, req *types.UpdateModelReq)) *MockModelComponent_Update_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(*types.UpdateModelReq)) + }) + return _c +} + +func (_c *MockModelComponent_Update_Call) Return(_a0 *types.Model, _a1 error) *MockModelComponent_Update_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *MockModelComponent_Update_Call) RunAndReturn(run func(context.Context, *types.UpdateModelReq) (*types.Model, error)) *MockModelComponent_Update_Call { + _c.Call.Return(run) + return _c +} + +// NewMockModelComponent creates a new instance of MockModelComponent. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. +// The first argument is typically a *testing.T value. +func NewMockModelComponent(t interface { + mock.TestingT + Cleanup(func()) +}) *MockModelComponent { + mock := &MockModelComponent{} + mock.Mock.Test(t) + + t.Cleanup(func() { mock.AssertExpectations(t) }) + + return mock +} diff --git a/_mocks/opencsg.com/csghub-server/component/mock_UserComponent.go b/_mocks/opencsg.com/csghub-server/component/mock_UserComponent.go new file mode 100644 index 00000000..69e029e0 --- /dev/null +++ b/_mocks/opencsg.com/csghub-server/component/mock_UserComponent.go @@ -0,0 +1,1277 @@ +// Code generated by mockery v2.49.1. DO NOT EDIT. + +package component + +import ( + context "context" + + mock "github.com/stretchr/testify/mock" + database "opencsg.com/csghub-server/builder/store/database" + + types "opencsg.com/csghub-server/common/types" +) + +// MockUserComponent is an autogenerated mock type for the UserComponent type +type MockUserComponent struct { + mock.Mock +} + +type MockUserComponent_Expecter struct { + mock *mock.Mock +} + +func (_m *MockUserComponent) EXPECT() *MockUserComponent_Expecter { + return &MockUserComponent_Expecter{mock: &_m.Mock} +} + +// AddLikes provides a mock function with given fields: ctx, req +func (_m *MockUserComponent) AddLikes(ctx context.Context, req *types.UserLikesRequest) error { + ret := _m.Called(ctx, req) + + if len(ret) == 0 { + panic("no return value specified for AddLikes") + } + + var r0 error + if rf, ok := ret.Get(0).(func(context.Context, *types.UserLikesRequest) error); ok { + r0 = rf(ctx, req) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// MockUserComponent_AddLikes_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'AddLikes' +type MockUserComponent_AddLikes_Call struct { + *mock.Call +} + +// AddLikes is a helper method to define mock.On call +// - ctx context.Context +// - req *types.UserLikesRequest +func (_e *MockUserComponent_Expecter) AddLikes(ctx interface{}, req interface{}) *MockUserComponent_AddLikes_Call { + return &MockUserComponent_AddLikes_Call{Call: _e.mock.On("AddLikes", ctx, req)} +} + +func (_c *MockUserComponent_AddLikes_Call) Run(run func(ctx context.Context, req *types.UserLikesRequest)) *MockUserComponent_AddLikes_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(*types.UserLikesRequest)) + }) + return _c +} + +func (_c *MockUserComponent_AddLikes_Call) Return(_a0 error) *MockUserComponent_AddLikes_Call { + _c.Call.Return(_a0) + return _c +} + +func (_c *MockUserComponent_AddLikes_Call) RunAndReturn(run func(context.Context, *types.UserLikesRequest) error) *MockUserComponent_AddLikes_Call { + _c.Call.Return(run) + return _c +} + +// Codes provides a mock function with given fields: ctx, req +func (_m *MockUserComponent) Codes(ctx context.Context, req *types.UserModelsReq) ([]types.Code, int, error) { + ret := _m.Called(ctx, req) + + if len(ret) == 0 { + panic("no return value specified for Codes") + } + + var r0 []types.Code + var r1 int + var r2 error + if rf, ok := ret.Get(0).(func(context.Context, *types.UserModelsReq) ([]types.Code, int, error)); ok { + return rf(ctx, req) + } + if rf, ok := ret.Get(0).(func(context.Context, *types.UserModelsReq) []types.Code); ok { + r0 = rf(ctx, req) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).([]types.Code) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, *types.UserModelsReq) int); ok { + r1 = rf(ctx, req) + } else { + r1 = ret.Get(1).(int) + } + + if rf, ok := ret.Get(2).(func(context.Context, *types.UserModelsReq) error); ok { + r2 = rf(ctx, req) + } else { + r2 = ret.Error(2) + } + + return r0, r1, r2 +} + +// MockUserComponent_Codes_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Codes' +type MockUserComponent_Codes_Call struct { + *mock.Call +} + +// Codes is a helper method to define mock.On call +// - ctx context.Context +// - req *types.UserModelsReq +func (_e *MockUserComponent_Expecter) Codes(ctx interface{}, req interface{}) *MockUserComponent_Codes_Call { + return &MockUserComponent_Codes_Call{Call: _e.mock.On("Codes", ctx, req)} +} + +func (_c *MockUserComponent_Codes_Call) Run(run func(ctx context.Context, req *types.UserModelsReq)) *MockUserComponent_Codes_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(*types.UserModelsReq)) + }) + return _c +} + +func (_c *MockUserComponent_Codes_Call) Return(_a0 []types.Code, _a1 int, _a2 error) *MockUserComponent_Codes_Call { + _c.Call.Return(_a0, _a1, _a2) + return _c +} + +func (_c *MockUserComponent_Codes_Call) RunAndReturn(run func(context.Context, *types.UserModelsReq) ([]types.Code, int, error)) *MockUserComponent_Codes_Call { + _c.Call.Return(run) + return _c +} + +// Collections provides a mock function with given fields: ctx, req +func (_m *MockUserComponent) Collections(ctx context.Context, req *types.UserCollectionReq) ([]types.Collection, int, error) { + ret := _m.Called(ctx, req) + + if len(ret) == 0 { + panic("no return value specified for Collections") + } + + var r0 []types.Collection + var r1 int + var r2 error + if rf, ok := ret.Get(0).(func(context.Context, *types.UserCollectionReq) ([]types.Collection, int, error)); ok { + return rf(ctx, req) + } + if rf, ok := ret.Get(0).(func(context.Context, *types.UserCollectionReq) []types.Collection); ok { + r0 = rf(ctx, req) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).([]types.Collection) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, *types.UserCollectionReq) int); ok { + r1 = rf(ctx, req) + } else { + r1 = ret.Get(1).(int) + } + + if rf, ok := ret.Get(2).(func(context.Context, *types.UserCollectionReq) error); ok { + r2 = rf(ctx, req) + } else { + r2 = ret.Error(2) + } + + return r0, r1, r2 +} + +// MockUserComponent_Collections_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Collections' +type MockUserComponent_Collections_Call struct { + *mock.Call +} + +// Collections is a helper method to define mock.On call +// - ctx context.Context +// - req *types.UserCollectionReq +func (_e *MockUserComponent_Expecter) Collections(ctx interface{}, req interface{}) *MockUserComponent_Collections_Call { + return &MockUserComponent_Collections_Call{Call: _e.mock.On("Collections", ctx, req)} +} + +func (_c *MockUserComponent_Collections_Call) Run(run func(ctx context.Context, req *types.UserCollectionReq)) *MockUserComponent_Collections_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(*types.UserCollectionReq)) + }) + return _c +} + +func (_c *MockUserComponent_Collections_Call) Return(_a0 []types.Collection, _a1 int, _a2 error) *MockUserComponent_Collections_Call { + _c.Call.Return(_a0, _a1, _a2) + return _c +} + +func (_c *MockUserComponent_Collections_Call) RunAndReturn(run func(context.Context, *types.UserCollectionReq) ([]types.Collection, int, error)) *MockUserComponent_Collections_Call { + _c.Call.Return(run) + return _c +} + +// Datasets provides a mock function with given fields: ctx, req +func (_m *MockUserComponent) Datasets(ctx context.Context, req *types.UserDatasetsReq) ([]types.Dataset, int, error) { + ret := _m.Called(ctx, req) + + if len(ret) == 0 { + panic("no return value specified for Datasets") + } + + var r0 []types.Dataset + var r1 int + var r2 error + if rf, ok := ret.Get(0).(func(context.Context, *types.UserDatasetsReq) ([]types.Dataset, int, error)); ok { + return rf(ctx, req) + } + if rf, ok := ret.Get(0).(func(context.Context, *types.UserDatasetsReq) []types.Dataset); ok { + r0 = rf(ctx, req) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).([]types.Dataset) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, *types.UserDatasetsReq) int); ok { + r1 = rf(ctx, req) + } else { + r1 = ret.Get(1).(int) + } + + if rf, ok := ret.Get(2).(func(context.Context, *types.UserDatasetsReq) error); ok { + r2 = rf(ctx, req) + } else { + r2 = ret.Error(2) + } + + return r0, r1, r2 +} + +// MockUserComponent_Datasets_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Datasets' +type MockUserComponent_Datasets_Call struct { + *mock.Call +} + +// Datasets is a helper method to define mock.On call +// - ctx context.Context +// - req *types.UserDatasetsReq +func (_e *MockUserComponent_Expecter) Datasets(ctx interface{}, req interface{}) *MockUserComponent_Datasets_Call { + return &MockUserComponent_Datasets_Call{Call: _e.mock.On("Datasets", ctx, req)} +} + +func (_c *MockUserComponent_Datasets_Call) Run(run func(ctx context.Context, req *types.UserDatasetsReq)) *MockUserComponent_Datasets_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(*types.UserDatasetsReq)) + }) + return _c +} + +func (_c *MockUserComponent_Datasets_Call) Return(_a0 []types.Dataset, _a1 int, _a2 error) *MockUserComponent_Datasets_Call { + _c.Call.Return(_a0, _a1, _a2) + return _c +} + +func (_c *MockUserComponent_Datasets_Call) RunAndReturn(run func(context.Context, *types.UserDatasetsReq) ([]types.Dataset, int, error)) *MockUserComponent_Datasets_Call { + _c.Call.Return(run) + return _c +} + +// DeleteLikes provides a mock function with given fields: ctx, req +func (_m *MockUserComponent) DeleteLikes(ctx context.Context, req *types.UserLikesRequest) error { + ret := _m.Called(ctx, req) + + if len(ret) == 0 { + panic("no return value specified for DeleteLikes") + } + + var r0 error + if rf, ok := ret.Get(0).(func(context.Context, *types.UserLikesRequest) error); ok { + r0 = rf(ctx, req) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// MockUserComponent_DeleteLikes_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'DeleteLikes' +type MockUserComponent_DeleteLikes_Call struct { + *mock.Call +} + +// DeleteLikes is a helper method to define mock.On call +// - ctx context.Context +// - req *types.UserLikesRequest +func (_e *MockUserComponent_Expecter) DeleteLikes(ctx interface{}, req interface{}) *MockUserComponent_DeleteLikes_Call { + return &MockUserComponent_DeleteLikes_Call{Call: _e.mock.On("DeleteLikes", ctx, req)} +} + +func (_c *MockUserComponent_DeleteLikes_Call) Run(run func(ctx context.Context, req *types.UserLikesRequest)) *MockUserComponent_DeleteLikes_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(*types.UserLikesRequest)) + }) + return _c +} + +func (_c *MockUserComponent_DeleteLikes_Call) Return(_a0 error) *MockUserComponent_DeleteLikes_Call { + _c.Call.Return(_a0) + return _c +} + +func (_c *MockUserComponent_DeleteLikes_Call) RunAndReturn(run func(context.Context, *types.UserLikesRequest) error) *MockUserComponent_DeleteLikes_Call { + _c.Call.Return(run) + return _c +} + +// Evaluations provides a mock function with given fields: ctx, req +func (_m *MockUserComponent) Evaluations(ctx context.Context, req *types.UserEvaluationReq) ([]types.ArgoWorkFlowRes, int, error) { + ret := _m.Called(ctx, req) + + if len(ret) == 0 { + panic("no return value specified for Evaluations") + } + + var r0 []types.ArgoWorkFlowRes + var r1 int + var r2 error + if rf, ok := ret.Get(0).(func(context.Context, *types.UserEvaluationReq) ([]types.ArgoWorkFlowRes, int, error)); ok { + return rf(ctx, req) + } + if rf, ok := ret.Get(0).(func(context.Context, *types.UserEvaluationReq) []types.ArgoWorkFlowRes); ok { + r0 = rf(ctx, req) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).([]types.ArgoWorkFlowRes) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, *types.UserEvaluationReq) int); ok { + r1 = rf(ctx, req) + } else { + r1 = ret.Get(1).(int) + } + + if rf, ok := ret.Get(2).(func(context.Context, *types.UserEvaluationReq) error); ok { + r2 = rf(ctx, req) + } else { + r2 = ret.Error(2) + } + + return r0, r1, r2 +} + +// MockUserComponent_Evaluations_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Evaluations' +type MockUserComponent_Evaluations_Call struct { + *mock.Call +} + +// Evaluations is a helper method to define mock.On call +// - ctx context.Context +// - req *types.UserEvaluationReq +func (_e *MockUserComponent_Expecter) Evaluations(ctx interface{}, req interface{}) *MockUserComponent_Evaluations_Call { + return &MockUserComponent_Evaluations_Call{Call: _e.mock.On("Evaluations", ctx, req)} +} + +func (_c *MockUserComponent_Evaluations_Call) Run(run func(ctx context.Context, req *types.UserEvaluationReq)) *MockUserComponent_Evaluations_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(*types.UserEvaluationReq)) + }) + return _c +} + +func (_c *MockUserComponent_Evaluations_Call) Return(_a0 []types.ArgoWorkFlowRes, _a1 int, _a2 error) *MockUserComponent_Evaluations_Call { + _c.Call.Return(_a0, _a1, _a2) + return _c +} + +func (_c *MockUserComponent_Evaluations_Call) RunAndReturn(run func(context.Context, *types.UserEvaluationReq) ([]types.ArgoWorkFlowRes, int, error)) *MockUserComponent_Evaluations_Call { + _c.Call.Return(run) + return _c +} + +// GetUserByName provides a mock function with given fields: ctx, userName +func (_m *MockUserComponent) GetUserByName(ctx context.Context, userName string) (*database.User, error) { + ret := _m.Called(ctx, userName) + + if len(ret) == 0 { + panic("no return value specified for GetUserByName") + } + + var r0 *database.User + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, string) (*database.User, error)); ok { + return rf(ctx, userName) + } + if rf, ok := ret.Get(0).(func(context.Context, string) *database.User); ok { + r0 = rf(ctx, userName) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*database.User) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, string) error); ok { + r1 = rf(ctx, userName) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// MockUserComponent_GetUserByName_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'GetUserByName' +type MockUserComponent_GetUserByName_Call struct { + *mock.Call +} + +// GetUserByName is a helper method to define mock.On call +// - ctx context.Context +// - userName string +func (_e *MockUserComponent_Expecter) GetUserByName(ctx interface{}, userName interface{}) *MockUserComponent_GetUserByName_Call { + return &MockUserComponent_GetUserByName_Call{Call: _e.mock.On("GetUserByName", ctx, userName)} +} + +func (_c *MockUserComponent_GetUserByName_Call) Run(run func(ctx context.Context, userName string)) *MockUserComponent_GetUserByName_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(string)) + }) + return _c +} + +func (_c *MockUserComponent_GetUserByName_Call) Return(_a0 *database.User, _a1 error) *MockUserComponent_GetUserByName_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *MockUserComponent_GetUserByName_Call) RunAndReturn(run func(context.Context, string) (*database.User, error)) *MockUserComponent_GetUserByName_Call { + _c.Call.Return(run) + return _c +} + +// LikeCollection provides a mock function with given fields: ctx, req +func (_m *MockUserComponent) LikeCollection(ctx context.Context, req *types.UserLikesRequest) error { + ret := _m.Called(ctx, req) + + if len(ret) == 0 { + panic("no return value specified for LikeCollection") + } + + var r0 error + if rf, ok := ret.Get(0).(func(context.Context, *types.UserLikesRequest) error); ok { + r0 = rf(ctx, req) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// MockUserComponent_LikeCollection_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'LikeCollection' +type MockUserComponent_LikeCollection_Call struct { + *mock.Call +} + +// LikeCollection is a helper method to define mock.On call +// - ctx context.Context +// - req *types.UserLikesRequest +func (_e *MockUserComponent_Expecter) LikeCollection(ctx interface{}, req interface{}) *MockUserComponent_LikeCollection_Call { + return &MockUserComponent_LikeCollection_Call{Call: _e.mock.On("LikeCollection", ctx, req)} +} + +func (_c *MockUserComponent_LikeCollection_Call) Run(run func(ctx context.Context, req *types.UserLikesRequest)) *MockUserComponent_LikeCollection_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(*types.UserLikesRequest)) + }) + return _c +} + +func (_c *MockUserComponent_LikeCollection_Call) Return(_a0 error) *MockUserComponent_LikeCollection_Call { + _c.Call.Return(_a0) + return _c +} + +func (_c *MockUserComponent_LikeCollection_Call) RunAndReturn(run func(context.Context, *types.UserLikesRequest) error) *MockUserComponent_LikeCollection_Call { + _c.Call.Return(run) + return _c +} + +// LikesCodes provides a mock function with given fields: ctx, req +func (_m *MockUserComponent) LikesCodes(ctx context.Context, req *types.UserModelsReq) ([]types.Code, int, error) { + ret := _m.Called(ctx, req) + + if len(ret) == 0 { + panic("no return value specified for LikesCodes") + } + + var r0 []types.Code + var r1 int + var r2 error + if rf, ok := ret.Get(0).(func(context.Context, *types.UserModelsReq) ([]types.Code, int, error)); ok { + return rf(ctx, req) + } + if rf, ok := ret.Get(0).(func(context.Context, *types.UserModelsReq) []types.Code); ok { + r0 = rf(ctx, req) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).([]types.Code) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, *types.UserModelsReq) int); ok { + r1 = rf(ctx, req) + } else { + r1 = ret.Get(1).(int) + } + + if rf, ok := ret.Get(2).(func(context.Context, *types.UserModelsReq) error); ok { + r2 = rf(ctx, req) + } else { + r2 = ret.Error(2) + } + + return r0, r1, r2 +} + +// MockUserComponent_LikesCodes_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'LikesCodes' +type MockUserComponent_LikesCodes_Call struct { + *mock.Call +} + +// LikesCodes is a helper method to define mock.On call +// - ctx context.Context +// - req *types.UserModelsReq +func (_e *MockUserComponent_Expecter) LikesCodes(ctx interface{}, req interface{}) *MockUserComponent_LikesCodes_Call { + return &MockUserComponent_LikesCodes_Call{Call: _e.mock.On("LikesCodes", ctx, req)} +} + +func (_c *MockUserComponent_LikesCodes_Call) Run(run func(ctx context.Context, req *types.UserModelsReq)) *MockUserComponent_LikesCodes_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(*types.UserModelsReq)) + }) + return _c +} + +func (_c *MockUserComponent_LikesCodes_Call) Return(_a0 []types.Code, _a1 int, _a2 error) *MockUserComponent_LikesCodes_Call { + _c.Call.Return(_a0, _a1, _a2) + return _c +} + +func (_c *MockUserComponent_LikesCodes_Call) RunAndReturn(run func(context.Context, *types.UserModelsReq) ([]types.Code, int, error)) *MockUserComponent_LikesCodes_Call { + _c.Call.Return(run) + return _c +} + +// LikesCollection provides a mock function with given fields: ctx, req +func (_m *MockUserComponent) LikesCollection(ctx context.Context, req *types.UserSpacesReq) ([]types.Collection, int, error) { + ret := _m.Called(ctx, req) + + if len(ret) == 0 { + panic("no return value specified for LikesCollection") + } + + var r0 []types.Collection + var r1 int + var r2 error + if rf, ok := ret.Get(0).(func(context.Context, *types.UserSpacesReq) ([]types.Collection, int, error)); ok { + return rf(ctx, req) + } + if rf, ok := ret.Get(0).(func(context.Context, *types.UserSpacesReq) []types.Collection); ok { + r0 = rf(ctx, req) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).([]types.Collection) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, *types.UserSpacesReq) int); ok { + r1 = rf(ctx, req) + } else { + r1 = ret.Get(1).(int) + } + + if rf, ok := ret.Get(2).(func(context.Context, *types.UserSpacesReq) error); ok { + r2 = rf(ctx, req) + } else { + r2 = ret.Error(2) + } + + return r0, r1, r2 +} + +// MockUserComponent_LikesCollection_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'LikesCollection' +type MockUserComponent_LikesCollection_Call struct { + *mock.Call +} + +// LikesCollection is a helper method to define mock.On call +// - ctx context.Context +// - req *types.UserSpacesReq +func (_e *MockUserComponent_Expecter) LikesCollection(ctx interface{}, req interface{}) *MockUserComponent_LikesCollection_Call { + return &MockUserComponent_LikesCollection_Call{Call: _e.mock.On("LikesCollection", ctx, req)} +} + +func (_c *MockUserComponent_LikesCollection_Call) Run(run func(ctx context.Context, req *types.UserSpacesReq)) *MockUserComponent_LikesCollection_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(*types.UserSpacesReq)) + }) + return _c +} + +func (_c *MockUserComponent_LikesCollection_Call) Return(_a0 []types.Collection, _a1 int, _a2 error) *MockUserComponent_LikesCollection_Call { + _c.Call.Return(_a0, _a1, _a2) + return _c +} + +func (_c *MockUserComponent_LikesCollection_Call) RunAndReturn(run func(context.Context, *types.UserSpacesReq) ([]types.Collection, int, error)) *MockUserComponent_LikesCollection_Call { + _c.Call.Return(run) + return _c +} + +// LikesDatasets provides a mock function with given fields: ctx, req +func (_m *MockUserComponent) LikesDatasets(ctx context.Context, req *types.UserDatasetsReq) ([]types.Dataset, int, error) { + ret := _m.Called(ctx, req) + + if len(ret) == 0 { + panic("no return value specified for LikesDatasets") + } + + var r0 []types.Dataset + var r1 int + var r2 error + if rf, ok := ret.Get(0).(func(context.Context, *types.UserDatasetsReq) ([]types.Dataset, int, error)); ok { + return rf(ctx, req) + } + if rf, ok := ret.Get(0).(func(context.Context, *types.UserDatasetsReq) []types.Dataset); ok { + r0 = rf(ctx, req) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).([]types.Dataset) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, *types.UserDatasetsReq) int); ok { + r1 = rf(ctx, req) + } else { + r1 = ret.Get(1).(int) + } + + if rf, ok := ret.Get(2).(func(context.Context, *types.UserDatasetsReq) error); ok { + r2 = rf(ctx, req) + } else { + r2 = ret.Error(2) + } + + return r0, r1, r2 +} + +// MockUserComponent_LikesDatasets_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'LikesDatasets' +type MockUserComponent_LikesDatasets_Call struct { + *mock.Call +} + +// LikesDatasets is a helper method to define mock.On call +// - ctx context.Context +// - req *types.UserDatasetsReq +func (_e *MockUserComponent_Expecter) LikesDatasets(ctx interface{}, req interface{}) *MockUserComponent_LikesDatasets_Call { + return &MockUserComponent_LikesDatasets_Call{Call: _e.mock.On("LikesDatasets", ctx, req)} +} + +func (_c *MockUserComponent_LikesDatasets_Call) Run(run func(ctx context.Context, req *types.UserDatasetsReq)) *MockUserComponent_LikesDatasets_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(*types.UserDatasetsReq)) + }) + return _c +} + +func (_c *MockUserComponent_LikesDatasets_Call) Return(_a0 []types.Dataset, _a1 int, _a2 error) *MockUserComponent_LikesDatasets_Call { + _c.Call.Return(_a0, _a1, _a2) + return _c +} + +func (_c *MockUserComponent_LikesDatasets_Call) RunAndReturn(run func(context.Context, *types.UserDatasetsReq) ([]types.Dataset, int, error)) *MockUserComponent_LikesDatasets_Call { + _c.Call.Return(run) + return _c +} + +// LikesModels provides a mock function with given fields: ctx, req +func (_m *MockUserComponent) LikesModels(ctx context.Context, req *types.UserModelsReq) ([]types.Model, int, error) { + ret := _m.Called(ctx, req) + + if len(ret) == 0 { + panic("no return value specified for LikesModels") + } + + var r0 []types.Model + var r1 int + var r2 error + if rf, ok := ret.Get(0).(func(context.Context, *types.UserModelsReq) ([]types.Model, int, error)); ok { + return rf(ctx, req) + } + if rf, ok := ret.Get(0).(func(context.Context, *types.UserModelsReq) []types.Model); ok { + r0 = rf(ctx, req) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).([]types.Model) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, *types.UserModelsReq) int); ok { + r1 = rf(ctx, req) + } else { + r1 = ret.Get(1).(int) + } + + if rf, ok := ret.Get(2).(func(context.Context, *types.UserModelsReq) error); ok { + r2 = rf(ctx, req) + } else { + r2 = ret.Error(2) + } + + return r0, r1, r2 +} + +// MockUserComponent_LikesModels_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'LikesModels' +type MockUserComponent_LikesModels_Call struct { + *mock.Call +} + +// LikesModels is a helper method to define mock.On call +// - ctx context.Context +// - req *types.UserModelsReq +func (_e *MockUserComponent_Expecter) LikesModels(ctx interface{}, req interface{}) *MockUserComponent_LikesModels_Call { + return &MockUserComponent_LikesModels_Call{Call: _e.mock.On("LikesModels", ctx, req)} +} + +func (_c *MockUserComponent_LikesModels_Call) Run(run func(ctx context.Context, req *types.UserModelsReq)) *MockUserComponent_LikesModels_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(*types.UserModelsReq)) + }) + return _c +} + +func (_c *MockUserComponent_LikesModels_Call) Return(_a0 []types.Model, _a1 int, _a2 error) *MockUserComponent_LikesModels_Call { + _c.Call.Return(_a0, _a1, _a2) + return _c +} + +func (_c *MockUserComponent_LikesModels_Call) RunAndReturn(run func(context.Context, *types.UserModelsReq) ([]types.Model, int, error)) *MockUserComponent_LikesModels_Call { + _c.Call.Return(run) + return _c +} + +// LikesSpaces provides a mock function with given fields: ctx, req +func (_m *MockUserComponent) LikesSpaces(ctx context.Context, req *types.UserSpacesReq) ([]types.Space, int, error) { + ret := _m.Called(ctx, req) + + if len(ret) == 0 { + panic("no return value specified for LikesSpaces") + } + + var r0 []types.Space + var r1 int + var r2 error + if rf, ok := ret.Get(0).(func(context.Context, *types.UserSpacesReq) ([]types.Space, int, error)); ok { + return rf(ctx, req) + } + if rf, ok := ret.Get(0).(func(context.Context, *types.UserSpacesReq) []types.Space); ok { + r0 = rf(ctx, req) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).([]types.Space) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, *types.UserSpacesReq) int); ok { + r1 = rf(ctx, req) + } else { + r1 = ret.Get(1).(int) + } + + if rf, ok := ret.Get(2).(func(context.Context, *types.UserSpacesReq) error); ok { + r2 = rf(ctx, req) + } else { + r2 = ret.Error(2) + } + + return r0, r1, r2 +} + +// MockUserComponent_LikesSpaces_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'LikesSpaces' +type MockUserComponent_LikesSpaces_Call struct { + *mock.Call +} + +// LikesSpaces is a helper method to define mock.On call +// - ctx context.Context +// - req *types.UserSpacesReq +func (_e *MockUserComponent_Expecter) LikesSpaces(ctx interface{}, req interface{}) *MockUserComponent_LikesSpaces_Call { + return &MockUserComponent_LikesSpaces_Call{Call: _e.mock.On("LikesSpaces", ctx, req)} +} + +func (_c *MockUserComponent_LikesSpaces_Call) Run(run func(ctx context.Context, req *types.UserSpacesReq)) *MockUserComponent_LikesSpaces_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(*types.UserSpacesReq)) + }) + return _c +} + +func (_c *MockUserComponent_LikesSpaces_Call) Return(_a0 []types.Space, _a1 int, _a2 error) *MockUserComponent_LikesSpaces_Call { + _c.Call.Return(_a0, _a1, _a2) + return _c +} + +func (_c *MockUserComponent_LikesSpaces_Call) RunAndReturn(run func(context.Context, *types.UserSpacesReq) ([]types.Space, int, error)) *MockUserComponent_LikesSpaces_Call { + _c.Call.Return(run) + return _c +} + +// ListDeploys provides a mock function with given fields: ctx, repoType, req +func (_m *MockUserComponent) ListDeploys(ctx context.Context, repoType types.RepositoryType, req *types.DeployReq) ([]types.DeployRepo, int, error) { + ret := _m.Called(ctx, repoType, req) + + if len(ret) == 0 { + panic("no return value specified for ListDeploys") + } + + var r0 []types.DeployRepo + var r1 int + var r2 error + if rf, ok := ret.Get(0).(func(context.Context, types.RepositoryType, *types.DeployReq) ([]types.DeployRepo, int, error)); ok { + return rf(ctx, repoType, req) + } + if rf, ok := ret.Get(0).(func(context.Context, types.RepositoryType, *types.DeployReq) []types.DeployRepo); ok { + r0 = rf(ctx, repoType, req) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).([]types.DeployRepo) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, types.RepositoryType, *types.DeployReq) int); ok { + r1 = rf(ctx, repoType, req) + } else { + r1 = ret.Get(1).(int) + } + + if rf, ok := ret.Get(2).(func(context.Context, types.RepositoryType, *types.DeployReq) error); ok { + r2 = rf(ctx, repoType, req) + } else { + r2 = ret.Error(2) + } + + return r0, r1, r2 +} + +// MockUserComponent_ListDeploys_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'ListDeploys' +type MockUserComponent_ListDeploys_Call struct { + *mock.Call +} + +// ListDeploys is a helper method to define mock.On call +// - ctx context.Context +// - repoType types.RepositoryType +// - req *types.DeployReq +func (_e *MockUserComponent_Expecter) ListDeploys(ctx interface{}, repoType interface{}, req interface{}) *MockUserComponent_ListDeploys_Call { + return &MockUserComponent_ListDeploys_Call{Call: _e.mock.On("ListDeploys", ctx, repoType, req)} +} + +func (_c *MockUserComponent_ListDeploys_Call) Run(run func(ctx context.Context, repoType types.RepositoryType, req *types.DeployReq)) *MockUserComponent_ListDeploys_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(types.RepositoryType), args[2].(*types.DeployReq)) + }) + return _c +} + +func (_c *MockUserComponent_ListDeploys_Call) Return(_a0 []types.DeployRepo, _a1 int, _a2 error) *MockUserComponent_ListDeploys_Call { + _c.Call.Return(_a0, _a1, _a2) + return _c +} + +func (_c *MockUserComponent_ListDeploys_Call) RunAndReturn(run func(context.Context, types.RepositoryType, *types.DeployReq) ([]types.DeployRepo, int, error)) *MockUserComponent_ListDeploys_Call { + _c.Call.Return(run) + return _c +} + +// ListInstances provides a mock function with given fields: ctx, req +func (_m *MockUserComponent) ListInstances(ctx context.Context, req *types.UserRepoReq) ([]types.DeployRepo, int, error) { + ret := _m.Called(ctx, req) + + if len(ret) == 0 { + panic("no return value specified for ListInstances") + } + + var r0 []types.DeployRepo + var r1 int + var r2 error + if rf, ok := ret.Get(0).(func(context.Context, *types.UserRepoReq) ([]types.DeployRepo, int, error)); ok { + return rf(ctx, req) + } + if rf, ok := ret.Get(0).(func(context.Context, *types.UserRepoReq) []types.DeployRepo); ok { + r0 = rf(ctx, req) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).([]types.DeployRepo) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, *types.UserRepoReq) int); ok { + r1 = rf(ctx, req) + } else { + r1 = ret.Get(1).(int) + } + + if rf, ok := ret.Get(2).(func(context.Context, *types.UserRepoReq) error); ok { + r2 = rf(ctx, req) + } else { + r2 = ret.Error(2) + } + + return r0, r1, r2 +} + +// MockUserComponent_ListInstances_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'ListInstances' +type MockUserComponent_ListInstances_Call struct { + *mock.Call +} + +// ListInstances is a helper method to define mock.On call +// - ctx context.Context +// - req *types.UserRepoReq +func (_e *MockUserComponent_Expecter) ListInstances(ctx interface{}, req interface{}) *MockUserComponent_ListInstances_Call { + return &MockUserComponent_ListInstances_Call{Call: _e.mock.On("ListInstances", ctx, req)} +} + +func (_c *MockUserComponent_ListInstances_Call) Run(run func(ctx context.Context, req *types.UserRepoReq)) *MockUserComponent_ListInstances_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(*types.UserRepoReq)) + }) + return _c +} + +func (_c *MockUserComponent_ListInstances_Call) Return(_a0 []types.DeployRepo, _a1 int, _a2 error) *MockUserComponent_ListInstances_Call { + _c.Call.Return(_a0, _a1, _a2) + return _c +} + +func (_c *MockUserComponent_ListInstances_Call) RunAndReturn(run func(context.Context, *types.UserRepoReq) ([]types.DeployRepo, int, error)) *MockUserComponent_ListInstances_Call { + _c.Call.Return(run) + return _c +} + +// ListServerless provides a mock function with given fields: ctx, req +func (_m *MockUserComponent) ListServerless(ctx context.Context, req types.DeployReq) ([]types.DeployRepo, int, error) { + ret := _m.Called(ctx, req) + + if len(ret) == 0 { + panic("no return value specified for ListServerless") + } + + var r0 []types.DeployRepo + var r1 int + var r2 error + if rf, ok := ret.Get(0).(func(context.Context, types.DeployReq) ([]types.DeployRepo, int, error)); ok { + return rf(ctx, req) + } + if rf, ok := ret.Get(0).(func(context.Context, types.DeployReq) []types.DeployRepo); ok { + r0 = rf(ctx, req) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).([]types.DeployRepo) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, types.DeployReq) int); ok { + r1 = rf(ctx, req) + } else { + r1 = ret.Get(1).(int) + } + + if rf, ok := ret.Get(2).(func(context.Context, types.DeployReq) error); ok { + r2 = rf(ctx, req) + } else { + r2 = ret.Error(2) + } + + return r0, r1, r2 +} + +// MockUserComponent_ListServerless_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'ListServerless' +type MockUserComponent_ListServerless_Call struct { + *mock.Call +} + +// ListServerless is a helper method to define mock.On call +// - ctx context.Context +// - req types.DeployReq +func (_e *MockUserComponent_Expecter) ListServerless(ctx interface{}, req interface{}) *MockUserComponent_ListServerless_Call { + return &MockUserComponent_ListServerless_Call{Call: _e.mock.On("ListServerless", ctx, req)} +} + +func (_c *MockUserComponent_ListServerless_Call) Run(run func(ctx context.Context, req types.DeployReq)) *MockUserComponent_ListServerless_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(types.DeployReq)) + }) + return _c +} + +func (_c *MockUserComponent_ListServerless_Call) Return(_a0 []types.DeployRepo, _a1 int, _a2 error) *MockUserComponent_ListServerless_Call { + _c.Call.Return(_a0, _a1, _a2) + return _c +} + +func (_c *MockUserComponent_ListServerless_Call) RunAndReturn(run func(context.Context, types.DeployReq) ([]types.DeployRepo, int, error)) *MockUserComponent_ListServerless_Call { + _c.Call.Return(run) + return _c +} + +// Models provides a mock function with given fields: ctx, req +func (_m *MockUserComponent) Models(ctx context.Context, req *types.UserModelsReq) ([]types.Model, int, error) { + ret := _m.Called(ctx, req) + + if len(ret) == 0 { + panic("no return value specified for Models") + } + + var r0 []types.Model + var r1 int + var r2 error + if rf, ok := ret.Get(0).(func(context.Context, *types.UserModelsReq) ([]types.Model, int, error)); ok { + return rf(ctx, req) + } + if rf, ok := ret.Get(0).(func(context.Context, *types.UserModelsReq) []types.Model); ok { + r0 = rf(ctx, req) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).([]types.Model) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, *types.UserModelsReq) int); ok { + r1 = rf(ctx, req) + } else { + r1 = ret.Get(1).(int) + } + + if rf, ok := ret.Get(2).(func(context.Context, *types.UserModelsReq) error); ok { + r2 = rf(ctx, req) + } else { + r2 = ret.Error(2) + } + + return r0, r1, r2 +} + +// MockUserComponent_Models_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Models' +type MockUserComponent_Models_Call struct { + *mock.Call +} + +// Models is a helper method to define mock.On call +// - ctx context.Context +// - req *types.UserModelsReq +func (_e *MockUserComponent_Expecter) Models(ctx interface{}, req interface{}) *MockUserComponent_Models_Call { + return &MockUserComponent_Models_Call{Call: _e.mock.On("Models", ctx, req)} +} + +func (_c *MockUserComponent_Models_Call) Run(run func(ctx context.Context, req *types.UserModelsReq)) *MockUserComponent_Models_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(*types.UserModelsReq)) + }) + return _c +} + +func (_c *MockUserComponent_Models_Call) Return(_a0 []types.Model, _a1 int, _a2 error) *MockUserComponent_Models_Call { + _c.Call.Return(_a0, _a1, _a2) + return _c +} + +func (_c *MockUserComponent_Models_Call) RunAndReturn(run func(context.Context, *types.UserModelsReq) ([]types.Model, int, error)) *MockUserComponent_Models_Call { + _c.Call.Return(run) + return _c +} + +// Prompts provides a mock function with given fields: ctx, req +func (_m *MockUserComponent) Prompts(ctx context.Context, req *types.UserPromptsReq) ([]types.PromptRes, int, error) { + ret := _m.Called(ctx, req) + + if len(ret) == 0 { + panic("no return value specified for Prompts") + } + + var r0 []types.PromptRes + var r1 int + var r2 error + if rf, ok := ret.Get(0).(func(context.Context, *types.UserPromptsReq) ([]types.PromptRes, int, error)); ok { + return rf(ctx, req) + } + if rf, ok := ret.Get(0).(func(context.Context, *types.UserPromptsReq) []types.PromptRes); ok { + r0 = rf(ctx, req) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).([]types.PromptRes) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, *types.UserPromptsReq) int); ok { + r1 = rf(ctx, req) + } else { + r1 = ret.Get(1).(int) + } + + if rf, ok := ret.Get(2).(func(context.Context, *types.UserPromptsReq) error); ok { + r2 = rf(ctx, req) + } else { + r2 = ret.Error(2) + } + + return r0, r1, r2 +} + +// MockUserComponent_Prompts_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Prompts' +type MockUserComponent_Prompts_Call struct { + *mock.Call +} + +// Prompts is a helper method to define mock.On call +// - ctx context.Context +// - req *types.UserPromptsReq +func (_e *MockUserComponent_Expecter) Prompts(ctx interface{}, req interface{}) *MockUserComponent_Prompts_Call { + return &MockUserComponent_Prompts_Call{Call: _e.mock.On("Prompts", ctx, req)} +} + +func (_c *MockUserComponent_Prompts_Call) Run(run func(ctx context.Context, req *types.UserPromptsReq)) *MockUserComponent_Prompts_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(*types.UserPromptsReq)) + }) + return _c +} + +func (_c *MockUserComponent_Prompts_Call) Return(_a0 []types.PromptRes, _a1 int, _a2 error) *MockUserComponent_Prompts_Call { + _c.Call.Return(_a0, _a1, _a2) + return _c +} + +func (_c *MockUserComponent_Prompts_Call) RunAndReturn(run func(context.Context, *types.UserPromptsReq) ([]types.PromptRes, int, error)) *MockUserComponent_Prompts_Call { + _c.Call.Return(run) + return _c +} + +// Spaces provides a mock function with given fields: ctx, req +func (_m *MockUserComponent) Spaces(ctx context.Context, req *types.UserSpacesReq) ([]types.Space, int, error) { + ret := _m.Called(ctx, req) + + if len(ret) == 0 { + panic("no return value specified for Spaces") + } + + var r0 []types.Space + var r1 int + var r2 error + if rf, ok := ret.Get(0).(func(context.Context, *types.UserSpacesReq) ([]types.Space, int, error)); ok { + return rf(ctx, req) + } + if rf, ok := ret.Get(0).(func(context.Context, *types.UserSpacesReq) []types.Space); ok { + r0 = rf(ctx, req) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).([]types.Space) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, *types.UserSpacesReq) int); ok { + r1 = rf(ctx, req) + } else { + r1 = ret.Get(1).(int) + } + + if rf, ok := ret.Get(2).(func(context.Context, *types.UserSpacesReq) error); ok { + r2 = rf(ctx, req) + } else { + r2 = ret.Error(2) + } + + return r0, r1, r2 +} + +// MockUserComponent_Spaces_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Spaces' +type MockUserComponent_Spaces_Call struct { + *mock.Call +} + +// Spaces is a helper method to define mock.On call +// - ctx context.Context +// - req *types.UserSpacesReq +func (_e *MockUserComponent_Expecter) Spaces(ctx interface{}, req interface{}) *MockUserComponent_Spaces_Call { + return &MockUserComponent_Spaces_Call{Call: _e.mock.On("Spaces", ctx, req)} +} + +func (_c *MockUserComponent_Spaces_Call) Run(run func(ctx context.Context, req *types.UserSpacesReq)) *MockUserComponent_Spaces_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(*types.UserSpacesReq)) + }) + return _c +} + +func (_c *MockUserComponent_Spaces_Call) Return(_a0 []types.Space, _a1 int, _a2 error) *MockUserComponent_Spaces_Call { + _c.Call.Return(_a0, _a1, _a2) + return _c +} + +func (_c *MockUserComponent_Spaces_Call) RunAndReturn(run func(context.Context, *types.UserSpacesReq) ([]types.Space, int, error)) *MockUserComponent_Spaces_Call { + _c.Call.Return(run) + return _c +} + +// UnLikeCollection provides a mock function with given fields: ctx, req +func (_m *MockUserComponent) UnLikeCollection(ctx context.Context, req *types.UserLikesRequest) error { + ret := _m.Called(ctx, req) + + if len(ret) == 0 { + panic("no return value specified for UnLikeCollection") + } + + var r0 error + if rf, ok := ret.Get(0).(func(context.Context, *types.UserLikesRequest) error); ok { + r0 = rf(ctx, req) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// MockUserComponent_UnLikeCollection_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'UnLikeCollection' +type MockUserComponent_UnLikeCollection_Call struct { + *mock.Call +} + +// UnLikeCollection is a helper method to define mock.On call +// - ctx context.Context +// - req *types.UserLikesRequest +func (_e *MockUserComponent_Expecter) UnLikeCollection(ctx interface{}, req interface{}) *MockUserComponent_UnLikeCollection_Call { + return &MockUserComponent_UnLikeCollection_Call{Call: _e.mock.On("UnLikeCollection", ctx, req)} +} + +func (_c *MockUserComponent_UnLikeCollection_Call) Run(run func(ctx context.Context, req *types.UserLikesRequest)) *MockUserComponent_UnLikeCollection_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(*types.UserLikesRequest)) + }) + return _c +} + +func (_c *MockUserComponent_UnLikeCollection_Call) Return(_a0 error) *MockUserComponent_UnLikeCollection_Call { + _c.Call.Return(_a0) + return _c +} + +func (_c *MockUserComponent_UnLikeCollection_Call) RunAndReturn(run func(context.Context, *types.UserLikesRequest) error) *MockUserComponent_UnLikeCollection_Call { + _c.Call.Return(run) + return _c +} + +// NewMockUserComponent creates a new instance of MockUserComponent. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. +// The first argument is typically a *testing.T value. +func NewMockUserComponent(t interface { + mock.TestingT + Cleanup(func()) +}) *MockUserComponent { + mock := &MockUserComponent{} + mock.Mock.Test(t) + + t.Cleanup(func() { mock.AssertExpectations(t) }) + + return mock +} diff --git a/api/handler/git_http.go b/api/handler/git_http.go index 5047cea1..a9bcb86b 100644 --- a/api/handler/git_http.go +++ b/api/handler/git_http.go @@ -25,12 +25,12 @@ func NewGitHTTPHandler(config *config.Config) (*GitHTTPHandler, error) { return nil, err } return &GitHTTPHandler{ - c: uc, + gitHttp: uc, }, nil } type GitHTTPHandler struct { - c component.GitHTTPComponent + gitHttp component.GitHTTPComponent } func (h *GitHTTPHandler) InfoRefs(ctx *gin.Context) { @@ -53,7 +53,7 @@ func (h *GitHTTPHandler) InfoRefs(ctx *gin.Context) { GitProtocol: gitProtocol, CurrentUser: httpbase.GetCurrentUser(ctx), } - reader, err := h.c.InfoRefs(ctx, req) + reader, err := h.gitHttp.InfoRefs(ctx, req) if err != nil { if err == component.ErrUnauthorized { ctx.Header("WWW-Authenticate", "Basic realm=opencsg-git") @@ -109,7 +109,7 @@ func (h *GitHTTPHandler) GitUploadPack(ctx *gin.Context) { ctx.Header("Content-Type", fmt.Sprintf("application/x-%s-result", action)) ctx.Header("Cache-Control", "no-cache") - err := h.c.GitUploadPack(ctx, req) + err := h.gitHttp.GitUploadPack(ctx, req) if err != nil { httpbase.ServerError(ctx, err) return @@ -132,7 +132,7 @@ func (h *GitHTTPHandler) GitReceivePack(ctx *gin.Context) { ctx.Header("Content-Type", fmt.Sprintf("application/x-%s-result", action)) ctx.Header("Cache-Control", "no-cache") - err := h.c.GitReceivePack(ctx, req) + err := h.gitHttp.GitReceivePack(ctx, req) if err != nil { if err == component.ErrUnauthorized { ctx.Header("WWW-Authenticate", "Basic realm=opencsg-git") @@ -175,12 +175,7 @@ func (h *GitHTTPHandler) LfsBatch(ctx *gin.Context) { return } - s3Internal := ctx.GetHeader("X-OPENCSG-S3-Internal") - if s3Internal == "true" { - ctx.Set("X-OPENCSG-S3-Internal", true) - } - - objectResponse, err := h.c.BuildObjectResponse(ctx, batchRequest, isUpload) + objectResponse, err := h.gitHttp.BuildObjectResponse(ctx, batchRequest, isUpload) if err != nil { if errors.Is(err, component.ErrUnauthorized) { ctx.Header("WWW-Authenticate", "Basic realm=opencsg-git") @@ -216,7 +211,7 @@ func (h *GitHTTPHandler) LfsUpload(ctx *gin.Context) { uploadRequest.RepoType = types.RepositoryType(ctx.GetString("repo_type")) uploadRequest.CurrentUser = httpbase.GetCurrentUser(ctx) - err = h.c.LfsUpload(ctx, ctx.Request.Body, uploadRequest) + err = h.gitHttp.LfsUpload(ctx, ctx.Request.Body, uploadRequest) if err != nil { httpbase.ServerError(ctx, err) return @@ -240,12 +235,7 @@ func (h *GitHTTPHandler) LfsDownload(ctx *gin.Context) { downloadRequest.CurrentUser = httpbase.GetCurrentUser(ctx) downloadRequest.SaveAs = ctx.Query("save_as") - s3Internal := ctx.GetHeader("X-OPENCSG-S3-Internal") - if s3Internal == "true" { - ctx.Set("X-OPENCSG-S3-Internal", true) - } - - url, err := h.c.LfsDownload(ctx, downloadRequest) + url, err := h.gitHttp.LfsDownload(ctx, downloadRequest) if err != nil { httpbase.ServerError(ctx, err) return @@ -269,7 +259,7 @@ func (h *GitHTTPHandler) LfsVerify(ctx *gin.Context) { verifyRequest.RepoType = types.RepositoryType(ctx.GetString("repo_type")) verifyRequest.CurrentUser = httpbase.GetCurrentUser(ctx) - err := h.c.LfsVerify(ctx, verifyRequest, pointer) + err := h.gitHttp.LfsVerify(ctx, verifyRequest, pointer) if err != nil { slog.Error("Bad request format", "error", err) httpbase.BadRequest(ctx, err.Error()) @@ -312,7 +302,7 @@ func (h *GitHTTPHandler) ListLocks(ctx *gin.Context) { } req.Limit = limit - res, err := h.c.ListLocks(ctx, req) + res, err := h.gitHttp.ListLocks(ctx, req) if err != nil { if errors.Is(err, component.ErrUnauthorized) { ctx.Header("WWW-Authenticate", "Basic realm=opencsg-git") @@ -344,7 +334,7 @@ func (h *GitHTTPHandler) CreateLock(ctx *gin.Context) { req.RepoType = types.RepositoryType(ctx.GetString("repo_type")) req.CurrentUser = httpbase.GetCurrentUser(ctx) - lock, err := h.c.CreateLock(ctx, req) + lock, err := h.gitHttp.CreateLock(ctx, req) if err != nil { if errors.Is(err, component.ErrAlreadyExists) { ctx.PureJSON(http.StatusConflict, types.LFSLockError{ @@ -413,7 +403,7 @@ func (h *GitHTTPHandler) VerifyLock(ctx *gin.Context) { } req.Limit = limit - res, err := h.c.VerifyLock(ctx, req) + res, err := h.gitHttp.VerifyLock(ctx, req) if err != nil { slog.Error("Bad request format", "error", err) ctx.PureJSON(http.StatusInternalServerError, types.LFSLockError{ @@ -452,7 +442,7 @@ func (h *GitHTTPHandler) UnLock(ctx *gin.Context) { req.RepoType = types.RepositoryType(ctx.GetString("repo_type")) req.CurrentUser = httpbase.GetCurrentUser(ctx) - lock, err = h.c.UnLock(ctx, req) + lock, err = h.gitHttp.UnLock(ctx, req) if err != nil { if errors.Is(err, component.ErrUnauthorized) { ctx.Header("WWW-Authenticate", "Basic realm=opencsg-git") diff --git a/api/handler/git_http_test.go b/api/handler/git_http_test.go new file mode 100644 index 00000000..e1d844fa --- /dev/null +++ b/api/handler/git_http_test.go @@ -0,0 +1,315 @@ +package handler + +import ( + "bytes" + "compress/gzip" + "io" + "net/url" + "testing" + + "github.com/gin-gonic/gin" + "github.com/stretchr/testify/require" + mockcomponent "opencsg.com/csghub-server/_mocks/opencsg.com/csghub-server/component" + "opencsg.com/csghub-server/builder/store/database" + "opencsg.com/csghub-server/common/types" +) + +type GitHTTPTester struct { + *GinTester + handler *GitHTTPHandler + mocks struct { + gitHttp *mockcomponent.MockGitHTTPComponent + } +} + +func NewGitHTTPTester(t *testing.T) *GitHTTPTester { + tester := &GitHTTPTester{GinTester: NewGinTester()} + tester.mocks.gitHttp = mockcomponent.NewMockGitHTTPComponent(t) + + tester.handler = &GitHTTPHandler{ + gitHttp: tester.mocks.gitHttp, + } + tester.WithParam("repo", "testRepo") + tester.WithParam("branch", "testBranch") + return tester +} + +func (t *GitHTTPTester) WithHandleFunc(fn func(h *GitHTTPHandler) gin.HandlerFunc) *GitHTTPTester { + t.ginHandler = fn(t.handler) + return t +} + +func TestGitHTTPHandler_InfoRefs(t *testing.T) { + tester := NewGitHTTPTester(t).WithHandleFunc(func(h *GitHTTPHandler) gin.HandlerFunc { + return h.InfoRefs + }) + + reader := io.NopCloser(bytes.NewBuffer([]byte("foo"))) + tester.mocks.gitHttp.EXPECT().InfoRefs(tester.ctx, types.InfoRefsReq{ + Namespace: "u", + Name: "r", + RepoType: types.ModelRepo, + Rpc: "git-upload-pack", + GitProtocol: "ssh", + CurrentUser: "u", + }).Return(reader, nil) + tester.WithQuery("service", "git-upload-pack").WithHeader("Git-Protocol", "ssh") + tester.WithKV("namespace", "u").WithKV("name", "r") + tester.WithKV("repo_type", "model").WithUser().WithHeader("Accept-Encoding", "gzip").Execute() + + require.Equal(t, 200, tester.response.Code) + var b bytes.Buffer + gz := gzip.NewWriter(&b) + _, err := gz.Write([]byte("foo")) + require.NoError(t, err) + err = gz.Close() + require.NoError(t, err) + require.Equal(t, b.String(), tester.response.Body.String()) + headers := tester.response.Header() + require.Equal(t, "application/x-git-upload-pack-advertisement", headers.Get("Content-Type")) + require.Equal(t, "no-cache", headers.Get("Cache-Control")) +} + +func TestGitHTTPHandler_GitUploadPack(t *testing.T) { + tester := NewGitHTTPTester(t).WithHandleFunc(func(h *GitHTTPHandler) gin.HandlerFunc { + return h.GitUploadPack + }) + + tester.mocks.gitHttp.EXPECT().GitUploadPack(tester.ctx, types.GitUploadPackReq{ + Namespace: "u", + Name: "r", + RepoType: types.ModelRepo, + GitProtocol: "ssh", + Request: tester.ctx.Request, + Writer: tester.ctx.Writer, + CurrentUser: "u", + }).Return(nil) + tester.SetPath("git").WithQuery("service", "git-upload-pack").WithHeader("Git-Protocol", "ssh") + tester.WithKV("namespace", "u").WithKV("name", "r") + tester.WithKV("repo_type", "model").WithUser().WithHeader("Accept-Encoding", "gzip").Execute() + + require.Equal(t, 200, tester.response.Code) + headers := tester.response.Header() + require.Equal(t, "application/x-git-result", headers.Get("Content-Type")) + require.Equal(t, "no-cache", headers.Get("Cache-Control")) +} + +func TestGitHTTPHandler_GitReceivePack(t *testing.T) { + tester := NewGitHTTPTester(t).WithHandleFunc(func(h *GitHTTPHandler) gin.HandlerFunc { + return h.GitReceivePack + }) + + tester.mocks.gitHttp.EXPECT().GitReceivePack(tester.ctx, types.GitUploadPackReq{ + Namespace: "u", + Name: "r", + RepoType: types.ModelRepo, + GitProtocol: "ssh", + Request: tester.ctx.Request, + Writer: tester.ctx.Writer, + CurrentUser: "u", + }).Return(nil) + tester.SetPath("git").WithQuery("service", "git-upload-pack").WithHeader("Git-Protocol", "ssh") + tester.WithKV("namespace", "u").WithKV("name", "r") + tester.WithKV("repo_type", "model").WithUser().WithHeader("Accept-Encoding", "gzip").Execute() + + require.Equal(t, 200, tester.response.Code) + headers := tester.response.Header() + require.Equal(t, "application/x-git-result", headers.Get("Content-Type")) + require.Equal(t, "no-cache", headers.Get("Cache-Control")) +} + +func TestGitHTTPHandler_LfsBatch(t *testing.T) { + + for _, c := range []string{"upload", "download"} { + t.Run("c", func(t *testing.T) { + + tester := NewGitHTTPTester(t).WithHandleFunc(func(h *GitHTTPHandler) gin.HandlerFunc { + return h.LfsBatch + }) + + tester.mocks.gitHttp.EXPECT().BuildObjectResponse(tester.ctx, types.BatchRequest{ + Namespace: "u", + Name: "r", + RepoType: types.ModelRepo, + CurrentUser: "u", + Authorization: "auth", + Operation: c, + }, c == "upload").Return(&types.BatchResponse{Transfer: "t"}, nil) + tester.SetPath("git").WithQuery("service", "git-upload-pack").WithHeader("Git-Protocol", "ssh") + tester.WithKV("namespace", "u").WithKV("name", "r").WithHeader("Authorization", "auth") + tester.WithBody(t, &types.BatchRequest{Operation: c}) + tester.WithKV("repo_type", "model").WithUser().WithHeader("Accept-Encoding", "gzip").Execute() + + tester.ResponseEqSimple(t, 200, &types.BatchResponse{Transfer: "t"}) + headers := tester.response.Header() + require.Equal(t, types.LfsMediaType, headers.Get("Content-Type")) + }) + } +} + +func TestGitHTTPHandler_LfsUpload(t *testing.T) { + tester := NewGitHTTPTester(t).WithHandleFunc(func(h *GitHTTPHandler) gin.HandlerFunc { + return h.LfsUpload + }) + + tester.mocks.gitHttp.EXPECT().LfsUpload(tester.ctx, tester.ctx.Request.Body, types.UploadRequest{ + Namespace: "u", + Name: "r", + RepoType: types.ModelRepo, + CurrentUser: "u", + Oid: "o", + Size: 100, + }).Return(nil) + tester.SetPath("git").WithParam("oid", "o").WithParam("size", "100") + tester.WithKV("namespace", "u").WithKV("name", "r").WithHeader("Authorization", "auth") + tester.WithKV("repo_type", "model").WithUser().Execute() + + require.Equal(t, 200, tester.response.Code) + headers := tester.response.Header() + require.Equal(t, types.LfsMediaType, headers.Get("Content-Type")) +} + +func TestGitHTTPHandler_LfsDownload(t *testing.T) { + tester := NewGitHTTPTester(t).WithHandleFunc(func(h *GitHTTPHandler) gin.HandlerFunc { + return h.LfsDownload + }) + + tester.mocks.gitHttp.EXPECT().LfsDownload(tester.ctx, types.DownloadRequest{ + Namespace: "u", + Name: "r", + RepoType: types.ModelRepo, + CurrentUser: "u", + Oid: "o", + Size: 100, + }).Return(&url.URL{Path: "url"}, nil) + tester.SetPath("git").WithParam("oid", "o").WithParam("size", "100") + tester.WithKV("namespace", "u").WithKV("name", "r").WithHeader("Authorization", "auth") + tester.WithKV("repo_type", "model").WithUser().Execute() + + require.Equal(t, 200, tester.response.Code) + resp := tester.response.Result() + defer resp.Body.Close() + lc, err := resp.Location() + require.NoError(t, err) + require.Equal(t, "url", lc.String()) +} + +func TestGitHTTPHandler_LfsVerify(t *testing.T) { + tester := NewGitHTTPTester(t).WithHandleFunc(func(h *GitHTTPHandler) gin.HandlerFunc { + return h.LfsVerify + }) + + tester.mocks.gitHttp.EXPECT().LfsVerify(tester.ctx, types.VerifyRequest{ + Namespace: "u", + Name: "r", + RepoType: types.ModelRepo, + CurrentUser: "u", + }, types.Pointer{Oid: "o"}).Return(nil) + tester.WithKV("namespace", "u").WithKV("name", "r").WithHeader("Authorization", "auth") + tester.WithKV("repo_type", "model").WithUser().WithBody(t, &types.Pointer{ + Oid: "o", + }).Execute() + + tester.ResponseEqSimple(t, 200, nil) +} + +func TestGitHTTPHandler_ListLocks(t *testing.T) { + tester := NewGitHTTPTester(t).WithHandleFunc(func(h *GitHTTPHandler) gin.HandlerFunc { + return h.ListLocks + }) + + tester.mocks.gitHttp.EXPECT().ListLocks(tester.ctx, types.ListLFSLockReq{ + ID: 1, + Namespace: "u", + Name: "r", + RepoType: types.ModelRepo, + CurrentUser: "u", + Cursor: 12, + Path: "p", + Limit: 5, + }).Return(&types.LFSLockList{Next: "n"}, nil) + tester.WithKV("namespace", "u").WithKV("name", "r").WithQuery("path", "p").WithQuery("id", "1") + tester.WithKV("repo_type", "model").WithUser().WithQuery("cursor", "12").WithQuery("limit", "5").Execute() + + tester.ResponseEqSimple(t, 200, &types.LFSLockList{Next: "n"}) +} + +func TestGitHTTPHandler_CreateLock(t *testing.T) { + tester := NewGitHTTPTester(t).WithHandleFunc(func(h *GitHTTPHandler) gin.HandlerFunc { + return h.CreateLock + }) + + tester.mocks.gitHttp.EXPECT().CreateLock(tester.ctx, types.LfsLockReq{ + Namespace: "u", + Name: "r", + RepoType: types.ModelRepo, + CurrentUser: "u", + }).Return(&database.LfsLock{ + ID: 1, + Path: "p", + User: database.User{Username: "u"}, + }, nil) + tester.WithKV("namespace", "u").WithKV("name", "r") + tester.WithKV("repo_type", "model").WithUser().WithBody(t, types.LfsLockReq{}).Execute() + + tester.ResponseEqSimple(t, 200, &types.LFSLockResponse{ + Lock: &types.LFSLock{ + ID: "1", + Path: "p", + Owner: &types.LFSLockOwner{ + Name: "u", + }, + }, + }) +} + +func TestGitHTTPHandler_VerifyLock(t *testing.T) { + tester := NewGitHTTPTester(t).WithHandleFunc(func(h *GitHTTPHandler) gin.HandlerFunc { + return h.VerifyLock + }) + + tester.mocks.gitHttp.EXPECT().VerifyLock(tester.ctx, types.VerifyLFSLockReq{ + Namespace: "u", + Name: "r", + RepoType: types.ModelRepo, + CurrentUser: "u", + }).Return(&types.LFSLockListVerify{ + Next: "n", + }, nil) + tester.WithKV("namespace", "u").WithKV("name", "r") + tester.WithKV("repo_type", "model").WithUser().WithBody(t, types.VerifyLFSLockReq{}).Execute() + + tester.ResponseEqSimple(t, 200, &types.LFSLockListVerify{ + Next: "n", + }) +} + +func TestGitHTTPHandler_UnLock(t *testing.T) { + tester := NewGitHTTPTester(t).WithHandleFunc(func(h *GitHTTPHandler) gin.HandlerFunc { + return h.UnLock + }) + + tester.mocks.gitHttp.EXPECT().UnLock(tester.ctx, types.UnlockLFSReq{ + ID: 1, + Namespace: "u", + Name: "r", + RepoType: types.ModelRepo, + CurrentUser: "u", + }).Return(&database.LfsLock{ + ID: 1, + Path: "p", + User: database.User{Username: "u"}, + }, nil) + tester.WithKV("namespace", "u").WithKV("name", "r").WithParam("lid", "1") + tester.WithKV("repo_type", "model").WithUser().WithBody(t, types.UnlockLFSReq{}).Execute() + + tester.ResponseEqSimple(t, 200, &types.LFSLockResponse{ + Lock: &types.LFSLock{ + ID: "1", + Path: "p", + Owner: &types.LFSLockOwner{ + Name: "u", + }, + }, + }) +} diff --git a/api/handler/helper_test.go b/api/handler/helper_test.go index 4c62d995..7b9bdc38 100644 --- a/api/handler/helper_test.go +++ b/api/handler/helper_test.go @@ -22,6 +22,7 @@ type GinTester struct { ctx *gin.Context response *httptest.ResponseRecorder OKText string // text of httpbase.OK + _executed bool } func NewGinTester() *GinTester { @@ -45,6 +46,7 @@ func (g *GinTester) Handler(handler gin.HandlerFunc) { func (g *GinTester) Execute() { g.ginHandler(g.ctx) + g._executed = true } func (g *GinTester) WithUser() *GinTester { g.ctx.Set(httpbase.CurrentUserCtxVar, "u") @@ -87,6 +89,21 @@ func (g *GinTester) WithQuery(key, value string) *GinTester { return g } +func (g *GinTester) SetPath(path string) *GinTester { + g.ctx.Request.URL.Path = path + return g +} + +func (g *GinTester) WithHeader(key, value string) *GinTester { + h := g.ctx.Request.Header + if h == nil { + h = map[string][]string{} + } + h.Add(key, value) + g.ctx.Request.Header = h + return g +} + func (g *GinTester) AddPagination(page int, per int) *GinTester { g.WithQuery("page", cast.ToString(page)) g.WithQuery("per", cast.ToString(per)) @@ -94,6 +111,9 @@ func (g *GinTester) AddPagination(page int, per int) *GinTester { } func (g *GinTester) ResponseEq(t *testing.T, code int, msg string, expected any) { + if !g._executed { + require.FailNow(t, "call Execute method first") + } var r = struct { Msg string `json:"msg"` Data any `json:"data,omitempty"` @@ -109,9 +129,12 @@ func (g *GinTester) ResponseEq(t *testing.T, code int, msg string, expected any) } func (g *GinTester) ResponseEqSimple(t *testing.T, code int, expected any) { + if !g._executed { + require.FailNow(t, "call Execute method first") + } b, err := json.Marshal(expected) require.NoError(t, err) - require.Equal(t, code, g.response.Code) + require.Equal(t, code, g.response.Code, g.response.Body.String()) require.JSONEq(t, string(b), g.response.Body.String()) } @@ -121,6 +144,7 @@ func (g *GinTester) RequireUser(t *testing.T) { tmp := NewGinTester() tmp.ctx.Params = g.ctx.Params g.ginHandler(tmp.ctx) + tmp._executed = true tmp.ResponseEq(t, http.StatusUnauthorized, component.ErrUserNotFound.Error(), nil) // add user to original test ctx now _ = g.WithUser() diff --git a/api/handler/model.go b/api/handler/model.go index cd0eb829..218c1294 100644 --- a/api/handler/model.go +++ b/api/handler/model.go @@ -32,16 +32,16 @@ func NewModelHandler(config *config.Config) (*ModelHandler, error) { } return &ModelHandler{ - c: uc, - sc: sc, - repo: repo, + model: uc, + sensitive: sc, + repo: repo, }, nil } type ModelHandler struct { - c component.ModelComponent - sc component.SensitiveComponent - repo component.RepoComponent + model component.ModelComponent + repo component.RepoComponent + sensitive component.SensitiveComponent } // GetVisiableModels godoc @@ -92,7 +92,8 @@ func (h *ModelHandler) Index(ctx *gin.Context) { return } - models, total, err := h.c.Index(ctx, filter, per, page, false) + models, total, err := h.model.Index(ctx, filter, per, page, false) + if err != nil { slog.Error("Failed to get models", slog.Any("error", err)) httpbase.ServerError(ctx, err) @@ -133,14 +134,14 @@ func (h *ModelHandler) Create(ctx *gin.Context) { } req.Username = currentUser - _, err := h.sc.CheckRequestV2(ctx, req) + _, err := h.sensitive.CheckRequestV2(ctx, req) if err != nil { slog.Error("failed to check sensitive request", slog.Any("error", err)) httpbase.BadRequest(ctx, fmt.Errorf("sensitive check failed: %w", err).Error()) return } - model, err := h.c.Create(ctx, req) + model, err := h.model.Create(ctx, req) if err != nil { slog.Error("Failed to create model", slog.Any("error", err)) httpbase.ServerError(ctx, err) @@ -178,7 +179,7 @@ func (h *ModelHandler) Update(ctx *gin.Context) { return } - _, err := h.sc.CheckRequestV2(ctx, req) + _, err := h.sensitive.CheckRequestV2(ctx, req) if err != nil { slog.Error("failed to check sensitive request", slog.Any("error", err)) httpbase.BadRequest(ctx, fmt.Errorf("sensitive check failed: %w", err).Error()) @@ -195,7 +196,7 @@ func (h *ModelHandler) Update(ctx *gin.Context) { req.Name = name req.Username = currentUser - model, err := h.c.Update(ctx, req) + model, err := h.model.Update(ctx, req) if err != nil { slog.Error("Failed to update model", slog.Any("error", err)) httpbase.ServerError(ctx, err) @@ -232,7 +233,7 @@ func (h *ModelHandler) Delete(ctx *gin.Context) { httpbase.BadRequest(ctx, err.Error()) return } - err = h.c.Delete(ctx, namespace, name, currentUser) + err = h.model.Delete(ctx, namespace, name, currentUser) if err != nil { slog.Error("Failed to delete model", slog.Any("error", err)) httpbase.ServerError(ctx, err) @@ -264,7 +265,8 @@ func (h *ModelHandler) Show(ctx *gin.Context) { return } currentUser := httpbase.GetCurrentUser(ctx) - detail, err := h.c.Show(ctx, namespace, name, currentUser, false) + detail, err := h.model.Show(ctx, namespace, name, currentUser, false) + if err != nil { if errors.Is(err, component.ErrUnauthorized) { httpbase.UnauthorizedError(ctx, err) @@ -292,7 +294,7 @@ func (h *ModelHandler) SDKModelInfo(ctx *gin.Context) { ref = mappedBranch } currentUser := httpbase.GetCurrentUser(ctx) - modelInfo, err := h.c.SDKModelInfo(ctx, namespace, name, ref, currentUser) + modelInfo, err := h.model.SDKModelInfo(ctx, namespace, name, ref, currentUser) if err != nil { if errors.Is(err, component.ErrUnauthorized) { httpbase.UnauthorizedError(ctx, err) @@ -327,7 +329,7 @@ func (h *ModelHandler) Relations(ctx *gin.Context) { return } currentUser := httpbase.GetCurrentUser(ctx) - detail, err := h.c.Relations(ctx, namespace, name, currentUser) + detail, err := h.model.Relations(ctx, namespace, name, currentUser) if err != nil { if errors.Is(err, component.ErrUnauthorized) { httpbase.UnauthorizedError(ctx, err) @@ -379,7 +381,7 @@ func (h *ModelHandler) SetRelations(ctx *gin.Context) { req.Name = name req.CurrentUser = currentUser - err = h.c.SetRelationDatasets(ctx, req) + err = h.model.SetRelationDatasets(ctx, req) if err != nil { slog.Error("Failed to set datasets for model", slog.Any("req", req), slog.Any("error", err)) httpbase.ServerError(ctx, err) @@ -426,7 +428,7 @@ func (h *ModelHandler) AddDatasetRelation(ctx *gin.Context) { req.Name = name req.CurrentUser = currentUser - err = h.c.AddRelationDataset(ctx, req) + err = h.model.AddRelationDataset(ctx, req) if err != nil { slog.Error("Failed to add dataset for model", slog.Any("error", err)) httpbase.ServerError(ctx, err) @@ -473,7 +475,7 @@ func (h *ModelHandler) DelDatasetRelation(ctx *gin.Context) { req.Name = name req.CurrentUser = currentUser - err = h.c.DelRelationDataset(ctx, req) + err = h.model.DelRelationDataset(ctx, req) if err != nil { slog.Error("Failed to delete dataset for model", slog.Any("error", err)) httpbase.ServerError(ctx, err) @@ -598,7 +600,7 @@ func (h *ModelHandler) DeployDedicated(ctx *gin.Context) { return } - _, err = h.sc.CheckRequestV2(ctx, &req) + _, err = h.sensitive.CheckRequestV2(ctx, &req) if err != nil { slog.Error("failed to check sensitive request", slog.Any("error", err)) httpbase.BadRequest(ctx, fmt.Errorf("sensitive check failed: %w", err).Error()) @@ -611,7 +613,7 @@ func (h *ModelHandler) DeployDedicated(ctx *gin.Context) { CurrentUser: currentUser, DeployType: types.InferenceType, } - deployID, err := h.c.Deploy(ctx, epReq, req) + deployID, err := h.model.Deploy(ctx, epReq, req) if err != nil { slog.Error("failed to deploy model as inference", slog.String("namespace", namespace), slog.String("name", name), slog.Any("currentUser", currentUser), slog.Any("req", req), slog.Any("error", err)) @@ -693,7 +695,7 @@ func (h *ModelHandler) FinetuneCreate(ctx *gin.Context) { DeployType: types.FinetuneType, } - deployID, err := h.c.Deploy(ctx, ftReq, *modelReq) + deployID, err := h.model.Deploy(ctx, ftReq, *modelReq) if err != nil { slog.Error("failed to deploy model as notebook instance", slog.String("namespace", namespace), slog.String("name", name), slog.Any("error", err)) @@ -975,7 +977,7 @@ func (h *ModelHandler) ListByRuntimeFrameworkID(ctx *gin.Context) { return } - models, total, err := h.c.ListModelsByRuntimeFrameworkID(ctx, currentUser, per, page, id, deployType) + models, total, err := h.model.ListModelsByRuntimeFrameworkID(ctx, currentUser, per, page, id, deployType) if err != nil { slog.Error("Failed to get models", slog.Any("error", err)) httpbase.ServerError(ctx, err) @@ -1115,7 +1117,7 @@ func (h *ModelHandler) ListAllRuntimeFramework(ctx *gin.Context) { return } - runtimes, err := h.c.ListAllByRuntimeFramework(ctx, currentUser) + runtimes, err := h.model.ListAllByRuntimeFramework(ctx, currentUser) if err != nil { slog.Error("Failed to get runtime frameworks", slog.Any("error", err)) httpbase.ServerError(ctx, err) @@ -1171,7 +1173,7 @@ func (h *ModelHandler) UpdateModelRuntimeFrameworks(ctx *gin.Context) { slog.Info("update runtime frameworks models", slog.Any("req", req), slog.Any("runtime framework id", id), slog.Any("deployType", deployType)) - list, err := h.c.SetRuntimeFrameworkModes(ctx, deployType, id, req.Models) + list, err := h.model.SetRuntimeFrameworkModes(ctx, deployType, id, req.Models) if err != nil { slog.Error("Failed to set models runtime framework", slog.Any("error", err)) httpbase.ServerError(ctx, err) @@ -1224,7 +1226,7 @@ func (h *ModelHandler) DeleteModelRuntimeFrameworks(ctx *gin.Context) { slog.Info("update runtime frameworks models", slog.Any("req", req), slog.Any("runtime framework id", id), slog.Any("deployType", deployType)) - list, err := h.c.DeleteRuntimeFrameworkModes(ctx, deployType, id, req.Models) + list, err := h.model.DeleteRuntimeFrameworkModes(ctx, deployType, id, req.Models) if err != nil { slog.Error("Failed to set models runtime framework", slog.Any("error", err)) httpbase.ServerError(ctx, err) @@ -1277,7 +1279,7 @@ func (h *ModelHandler) ListModelsOfRuntimeFrameworks(ctx *gin.Context) { return } - models, total, err := h.c.ListModelsOfRuntimeFrameworks(ctx, currentUser, filter.Search, filter.Sort, per, page, deployType) + models, total, err := h.model.ListModelsOfRuntimeFrameworks(ctx, currentUser, filter.Search, filter.Sort, per, page, deployType) if err != nil { slog.Error("fail to get models for all runtime frameworks", slog.Any("deployType", deployType), slog.Any("per", per), slog.Any("page", page), slog.Any("error", err)) httpbase.ServerError(ctx, err) @@ -1382,7 +1384,7 @@ func (h *ModelHandler) DeployServerless(ctx *gin.Context) { } req.SecureLevel = 1 // public for serverless - deployID, err := h.c.Deploy(ctx, deployReq, req) + deployID, err := h.model.Deploy(ctx, deployReq, req) if err != nil { slog.Error("failed to deploy model as serverless", slog.String("namespace", namespace), slog.String("name", name), slog.Any("currentUser", currentUser), slog.Any("req", req), slog.Any("error", err)) @@ -1532,7 +1534,7 @@ func (h *ModelHandler) GetDeployServerless(ctx *gin.Context) { return } - response, err := h.c.GetServerless(ctx, namespace, name, currentUser) + response, err := h.model.GetServerless(ctx, namespace, name, currentUser) if err != nil { slog.Error("failed to get model serverless endpoint", slog.String("namespace", namespace), slog.String("name", name), slog.Any("currentUser", currentUser), slog.Any("error", err)) diff --git a/api/handler/model_test.go b/api/handler/model_test.go new file mode 100644 index 00000000..5a082cdc --- /dev/null +++ b/api/handler/model_test.go @@ -0,0 +1,561 @@ +package handler + +import ( + "fmt" + "testing" + + "github.com/gin-gonic/gin" + "github.com/stretchr/testify/require" + mockcomponent "opencsg.com/csghub-server/_mocks/opencsg.com/csghub-server/component" + "opencsg.com/csghub-server/builder/store/database" + "opencsg.com/csghub-server/common/types" +) + +type ModelTester struct { + *GinTester + handler *ModelHandler + mocks struct { + model *mockcomponent.MockModelComponent + sensitive *mockcomponent.MockSensitiveComponent + repo *mockcomponent.MockRepoComponent + } +} + +func NewModelTester(t *testing.T) *ModelTester { + tester := &ModelTester{GinTester: NewGinTester()} + tester.mocks.model = mockcomponent.NewMockModelComponent(t) + tester.mocks.sensitive = mockcomponent.NewMockSensitiveComponent(t) + tester.mocks.repo = mockcomponent.NewMockRepoComponent(t) + + tester.handler = &ModelHandler{ + model: tester.mocks.model, sensitive: tester.mocks.sensitive, + repo: tester.mocks.repo, + } + tester.WithParam("name", "r") + tester.WithParam("namespace", "u") + return tester + +} + +func (t *ModelTester) WithHandleFunc(fn func(h *ModelHandler) gin.HandlerFunc) *ModelTester { + t.ginHandler = fn(t.handler) + return t + +} + +func TestModelHandler_Index(t *testing.T) { + cases := []struct { + sort string + source string + error bool + }{ + {"most_download", "local", false}, + {"foo", "local", true}, + {"most_download", "bar", true}, + } + + for _, c := range cases { + t.Run(fmt.Sprintf("%+v", c), func(t *testing.T) { + + tester := NewModelTester(t).WithHandleFunc(func(h *ModelHandler) gin.HandlerFunc { + return h.Index + }) + + if !c.error { + tester.mocks.model.EXPECT().Index(tester.ctx, &types.RepoFilter{ + Search: "foo", + Sort: c.sort, + Source: c.source, + }, 10, 1, false).Return([]*types.Model{ + {Name: "cc"}, + }, 100, nil) + } + + tester.AddPagination(1, 10).WithQuery("search", "foo"). + WithQuery("sort", c.sort). + WithQuery("source", c.source).Execute() + + if c.error { + require.Equal(t, 400, tester.response.Code) + } else { + tester.ResponseEqSimple(t, 200, gin.H{ + "data": []*types.Model{{Name: "cc"}}, + "total": 100, + }) + } + }) + } +} + +func TestModelHandler_Create(t *testing.T) { + tester := NewModelTester(t).WithHandleFunc(func(h *ModelHandler) gin.HandlerFunc { + return h.Create + }) + tester.RequireUser(t) + + req := &types.CreateModelReq{CreateRepoReq: types.CreateRepoReq{Username: "u"}} + tester.mocks.sensitive.EXPECT().CheckRequestV2(tester.ctx, req).Return(true, nil) + tester.mocks.model.EXPECT().Create(tester.ctx, req).Return(&types.Model{Name: "m"}, nil) + tester.WithBody(t, req).Execute() + + tester.ResponseEq(t, 200, tester.OKText, &types.Model{Name: "m"}) +} + +func TestModelHandler_Update(t *testing.T) { + tester := NewModelTester(t).WithHandleFunc(func(h *ModelHandler) gin.HandlerFunc { + return h.Update + }) + tester.RequireUser(t) + + req := &types.UpdateModelReq{UpdateRepoReq: types.UpdateRepoReq{}} + tester.mocks.sensitive.EXPECT().CheckRequestV2(tester.ctx, req).Return(true, nil) + tester.mocks.model.EXPECT().Update(tester.ctx, &types.UpdateModelReq{ + UpdateRepoReq: types.UpdateRepoReq{ + Namespace: "u", + Name: "r", + Username: "u", + }, + }).Return(&types.Model{Name: "m"}, nil) + tester.WithBody(t, req).Execute() + + tester.ResponseEq(t, 200, tester.OKText, &types.Model{Name: "m"}) +} + +func TestModelHandler_Delete(t *testing.T) { + tester := NewModelTester(t).WithHandleFunc(func(h *ModelHandler) gin.HandlerFunc { + return h.Delete + }) + tester.RequireUser(t) + + tester.mocks.model.EXPECT().Delete(tester.ctx, "u", "r", "u").Return(nil) + tester.Execute() + + tester.ResponseEq(t, 200, tester.OKText, nil) +} + +func TestModelHandler_Show(t *testing.T) { + tester := NewModelTester(t).WithHandleFunc(func(h *ModelHandler) gin.HandlerFunc { + return h.Show + }) + + tester.WithUser() + tester.mocks.model.EXPECT().Show(tester.ctx, "u", "r", "u", false).Return(&types.Model{ + Name: "m", + }, nil) + tester.Execute() + + tester.ResponseEq(t, 200, tester.OKText, &types.Model{Name: "m"}) +} + +func TestModelHandler_SDKModelInfo(t *testing.T) { + tester := NewModelTester(t).WithHandleFunc(func(h *ModelHandler) gin.HandlerFunc { + return h.SDKModelInfo + }) + + tester.WithUser() + tester.mocks.model.EXPECT().SDKModelInfo(tester.ctx, "u", "r", "main", "u").Return(&types.SDKModelInfo{ + ID: "m", + }, nil) + tester.WithParam("ref", "main").Execute() + + tester.ResponseEqSimple(t, 200, &types.SDKModelInfo{ID: "m"}) +} + +func TestModelHandler_Relations(t *testing.T) { + tester := NewModelTester(t).WithHandleFunc(func(h *ModelHandler) gin.HandlerFunc { + return h.Relations + }) + + tester.WithUser() + tester.mocks.model.EXPECT().Relations(tester.ctx, "u", "r", "u").Return(&types.Relations{ + Models: []*types.Model{{Name: "m1"}}, + }, nil) + tester.Execute() + + tester.ResponseEq(t, 200, tester.OKText, &types.Relations{ + Models: []*types.Model{{Name: "m1"}}, + }) +} + +func TestModelHandler_SetRelations(t *testing.T) { + tester := NewModelTester(t).WithHandleFunc(func(h *ModelHandler) gin.HandlerFunc { + return h.SetRelations + }) + tester.RequireUser(t) + + req := &types.RelationDatasets{ + Namespace: "u", + Name: "r", + CurrentUser: "u", + } + tester.mocks.model.EXPECT().SetRelationDatasets(tester.ctx, *req).Return(nil) + tester.WithBody(t, &types.RelationDatasets{}).Execute() + + tester.ResponseEq(t, 200, tester.OKText, nil) +} + +func TestModelHandler_AddDatasetRelation(t *testing.T) { + tester := NewModelTester(t).WithHandleFunc(func(h *ModelHandler) gin.HandlerFunc { + return h.AddDatasetRelation + }) + tester.RequireUser(t) + + req := &types.RelationDataset{ + Namespace: "u", + Name: "r", + CurrentUser: "u", + } + tester.mocks.model.EXPECT().AddRelationDataset(tester.ctx, *req).Return(nil) + tester.WithBody(t, &types.RelationDataset{}).Execute() + + tester.ResponseEq(t, 200, tester.OKText, nil) +} + +func TestModelHandler_DelDatasetRelation(t *testing.T) { + tester := NewModelTester(t).WithHandleFunc(func(h *ModelHandler) gin.HandlerFunc { + return h.DelDatasetRelation + }) + tester.RequireUser(t) + + req := &types.RelationDataset{ + Namespace: "u", + Name: "r", + CurrentUser: "u", + } + tester.mocks.model.EXPECT().DelRelationDataset(tester.ctx, *req).Return(nil) + tester.WithBody(t, &types.RelationDataset{}).Execute() + + tester.ResponseEq(t, 200, tester.OKText, nil) +} + +func TestModelHandler_DeployDedicated(t *testing.T) { + tester := NewModelTester(t).WithHandleFunc(func(h *ModelHandler) gin.HandlerFunc { + return h.DeployDedicated + }) + tester.RequireUser(t) + + tester.mocks.repo.EXPECT().AllowReadAccess(tester.ctx, types.ModelRepo, "u", "r", "u").Return(true, nil) + tester.mocks.sensitive.EXPECT().CheckRequestV2(tester.ctx, &types.ModelRunReq{ + Revision: "main", + MinReplica: 1, + MaxReplica: 2, + }).Return(true, nil) + tester.mocks.model.EXPECT().Deploy(tester.ctx, types.DeployActReq{ + Namespace: "u", + Name: "r", + CurrentUser: "u", + DeployType: types.InferenceType, + }, types.ModelRunReq{MinReplica: 1, MaxReplica: 2, Revision: "main"}).Return(123, nil) + + tester.WithBody(t, &types.ModelRunReq{MinReplica: 1, MaxReplica: 2}).Execute() + + tester.ResponseEq(t, 200, tester.OKText, types.DeployRepo{DeployID: 123}) +} + +func TestModelHandler_FinetuneCreate(t *testing.T) { + tester := NewModelTester(t).WithHandleFunc(func(h *ModelHandler) gin.HandlerFunc { + return h.FinetuneCreate + }) + tester.RequireUser(t) + + tester.mocks.repo.EXPECT().AllowAdminAccess(tester.ctx, types.ModelRepo, "u", "r", "u").Return(true, nil) + + tester.mocks.model.EXPECT().Deploy(tester.ctx, types.DeployActReq{ + Namespace: "u", + Name: "r", + CurrentUser: "u", + DeployType: types.FinetuneType, + }, types.ModelRunReq{MinReplica: 1, MaxReplica: 1, Revision: "main", SecureLevel: 2}).Return(123, nil) + + tester.WithBody(t, &types.ModelRunReq{MinReplica: 1, MaxReplica: 2, Revision: "main"}).Execute() + + tester.ResponseEq(t, 200, tester.OKText, types.DeployRepo{DeployID: 123}) + +} + +func TestModelHandler_DeployDelete(t *testing.T) { + tester := NewModelTester(t).WithHandleFunc(func(h *ModelHandler) gin.HandlerFunc { + return h.DeployDelete + }) + tester.RequireUser(t) + + tester.mocks.repo.EXPECT().DeleteDeploy(tester.ctx, types.DeployActReq{ + RepoType: types.ModelRepo, + Namespace: "u", + Name: "r", + CurrentUser: "u", + DeployID: 1, + DeployType: types.InferenceType, + }).Return(nil) + + tester.WithParam("id", "1").Execute() + + tester.ResponseEq(t, 200, tester.OKText, nil) +} + +func TestModelHandler_FinetuneDelete(t *testing.T) { + tester := NewModelTester(t).WithHandleFunc(func(h *ModelHandler) gin.HandlerFunc { + return h.FinetuneDelete + }) + tester.RequireUser(t) + + tester.mocks.repo.EXPECT().DeleteDeploy(tester.ctx, types.DeployActReq{ + RepoType: types.ModelRepo, + Namespace: "u", + Name: "r", + CurrentUser: "u", + DeployID: 1, + DeployType: types.FinetuneType, + }).Return(nil) + + tester.WithParam("id", "1").Execute() + + tester.ResponseEq(t, 200, tester.OKText, nil) +} + +func TestModelHandler_DeployStop(t *testing.T) { + tester := NewModelTester(t).WithHandleFunc(func(h *ModelHandler) gin.HandlerFunc { + return h.DeployStop + }) + tester.RequireUser(t) + + tester.mocks.repo.EXPECT().DeployStop(tester.ctx, types.DeployActReq{ + RepoType: types.ModelRepo, + Namespace: "u", + Name: "r", + CurrentUser: "u", + DeployID: 1, + DeployType: types.InferenceType, + }).Return(nil) + + tester.WithParam("id", "1").Execute() + + tester.ResponseEq(t, 200, tester.OKText, nil) +} + +func TestModelHandler_DeployStart(t *testing.T) { + tester := NewModelTester(t).WithHandleFunc(func(h *ModelHandler) gin.HandlerFunc { + return h.DeployStart + }) + tester.RequireUser(t) + + tester.mocks.repo.EXPECT().DeployStart(tester.ctx, types.DeployActReq{ + RepoType: types.ModelRepo, + Namespace: "u", + Name: "r", + CurrentUser: "u", + DeployID: 1, + DeployType: types.InferenceType, + }).Return(nil) + + tester.WithParam("id", "1").Execute() + + tester.ResponseEq(t, 200, tester.OKText, nil) +} + +func TestModelHandler_FinetuneStop(t *testing.T) { + tester := NewModelTester(t).WithHandleFunc(func(h *ModelHandler) gin.HandlerFunc { + return h.FinetuneStop + }) + tester.RequireUser(t) + + tester.mocks.repo.EXPECT().DeployStop(tester.ctx, types.DeployActReq{ + RepoType: types.ModelRepo, + Namespace: "u", + Name: "r", + CurrentUser: "u", + DeployID: 1, + DeployType: types.FinetuneType, + }).Return(nil) + + tester.WithParam("id", "1").Execute() + + tester.ResponseEq(t, 200, tester.OKText, nil) +} + +func TestModelHandler_FinetuneStart(t *testing.T) { + tester := NewModelTester(t).WithHandleFunc(func(h *ModelHandler) gin.HandlerFunc { + return h.FinetuneStart + }) + tester.RequireUser(t) + + tester.mocks.repo.EXPECT().DeployStart(tester.ctx, types.DeployActReq{ + RepoType: types.ModelRepo, + Namespace: "u", + Name: "r", + CurrentUser: "u", + DeployID: 1, + DeployType: types.FinetuneType, + }).Return(nil) + + tester.WithParam("id", "1").Execute() + + tester.ResponseEq(t, 200, tester.OKText, nil) +} + +func TestModelHandler_ListByRuntimeFrameworkID(t *testing.T) { + tester := NewModelTester(t).WithHandleFunc(func(h *ModelHandler) gin.HandlerFunc { + return h.ListByRuntimeFrameworkID + }) + tester.RequireUser(t) + + tester.WithQuery("deploy_type", "").AddPagination(1, 10).WithParam("id", "1") + tester.mocks.model.EXPECT().ListModelsByRuntimeFrameworkID( + tester.ctx, "u", 10, 1, int64(1), types.InferenceType, + ).Return([]types.Model{{Name: "foo"}}, 100, nil) + tester.Execute() + tester.ResponseEqSimple(t, 200, gin.H{ + "data": []types.Model{{Name: "foo"}}, + "total": 100, + }) +} + +func TestModelHandler_ListAllRuntimeFramework(t *testing.T) { + tester := NewModelTester(t).WithHandleFunc(func(h *ModelHandler) gin.HandlerFunc { + return h.ListAllRuntimeFramework + }) + tester.RequireUser(t) + + tester.mocks.model.EXPECT().ListAllByRuntimeFramework( + tester.ctx, "u", + ).Return([]database.RuntimeFramework{{FrameName: "foo"}}, nil) + tester.Execute() + tester.ResponseEqSimple(t, 200, gin.H{ + "data": []database.RuntimeFramework{{FrameName: "foo"}}, + }) +} + +func TestModelHandler_UpdateModelRuntimeFramework(t *testing.T) { + tester := NewModelTester(t).WithHandleFunc(func(h *ModelHandler) gin.HandlerFunc { + return h.UpdateModelRuntimeFrameworks + }) + + tester.WithUser().WithQuery("deploy_type", "").AddPagination(1, 10).WithParam("id", "1") + tester.mocks.model.EXPECT().SetRuntimeFrameworkModes( + tester.ctx, types.InferenceType, int64(1), []string{"foo"}, + ).Return([]string{"bar"}, nil) + tester.WithBody(t, types.RuntimeFrameworkModels{ + Models: []string{"foo"}, + }).Execute() + tester.ResponseEq(t, 200, tester.OKText, []string{"bar"}) +} + +func TestModelHandler_DeleteModelRuntimeFramework(t *testing.T) { + tester := NewModelTester(t).WithHandleFunc(func(h *ModelHandler) gin.HandlerFunc { + return h.DeleteModelRuntimeFrameworks + }) + + tester.WithUser().WithQuery("deploy_type", "").AddPagination(1, 10).WithParam("id", "1") + tester.mocks.model.EXPECT().DeleteRuntimeFrameworkModes( + tester.ctx, types.InferenceType, int64(1), []string{"foo"}, + ).Return([]string{"bar"}, nil) + tester.WithBody(t, types.RuntimeFrameworkModels{ + Models: []string{"foo"}, + }).Execute() + tester.ResponseEq(t, 200, tester.OKText, []string{"bar"}) +} + +func TestModelHandler_ListModelsOfRuntimeFrameworks(t *testing.T) { + tester := NewModelTester(t).WithHandleFunc(func(h *ModelHandler) gin.HandlerFunc { + return h.ListModelsOfRuntimeFrameworks + }) + tester.RequireUser(t) + + tester.WithQuery("deploy_type", "").AddPagination(1, 10).WithParam("id", "1") + tester.mocks.model.EXPECT().ListModelsOfRuntimeFrameworks( + tester.ctx, "u", "foo", "most_downloads", 10, 1, types.InferenceType, + ).Return([]types.Model{{Name: "foo"}}, 100, nil) + tester.WithQuery("search", "foo").WithQuery("sort", "most_downloads").Execute() + + tester.ResponseEqSimple(t, 200, gin.H{ + "data": []types.Model{{Name: "foo"}}, + "total": 100, + }) +} + +func TestModelHandler_AllFiles(t *testing.T) { + tester := NewModelTester(t).WithHandleFunc(func(h *ModelHandler) gin.HandlerFunc { + return h.AllFiles + }) + + tester.mocks.repo.EXPECT().AllFiles(tester.ctx, types.GetAllFilesReq{ + Namespace: "u", + Name: "r", + RepoType: types.ModelRepo, + CurrentUser: "u", + }).Return([]*types.File{{Name: "foo"}}, nil) + tester.WithUser().Execute() + + tester.ResponseEq(t, 200, tester.OKText, []*types.File{{Name: "foo"}}) +} + +func TestModelHandler_DeployServerless(t *testing.T) { + tester := NewModelTester(t).WithHandleFunc(func(h *ModelHandler) gin.HandlerFunc { + return h.DeployServerless + }) + tester.RequireUser(t) + + tester.mocks.model.EXPECT().Deploy(tester.ctx, types.DeployActReq{ + Namespace: "u", + Name: "r", + CurrentUser: "u", + DeployType: types.ServerlessType, + }, types.ModelRunReq{MinReplica: 1, MaxReplica: 2, Revision: "main", SecureLevel: 1}).Return(123, nil) + + tester.WithBody(t, &types.ModelRunReq{MinReplica: 1, MaxReplica: 2}).Execute() + + tester.ResponseEq(t, 200, tester.OKText, types.DeployRepo{DeployID: 123}) +} + +func TestModelHandler_ServerlessStop(t *testing.T) { + tester := NewModelTester(t).WithHandleFunc(func(h *ModelHandler) gin.HandlerFunc { + return h.ServerlessStop + }) + tester.RequireUser(t) + + tester.mocks.repo.EXPECT().DeployStop(tester.ctx, types.DeployActReq{ + RepoType: types.ModelRepo, + Namespace: "u", + Name: "r", + CurrentUser: "u", + DeployID: 1, + DeployType: types.ServerlessType, + }).Return(nil) + + tester.WithParam("id", "1").Execute() + + tester.ResponseEq(t, 200, tester.OKText, nil) +} + +func TestModelHandler_ServerlessStart(t *testing.T) { + tester := NewModelTester(t).WithHandleFunc(func(h *ModelHandler) gin.HandlerFunc { + return h.ServerlessStart + }) + tester.RequireUser(t) + + tester.mocks.repo.EXPECT().DeployStart(tester.ctx, types.DeployActReq{ + RepoType: types.ModelRepo, + Namespace: "u", + Name: "r", + CurrentUser: "u", + DeployID: 1, + DeployType: types.ServerlessType, + }).Return(nil) + + tester.WithParam("id", "1").Execute() + + tester.ResponseEq(t, 200, tester.OKText, nil) +} + +func TestModelHandler_GetDeployServerless(t *testing.T) { + tester := NewModelTester(t).WithHandleFunc(func(h *ModelHandler) gin.HandlerFunc { + return h.GetDeployServerless + }) + + tester.mocks.model.EXPECT().GetServerless(tester.ctx, "u", "r", "u").Return(&types.DeployRepo{ + DeployID: 1, + }, nil) + tester.WithUser().Execute() + + tester.ResponseEq(t, 200, tester.OKText, &types.DeployRepo{DeployID: 1}) +} diff --git a/api/handler/user.go b/api/handler/user.go index c062710f..41984d76 100644 --- a/api/handler/user.go +++ b/api/handler/user.go @@ -20,12 +20,12 @@ func NewUserHandler(config *config.Config) (*UserHandler, error) { return nil, err } return &UserHandler{ - c: uc, + user: uc, }, nil } type UserHandler struct { - c component.UserComponent + user component.UserComponent } // GetUserDatasets godoc @@ -53,7 +53,7 @@ func (h *UserHandler) Datasets(ctx *gin.Context) { req.CurrentUser = httpbase.GetCurrentUser(ctx) req.Page = page req.PageSize = per - ds, total, err := h.c.Datasets(ctx, &req) + ds, total, err := h.user.Datasets(ctx, &req) if err != nil { slog.Error("Failed to gat user datasets", slog.Any("error", err)) httpbase.ServerError(ctx, err) @@ -94,7 +94,7 @@ func (h *UserHandler) Models(ctx *gin.Context) { req.CurrentUser = httpbase.GetCurrentUser(ctx) req.Page = page req.PageSize = per - ms, total, err := h.c.Models(ctx, &req) + ms, total, err := h.user.Models(ctx, &req) if err != nil { slog.Error("Failed to gat user models", slog.Any("error", err)) httpbase.ServerError(ctx, err) @@ -136,7 +136,7 @@ func (h *UserHandler) Codes(ctx *gin.Context) { req.CurrentUser = httpbase.GetCurrentUser(ctx) req.Page = page req.PageSize = per - ms, total, err := h.c.Codes(ctx, &req) + ms, total, err := h.user.Codes(ctx, &req) if err != nil { slog.Error("Failed to gat user codes", slog.Any("error", err)) httpbase.ServerError(ctx, err) @@ -177,7 +177,7 @@ func (h *UserHandler) Spaces(ctx *gin.Context) { req.CurrentUser = httpbase.GetCurrentUser(ctx) req.Page = page req.PageSize = per - ms, total, err := h.c.Spaces(ctx, &req) + ms, total, err := h.user.Spaces(ctx, &req) if err != nil { slog.Error("Failed to gat user space", slog.Any("error", err)) httpbase.ServerError(ctx, err) @@ -219,8 +219,8 @@ func (h *UserHandler) LikesAdd(ctx *gin.Context) { httpbase.ServerError(ctx, err) return } - req.Repo_id = repo_id - err = h.c.AddLikes(ctx, &req) + req.RepoID = repo_id + err = h.user.AddLikes(ctx, &req) if err != nil { httpbase.ServerError(ctx, err) return @@ -257,7 +257,7 @@ func (h *UserHandler) LikesCollections(ctx *gin.Context) { req.CurrentUser = httpbase.GetCurrentUser(ctx) req.Page = page req.PageSize = per - ms, total, err := h.c.LikesCollection(ctx, &req) + ms, total, err := h.user.LikesCollection(ctx, &req) if err != nil { slog.Error("Failed to get user collections", slog.Any("error", err)) httpbase.ServerError(ctx, err) @@ -296,7 +296,7 @@ func (h *UserHandler) UserCollections(ctx *gin.Context) { req.CurrentUser = httpbase.GetCurrentUser(ctx) req.Page = page req.PageSize = per - ms, total, err := h.c.Collections(ctx, &req) + ms, total, err := h.user.Collections(ctx, &req) if err != nil { slog.Error("Failed to get user collections", slog.Any("error", err)) httpbase.ServerError(ctx, err) @@ -331,13 +331,13 @@ func (h *UserHandler) LikeCollection(ctx *gin.Context) { } var req types.UserLikesRequest req.CurrentUser = currentUser - collection_id, err := strconv.ParseInt(ctx.Param("id"), 10, 64) + collectionID, err := strconv.ParseInt(ctx.Param("id"), 10, 64) if err != nil { httpbase.ServerError(ctx, err) return } - req.Collection_id = collection_id - err = h.c.LikeCollection(ctx, &req) + req.CollectionID = collectionID + err = h.user.LikeCollection(ctx, &req) if err != nil { httpbase.ServerError(ctx, err) return @@ -370,8 +370,8 @@ func (h *UserHandler) UnLikeCollection(ctx *gin.Context) { httpbase.ServerError(ctx, err) return } - req.Collection_id = collection_id - err = h.c.UnLikeCollection(ctx, &req) + req.CollectionID = collection_id + err = h.user.UnLikeCollection(ctx, &req) if err != nil { httpbase.ServerError(ctx, err) return @@ -406,9 +406,9 @@ func (h *UserHandler) LikesDelete(ctx *gin.Context) { httpbase.ServerError(ctx, err) return } - req.Repo_id = repo_id + req.RepoID = repo_id // slog.Info("user.likes.delete.req=%v", req) - err = h.c.DeleteLikes(ctx, &req) + err = h.user.DeleteLikes(ctx, &req) if err != nil { httpbase.ServerError(ctx, err) return @@ -446,7 +446,7 @@ func (h *UserHandler) LikesSpaces(ctx *gin.Context) { req.Page = page req.PageSize = per - ms, total, err := h.c.LikesSpaces(ctx, &req) + ms, total, err := h.user.LikesSpaces(ctx, &req) if err != nil { slog.Error("Failed to gat user space", slog.Any("error", err)) httpbase.ServerError(ctx, err) @@ -492,7 +492,7 @@ func (h *UserHandler) LikesCodes(ctx *gin.Context) { req.CurrentUser = currentUser req.Page = page req.PageSize = per - ms, total, err := h.c.LikesCodes(ctx, &req) + ms, total, err := h.user.LikesCodes(ctx, &req) if err != nil { slog.Error("Failed to gat user codes", slog.Any("error", err)) httpbase.ServerError(ctx, err) @@ -539,7 +539,7 @@ func (h *UserHandler) LikesModels(ctx *gin.Context) { req.CurrentUser = currentUser req.Page = page req.PageSize = per - ms, total, err := h.c.LikesModels(ctx, &req) + ms, total, err := h.user.LikesModels(ctx, &req) if err != nil { slog.Error("Failed to gat user models", slog.Any("error", err)) httpbase.ServerError(ctx, err) @@ -586,7 +586,7 @@ func (h *UserHandler) LikesDatasets(ctx *gin.Context) { req.CurrentUser = currentUser req.Page = page req.PageSize = per - ds, total, err := h.c.LikesDatasets(ctx, &req) + ds, total, err := h.user.LikesDatasets(ctx, &req) if err != nil { slog.Error("Failed to gat user datasets", slog.Any("error", err)) httpbase.ServerError(ctx, err) @@ -683,7 +683,7 @@ func (h *UserHandler) GetRunDeploys(ctx *gin.Context) { req.PageSize = per req.RepoType = repoType req.DeployType = deployType - ds, total, err := h.c.ListDeploys(ctx, repoType, &req) + ds, total, err := h.user.ListDeploys(ctx, repoType, &req) if err != nil { slog.Error("Failed to get deploy repo list", slog.Any("error", err), slog.Any("req", req)) httpbase.ServerError(ctx, err) @@ -743,7 +743,7 @@ func (h *UserHandler) GetFinetuneInstances(ctx *gin.Context) { req.CurrentUser = currentUser req.Page = page req.PageSize = per - ds, total, err := h.c.ListInstances(ctx, &req) + ds, total, err := h.user.ListInstances(ctx, &req) if err != nil { slog.Error("Failed to get instance list", slog.Any("error", err), slog.Any("req", req)) httpbase.ServerError(ctx, err) @@ -799,7 +799,7 @@ func (h *UserHandler) GetRunServerless(ctx *gin.Context) { req.PageSize = per req.RepoType = types.ModelRepo req.DeployType = types.ServerlessType - ds, total, err := h.c.ListServerless(ctx, req) + ds, total, err := h.user.ListServerless(ctx, req) if err != nil { slog.Error("Failed to get serverless list", slog.Any("error", err), slog.Any("req", req)) httpbase.ServerError(ctx, err) @@ -840,7 +840,7 @@ func (h *UserHandler) Prompts(ctx *gin.Context) { req.CurrentUser = httpbase.GetCurrentUser(ctx) req.Page = page req.PageSize = per - ds, total, err := h.c.Prompts(ctx, &req) + ds, total, err := h.user.Prompts(ctx, &req) if err != nil { slog.Error("Failed to get user prompts", slog.Any("error", err)) httpbase.ServerError(ctx, err) @@ -887,7 +887,7 @@ func (h *UserHandler) GetEvaluations(ctx *gin.Context) { req.CurrentUser = httpbase.GetCurrentUser(ctx) req.Page = page req.PageSize = per - ds, total, err := h.c.Evaluations(ctx, &req) + ds, total, err := h.user.Evaluations(ctx, &req) if err != nil { slog.Error("Failed to get user evaluations", slog.Any("error", err)) httpbase.ServerError(ctx, err) diff --git a/api/handler/user_test.go b/api/handler/user_test.go new file mode 100644 index 00000000..d8544d54 --- /dev/null +++ b/api/handler/user_test.go @@ -0,0 +1,439 @@ +package handler + +import ( + "testing" + + "github.com/gin-gonic/gin" + mockcomponent "opencsg.com/csghub-server/_mocks/opencsg.com/csghub-server/component" + "opencsg.com/csghub-server/common/types" +) + +type UserTester struct { + *GinTester + handler *UserHandler + mocks struct { + user *mockcomponent.MockUserComponent + } +} + +func NewUserTester(t *testing.T) *UserTester { + tester := &UserTester{GinTester: NewGinTester()} + tester.mocks.user = mockcomponent.NewMockUserComponent(t) + + tester.handler = &UserHandler{ + user: tester.mocks.user, + } + tester.WithParam("name", "u") + tester.WithParam("namespace", "r") + return tester +} + +func (t *UserTester) WithHandleFunc(fn func(h *UserHandler) gin.HandlerFunc) *UserTester { + t.ginHandler = fn(t.handler) + return t +} + +func TestUserHandler_Datasets(t *testing.T) { + tester := NewUserTester(t).WithHandleFunc(func(h *UserHandler) gin.HandlerFunc { + return h.Datasets + }) + + tester.mocks.user.EXPECT().Datasets(tester.ctx, &types.UserDatasetsReq{ + Owner: "go", + CurrentUser: "u", + PageOpts: types.PageOpts{ + Page: 1, + PageSize: 10, + }, + }).Return([]types.Dataset{{Name: "ds"}}, 100, nil) + tester.AddPagination(1, 10).WithUser().WithParam("username", "go").Execute() + tester.ResponseEqSimple(t, 200, gin.H{ + "message": "OK", + "data": []types.Dataset{{Name: "ds"}}, + "total": 100, + }) +} + +func TestUserHandler_Models(t *testing.T) { + tester := NewUserTester(t).WithHandleFunc(func(h *UserHandler) gin.HandlerFunc { + return h.Models + }) + + tester.mocks.user.EXPECT().Models(tester.ctx, &types.UserDatasetsReq{ + Owner: "go", + CurrentUser: "u", + PageOpts: types.PageOpts{ + Page: 1, + PageSize: 10, + }, + }).Return([]types.Model{{Name: "ds"}}, 100, nil) + tester.AddPagination(1, 10).WithUser().WithParam("username", "go").Execute() + tester.ResponseEqSimple(t, 200, gin.H{ + "message": "OK", + "data": []types.Model{{Name: "ds"}}, + "total": 100, + }) +} + +func TestUserHandler_Codes(t *testing.T) { + tester := NewUserTester(t).WithHandleFunc(func(h *UserHandler) gin.HandlerFunc { + return h.Codes + }) + + tester.mocks.user.EXPECT().Codes(tester.ctx, &types.UserDatasetsReq{ + Owner: "go", + CurrentUser: "u", + PageOpts: types.PageOpts{ + Page: 1, + PageSize: 10, + }, + }).Return([]types.Code{{Name: "ds"}}, 100, nil) + tester.AddPagination(1, 10).WithUser().WithParam("username", "go").Execute() + tester.ResponseEqSimple(t, 200, gin.H{ + "message": "OK", + "data": []types.Code{{Name: "ds"}}, + "total": 100, + }) +} + +func TestUserHandler_Spaces(t *testing.T) { + tester := NewUserTester(t).WithHandleFunc(func(h *UserHandler) gin.HandlerFunc { + return h.Spaces + }) + + tester.mocks.user.EXPECT().Spaces(tester.ctx, &types.UserDatasetsReq{ + Owner: "go", + CurrentUser: "u", + PageOpts: types.PageOpts{ + Page: 1, + PageSize: 10, + }, + }).Return([]types.Space{{Name: "ds"}}, 100, nil) + tester.AddPagination(1, 10).WithUser().WithParam("username", "go").Execute() + tester.ResponseEqSimple(t, 200, gin.H{ + "message": "OK", + "data": []types.Space{{Name: "ds"}}, + "total": 100, + }) +} + +func TestUserHandler_LikesAdd(t *testing.T) { + tester := NewUserTester(t).WithHandleFunc(func(h *UserHandler) gin.HandlerFunc { + return h.LikesAdd + }) + tester.RequireUser(t) + + tester.mocks.user.EXPECT().AddLikes(tester.ctx, &types.UserLikesRequest{ + Username: "go", + CurrentUser: "u", + RepoID: 123, + }).Return(nil) + tester.WithParam("username", "go").WithParam("repo_id", "123").Execute() + tester.ResponseEq(t, 200, tester.OKText, nil) +} + +func TestUserHandler_LikesCollections(t *testing.T) { + tester := NewUserTester(t).WithHandleFunc(func(h *UserHandler) gin.HandlerFunc { + return h.LikesCollections + }) + tester.RequireUser(t) + + tester.mocks.user.EXPECT().LikesCollection(tester.ctx, &types.UserCollectionReq{ + Owner: "go", + CurrentUser: "u", + PageOpts: types.PageOpts{ + Page: 1, + PageSize: 10, + }, + }).Return([]types.Collection{{ID: 1}}, 100, nil) + tester.WithParam("username", "go").AddPagination(1, 10).Execute() + tester.ResponseEqSimple(t, 200, gin.H{ + "data": []types.Collection{{ID: 1}}, + "total": 100, + }) +} + +func TestUserHandler_UserCollections(t *testing.T) { + tester := NewUserTester(t).WithHandleFunc(func(h *UserHandler) gin.HandlerFunc { + return h.UserCollections + }) + + tester.mocks.user.EXPECT().Collections(tester.ctx, &types.UserCollectionReq{ + Owner: "go", + CurrentUser: "u", + PageOpts: types.PageOpts{ + Page: 1, + PageSize: 10, + }, + }).Return([]types.Collection{{ID: 1}}, 100, nil) + tester.WithParam("username", "go").WithUser().AddPagination(1, 10).Execute() + tester.ResponseEqSimple(t, 200, gin.H{ + "data": []types.Collection{{ID: 1}}, + "total": 100, + }) +} + +func TestUserHandler_LikeCollection(t *testing.T) { + tester := NewUserTester(t).WithHandleFunc(func(h *UserHandler) gin.HandlerFunc { + return h.LikeCollection + }) + tester.RequireUser(t) + + tester.mocks.user.EXPECT().LikeCollection(tester.ctx, &types.UserLikesRequest{ + CurrentUser: "u", + CollectionID: 123, + }).Return(nil) + tester.WithParam("id", "123").Execute() + tester.ResponseEq(t, 200, tester.OKText, nil) +} + +func TestUserHandler_UnLikeCollection(t *testing.T) { + tester := NewUserTester(t).WithHandleFunc(func(h *UserHandler) gin.HandlerFunc { + return h.UnLikeCollection + }) + tester.RequireUser(t) + + tester.mocks.user.EXPECT().UnLikeCollection(tester.ctx, &types.UserLikesRequest{ + CurrentUser: "u", + CollectionID: 123, + }).Return(nil) + tester.WithParam("id", "123").Execute() + tester.ResponseEq(t, 200, tester.OKText, nil) +} + +func TestUserHandler_LikesDelete(t *testing.T) { + tester := NewUserTester(t).WithHandleFunc(func(h *UserHandler) gin.HandlerFunc { + return h.LikesDelete + }) + tester.RequireUser(t) + + tester.mocks.user.EXPECT().DeleteLikes(tester.ctx, &types.UserLikesRequest{ + CurrentUser: "u", + RepoID: 123, + }).Return(nil) + tester.WithParam("repo_id", "123").Execute() + tester.ResponseEq(t, 200, tester.OKText, nil) +} + +func TestUserHandler_LikesSpaces(t *testing.T) { + tester := NewUserTester(t).WithHandleFunc(func(h *UserHandler) gin.HandlerFunc { + return h.LikesSpaces + }) + tester.RequireUser(t) + + tester.mocks.user.EXPECT().LikesSpaces(tester.ctx, &types.UserSpacesReq{ + Owner: "foo", + CurrentUser: "u", + PageOpts: types.PageOpts{ + Page: 1, + PageSize: 10, + }, + }).Return([]types.Space{{Name: "sp"}}, 100, nil) + tester.WithParam("username", "foo").AddPagination(1, 10).Execute() + tester.ResponseEqSimple(t, 200, gin.H{ + "message": "OK", + "data": []types.Space{{Name: "sp"}}, + "total": 100, + }) +} + +func TestUserHandler_LikesCodes(t *testing.T) { + tester := NewUserTester(t).WithHandleFunc(func(h *UserHandler) gin.HandlerFunc { + return h.LikesCodes + }) + tester.RequireUser(t) + + tester.mocks.user.EXPECT().LikesCodes(tester.ctx, &types.UserDatasetsReq{ + Owner: "foo", + CurrentUser: "u", + PageOpts: types.PageOpts{ + Page: 1, + PageSize: 10, + }, + }).Return([]types.Code{{Name: "sp"}}, 100, nil) + tester.WithParam("username", "foo").AddPagination(1, 10).Execute() + tester.ResponseEqSimple(t, 200, gin.H{ + "message": "OK", + "data": []types.Code{{Name: "sp"}}, + "total": 100, + }) +} + +func TestUserHandler_LikesModels(t *testing.T) { + tester := NewUserTester(t).WithHandleFunc(func(h *UserHandler) gin.HandlerFunc { + return h.LikesModels + }) + tester.RequireUser(t) + + tester.mocks.user.EXPECT().LikesModels(tester.ctx, &types.UserDatasetsReq{ + Owner: "foo", + CurrentUser: "u", + PageOpts: types.PageOpts{ + Page: 1, + PageSize: 10, + }, + }).Return([]types.Model{{Name: "sp"}}, 100, nil) + tester.WithParam("username", "foo").AddPagination(1, 10).Execute() + tester.ResponseEqSimple(t, 200, gin.H{ + "message": "OK", + "data": []types.Model{{Name: "sp"}}, + "total": 100, + }) +} + +func TestUserHandler_LikesDatasets(t *testing.T) { + tester := NewUserTester(t).WithHandleFunc(func(h *UserHandler) gin.HandlerFunc { + return h.LikesDatasets + }) + tester.RequireUser(t) + + tester.mocks.user.EXPECT().LikesDatasets(tester.ctx, &types.UserDatasetsReq{ + Owner: "foo", + CurrentUser: "u", + PageOpts: types.PageOpts{ + Page: 1, + PageSize: 10, + }, + }).Return([]types.Dataset{{Name: "sp"}}, 100, nil) + tester.WithParam("username", "foo").AddPagination(1, 10).Execute() + tester.ResponseEqSimple(t, 200, gin.H{ + "message": "OK", + "data": []types.Dataset{{Name: "sp"}}, + "total": 100, + }) +} + +func TestUserHandler_UserPermission(t *testing.T) { + tester := NewUserTester(t).WithHandleFunc(func(h *UserHandler) gin.HandlerFunc { + return h.UserPermission + }) + tester.RequireUser(t) + + tester.Execute() + tester.ResponseEqSimple(t, 200, types.WhoamiResponse{ + Name: "u", + Auth: types.Auth{ + AccessToken: types.AccessToken{ + DisplayName: "u", + Role: "write", + }, + Type: "Bearer", + }, + }) +} + +func TestUserHandler_GetRunDeploys(t *testing.T) { + tester := NewUserTester(t).WithHandleFunc(func(h *UserHandler) gin.HandlerFunc { + return h.GetRunDeploys + }) + tester.RequireUser(t) + + tester.mocks.user.EXPECT().ListDeploys(tester.ctx, types.ModelRepo, &types.DeployReq{ + CurrentUser: "u", + RepoType: types.ModelRepo, + DeployType: 1, + PageOpts: types.PageOpts{ + Page: 1, + PageSize: 10, + }, + }).Return([]types.DeployRepo{{DeployID: 1}}, 100, nil) + tester.WithParam("username", "u").WithQuery("deploy_type", "").AddPagination(1, 10) + tester.WithParam("repo_type", "model").Execute() + + tester.ResponseEqSimple(t, 200, gin.H{ + "message": "OK", + "data": []types.DeployRepo{{DeployID: 1}}, + "total": 100, + }) +} + +func TestUserHandler_GetFinetuneInstances(t *testing.T) { + tester := NewUserTester(t).WithHandleFunc(func(h *UserHandler) gin.HandlerFunc { + return h.GetFinetuneInstances + }) + + tester.mocks.user.EXPECT().ListInstances(tester.ctx, &types.UserRepoReq{ + CurrentUser: "u", + PageOpts: types.PageOpts{ + Page: 1, + PageSize: 10, + }, + }).Return([]types.DeployRepo{{DeployID: 1}}, 100, nil) + tester.WithUser().WithParam("username", "u").WithQuery("deploy_type", "").AddPagination(1, 10) + tester.Execute() + + tester.ResponseEqSimple(t, 200, gin.H{ + "message": "OK", + "data": []types.DeployRepo{{DeployID: 1}}, + "total": 100, + }) +} + +func TestUserHandler_GetRunServerless(t *testing.T) { + tester := NewUserTester(t).WithHandleFunc(func(h *UserHandler) gin.HandlerFunc { + return h.GetRunServerless + }) + tester.RequireUser(t) + + tester.mocks.user.EXPECT().ListServerless(tester.ctx, types.DeployReq{ + CurrentUser: "u", + RepoType: types.ModelRepo, + DeployType: 3, + PageOpts: types.PageOpts{ + Page: 1, + PageSize: 10, + }, + }).Return([]types.DeployRepo{{DeployID: 1}}, 100, nil) + tester.WithParam("username", "u").WithQuery("deploy_type", "").AddPagination(1, 10) + tester.WithParam("repo_type", "model").Execute() + + tester.ResponseEqSimple(t, 200, gin.H{ + "message": "OK", + "data": []types.DeployRepo{{DeployID: 1}}, + "total": 100, + }) +} + +func TestUserHandler_Prompts(t *testing.T) { + tester := NewUserTester(t).WithHandleFunc(func(h *UserHandler) gin.HandlerFunc { + return h.Prompts + }) + + tester.mocks.user.EXPECT().Prompts(tester.ctx, &types.UserPromptsReq{ + CurrentUser: "u", + Owner: "u", + PageOpts: types.PageOpts{ + Page: 1, + PageSize: 10, + }, + }).Return([]types.PromptRes{{ID: 123}}, 100, nil) + tester.WithUser().WithParam("username", "u").AddPagination(1, 10).Execute() + + tester.ResponseEqSimple(t, 200, gin.H{ + "message": "OK", + "data": []types.PromptRes{{ID: 123}}, + "total": 100, + }) +} + +func TestUserHandler_GetEvaluations(t *testing.T) { + tester := NewUserTester(t).WithHandleFunc(func(h *UserHandler) gin.HandlerFunc { + return h.GetEvaluations + }) + tester.RequireUser(t) + + tester.mocks.user.EXPECT().Evaluations(tester.ctx, &types.UserEvaluationReq{ + CurrentUser: "u", + Owner: "u", + PageOpts: types.PageOpts{ + Page: 1, + PageSize: 10, + }, + }).Return([]types.ArgoWorkFlowRes{{ID: 123}}, 100, nil) + tester.WithParam("username", "u").AddPagination(1, 10).Execute() + + tester.ResponseEqSimple(t, 200, gin.H{ + "data": []types.ArgoWorkFlowRes{{ID: 123}}, + "total": 100, + }) +} diff --git a/common/types/user.go b/common/types/user.go index 18d4ab65..4459713f 100644 --- a/common/types/user.go +++ b/common/types/user.go @@ -166,10 +166,10 @@ type User struct { } type UserLikesRequest struct { - Username string `json:"username"` - Repo_id int64 `json:"repo_id"` - Collection_id int64 `json:"collection_id"` - CurrentUser string `json:"current_user"` + Username string `json:"username"` + RepoID int64 `json:"repo_id"` + CollectionID int64 `json:"collection_id"` + CurrentUser string `json:"current_user"` } /* for HF compatible apis */ diff --git a/component/user.go b/component/user.go index 86cc31e9..85426990 100644 --- a/component/user.go +++ b/component/user.go @@ -296,7 +296,7 @@ func (c *userComponentImpl) AddLikes(ctx context.Context, req *types.UserLikesRe return newError } var likesRepoIDs []int64 - likesRepoIDs = append(likesRepoIDs, req.Repo_id) + likesRepoIDs = append(likesRepoIDs, req.RepoID) var opts []database.SelectOption opts = append(opts, database.Columns("id", "repository_type", "path", "user_id", "private")) @@ -314,7 +314,7 @@ func (c *userComponentImpl) AddLikes(ctx context.Context, req *types.UserLikesRe return fmt.Errorf("do not found likes repositories visiable to user:%s, %w", req.CurrentUser, err) } - err = c.userLikeStore.Add(ctx, user.ID, req.Repo_id) + err = c.userLikeStore.Add(ctx, user.ID, req.RepoID) return err } @@ -391,7 +391,7 @@ func (c *userComponentImpl) LikeCollection(ctx context.Context, req *types.UserL return newError } - collection, err := c.collectionStore.FindById(ctx, req.Collection_id) + collection, err := c.collectionStore.FindById(ctx, req.CollectionID) if err != nil { return fmt.Errorf("failed to get likes collection by id, error: %w", err) } @@ -400,7 +400,7 @@ func (c *userComponentImpl) LikeCollection(ctx context.Context, req *types.UserL return fmt.Errorf("no permission to like this collection for user:%s", req.CurrentUser) } - err = c.userLikeStore.LikeCollection(ctx, user.ID, req.Collection_id) + err = c.userLikeStore.LikeCollection(ctx, user.ID, req.CollectionID) return err } @@ -410,7 +410,7 @@ func (c *userComponentImpl) UnLikeCollection(ctx context.Context, req *types.Use newError := fmt.Errorf("failed to check for the presence of the user,error:%w", err) return newError } - err = c.userLikeStore.UnLikeCollection(ctx, user.ID, req.Collection_id) + err = c.userLikeStore.UnLikeCollection(ctx, user.ID, req.CollectionID) return err } @@ -420,7 +420,7 @@ func (c *userComponentImpl) DeleteLikes(ctx context.Context, req *types.UserLike newError := fmt.Errorf("failed to check for the presence of the user,error:%w", err) return newError } - err = c.userLikeStore.Delete(ctx, user.ID, req.Repo_id) + err = c.userLikeStore.Delete(ctx, user.ID, req.RepoID) return err } diff --git a/component/user_test.go b/component/user_test.go index efc2b3fb..0052ab62 100644 --- a/component/user_test.go +++ b/component/user_test.go @@ -136,10 +136,10 @@ func TestUserComponent_AddLikes(t *testing.T) { uc.mocks.stores.UserLikesMock().EXPECT().Add(ctx, int64(1), int64(123)).Return(nil) err := uc.AddLikes(ctx, &types.UserLikesRequest{ - Username: "user", - Repo_id: 123, - Collection_id: 456, - CurrentUser: "user", + Username: "user", + RepoID: 123, + CollectionID: 456, + CurrentUser: "user", }) require.Nil(t, err) } @@ -201,10 +201,10 @@ func TestUserComponent_LikeCollection(t *testing.T) { }, nil) uc.mocks.stores.UserLikesMock().EXPECT().LikeCollection(ctx, int64(1), int64(456)).Return(nil) err := uc.LikeCollection(ctx, &types.UserLikesRequest{ - Username: "user", - Repo_id: 123, - Collection_id: 456, - CurrentUser: "user", + Username: "user", + RepoID: 123, + CollectionID: 456, + CurrentUser: "user", }) require.Nil(t, err) } @@ -216,10 +216,10 @@ func TestUserComponent_UnLikeCollection(t *testing.T) { uc.mocks.stores.UserMock().EXPECT().FindByUsername(ctx, "user").Return(database.User{ID: 1}, nil) uc.mocks.stores.UserLikesMock().EXPECT().UnLikeCollection(ctx, int64(1), int64(456)).Return(nil) err := uc.UnLikeCollection(ctx, &types.UserLikesRequest{ - Username: "user", - Repo_id: 123, - Collection_id: 456, - CurrentUser: "user", + Username: "user", + RepoID: 123, + CollectionID: 456, + CurrentUser: "user", }) require.Nil(t, err) } @@ -231,10 +231,10 @@ func TestUserComponent_DeleteLikes(t *testing.T) { uc.mocks.stores.UserMock().EXPECT().FindByUsername(ctx, "user").Return(database.User{ID: 1}, nil) uc.mocks.stores.UserLikesMock().EXPECT().Delete(ctx, int64(1), int64(123)).Return(nil) err := uc.DeleteLikes(ctx, &types.UserLikesRequest{ - Username: "user", - Repo_id: 123, - Collection_id: 456, - CurrentUser: "user", + Username: "user", + RepoID: 123, + CollectionID: 456, + CurrentUser: "user", }) require.Nil(t, err) } From 378004a9e669a11dc13e23b0719d5fca414e3a0e Mon Sep 17 00:00:00 2001 From: yiling Date: Tue, 24 Dec 2024 17:53:53 +0800 Subject: [PATCH 23/34] Add opencsg check back to git http handler --- api/handler/git_http.go | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/api/handler/git_http.go b/api/handler/git_http.go index a9bcb86b..2c22bf57 100644 --- a/api/handler/git_http.go +++ b/api/handler/git_http.go @@ -175,6 +175,11 @@ func (h *GitHTTPHandler) LfsBatch(ctx *gin.Context) { return } + s3Internal := ctx.GetHeader("X-OPENCSG-S3-Internal") + if s3Internal == "true" { + ctx.Set("X-OPENCSG-S3-Internal", true) + } + objectResponse, err := h.gitHttp.BuildObjectResponse(ctx, batchRequest, isUpload) if err != nil { if errors.Is(err, component.ErrUnauthorized) { @@ -235,6 +240,11 @@ func (h *GitHTTPHandler) LfsDownload(ctx *gin.Context) { downloadRequest.CurrentUser = httpbase.GetCurrentUser(ctx) downloadRequest.SaveAs = ctx.Query("save_as") + s3Internal := ctx.GetHeader("X-OPENCSG-S3-Internal") + if s3Internal == "true" { + ctx.Set("X-OPENCSG-S3-Internal", true) + } + url, err := h.gitHttp.LfsDownload(ctx, downloadRequest) if err != nil { httpbase.ServerError(ctx, err) From 9ceab22a070f497f37caba71479ac92bd2a8ef3a Mon Sep 17 00:00:00 2001 From: SeanHH86 <154984842+SeanHH86@users.noreply.github.com> Date: Wed, 25 Dec 2024 10:13:21 +0800 Subject: [PATCH 24/34] cherry pick code of prompt (#221) --- api/handler/prompt.go | 393 --------------------------------------- api/router/api.go | 11 -- component/prompt.go | 195 ------------------- component/prompt_ce.go | 51 +++++ component/prompt_test.go | 230 ----------------------- 5 files changed, 51 insertions(+), 829 deletions(-) create mode 100644 component/prompt_ce.go diff --git a/api/handler/prompt.go b/api/handler/prompt.go index 6354c638..a8defe4b 100644 --- a/api/handler/prompt.go +++ b/api/handler/prompt.go @@ -1,14 +1,11 @@ package handler import ( - "encoding/json" "errors" "fmt" "log/slog" "net/http" "slices" - "strconv" - "strings" "time" "github.com/gin-gonic/gin" @@ -369,396 +366,6 @@ func (h *PromptHandler) DeletePrompt(ctx *gin.Context) { httpbase.OK(ctx, nil) } -// NewConversation godoc -// @Security ApiKey -// @Summary Create new conversation -// @Description Create new conversation -// @Tags Prompt -// @Accept json -// @Produce json -// @Param body body types.Conversation true "body" -// @Success 200 {object} types.Response{} "OK" -// @Failure 400 {object} types.APIBadRequest "Bad request" -// @Failure 500 {object} types.APIInternalServerError "Internal server error" -// @Router /prompts/conversations [post] -func (h *PromptHandler) NewConversation(ctx *gin.Context) { - currentUser := httpbase.GetCurrentUser(ctx) - if currentUser == "" { - httpbase.UnauthorizedError(ctx, component.ErrUserNotFound) - return - } - var body *types.ConversationTitle - if err := ctx.ShouldBindJSON(&body); err != nil { - slog.Error("Bad request conversation body", "error", err) - httpbase.BadRequest(ctx, err.Error()) - return - } - req := types.ConversationTitleReq{ - CurrentUser: currentUser, - ConversationTitle: types.ConversationTitle{ - Uuid: body.Uuid, - Title: body.Title, - }, - } - resp, err := h.pc.NewConversation(ctx, req) - if err != nil { - slog.Error("Failed to create conversation", slog.Any("req", req), slog.Any("error", err)) - httpbase.ServerError(ctx, err) - return - } - httpbase.OK(ctx, resp) -} - -// ListConversation godoc -// @Security ApiKey -// @Summary List conversations of user -// @Description List conversations of user -// @Tags Prompt -// @Accept json -// @Produce json -// @Success 200 {object} types.Response{} "OK" -// @Failure 400 {object} types.APIBadRequest "Bad request" -// @Failure 500 {object} types.APIInternalServerError "Internal server error" -// @Router /prompts/conversations [get] -func (h *PromptHandler) ListConversation(ctx *gin.Context) { - currentUser := httpbase.GetCurrentUser(ctx) - if currentUser == "" { - httpbase.UnauthorizedError(ctx, component.ErrUserNotFound) - return - } - data, err := h.pc.ListConversationsByUserID(ctx, currentUser) - if err != nil { - slog.Error("Failed to list conversations", slog.Any("currentUser", currentUser), slog.Any("error", err)) - httpbase.ServerError(ctx, err) - return - } - httpbase.OK(ctx, data) -} - -// GetConversation godoc -// @Security ApiKey -// @Summary Get a conversation by uuid -// @Description Get a conversation by uuid -// @Tags Prompt -// @Accept json -// @Produce json -// @Param id path string true "conversation uuid" -// @Success 200 {object} types.Response{} "OK" -// @Failure 400 {object} types.APIBadRequest "Bad request" -// @Failure 500 {object} types.APIInternalServerError "Internal server error" -// @Router /prompts/conversations/{id} [get] -func (h *PromptHandler) GetConversation(ctx *gin.Context) { - currentUser := httpbase.GetCurrentUser(ctx) - if currentUser == "" { - httpbase.UnauthorizedError(ctx, component.ErrUserNotFound) - return - } - uuid := ctx.Param("id") - if len(uuid) < 1 { - slog.Error("Bad request conversation uuid") - httpbase.BadRequest(ctx, "uuid is empty") - return - } - req := types.ConversationReq{ - CurrentUser: currentUser, - Conversation: types.Conversation{ - Uuid: uuid, - }, - } - conversation, err := h.pc.GetConversation(ctx, req) - if err != nil { - slog.Error("Failed to get conversation by id", slog.Any("req", req), slog.Any("error", err)) - httpbase.ServerError(ctx, err) - return - } - httpbase.OK(ctx, conversation) -} - -// SubmitMessage godoc -// @Security ApiKey -// @Summary Submit a conversation message -// @Description Submit a conversation message -// @Tags Prompt -// @Accept json -// @Produce json -// @Param id path string true "conversation uuid" -// @Param body body types.Conversation true "body" -// @Success 200 {object} types.Response{} "OK" -// @Failure 400 {object} types.APIBadRequest "Bad request" -// @Failure 500 {object} types.APIInternalServerError "Internal server error" -// @Router /prompts/conversations/{id} [post] -func (h *PromptHandler) SubmitMessage(ctx *gin.Context) { - currentUser := httpbase.GetCurrentUser(ctx) - if currentUser == "" { - httpbase.UnauthorizedError(ctx, component.ErrUserNotFound) - return - } - uuid := ctx.Param("id") - if len(uuid) < 1 { - slog.Error("Bad request conversation uuid") - httpbase.BadRequest(ctx, "uuid is empty") - return - } - var body *types.Conversation - if err := ctx.ShouldBindJSON(&body); err != nil { - slog.Error("Bad request messsage body", "error", err) - httpbase.BadRequest(ctx, err.Error()) - return - } - req := types.ConversationReq{ - CurrentUser: currentUser, - Conversation: types.Conversation{ - Uuid: uuid, - Message: body.Message, - Temperature: body.Temperature, - }, - } - - ch, err := h.pc.SubmitMessage(ctx, req) - if err != nil { - slog.Error("Failed to submit message", slog.Any("req", req), slog.Any("error", err)) - httpbase.ServerError(ctx, err) - return - } - ctx.Writer.Header().Set("Content-Type", "text/event-stream") - ctx.Writer.Header().Set("Cache-Control", "no-cache") - ctx.Writer.Header().Set("Connection", "keep-alive") - ctx.Writer.Header().Set("Transfer-Encoding", "chunked") - - ctx.Writer.WriteHeader(http.StatusOK) - ctx.Writer.Flush() - - generatedText := "" - for { - select { - case <-ctx.Request.Context().Done(): - slog.Debug("generate respose end for context done", slog.Any("error", ctx.Request.Context().Err())) - res := types.Conversation{ - Uuid: uuid, - Message: generatedText, - } - _, err = h.pc.SaveGeneratedText(ctx, res) - if err != nil { - slog.Error("fail to save generated message for request cancel", slog.Any("res", res), slog.Any("error", err)) - httpbase.ServerError(ctx, err) - } - return - case data, ok := <-ch: - if ok { - if len(data) < 1 { - continue - } - data := strings.TrimSpace(strings.TrimPrefix(data, "data:")) - ctx.SSEvent("data", data) - ctx.Writer.Flush() - resp := types.LLMResponse{} - err := json.Unmarshal([]byte(data), &resp) - if err != nil { - slog.Warn("unmarshal llm response", slog.Any("data", data), slog.Any("error", err)) - continue - } - if len(resp.Choices) < 1 { - continue - } - generatedText = fmt.Sprintf("%s%s", generatedText, resp.Choices[0].Delta.Content) - } else { - slog.Debug("stream channel closed") - res := types.Conversation{ - Uuid: uuid, - Message: generatedText, - } - msg, err := h.pc.SaveGeneratedText(ctx, res) - if err != nil { - slog.Error("fail to save generated message for stream close", slog.Any("res", res), slog.Any("error", err)) - httpbase.ServerError(ctx, err) - } - ctx.SSEvent("data", fmt.Sprintf("{\"msg_id\": %d}", msg.ID)) - ctx.Writer.Flush() - return - } - } - } -} - -// UpdateConversation godoc -// @Security ApiKey -// @Summary Update a conversation title -// @Description Update a conversation title -// @Tags Prompt -// @Accept json -// @Produce json -// @Param id path string true "conversation uuid" -// @Param body body types.ConversationTitle true "body" -// @Success 200 {object} types.Response{} "OK" -// @Failure 400 {object} types.APIBadRequest "Bad request" -// @Failure 500 {object} types.APIInternalServerError "Internal server error" -// @Router /prompts/conversations/{id} [put] -func (h *PromptHandler) UpdateConversation(ctx *gin.Context) { - currentUser := httpbase.GetCurrentUser(ctx) - if currentUser == "" { - httpbase.UnauthorizedError(ctx, component.ErrUserNotFound) - return - } - uuid := ctx.Param("id") - if len(uuid) < 1 { - slog.Error("Bad request conversation uuid") - httpbase.BadRequest(ctx, "uuid is empty") - return - } - var body *types.ConversationTitle - if err := ctx.ShouldBindJSON(&body); err != nil { - slog.Error("Bad request messsage body", "error", err) - httpbase.BadRequest(ctx, err.Error()) - return - } - - req := types.ConversationTitleReq{ - CurrentUser: currentUser, - ConversationTitle: types.ConversationTitle{ - Uuid: uuid, - Title: body.Title, - }, - } - resp, err := h.pc.UpdateConversation(ctx, req) - if err != nil { - slog.Error("Failed to update conversation", slog.Any("req", req), slog.Any("error", err)) - httpbase.ServerError(ctx, err) - return - } - httpbase.OK(ctx, resp) -} - -// DeleteConversation godoc -// @Security ApiKey -// @Summary Delete a conversation -// @Description Delete a conversation -// @Tags Prompt -// @Accept json -// @Produce json -// @Param id path string true "conversation uuid" -// @Success 200 {object} types.Response{} "OK" -// @Failure 400 {object} types.APIBadRequest "Bad request" -// @Failure 500 {object} types.APIInternalServerError "Internal server error" -// @Router /prompts/conversations/{id} [delete] -func (h *PromptHandler) RemoveConversation(ctx *gin.Context) { - currentUser := httpbase.GetCurrentUser(ctx) - if currentUser == "" { - httpbase.UnauthorizedError(ctx, component.ErrUserNotFound) - return - } - uuid := ctx.Param("id") - if len(uuid) < 1 { - slog.Error("Bad request conversation uuid") - httpbase.BadRequest(ctx, "uuid is empty") - return - } - req := types.ConversationReq{ - CurrentUser: currentUser, - Conversation: types.Conversation{ - Uuid: uuid, - }, - } - err := h.pc.RemoveConversation(ctx, req) - if err != nil { - slog.Error("Failed to remove conversation by id", slog.Any("req", req), slog.Any("error", err)) - httpbase.ServerError(ctx, err) - return - } - httpbase.OK(ctx, nil) -} - -// LikeMessage godoc -// @Security ApiKey -// @Summary Like a conversation message -// @Description Like a conversation message -// @Tags Prompt -// @Accept json -// @Produce json -// @Param uuid path string true "conversation uuid" -// @Param id path string true "message id" -// @Success 200 {object} types.Response{} "OK" -// @Failure 400 {object} types.APIBadRequest "Bad request" -// @Failure 500 {object} types.APIInternalServerError "Internal server error" -// @Router /prompts/conversations/{id}/message/{msgid}/like [put] -func (h *PromptHandler) LikeMessage(ctx *gin.Context) { - currentUser := httpbase.GetCurrentUser(ctx) - if currentUser == "" { - httpbase.UnauthorizedError(ctx, component.ErrUserNotFound) - return - } - uuid := ctx.Param("id") - if len(uuid) < 1 { - slog.Error("Bad request conversation uuid") - httpbase.BadRequest(ctx, "uuid is empty") - return - } - msgid := ctx.Param("msgid") - idInt, err := strconv.ParseInt(msgid, 10, 64) - if err != nil { - slog.Error("Bad request message id", slog.Any("msgid", msgid), slog.Any("error", err)) - httpbase.BadRequest(ctx, err.Error()) - return - } - req := types.ConversationMessageReq{ - Uuid: uuid, - Id: idInt, - CurrentUser: currentUser, - } - err = h.pc.LikeConversationMessage(ctx, req) - if err != nil { - slog.Error("Failed to like conversation message", slog.Any("req", req), slog.Any("error", err)) - httpbase.ServerError(ctx, err) - return - } - httpbase.OK(ctx, nil) -} - -// HateMessage godoc -// @Security ApiKey -// @Summary Hate a conversation message -// @Description Hate a conversation message -// @Tags Prompt -// @Accept json -// @Produce json -// @Param uuid path string true "conversation uuid" -// @Param id path string true "message id" -// @Success 200 {object} types.Response{} "OK" -// @Failure 400 {object} types.APIBadRequest "Bad request" -// @Failure 500 {object} types.APIInternalServerError "Internal server error" -// @Router /prompts/conversations/{id}/message/{msgid}/hate [put] -func (h *PromptHandler) HateMessage(ctx *gin.Context) { - currentUser := httpbase.GetCurrentUser(ctx) - if currentUser == "" { - httpbase.UnauthorizedError(ctx, component.ErrUserNotFound) - return - } - uuid := ctx.Param("id") - if len(uuid) < 1 { - slog.Error("Bad request conversation uuid") - httpbase.BadRequest(ctx, "uuid is empty") - return - } - msgid := ctx.Param("msgid") - idInt, err := strconv.ParseInt(msgid, 10, 64) - if err != nil { - slog.Error("Bad request message id", slog.Any("msgid", msgid), slog.Any("error", err)) - httpbase.BadRequest(ctx, err.Error()) - return - } - req := types.ConversationMessageReq{ - Uuid: uuid, - Id: idInt, - CurrentUser: currentUser, - } - err = h.pc.HateConversationMessage(ctx, req) - if err != nil { - slog.Error("Failed to hate conversation message", slog.Any("req", req), slog.Any("error", err)) - httpbase.ServerError(ctx, err) - return - } - httpbase.OK(ctx, nil) -} - // PromptRelations godoc // @Security ApiKey // @Summary Get prompt related assets diff --git a/api/router/api.go b/api/router/api.go index 6fc3ae6a..5b3f197c 100644 --- a/api/router/api.go +++ b/api/router/api.go @@ -756,17 +756,6 @@ func createPromptRoutes(apiGroup *gin.RouterGroup, promptHandler *handler.Prompt promptGrp.PUT("/:namespace/:name/relations", promptHandler.SetRelations) promptGrp.POST("/:namespace/:name/relations/model", promptHandler.AddModelRelation) promptGrp.DELETE("/:namespace/:name/relations/model", promptHandler.DelModelRelation) - conversationGrp := promptGrp.Group("/conversations") - { - conversationGrp.POST("", promptHandler.NewConversation) - conversationGrp.GET("", promptHandler.ListConversation) - conversationGrp.GET("/:id", promptHandler.GetConversation) - conversationGrp.POST("/:id", promptHandler.SubmitMessage) - conversationGrp.PUT("/:id", promptHandler.UpdateConversation) - conversationGrp.DELETE("/:id", promptHandler.RemoveConversation) - conversationGrp.PUT("/:id/message/:msgid/like", promptHandler.LikeMessage) - conversationGrp.PUT("/:id/message/:msgid/hate", promptHandler.HateMessage) - } promptGrp.POST("", promptHandler.Create) promptGrp.PUT("/:namespace/:name", promptHandler.Update) diff --git a/component/prompt.go b/component/prompt.go index dbe41c1b..1d4e789e 100644 --- a/component/prompt.go +++ b/component/prompt.go @@ -7,7 +7,6 @@ import ( "errors" "fmt" "log/slog" - "regexp" "strings" "sync" @@ -377,200 +376,6 @@ func (c *promptComponentImpl) checkPromptRepoPermission(ctx context.Context, req return &user, nil } -func (c *promptComponentImpl) NewConversation(ctx context.Context, req types.ConversationTitleReq) (*database.PromptConversation, error) { - user, err := c.userStore.FindByUsername(ctx, req.CurrentUser) - if err != nil { - return nil, errors.New("user does not exist") - } - conversation := database.PromptConversation{ - UserID: user.ID, - ConversationID: req.Uuid, - Title: req.Title, - } - - err = c.promptConvStore.CreateConversation(ctx, conversation) - if err != nil { - return nil, fmt.Errorf("new conversation error: %w", err) - } - - return &conversation, nil -} - -func (c *promptComponentImpl) ListConversationsByUserID(ctx context.Context, currentUser string) ([]database.PromptConversation, error) { - user, err := c.userStore.FindByUsername(ctx, currentUser) - if err != nil { - return nil, errors.New("user does not exist") - } - conversations, err := c.promptConvStore.FindConversationsByUserID(ctx, user.ID) - if err != nil { - return nil, fmt.Errorf("find conversations by user %s error: %w", currentUser, err) - } - return conversations, nil -} - -func (c *promptComponentImpl) GetConversation(ctx context.Context, req types.ConversationReq) (*database.PromptConversation, error) { - user, err := c.userStore.FindByUsername(ctx, req.CurrentUser) - if err != nil { - return nil, errors.New("user does not exist") - } - conversation, err := c.promptConvStore.GetConversationByID(ctx, user.ID, req.Uuid, true) - if err != nil { - return nil, fmt.Errorf("get conversation by id %s error: %w", req.Uuid, err) - } - return conversation, nil -} - -func (c *promptComponentImpl) SubmitMessage(ctx context.Context, req types.ConversationReq) (<-chan string, error) { - user, err := c.userStore.FindByUsername(ctx, req.CurrentUser) - if err != nil { - return nil, errors.New("user does not exist") - } - - _, err = c.promptConvStore.GetConversationByID(ctx, user.ID, req.Uuid, false) - if err != nil { - return nil, fmt.Errorf("invalid conversation by uuid %s error: %w", req.Uuid, err) - } - - reqMsg := database.PromptConversationMessage{ - ConversationID: req.Uuid, - Role: UserRole, - Content: req.Message, - } - _, err = c.promptConvStore.SaveConversationMessage(ctx, reqMsg) - if err != nil { - return nil, fmt.Errorf("save user prompt input error: %w", err) - } - - llmConfig, err := c.llmConfigStore.GetOptimization(ctx) - if err != nil { - return nil, fmt.Errorf("get llm config error: %w", err) - } - slog.Debug("use llm", slog.Any("llmConfig", llmConfig)) - var headers map[string]string - err = json.Unmarshal([]byte(llmConfig.AuthHeader), &headers) - if err != nil { - return nil, fmt.Errorf("parse llm config header error: %w", err) - } - - promptPrefix := "" - prefix, err := c.promptPrefixStore.Get(ctx) - if err != nil { - slog.Warn("fail to find prompt prefix", slog.Any("err", err)) - } else { - chs := isChinese(reqMsg.Content) - if chs { - promptPrefix = prefix.ZH - } else { - promptPrefix = prefix.EN - } - } - - reqData := types.LLMReqBody{ - Model: llmConfig.ModelName, - Messages: []types.LLMMessage{ - {Role: SystemRole, Content: promptPrefix}, - {Role: UserRole, Content: reqMsg.Content}, - }, - Stream: true, - Temperature: 0.2, - } - if req.Temperature != nil { - reqData.Temperature = *req.Temperature - } - - slog.Debug("llm request", slog.Any("reqData", reqData)) - ch, err := c.llmClient.Chat(ctx, llmConfig.ApiEndpoint, headers, reqData) - if err != nil { - return nil, fmt.Errorf("call llm error: %w", err) - } - return ch, nil -} - -func (c *promptComponentImpl) SaveGeneratedText(ctx context.Context, req types.Conversation) (*database.PromptConversationMessage, error) { - respMsg := database.PromptConversationMessage{ - ConversationID: req.Uuid, - Role: AssistantRole, - Content: req.Message, - } - msg, err := c.promptConvStore.SaveConversationMessage(ctx, respMsg) - if err != nil { - return nil, fmt.Errorf("save system generated response error: %w", err) - } - return msg, nil -} - -func (c *promptComponentImpl) RemoveConversation(ctx context.Context, req types.ConversationReq) error { - user, err := c.userStore.FindByUsername(ctx, req.CurrentUser) - if err != nil { - return errors.New("user does not exist") - } - - err = c.promptConvStore.DeleteConversationsByID(ctx, user.ID, req.Uuid) - if err != nil { - return fmt.Errorf("remove conversation error: %w", err) - } - return nil -} - -func (c *promptComponentImpl) UpdateConversation(ctx context.Context, req types.ConversationTitleReq) (*database.PromptConversation, error) { - user, err := c.userStore.FindByUsername(ctx, req.CurrentUser) - if err != nil { - return nil, errors.New("user does not exist") - } - - err = c.promptConvStore.UpdateConversation(ctx, database.PromptConversation{ - UserID: user.ID, - ConversationID: req.Uuid, - Title: req.Title, - }) - if err != nil { - return nil, fmt.Errorf("update conversation title error: %w", err) - } - - resp, err := c.promptConvStore.GetConversationByID(ctx, user.ID, req.Uuid, false) - if err != nil { - return nil, fmt.Errorf("invalid conversation by uuid %s error: %w", req.Uuid, err) - } - return resp, nil -} - -func (c *promptComponentImpl) LikeConversationMessage(ctx context.Context, req types.ConversationMessageReq) error { - user, err := c.userStore.FindByUsername(ctx, req.CurrentUser) - if err != nil { - return errors.New("user does not exist") - } - _, err = c.promptConvStore.GetConversationByID(ctx, user.ID, req.Uuid, false) - if err != nil { - return fmt.Errorf("invalid conversation by uuid %s error: %w", req.Uuid, err) - } - err = c.promptConvStore.LikeMessageByID(ctx, req.Id) - if err != nil { - return fmt.Errorf("update like message by id %d error: %w", req.Id, err) - } - return nil -} - -func (c *promptComponentImpl) HateConversationMessage(ctx context.Context, req types.ConversationMessageReq) error { - user, err := c.userStore.FindByUsername(ctx, req.CurrentUser) - if err != nil { - return errors.New("user does not exist") - } - _, err = c.promptConvStore.GetConversationByID(ctx, user.ID, req.Uuid, false) - if err != nil { - return fmt.Errorf("invalid conversation by uuid %s error: %w", req.Uuid, err) - } - err = c.promptConvStore.HateMessageByID(ctx, req.Id) - if err != nil { - return fmt.Errorf("update hate message by id %d error: %w", req.Id, err) - } - return nil -} - -func isChinese(s string) bool { - re := regexp.MustCompile(`[\p{Han}]`) - return re.MatchString(s) -} - func (c *promptComponentImpl) SetRelationModels(ctx context.Context, req types.RelationModels) error { user, err := c.userStore.FindByUsername(ctx, req.CurrentUser) if err != nil { diff --git a/component/prompt_ce.go b/component/prompt_ce.go new file mode 100644 index 00000000..37af931a --- /dev/null +++ b/component/prompt_ce.go @@ -0,0 +1,51 @@ +//go:build !ee && !saas + +package component + +import ( + "context" + + "opencsg.com/csghub-server/builder/store/database" + "opencsg.com/csghub-server/common/types" +) + +func (c *promptComponentImpl) NewConversation(ctx context.Context, req types.ConversationTitleReq) (*database.PromptConversation, error) { + return nil, nil +} + +func (c *promptComponentImpl) ListConversationsByUserID(ctx context.Context, currentUser string) ([]database.PromptConversation, error) { + return nil, nil +} + +func (c *promptComponentImpl) GetConversation(ctx context.Context, req types.ConversationReq) (*database.PromptConversation, error) { + return nil, nil +} + +func (c *promptComponentImpl) SubmitMessage(ctx context.Context, req types.ConversationReq) (<-chan string, error) { + return nil, nil +} + +func (c *promptComponentImpl) SaveGeneratedText(ctx context.Context, req types.Conversation) (*database.PromptConversationMessage, error) { + return nil, nil +} + +func (c *promptComponentImpl) RemoveConversation(ctx context.Context, req types.ConversationReq) error { + + return nil +} + +func (c *promptComponentImpl) UpdateConversation(ctx context.Context, req types.ConversationTitleReq) (*database.PromptConversation, error) { + return nil, nil +} + +func (c *promptComponentImpl) LikeConversationMessage(ctx context.Context, req types.ConversationMessageReq) error { + return nil +} + +func (c *promptComponentImpl) HateConversationMessage(ctx context.Context, req types.ConversationMessageReq) error { + return nil +} + +func (c *promptComponentImpl) SummarizeConversationTitle(ctx context.Context, req types.ConversationTitleReq) (*database.PromptConversation, error) { + return nil, nil +} diff --git a/component/prompt_test.go b/component/prompt_test.go index 5704c66c..dddc8b26 100644 --- a/component/prompt_test.go +++ b/component/prompt_test.go @@ -3,15 +3,10 @@ package component import ( "context" "encoding/base64" - "encoding/json" "errors" "fmt" - "net/http" "testing" - "github.com/alibabacloud-go/tea/tea" - "github.com/jarcoal/httpmock" - "github.com/spf13/cast" "github.com/stretchr/testify/mock" "github.com/stretchr/testify/require" "opencsg.com/csghub-server/builder/git/gitserver" @@ -361,231 +356,6 @@ func TestPromptComponent_GetPrompt(t *testing.T) { require.Equal(t, "foo.jsonl", output.FilePath) } -func TestPromptComponent_NewConversation(t *testing.T) { - ctx := context.TODO() - pc := initializeTestPromptComponent(ctx, t) - - pc.mocks.stores.UserMock().EXPECT().FindByUsername(ctx, "foo").Return(database.User{ID: 123}, nil).Once() - pc.mocks.stores.PromptConversationMock().EXPECT().CreateConversation(ctx, database.PromptConversation{ - UserID: 123, - ConversationID: "zzz", - Title: "test", - }).Return(nil).Once() - - cv, err := pc.NewConversation(ctx, types.ConversationTitleReq{ - CurrentUser: "foo", - ConversationTitle: types.ConversationTitle{ - Uuid: "zzz", - Title: "test", - }, - }) - require.Nil(t, err) - require.Equal(t, 123, int(cv.UserID)) - -} - -func TestPromptComponent_ListConversationByUserID(t *testing.T) { - ctx := context.TODO() - pc := initializeTestPromptComponent(ctx, t) - - pc.mocks.stores.UserMock().EXPECT().FindByUsername(ctx, "foo").Return(database.User{ID: 123}, nil).Once() - mockedResults := []database.PromptConversation{ - {Title: "foo"}, - {Title: "bar"}, - } - pc.mocks.stores.PromptConversationMock().EXPECT().FindConversationsByUserID(ctx, int64(123)).Return(mockedResults, nil).Once() - - results, err := pc.ListConversationsByUserID(ctx, "foo") - require.Nil(t, err) - require.Equal(t, mockedResults, results) - -} - -func TestPromptComponent_GetConversation(t *testing.T) { - ctx := context.TODO() - pc := initializeTestPromptComponent(ctx, t) - - pc.mocks.stores.UserMock().EXPECT().FindByUsername(ctx, "foo").Return(database.User{ID: 123}, nil).Once() - mocked := &database.PromptConversation{} - pc.mocks.stores.PromptConversationMock().EXPECT().GetConversationByID(ctx, int64(123), "uuid", true).Return(mocked, nil).Once() - - cv, err := pc.GetConversation(ctx, types.ConversationReq{ - CurrentUser: "foo", - Conversation: types.Conversation{ - Uuid: "uuid", - }, - }) - require.Nil(t, err) - require.Equal(t, mocked, cv) - -} - -func TestPromptComponent_SubmitMessage(t *testing.T) { - ctx := context.TODO() - pc := initializeTestPromptComponent(ctx, t) - - for _, lang := range []string{"en", "zh"} { - t.Run(lang, func(t *testing.T) { - - pc.mocks.stores.UserMock().EXPECT().FindByUsername(ctx, "foo").Return(database.User{ID: 123}, nil).Once() - pc.mocks.stores.PromptConversationMock().EXPECT().GetConversationByID(ctx, int64(123), "uuid", false).Return(&database.PromptConversation{}, nil).Once() - - content := "go" - if lang == "zh" { - content = "围棋" - } - pc.mocks.stores.PromptConversationMock().EXPECT().SaveConversationMessage( - ctx, database.PromptConversationMessage{ - ConversationID: "uuid", - Role: UserRole, - Content: content, - }, - ).Return(&database.PromptConversationMessage{}, nil) - pc.mocks.stores.LLMConfigMock().EXPECT().GetOptimization(ctx).Return(&database.LLMConfig{ - ApiEndpoint: "https://llm.com", - AuthHeader: `{"token": "foobar"}`, - }, nil).Once() - pc.mocks.stores.PromptPrefixMock().EXPECT().Get(ctx).Return(&database.PromptPrefix{ - ZH: "use Chinese", - EN: "use English", - }, nil).Once() - httpmock.Activate() - t.Cleanup(httpmock.DeactivateAndReset) - - httpmock.RegisterResponder("POST", "https://llm.com", - func(req *http.Request) (*http.Response, error) { - article := make(map[string]interface{}) - if err := json.NewDecoder(req.Body).Decode(&article); err != nil { - return httpmock.NewStringResponse(400, ""), nil - } - prefix := cast.ToStringMap(cast.ToSlice(article["messages"])[0])["content"] - d := "" - switch prefix { - case "use English": - d = `[{"id": 1, "name": "My Great Article"}]` - case "use Chinese": - d = `[{"id": 1, "name": "好好好"}]` - default: - d = "wrong" - } - return httpmock.NewStringResponse( - 200, d, - ), nil - }) - - ch, err := pc.SubmitMessage(ctx, types.ConversationReq{ - CurrentUser: "foo", - Conversation: types.Conversation{ - Uuid: "uuid", - Message: content, - }, - }) - require.Nil(t, err) - all := "" - for i := range ch { - all += i - } - if lang == "en" { - require.Equal(t, "[{\"id\": 1, \"name\": \"My Great Article\"}]", all) - } else { - require.Equal(t, "[{\"id\": 1, \"name\": \"好好好\"}]", all) - } - }) - } -} - -func TestPromptComponent_SaveGeneratedText(t *testing.T) { - ctx := context.TODO() - pc := initializeTestPromptComponent(ctx, t) - - mocked := &database.PromptConversationMessage{} - pc.mocks.stores.PromptConversationMock().EXPECT().SaveConversationMessage(ctx, database.PromptConversationMessage{ - ConversationID: "uuid", - Role: AssistantRole, - Content: "m", - }).Return(mocked, nil).Once() - - m, err := pc.SaveGeneratedText(ctx, types.Conversation{ - Uuid: "uuid", - Message: "m", - Temperature: tea.Float64(0.8), - }) - require.Nil(t, err) - require.Equal(t, mocked, m) -} - -func TestPromptComponent_RemoveConversation(t *testing.T) { - ctx := context.TODO() - pc := initializeTestPromptComponent(ctx, t) - - pc.mocks.stores.UserMock().EXPECT().FindByUsername(ctx, "foo").Return(database.User{ID: 123}, nil).Once() - pc.mocks.stores.PromptConversationMock().EXPECT().DeleteConversationsByID(ctx, int64(123), "uuid").Return(nil).Once() - - err := pc.RemoveConversation(ctx, types.ConversationReq{ - CurrentUser: "foo", - Conversation: types.Conversation{ - Uuid: "uuid", - }, - }) - require.Nil(t, err) -} - -func TestPromptComponent_UpdateConversation(t *testing.T) { - ctx := context.TODO() - pc := initializeTestPromptComponent(ctx, t) - - pc.mocks.stores.UserMock().EXPECT().FindByUsername(ctx, "foo").Return(database.User{ID: 123}, nil).Once() - pc.mocks.stores.PromptConversationMock().EXPECT().UpdateConversation(ctx, database.PromptConversation{ - UserID: 123, - ConversationID: "uuid", - Title: "title", - }).Return(nil).Once() - mocked := &database.PromptConversation{} - pc.mocks.stores.PromptConversationMock().EXPECT().GetConversationByID(ctx, int64(123), "uuid", false).Return(mocked, nil) - - cv, err := pc.UpdateConversation(ctx, types.ConversationTitleReq{ - CurrentUser: "foo", - ConversationTitle: types.ConversationTitle{ - Uuid: "uuid", - Title: "title", - }, - }) - require.Nil(t, err) - require.Equal(t, mocked, cv) -} - -func TestPromptComponent_LikeConversationMessage(t *testing.T) { - ctx := context.TODO() - pc := initializeTestPromptComponent(ctx, t) - - pc.mocks.stores.UserMock().EXPECT().FindByUsername(ctx, "foo").Return(database.User{ID: 123}, nil).Once() - pc.mocks.stores.PromptConversationMock().EXPECT().GetConversationByID(ctx, int64(123), "uuid", false).Return(&database.PromptConversation{}, nil).Once() - pc.mocks.stores.PromptConversationMock().EXPECT().LikeMessageByID(ctx, int64(123)).Return(nil).Once() - - err := pc.LikeConversationMessage(ctx, types.ConversationMessageReq{ - Uuid: "uuid", - Id: 123, - CurrentUser: "foo", - }) - require.Nil(t, err) -} - -func TestPromptComponent_HateConversationMessage(t *testing.T) { - ctx := context.TODO() - pc := initializeTestPromptComponent(ctx, t) - - pc.mocks.stores.UserMock().EXPECT().FindByUsername(ctx, "foo").Return(database.User{ID: 123}, nil).Once() - pc.mocks.stores.PromptConversationMock().EXPECT().GetConversationByID(ctx, int64(123), "uuid", false).Return(&database.PromptConversation{}, nil).Once() - pc.mocks.stores.PromptConversationMock().EXPECT().HateMessageByID(ctx, int64(123)).Return(nil).Once() - - err := pc.HateConversationMessage(ctx, types.ConversationMessageReq{ - Uuid: "uuid", - Id: 123, - CurrentUser: "foo", - }) - require.Nil(t, err) -} - func TestPromptComponent_SetRelationModels(t *testing.T) { ctx := context.TODO() pc := initializeTestPromptComponent(ctx, t) From b0a0d3ace534567b3f332a0a8e90130f16054fb8 Mon Sep 17 00:00:00 2001 From: SeanHH86 <154984842+SeanHH86@users.noreply.github.com> Date: Wed, 25 Dec 2024 11:19:58 +0800 Subject: [PATCH 25/34] [UT] add metering consumer ut (#222) --- .mockery.yaml | 4 + .../nats-io/nats.go/jetstream/mock_Msg.go | 600 ++++++++++++++++++ .../component/mock_MeteringComponent.go | 211 ++++++ .../database/mock_AccountMeteringStore.go | 59 ++ .../csghub-server/mq/mock_MessageQueue.go | 196 ++++++ accounting/component/metering.go | 23 +- accounting/component/metering_test.go | 109 ++++ accounting/consumer/metering.go | 53 +- accounting/consumer/metering_test.go | 307 +++++++++ accounting/utils/format_test.go | 22 + accounting/utils/parameters_test.go | 43 ++ accounting/utils/scene.go | 58 ++ accounting/utils/scene_test.go | 75 +++ builder/store/database/account_metering.go | 29 +- .../store/database/account_metering_test.go | 114 +++- common/config/config.go | 5 +- common/types/accounting.go | 4 + mq/messagequeue.go | 4 + mq/nats.go | 14 + 19 files changed, 1888 insertions(+), 42 deletions(-) create mode 100644 _mocks/github.com/nats-io/nats.go/jetstream/mock_Msg.go create mode 100644 _mocks/opencsg.com/csghub-server/accounting/component/mock_MeteringComponent.go create mode 100644 accounting/component/metering_test.go create mode 100644 accounting/consumer/metering_test.go create mode 100644 accounting/utils/format_test.go create mode 100644 accounting/utils/parameters_test.go create mode 100644 accounting/utils/scene.go create mode 100644 accounting/utils/scene_test.go diff --git a/.mockery.yaml b/.mockery.yaml index 04274577..6f7dc0b1 100644 --- a/.mockery.yaml +++ b/.mockery.yaml @@ -112,3 +112,7 @@ packages: config: interfaces: Client: + github.com/nats-io/nats.go/jetstream: + config: + interfaces: + Msg: diff --git a/_mocks/github.com/nats-io/nats.go/jetstream/mock_Msg.go b/_mocks/github.com/nats-io/nats.go/jetstream/mock_Msg.go new file mode 100644 index 00000000..11cf26ac --- /dev/null +++ b/_mocks/github.com/nats-io/nats.go/jetstream/mock_Msg.go @@ -0,0 +1,600 @@ +// Code generated by mockery v2.49.1. DO NOT EDIT. + +package jetstream + +import ( + context "context" + + jetstream "github.com/nats-io/nats.go/jetstream" + mock "github.com/stretchr/testify/mock" + + nats "github.com/nats-io/nats.go" + + time "time" +) + +// MockMsg is an autogenerated mock type for the Msg type +type MockMsg struct { + mock.Mock +} + +type MockMsg_Expecter struct { + mock *mock.Mock +} + +func (_m *MockMsg) EXPECT() *MockMsg_Expecter { + return &MockMsg_Expecter{mock: &_m.Mock} +} + +// Ack provides a mock function with given fields: +func (_m *MockMsg) Ack() error { + ret := _m.Called() + + if len(ret) == 0 { + panic("no return value specified for Ack") + } + + var r0 error + if rf, ok := ret.Get(0).(func() error); ok { + r0 = rf() + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// MockMsg_Ack_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Ack' +type MockMsg_Ack_Call struct { + *mock.Call +} + +// Ack is a helper method to define mock.On call +func (_e *MockMsg_Expecter) Ack() *MockMsg_Ack_Call { + return &MockMsg_Ack_Call{Call: _e.mock.On("Ack")} +} + +func (_c *MockMsg_Ack_Call) Run(run func()) *MockMsg_Ack_Call { + _c.Call.Run(func(args mock.Arguments) { + run() + }) + return _c +} + +func (_c *MockMsg_Ack_Call) Return(_a0 error) *MockMsg_Ack_Call { + _c.Call.Return(_a0) + return _c +} + +func (_c *MockMsg_Ack_Call) RunAndReturn(run func() error) *MockMsg_Ack_Call { + _c.Call.Return(run) + return _c +} + +// Data provides a mock function with given fields: +func (_m *MockMsg) Data() []byte { + ret := _m.Called() + + if len(ret) == 0 { + panic("no return value specified for Data") + } + + var r0 []byte + if rf, ok := ret.Get(0).(func() []byte); ok { + r0 = rf() + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).([]byte) + } + } + + return r0 +} + +// MockMsg_Data_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Data' +type MockMsg_Data_Call struct { + *mock.Call +} + +// Data is a helper method to define mock.On call +func (_e *MockMsg_Expecter) Data() *MockMsg_Data_Call { + return &MockMsg_Data_Call{Call: _e.mock.On("Data")} +} + +func (_c *MockMsg_Data_Call) Run(run func()) *MockMsg_Data_Call { + _c.Call.Run(func(args mock.Arguments) { + run() + }) + return _c +} + +func (_c *MockMsg_Data_Call) Return(_a0 []byte) *MockMsg_Data_Call { + _c.Call.Return(_a0) + return _c +} + +func (_c *MockMsg_Data_Call) RunAndReturn(run func() []byte) *MockMsg_Data_Call { + _c.Call.Return(run) + return _c +} + +// DoubleAck provides a mock function with given fields: _a0 +func (_m *MockMsg) DoubleAck(_a0 context.Context) error { + ret := _m.Called(_a0) + + if len(ret) == 0 { + panic("no return value specified for DoubleAck") + } + + var r0 error + if rf, ok := ret.Get(0).(func(context.Context) error); ok { + r0 = rf(_a0) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// MockMsg_DoubleAck_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'DoubleAck' +type MockMsg_DoubleAck_Call struct { + *mock.Call +} + +// DoubleAck is a helper method to define mock.On call +// - _a0 context.Context +func (_e *MockMsg_Expecter) DoubleAck(_a0 interface{}) *MockMsg_DoubleAck_Call { + return &MockMsg_DoubleAck_Call{Call: _e.mock.On("DoubleAck", _a0)} +} + +func (_c *MockMsg_DoubleAck_Call) Run(run func(_a0 context.Context)) *MockMsg_DoubleAck_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context)) + }) + return _c +} + +func (_c *MockMsg_DoubleAck_Call) Return(_a0 error) *MockMsg_DoubleAck_Call { + _c.Call.Return(_a0) + return _c +} + +func (_c *MockMsg_DoubleAck_Call) RunAndReturn(run func(context.Context) error) *MockMsg_DoubleAck_Call { + _c.Call.Return(run) + return _c +} + +// Headers provides a mock function with given fields: +func (_m *MockMsg) Headers() nats.Header { + ret := _m.Called() + + if len(ret) == 0 { + panic("no return value specified for Headers") + } + + var r0 nats.Header + if rf, ok := ret.Get(0).(func() nats.Header); ok { + r0 = rf() + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(nats.Header) + } + } + + return r0 +} + +// MockMsg_Headers_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Headers' +type MockMsg_Headers_Call struct { + *mock.Call +} + +// Headers is a helper method to define mock.On call +func (_e *MockMsg_Expecter) Headers() *MockMsg_Headers_Call { + return &MockMsg_Headers_Call{Call: _e.mock.On("Headers")} +} + +func (_c *MockMsg_Headers_Call) Run(run func()) *MockMsg_Headers_Call { + _c.Call.Run(func(args mock.Arguments) { + run() + }) + return _c +} + +func (_c *MockMsg_Headers_Call) Return(_a0 nats.Header) *MockMsg_Headers_Call { + _c.Call.Return(_a0) + return _c +} + +func (_c *MockMsg_Headers_Call) RunAndReturn(run func() nats.Header) *MockMsg_Headers_Call { + _c.Call.Return(run) + return _c +} + +// InProgress provides a mock function with given fields: +func (_m *MockMsg) InProgress() error { + ret := _m.Called() + + if len(ret) == 0 { + panic("no return value specified for InProgress") + } + + var r0 error + if rf, ok := ret.Get(0).(func() error); ok { + r0 = rf() + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// MockMsg_InProgress_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'InProgress' +type MockMsg_InProgress_Call struct { + *mock.Call +} + +// InProgress is a helper method to define mock.On call +func (_e *MockMsg_Expecter) InProgress() *MockMsg_InProgress_Call { + return &MockMsg_InProgress_Call{Call: _e.mock.On("InProgress")} +} + +func (_c *MockMsg_InProgress_Call) Run(run func()) *MockMsg_InProgress_Call { + _c.Call.Run(func(args mock.Arguments) { + run() + }) + return _c +} + +func (_c *MockMsg_InProgress_Call) Return(_a0 error) *MockMsg_InProgress_Call { + _c.Call.Return(_a0) + return _c +} + +func (_c *MockMsg_InProgress_Call) RunAndReturn(run func() error) *MockMsg_InProgress_Call { + _c.Call.Return(run) + return _c +} + +// Metadata provides a mock function with given fields: +func (_m *MockMsg) Metadata() (*jetstream.MsgMetadata, error) { + ret := _m.Called() + + if len(ret) == 0 { + panic("no return value specified for Metadata") + } + + var r0 *jetstream.MsgMetadata + var r1 error + if rf, ok := ret.Get(0).(func() (*jetstream.MsgMetadata, error)); ok { + return rf() + } + if rf, ok := ret.Get(0).(func() *jetstream.MsgMetadata); ok { + r0 = rf() + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*jetstream.MsgMetadata) + } + } + + if rf, ok := ret.Get(1).(func() error); ok { + r1 = rf() + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// MockMsg_Metadata_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Metadata' +type MockMsg_Metadata_Call struct { + *mock.Call +} + +// Metadata is a helper method to define mock.On call +func (_e *MockMsg_Expecter) Metadata() *MockMsg_Metadata_Call { + return &MockMsg_Metadata_Call{Call: _e.mock.On("Metadata")} +} + +func (_c *MockMsg_Metadata_Call) Run(run func()) *MockMsg_Metadata_Call { + _c.Call.Run(func(args mock.Arguments) { + run() + }) + return _c +} + +func (_c *MockMsg_Metadata_Call) Return(_a0 *jetstream.MsgMetadata, _a1 error) *MockMsg_Metadata_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *MockMsg_Metadata_Call) RunAndReturn(run func() (*jetstream.MsgMetadata, error)) *MockMsg_Metadata_Call { + _c.Call.Return(run) + return _c +} + +// Nak provides a mock function with given fields: +func (_m *MockMsg) Nak() error { + ret := _m.Called() + + if len(ret) == 0 { + panic("no return value specified for Nak") + } + + var r0 error + if rf, ok := ret.Get(0).(func() error); ok { + r0 = rf() + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// MockMsg_Nak_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Nak' +type MockMsg_Nak_Call struct { + *mock.Call +} + +// Nak is a helper method to define mock.On call +func (_e *MockMsg_Expecter) Nak() *MockMsg_Nak_Call { + return &MockMsg_Nak_Call{Call: _e.mock.On("Nak")} +} + +func (_c *MockMsg_Nak_Call) Run(run func()) *MockMsg_Nak_Call { + _c.Call.Run(func(args mock.Arguments) { + run() + }) + return _c +} + +func (_c *MockMsg_Nak_Call) Return(_a0 error) *MockMsg_Nak_Call { + _c.Call.Return(_a0) + return _c +} + +func (_c *MockMsg_Nak_Call) RunAndReturn(run func() error) *MockMsg_Nak_Call { + _c.Call.Return(run) + return _c +} + +// NakWithDelay provides a mock function with given fields: delay +func (_m *MockMsg) NakWithDelay(delay time.Duration) error { + ret := _m.Called(delay) + + if len(ret) == 0 { + panic("no return value specified for NakWithDelay") + } + + var r0 error + if rf, ok := ret.Get(0).(func(time.Duration) error); ok { + r0 = rf(delay) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// MockMsg_NakWithDelay_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'NakWithDelay' +type MockMsg_NakWithDelay_Call struct { + *mock.Call +} + +// NakWithDelay is a helper method to define mock.On call +// - delay time.Duration +func (_e *MockMsg_Expecter) NakWithDelay(delay interface{}) *MockMsg_NakWithDelay_Call { + return &MockMsg_NakWithDelay_Call{Call: _e.mock.On("NakWithDelay", delay)} +} + +func (_c *MockMsg_NakWithDelay_Call) Run(run func(delay time.Duration)) *MockMsg_NakWithDelay_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(time.Duration)) + }) + return _c +} + +func (_c *MockMsg_NakWithDelay_Call) Return(_a0 error) *MockMsg_NakWithDelay_Call { + _c.Call.Return(_a0) + return _c +} + +func (_c *MockMsg_NakWithDelay_Call) RunAndReturn(run func(time.Duration) error) *MockMsg_NakWithDelay_Call { + _c.Call.Return(run) + return _c +} + +// Reply provides a mock function with given fields: +func (_m *MockMsg) Reply() string { + ret := _m.Called() + + if len(ret) == 0 { + panic("no return value specified for Reply") + } + + var r0 string + if rf, ok := ret.Get(0).(func() string); ok { + r0 = rf() + } else { + r0 = ret.Get(0).(string) + } + + return r0 +} + +// MockMsg_Reply_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Reply' +type MockMsg_Reply_Call struct { + *mock.Call +} + +// Reply is a helper method to define mock.On call +func (_e *MockMsg_Expecter) Reply() *MockMsg_Reply_Call { + return &MockMsg_Reply_Call{Call: _e.mock.On("Reply")} +} + +func (_c *MockMsg_Reply_Call) Run(run func()) *MockMsg_Reply_Call { + _c.Call.Run(func(args mock.Arguments) { + run() + }) + return _c +} + +func (_c *MockMsg_Reply_Call) Return(_a0 string) *MockMsg_Reply_Call { + _c.Call.Return(_a0) + return _c +} + +func (_c *MockMsg_Reply_Call) RunAndReturn(run func() string) *MockMsg_Reply_Call { + _c.Call.Return(run) + return _c +} + +// Subject provides a mock function with given fields: +func (_m *MockMsg) Subject() string { + ret := _m.Called() + + if len(ret) == 0 { + panic("no return value specified for Subject") + } + + var r0 string + if rf, ok := ret.Get(0).(func() string); ok { + r0 = rf() + } else { + r0 = ret.Get(0).(string) + } + + return r0 +} + +// MockMsg_Subject_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Subject' +type MockMsg_Subject_Call struct { + *mock.Call +} + +// Subject is a helper method to define mock.On call +func (_e *MockMsg_Expecter) Subject() *MockMsg_Subject_Call { + return &MockMsg_Subject_Call{Call: _e.mock.On("Subject")} +} + +func (_c *MockMsg_Subject_Call) Run(run func()) *MockMsg_Subject_Call { + _c.Call.Run(func(args mock.Arguments) { + run() + }) + return _c +} + +func (_c *MockMsg_Subject_Call) Return(_a0 string) *MockMsg_Subject_Call { + _c.Call.Return(_a0) + return _c +} + +func (_c *MockMsg_Subject_Call) RunAndReturn(run func() string) *MockMsg_Subject_Call { + _c.Call.Return(run) + return _c +} + +// Term provides a mock function with given fields: +func (_m *MockMsg) Term() error { + ret := _m.Called() + + if len(ret) == 0 { + panic("no return value specified for Term") + } + + var r0 error + if rf, ok := ret.Get(0).(func() error); ok { + r0 = rf() + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// MockMsg_Term_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Term' +type MockMsg_Term_Call struct { + *mock.Call +} + +// Term is a helper method to define mock.On call +func (_e *MockMsg_Expecter) Term() *MockMsg_Term_Call { + return &MockMsg_Term_Call{Call: _e.mock.On("Term")} +} + +func (_c *MockMsg_Term_Call) Run(run func()) *MockMsg_Term_Call { + _c.Call.Run(func(args mock.Arguments) { + run() + }) + return _c +} + +func (_c *MockMsg_Term_Call) Return(_a0 error) *MockMsg_Term_Call { + _c.Call.Return(_a0) + return _c +} + +func (_c *MockMsg_Term_Call) RunAndReturn(run func() error) *MockMsg_Term_Call { + _c.Call.Return(run) + return _c +} + +// TermWithReason provides a mock function with given fields: reason +func (_m *MockMsg) TermWithReason(reason string) error { + ret := _m.Called(reason) + + if len(ret) == 0 { + panic("no return value specified for TermWithReason") + } + + var r0 error + if rf, ok := ret.Get(0).(func(string) error); ok { + r0 = rf(reason) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// MockMsg_TermWithReason_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'TermWithReason' +type MockMsg_TermWithReason_Call struct { + *mock.Call +} + +// TermWithReason is a helper method to define mock.On call +// - reason string +func (_e *MockMsg_Expecter) TermWithReason(reason interface{}) *MockMsg_TermWithReason_Call { + return &MockMsg_TermWithReason_Call{Call: _e.mock.On("TermWithReason", reason)} +} + +func (_c *MockMsg_TermWithReason_Call) Run(run func(reason string)) *MockMsg_TermWithReason_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(string)) + }) + return _c +} + +func (_c *MockMsg_TermWithReason_Call) Return(_a0 error) *MockMsg_TermWithReason_Call { + _c.Call.Return(_a0) + return _c +} + +func (_c *MockMsg_TermWithReason_Call) RunAndReturn(run func(string) error) *MockMsg_TermWithReason_Call { + _c.Call.Return(run) + return _c +} + +// NewMockMsg creates a new instance of MockMsg. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. +// The first argument is typically a *testing.T value. +func NewMockMsg(t interface { + mock.TestingT + Cleanup(func()) +}) *MockMsg { + mock := &MockMsg{} + mock.Mock.Test(t) + + t.Cleanup(func() { mock.AssertExpectations(t) }) + + return mock +} diff --git a/_mocks/opencsg.com/csghub-server/accounting/component/mock_MeteringComponent.go b/_mocks/opencsg.com/csghub-server/accounting/component/mock_MeteringComponent.go new file mode 100644 index 00000000..865c3b2a --- /dev/null +++ b/_mocks/opencsg.com/csghub-server/accounting/component/mock_MeteringComponent.go @@ -0,0 +1,211 @@ +// Code generated by mockery v2.49.1. DO NOT EDIT. + +package component + +import ( + context "context" + + mock "github.com/stretchr/testify/mock" + database "opencsg.com/csghub-server/builder/store/database" + + types "opencsg.com/csghub-server/common/types" +) + +// MockMeteringComponent is an autogenerated mock type for the MeteringComponent type +type MockMeteringComponent struct { + mock.Mock +} + +type MockMeteringComponent_Expecter struct { + mock *mock.Mock +} + +func (_m *MockMeteringComponent) EXPECT() *MockMeteringComponent_Expecter { + return &MockMeteringComponent_Expecter{mock: &_m.Mock} +} + +// GetMeteringStatByDate provides a mock function with given fields: ctx, req +func (_m *MockMeteringComponent) GetMeteringStatByDate(ctx context.Context, req types.ACCT_STATEMENTS_REQ) ([]map[string]interface{}, error) { + ret := _m.Called(ctx, req) + + if len(ret) == 0 { + panic("no return value specified for GetMeteringStatByDate") + } + + var r0 []map[string]interface{} + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, types.ACCT_STATEMENTS_REQ) ([]map[string]interface{}, error)); ok { + return rf(ctx, req) + } + if rf, ok := ret.Get(0).(func(context.Context, types.ACCT_STATEMENTS_REQ) []map[string]interface{}); ok { + r0 = rf(ctx, req) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).([]map[string]interface{}) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, types.ACCT_STATEMENTS_REQ) error); ok { + r1 = rf(ctx, req) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// MockMeteringComponent_GetMeteringStatByDate_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'GetMeteringStatByDate' +type MockMeteringComponent_GetMeteringStatByDate_Call struct { + *mock.Call +} + +// GetMeteringStatByDate is a helper method to define mock.On call +// - ctx context.Context +// - req types.ACCT_STATEMENTS_REQ +func (_e *MockMeteringComponent_Expecter) GetMeteringStatByDate(ctx interface{}, req interface{}) *MockMeteringComponent_GetMeteringStatByDate_Call { + return &MockMeteringComponent_GetMeteringStatByDate_Call{Call: _e.mock.On("GetMeteringStatByDate", ctx, req)} +} + +func (_c *MockMeteringComponent_GetMeteringStatByDate_Call) Run(run func(ctx context.Context, req types.ACCT_STATEMENTS_REQ)) *MockMeteringComponent_GetMeteringStatByDate_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(types.ACCT_STATEMENTS_REQ)) + }) + return _c +} + +func (_c *MockMeteringComponent_GetMeteringStatByDate_Call) Return(_a0 []map[string]interface{}, _a1 error) *MockMeteringComponent_GetMeteringStatByDate_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *MockMeteringComponent_GetMeteringStatByDate_Call) RunAndReturn(run func(context.Context, types.ACCT_STATEMENTS_REQ) ([]map[string]interface{}, error)) *MockMeteringComponent_GetMeteringStatByDate_Call { + _c.Call.Return(run) + return _c +} + +// ListMeteringByUserIDAndDate provides a mock function with given fields: ctx, req +func (_m *MockMeteringComponent) ListMeteringByUserIDAndDate(ctx context.Context, req types.ACCT_STATEMENTS_REQ) ([]database.AccountMetering, int, error) { + ret := _m.Called(ctx, req) + + if len(ret) == 0 { + panic("no return value specified for ListMeteringByUserIDAndDate") + } + + var r0 []database.AccountMetering + var r1 int + var r2 error + if rf, ok := ret.Get(0).(func(context.Context, types.ACCT_STATEMENTS_REQ) ([]database.AccountMetering, int, error)); ok { + return rf(ctx, req) + } + if rf, ok := ret.Get(0).(func(context.Context, types.ACCT_STATEMENTS_REQ) []database.AccountMetering); ok { + r0 = rf(ctx, req) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).([]database.AccountMetering) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, types.ACCT_STATEMENTS_REQ) int); ok { + r1 = rf(ctx, req) + } else { + r1 = ret.Get(1).(int) + } + + if rf, ok := ret.Get(2).(func(context.Context, types.ACCT_STATEMENTS_REQ) error); ok { + r2 = rf(ctx, req) + } else { + r2 = ret.Error(2) + } + + return r0, r1, r2 +} + +// MockMeteringComponent_ListMeteringByUserIDAndDate_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'ListMeteringByUserIDAndDate' +type MockMeteringComponent_ListMeteringByUserIDAndDate_Call struct { + *mock.Call +} + +// ListMeteringByUserIDAndDate is a helper method to define mock.On call +// - ctx context.Context +// - req types.ACCT_STATEMENTS_REQ +func (_e *MockMeteringComponent_Expecter) ListMeteringByUserIDAndDate(ctx interface{}, req interface{}) *MockMeteringComponent_ListMeteringByUserIDAndDate_Call { + return &MockMeteringComponent_ListMeteringByUserIDAndDate_Call{Call: _e.mock.On("ListMeteringByUserIDAndDate", ctx, req)} +} + +func (_c *MockMeteringComponent_ListMeteringByUserIDAndDate_Call) Run(run func(ctx context.Context, req types.ACCT_STATEMENTS_REQ)) *MockMeteringComponent_ListMeteringByUserIDAndDate_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(types.ACCT_STATEMENTS_REQ)) + }) + return _c +} + +func (_c *MockMeteringComponent_ListMeteringByUserIDAndDate_Call) Return(_a0 []database.AccountMetering, _a1 int, _a2 error) *MockMeteringComponent_ListMeteringByUserIDAndDate_Call { + _c.Call.Return(_a0, _a1, _a2) + return _c +} + +func (_c *MockMeteringComponent_ListMeteringByUserIDAndDate_Call) RunAndReturn(run func(context.Context, types.ACCT_STATEMENTS_REQ) ([]database.AccountMetering, int, error)) *MockMeteringComponent_ListMeteringByUserIDAndDate_Call { + _c.Call.Return(run) + return _c +} + +// SaveMeteringEventRecord provides a mock function with given fields: ctx, req +func (_m *MockMeteringComponent) SaveMeteringEventRecord(ctx context.Context, req *types.METERING_EVENT) error { + ret := _m.Called(ctx, req) + + if len(ret) == 0 { + panic("no return value specified for SaveMeteringEventRecord") + } + + var r0 error + if rf, ok := ret.Get(0).(func(context.Context, *types.METERING_EVENT) error); ok { + r0 = rf(ctx, req) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// MockMeteringComponent_SaveMeteringEventRecord_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'SaveMeteringEventRecord' +type MockMeteringComponent_SaveMeteringEventRecord_Call struct { + *mock.Call +} + +// SaveMeteringEventRecord is a helper method to define mock.On call +// - ctx context.Context +// - req *types.METERING_EVENT +func (_e *MockMeteringComponent_Expecter) SaveMeteringEventRecord(ctx interface{}, req interface{}) *MockMeteringComponent_SaveMeteringEventRecord_Call { + return &MockMeteringComponent_SaveMeteringEventRecord_Call{Call: _e.mock.On("SaveMeteringEventRecord", ctx, req)} +} + +func (_c *MockMeteringComponent_SaveMeteringEventRecord_Call) Run(run func(ctx context.Context, req *types.METERING_EVENT)) *MockMeteringComponent_SaveMeteringEventRecord_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(*types.METERING_EVENT)) + }) + return _c +} + +func (_c *MockMeteringComponent_SaveMeteringEventRecord_Call) Return(_a0 error) *MockMeteringComponent_SaveMeteringEventRecord_Call { + _c.Call.Return(_a0) + return _c +} + +func (_c *MockMeteringComponent_SaveMeteringEventRecord_Call) RunAndReturn(run func(context.Context, *types.METERING_EVENT) error) *MockMeteringComponent_SaveMeteringEventRecord_Call { + _c.Call.Return(run) + return _c +} + +// NewMockMeteringComponent creates a new instance of MockMeteringComponent. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. +// The first argument is typically a *testing.T value. +func NewMockMeteringComponent(t interface { + mock.TestingT + Cleanup(func()) +}) *MockMeteringComponent { + mock := &MockMeteringComponent{} + mock.Mock.Test(t) + + t.Cleanup(func() { mock.AssertExpectations(t) }) + + return mock +} diff --git a/_mocks/opencsg.com/csghub-server/builder/store/database/mock_AccountMeteringStore.go b/_mocks/opencsg.com/csghub-server/builder/store/database/mock_AccountMeteringStore.go index ad7bb0ec..cb239af9 100644 --- a/_mocks/opencsg.com/csghub-server/builder/store/database/mock_AccountMeteringStore.go +++ b/_mocks/opencsg.com/csghub-server/builder/store/database/mock_AccountMeteringStore.go @@ -71,6 +71,65 @@ func (_c *MockAccountMeteringStore_Create_Call) RunAndReturn(run func(context.Co return _c } +// GetStatByDate provides a mock function with given fields: ctx, req +func (_m *MockAccountMeteringStore) GetStatByDate(ctx context.Context, req types.ACCT_STATEMENTS_REQ) ([]map[string]interface{}, error) { + ret := _m.Called(ctx, req) + + if len(ret) == 0 { + panic("no return value specified for GetStatByDate") + } + + var r0 []map[string]interface{} + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, types.ACCT_STATEMENTS_REQ) ([]map[string]interface{}, error)); ok { + return rf(ctx, req) + } + if rf, ok := ret.Get(0).(func(context.Context, types.ACCT_STATEMENTS_REQ) []map[string]interface{}); ok { + r0 = rf(ctx, req) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).([]map[string]interface{}) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, types.ACCT_STATEMENTS_REQ) error); ok { + r1 = rf(ctx, req) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// MockAccountMeteringStore_GetStatByDate_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'GetStatByDate' +type MockAccountMeteringStore_GetStatByDate_Call struct { + *mock.Call +} + +// GetStatByDate is a helper method to define mock.On call +// - ctx context.Context +// - req types.ACCT_STATEMENTS_REQ +func (_e *MockAccountMeteringStore_Expecter) GetStatByDate(ctx interface{}, req interface{}) *MockAccountMeteringStore_GetStatByDate_Call { + return &MockAccountMeteringStore_GetStatByDate_Call{Call: _e.mock.On("GetStatByDate", ctx, req)} +} + +func (_c *MockAccountMeteringStore_GetStatByDate_Call) Run(run func(ctx context.Context, req types.ACCT_STATEMENTS_REQ)) *MockAccountMeteringStore_GetStatByDate_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(types.ACCT_STATEMENTS_REQ)) + }) + return _c +} + +func (_c *MockAccountMeteringStore_GetStatByDate_Call) Return(_a0 []map[string]interface{}, _a1 error) *MockAccountMeteringStore_GetStatByDate_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *MockAccountMeteringStore_GetStatByDate_Call) RunAndReturn(run func(context.Context, types.ACCT_STATEMENTS_REQ) ([]map[string]interface{}, error)) *MockAccountMeteringStore_GetStatByDate_Call { + _c.Call.Return(run) + return _c +} + // ListAllByUserUUID provides a mock function with given fields: ctx, userUUID func (_m *MockAccountMeteringStore) ListAllByUserUUID(ctx context.Context, userUUID string) ([]database.AccountMetering, error) { ret := _m.Called(ctx, userUUID) diff --git a/_mocks/opencsg.com/csghub-server/mq/mock_MessageQueue.go b/_mocks/opencsg.com/csghub-server/mq/mock_MessageQueue.go index 855c7fc9..0d4f51b5 100644 --- a/_mocks/opencsg.com/csghub-server/mq/mock_MessageQueue.go +++ b/_mocks/opencsg.com/csghub-server/mq/mock_MessageQueue.go @@ -236,6 +236,64 @@ func (_c *MockMessageQueue_CreateOrUpdateStream_Call) RunAndReturn(run func(cont return _c } +// FetchMeterEventMessages provides a mock function with given fields: batch +func (_m *MockMessageQueue) FetchMeterEventMessages(batch int) (jetstream.MessageBatch, error) { + ret := _m.Called(batch) + + if len(ret) == 0 { + panic("no return value specified for FetchMeterEventMessages") + } + + var r0 jetstream.MessageBatch + var r1 error + if rf, ok := ret.Get(0).(func(int) (jetstream.MessageBatch, error)); ok { + return rf(batch) + } + if rf, ok := ret.Get(0).(func(int) jetstream.MessageBatch); ok { + r0 = rf(batch) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(jetstream.MessageBatch) + } + } + + if rf, ok := ret.Get(1).(func(int) error); ok { + r1 = rf(batch) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// MockMessageQueue_FetchMeterEventMessages_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'FetchMeterEventMessages' +type MockMessageQueue_FetchMeterEventMessages_Call struct { + *mock.Call +} + +// FetchMeterEventMessages is a helper method to define mock.On call +// - batch int +func (_e *MockMessageQueue_Expecter) FetchMeterEventMessages(batch interface{}) *MockMessageQueue_FetchMeterEventMessages_Call { + return &MockMessageQueue_FetchMeterEventMessages_Call{Call: _e.mock.On("FetchMeterEventMessages", batch)} +} + +func (_c *MockMessageQueue_FetchMeterEventMessages_Call) Run(run func(batch int)) *MockMessageQueue_FetchMeterEventMessages_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(int)) + }) + return _c +} + +func (_c *MockMessageQueue_FetchMeterEventMessages_Call) Return(_a0 jetstream.MessageBatch, _a1 error) *MockMessageQueue_FetchMeterEventMessages_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *MockMessageQueue_FetchMeterEventMessages_Call) RunAndReturn(run func(int) (jetstream.MessageBatch, error)) *MockMessageQueue_FetchMeterEventMessages_Call { + _c.Call.Return(run) + return _c +} + // GetConn provides a mock function with given fields: func (_m *MockMessageQueue) GetConn() *nats.Conn { ret := _m.Called() @@ -375,6 +433,144 @@ func (_c *MockMessageQueue_PublishData_Call) RunAndReturn(run func(string, []byt return _c } +// PublishFeeCreditData provides a mock function with given fields: data +func (_m *MockMessageQueue) PublishFeeCreditData(data []byte) error { + ret := _m.Called(data) + + if len(ret) == 0 { + panic("no return value specified for PublishFeeCreditData") + } + + var r0 error + if rf, ok := ret.Get(0).(func([]byte) error); ok { + r0 = rf(data) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// MockMessageQueue_PublishFeeCreditData_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'PublishFeeCreditData' +type MockMessageQueue_PublishFeeCreditData_Call struct { + *mock.Call +} + +// PublishFeeCreditData is a helper method to define mock.On call +// - data []byte +func (_e *MockMessageQueue_Expecter) PublishFeeCreditData(data interface{}) *MockMessageQueue_PublishFeeCreditData_Call { + return &MockMessageQueue_PublishFeeCreditData_Call{Call: _e.mock.On("PublishFeeCreditData", data)} +} + +func (_c *MockMessageQueue_PublishFeeCreditData_Call) Run(run func(data []byte)) *MockMessageQueue_PublishFeeCreditData_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].([]byte)) + }) + return _c +} + +func (_c *MockMessageQueue_PublishFeeCreditData_Call) Return(_a0 error) *MockMessageQueue_PublishFeeCreditData_Call { + _c.Call.Return(_a0) + return _c +} + +func (_c *MockMessageQueue_PublishFeeCreditData_Call) RunAndReturn(run func([]byte) error) *MockMessageQueue_PublishFeeCreditData_Call { + _c.Call.Return(run) + return _c +} + +// PublishFeeQuotaData provides a mock function with given fields: data +func (_m *MockMessageQueue) PublishFeeQuotaData(data []byte) error { + ret := _m.Called(data) + + if len(ret) == 0 { + panic("no return value specified for PublishFeeQuotaData") + } + + var r0 error + if rf, ok := ret.Get(0).(func([]byte) error); ok { + r0 = rf(data) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// MockMessageQueue_PublishFeeQuotaData_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'PublishFeeQuotaData' +type MockMessageQueue_PublishFeeQuotaData_Call struct { + *mock.Call +} + +// PublishFeeQuotaData is a helper method to define mock.On call +// - data []byte +func (_e *MockMessageQueue_Expecter) PublishFeeQuotaData(data interface{}) *MockMessageQueue_PublishFeeQuotaData_Call { + return &MockMessageQueue_PublishFeeQuotaData_Call{Call: _e.mock.On("PublishFeeQuotaData", data)} +} + +func (_c *MockMessageQueue_PublishFeeQuotaData_Call) Run(run func(data []byte)) *MockMessageQueue_PublishFeeQuotaData_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].([]byte)) + }) + return _c +} + +func (_c *MockMessageQueue_PublishFeeQuotaData_Call) Return(_a0 error) *MockMessageQueue_PublishFeeQuotaData_Call { + _c.Call.Return(_a0) + return _c +} + +func (_c *MockMessageQueue_PublishFeeQuotaData_Call) RunAndReturn(run func([]byte) error) *MockMessageQueue_PublishFeeQuotaData_Call { + _c.Call.Return(run) + return _c +} + +// PublishFeeTokenData provides a mock function with given fields: data +func (_m *MockMessageQueue) PublishFeeTokenData(data []byte) error { + ret := _m.Called(data) + + if len(ret) == 0 { + panic("no return value specified for PublishFeeTokenData") + } + + var r0 error + if rf, ok := ret.Get(0).(func([]byte) error); ok { + r0 = rf(data) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// MockMessageQueue_PublishFeeTokenData_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'PublishFeeTokenData' +type MockMessageQueue_PublishFeeTokenData_Call struct { + *mock.Call +} + +// PublishFeeTokenData is a helper method to define mock.On call +// - data []byte +func (_e *MockMessageQueue_Expecter) PublishFeeTokenData(data interface{}) *MockMessageQueue_PublishFeeTokenData_Call { + return &MockMessageQueue_PublishFeeTokenData_Call{Call: _e.mock.On("PublishFeeTokenData", data)} +} + +func (_c *MockMessageQueue_PublishFeeTokenData_Call) Run(run func(data []byte)) *MockMessageQueue_PublishFeeTokenData_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].([]byte)) + }) + return _c +} + +func (_c *MockMessageQueue_PublishFeeTokenData_Call) Return(_a0 error) *MockMessageQueue_PublishFeeTokenData_Call { + _c.Call.Return(_a0) + return _c +} + +func (_c *MockMessageQueue_PublishFeeTokenData_Call) RunAndReturn(run func([]byte) error) *MockMessageQueue_PublishFeeTokenData_Call { + _c.Call.Return(run) + return _c +} + // PublishMeterDataToDLQ provides a mock function with given fields: data func (_m *MockMessageQueue) PublishMeterDataToDLQ(data []byte) error { ret := _m.Called(data) diff --git a/accounting/component/metering.go b/accounting/component/metering.go index 73d245ad..d5cf26d4 100644 --- a/accounting/component/metering.go +++ b/accounting/component/metering.go @@ -4,6 +4,7 @@ import ( "context" "fmt" + "opencsg.com/csghub-server/accounting/utils" "opencsg.com/csghub-server/builder/store/database" "opencsg.com/csghub-server/common/types" ) @@ -15,6 +16,7 @@ type meteringComponentImpl struct { type MeteringComponent interface { SaveMeteringEventRecord(ctx context.Context, req *types.METERING_EVENT) error ListMeteringByUserIDAndDate(ctx context.Context, req types.ACCT_STATEMENTS_REQ) ([]database.AccountMetering, int, error) + GetMeteringStatByDate(ctx context.Context, req types.ACCT_STATEMENTS_REQ) ([]map[string]interface{}, error) } func NewMeteringComponent() MeteringComponent { @@ -37,7 +39,7 @@ func (mc *meteringComponentImpl) SaveMeteringEventRecord(ctx context.Context, re CustomerID: req.CustomerID, RecordedAt: req.CreatedAt, Extra: req.Extra, - SkuUnitType: getUnitString(req.Scene), + SkuUnitType: utils.GetSkuUnitTypeByScene(types.SceneType(req.Scene)), } err := mc.ams.Create(ctx, am) if err != nil { @@ -54,19 +56,10 @@ func (mc *meteringComponentImpl) ListMeteringByUserIDAndDate(ctx context.Context return meters, total, nil } -func getUnitString(scene int) string { - switch types.SceneType(scene) { - case types.SceneModelInference: - return types.UnitMinute - case types.SceneSpace: - return types.UnitMinute - case types.SceneModelFinetune: - return types.UnitMinute - case types.SceneStarship: - return types.UnitToken - case types.SceneMultiSync: - return types.UnitRepo - default: - return types.UnitMinute +func (mc *meteringComponentImpl) GetMeteringStatByDate(ctx context.Context, req types.ACCT_STATEMENTS_REQ) ([]map[string]interface{}, error) { + res, err := mc.ams.GetStatByDate(ctx, req) + if err != nil { + return nil, fmt.Errorf("fail to get metering stat, error: %w", err) } + return res, nil } diff --git a/accounting/component/metering_test.go b/accounting/component/metering_test.go new file mode 100644 index 00000000..a9ccb621 --- /dev/null +++ b/accounting/component/metering_test.go @@ -0,0 +1,109 @@ +package component + +import ( + "context" + "testing" + + "github.com/google/uuid" + "github.com/stretchr/testify/require" + mockdb "opencsg.com/csghub-server/_mocks/opencsg.com/csghub-server/builder/store/database" + "opencsg.com/csghub-server/accounting/utils" + "opencsg.com/csghub-server/builder/store/database" + "opencsg.com/csghub-server/common/types" +) + +func NewTestMeteringComponent(amss database.AccountMeteringStore) MeteringComponent { + ams := &meteringComponentImpl{ + ams: amss, + } + return ams +} + +func TestMeteringComponent_SaveMeteringEventRecord(t *testing.T) { + ctx := context.TODO() + + uid := uuid.New() + req := types.METERING_EVENT{ + Uuid: uid, + UserUUID: "test-user-uuid", + Value: 100, + ValueType: 10, + Scene: 1, + OpUID: "test-op-uid", + ResourceID: "test-ID", + } + + data := database.AccountMetering{ + EventUUID: req.Uuid, + UserUUID: req.UserUUID, + Value: float64(req.Value), + ValueType: req.ValueType, + Scene: types.SceneType(req.Scene), + OpUID: req.OpUID, + ResourceID: req.ResourceID, + ResourceName: req.ResourceName, + CustomerID: req.CustomerID, + RecordedAt: req.CreatedAt, + Extra: req.Extra, + SkuUnitType: utils.GetSkuUnitTypeByScene(types.SceneType(req.Scene)), + } + + mockStore := mockdb.NewMockAccountMeteringStore(t) + mockStore.EXPECT().Create(ctx, data).Return(nil) + + mockComp := NewTestMeteringComponent(mockStore) + + err := mockComp.SaveMeteringEventRecord(ctx, &req) + + require.Nil(t, err) +} + +func TestMeteringComponent_ListMeteringByUserIDAndDate(t *testing.T) { + ctx := context.TODO() + + req := types.ACCT_STATEMENTS_REQ{ + UserUUID: "test-user-uuid", + Scene: int(types.SceneModelInference), + StartTime: "2024-01-01", + EndTime: "2024-12-31", + Per: 10, + Page: 1, + } + + data := []database.AccountMetering{ + { + EventUUID: uuid.New(), + UserUUID: "test-user-uuid", + Value: 100, + ValueType: 10, + Scene: types.SceneType(types.SceneModelInference), + OpUID: "test-op-uid", + }, + } + + mockStore := mockdb.NewMockAccountMeteringStore(t) + mockStore.EXPECT().ListByUserIDAndTime(ctx, req).Return(data, 1, nil) + + mockComp := NewTestMeteringComponent(mockStore) + + res, total, err := mockComp.ListMeteringByUserIDAndDate(ctx, req) + require.Nil(t, err) + require.Equal(t, 1, total) + require.NotNil(t, res) +} + +func TestMeteringComponent_GetMeteringStatByDate(t *testing.T) { + ctx := context.TODO() + + req := types.ACCT_STATEMENTS_REQ{} + data := []map[string]interface{}{} + + mockStore := mockdb.NewMockAccountMeteringStore(t) + + mockStore.EXPECT().GetStatByDate(ctx, req).Return(data, nil) + + mockComp := NewTestMeteringComponent(mockStore) + res, err := mockComp.GetMeteringStatByDate(ctx, req) + require.Nil(t, err) + require.NotNil(t, res) +} diff --git a/accounting/consumer/metering.go b/accounting/consumer/metering.go index e2afb3a5..572d8158 100644 --- a/accounting/consumer/metering.go +++ b/accounting/consumer/metering.go @@ -15,15 +15,21 @@ import ( "opencsg.com/csghub-server/mq" ) +var ( + idleDuration = 10 * time.Second +) + type Metering struct { - sysMQ *mq.NatsHandler - meterComp component.MeteringComponent + sysMQ mq.MessageQueue + meterComp component.MeteringComponent + chargingEnable bool } -func NewMetering(natHandler *mq.NatsHandler, config *config.Config) *Metering { +func NewMetering(natHandler mq.MessageQueue, config *config.Config) *Metering { meter := &Metering{ - sysMQ: natHandler, - meterComp: component.NewMeteringComponent(), + sysMQ: natHandler, + meterComp: component.NewMeteringComponent(), + chargingEnable: config.Accounting.ChargingEnable, } return meter } @@ -36,7 +42,7 @@ func (m *Metering) startMetering() { for { m.preReadMsgs() m.handleReadMsgs(10) - time.Sleep(10 * time.Second) + time.Sleep(2 * idleDuration) } } @@ -105,10 +111,11 @@ func (m *Metering) handleMsgWithRetry(msg jetstream.Msg) { slog.Debug("Meter->received", slog.Any("msg.subject", msg.Subject()), slog.Any("msg.data", strData)) // A maximum of 3 attempts var ( - err error = nil + err error = nil + evt *types.METERING_EVENT = nil ) for j := 0; j < 3; j++ { - _, err = m.handleMsgData(msg) + evt, err = m.handleMsgData(msg) if err == nil { break } @@ -122,6 +129,15 @@ func (m *Metering) handleMsgWithRetry(msg jetstream.Msg) { tip := fmt.Sprintf("failed to move meter msg to DLQ with %d retries", 5) slog.Error(tip, slog.Any("msg.data", string(msg.Data())), slog.Any("error", err)) } + } else { + if m.chargingEnable { + err = m.pubFeeEventWithReTry(msg, evt, 5) + if err != nil { + tip := fmt.Sprintf("failed to pub fee event msg with %d retries", 5) + slog.Error(tip, slog.Any("msg.data", string(msg.Data())), slog.Any("error", err)) + // todo: need more discuss on how to persist failed message finally + } + } } // ack for handle metering message done @@ -155,9 +171,28 @@ func (m *Metering) parseMessageData(msg jetstream.Msg) (*types.METERING_EVENT, e return &evt, nil } +func (m *Metering) pubFeeEventWithReTry(msg jetstream.Msg, evt *types.METERING_EVENT, limit int) error { + // A maximum of five attempts for pub fee event + var err error + for i := 0; i < limit; i++ { + switch evt.ValueType { + case types.TimeDurationMinType: + err = m.sysMQ.PublishFeeCreditData(msg.Data()) + case types.TokenNumberType: + err = m.sysMQ.PublishFeeTokenData(msg.Data()) + case types.QuotaNumberType: + err = m.sysMQ.PublishFeeQuotaData(msg.Data()) + } + if err == nil { + break + } + } + return err +} + func (m *Metering) moveMsgToDLQWithReTry(msg jetstream.Msg, limit int) error { // A maximum of five attempts for move DLQ - var err error = nil + var err error for i := 0; i < limit; i++ { err = m.sysMQ.PublishMeterDataToDLQ(msg.Data()) if err == nil { diff --git a/accounting/consumer/metering_test.go b/accounting/consumer/metering_test.go new file mode 100644 index 00000000..979aedfd --- /dev/null +++ b/accounting/consumer/metering_test.go @@ -0,0 +1,307 @@ +package consumer + +import ( + "encoding/json" + "errors" + "testing" + "time" + + "github.com/google/uuid" + "github.com/stretchr/testify/mock" + "github.com/stretchr/testify/require" + mockNats "opencsg.com/csghub-server/_mocks/github.com/nats-io/nats.go/jetstream" + mockacct "opencsg.com/csghub-server/_mocks/opencsg.com/csghub-server/accounting/component" + mockmq "opencsg.com/csghub-server/_mocks/opencsg.com/csghub-server/mq" + "opencsg.com/csghub-server/accounting/component" + "opencsg.com/csghub-server/common/config" + "opencsg.com/csghub-server/common/types" + "opencsg.com/csghub-server/mq" +) + +func NewTestConsumerMetering(natHandler mq.MessageQueue, meterComp component.MeteringComponent, config *config.Config) *Metering { + meter := &Metering{ + sysMQ: natHandler, + meterComp: meterComp, + chargingEnable: config.Accounting.ChargingEnable, + } + return meter +} + +func TestConsumerMetering_preReadMsgs(t *testing.T) { + cfg, err := config.LoadConfig() + cfg.Accounting.ChargingEnable = true + require.Nil(t, err) + + mq := mockmq.NewMockMessageQueue(t) + mq.EXPECT().BuildMeterEventStream().Return(nil) + mq.EXPECT().BuildDLQStream().Return(nil) + + meterComp := mockacct.NewMockMeteringComponent(t) + + meter := NewTestConsumerMetering(mq, meterComp, cfg) + + meter.preReadMsgs() +} + +func TestConsumerMetering_handleReadMsgs(t *testing.T) { + cfg, err := config.LoadConfig() + cfg.Accounting.ChargingEnable = true + require.Nil(t, err) + + mq := mockmq.NewMockMessageQueue(t) + mq.EXPECT().VerifyMeteringStream().Return(nil) + mq.EXPECT().FetchMeterEventMessages(5).Return(nil, errors.New("can not get msg")) + + meterComp := mockacct.NewMockMeteringComponent(t) + meter := NewTestConsumerMetering(mq, meterComp, cfg) + + done := make(chan bool) + go func() { + meter.handleReadMsgs(1) + close(done) + }() + + <-done +} + +func TestConsumerMetering_handleMsgWithRetry(t *testing.T) { + + cfg, err := config.LoadConfig() + cfg.Accounting.ChargingEnable = true + require.Nil(t, err) + + var event = types.METERING_EVENT{ + Uuid: uuid.MustParse("e2a0683d-ff52-4caf-915d-1ab052c57322"), + UserUUID: "bd05a582-a185-42d7-bf19-ad8108c4523b", + Value: 5000, + ValueType: 1, + Scene: 22, + OpUID: "", + ResourceID: "Autohub/gui_agent", + ResourceName: "Autohub/gui_agent", + CustomerID: "gui_agent", + CreatedAt: time.Date(2024, time.November, 6, 13, 19, 0, 0, time.UTC), + Extra: "{}", + } + + testData := []struct { + typeStr string + typeValue int + }{ + {"token", types.TokenNumberType}, + {"mintue", types.TimeDurationMinType}, + {"quota", types.QuotaNumberType}, + } + + for _, k := range testData { + t.Run(k.typeStr, func(t *testing.T) { + event.ValueType = k.typeValue + str, err := json.Marshal(event) + require.Nil(t, err) + + msg := mockNats.NewMockMsg(t) + msg.EXPECT().Data().Return(str) + msg.EXPECT().Subject().Return("") + msg.EXPECT().Ack().Return(nil) + + mq := mockmq.NewMockMessageQueue(t) + if k.typeStr == "token" { + mq.EXPECT().PublishFeeTokenData(str).Return(nil) + } + if k.typeStr == "mintue" { + mq.EXPECT().PublishFeeCreditData(str).Return(nil) + } + if k.typeStr == "quota" { + mq.EXPECT().PublishFeeQuotaData(str).Return(nil) + } + + meterComp := mockacct.NewMockMeteringComponent(t) + meterComp.EXPECT().SaveMeteringEventRecord(mock.Anything, &event).Return(nil) + + meter := NewTestConsumerMetering(mq, meterComp, cfg) + meter.handleMsgWithRetry(msg) + }) + } + + t.Run("error", func(t *testing.T) { + str := []byte("error error error") + + msg := mockNats.NewMockMsg(t) + msg.EXPECT().Data().Return(str) + msg.EXPECT().Subject().Return("") + msg.EXPECT().Ack().Return(nil) + + mq := mockmq.NewMockMessageQueue(t) + mq.EXPECT().PublishMeterDataToDLQ(str).Return(nil) + + meterComp := mockacct.NewMockMeteringComponent(t) + + meter := NewTestConsumerMetering(mq, meterComp, cfg) + meter.handleMsgWithRetry(msg) + }) + +} + +func TestConsumerMetering_handleMsgData(t *testing.T) { + cfg, err := config.LoadConfig() + cfg.Accounting.ChargingEnable = true + require.Nil(t, err) + + var event = types.METERING_EVENT{ + Uuid: uuid.MustParse("e2a0683d-ff52-4caf-915d-1ab052c57322"), + UserUUID: "bd05a582-a185-42d7-bf19-ad8108c4523b", + Value: 5000, + ValueType: 1, + Scene: 22, + OpUID: "", + ResourceID: "Autohub/gui_agent", + ResourceName: "Autohub/gui_agent", + CustomerID: "gui_agent", + CreatedAt: time.Date(2024, time.November, 6, 13, 19, 0, 0, time.UTC), + Extra: "{}", + } + + str, err := json.Marshal(event) + require.Nil(t, err) + + msg := mockNats.NewMockMsg(t) + msg.EXPECT().Data().Return(str) + + mq := mockmq.NewMockMessageQueue(t) + + meterComp := mockacct.NewMockMeteringComponent(t) + meterComp.EXPECT().SaveMeteringEventRecord(mock.Anything, &event).Return(nil) + + meter := NewTestConsumerMetering(mq, meterComp, cfg) + res, err := meter.handleMsgData(msg) + require.Nil(t, err) + require.Equal(t, event, *res) +} + +func TestConsumerMetering_parseMessageData(t *testing.T) { + cfg, err := config.LoadConfig() + cfg.Accounting.ChargingEnable = true + require.Nil(t, err) + + var event = types.METERING_EVENT{ + Uuid: uuid.MustParse("e2a0683d-ff52-4caf-915d-1ab052c57322"), + UserUUID: "bd05a582-a185-42d7-bf19-ad8108c4523b", + Value: 5000, + ValueType: 1, + Scene: 22, + OpUID: "", + ResourceID: "Autohub/gui_agent", + ResourceName: "Autohub/gui_agent", + CustomerID: "gui_agent", + CreatedAt: time.Date(2024, time.November, 6, 13, 19, 0, 0, time.UTC), + Extra: "{}", + } + + str, err := json.Marshal(event) + require.Nil(t, err) + + msg := mockNats.NewMockMsg(t) + msg.EXPECT().Data().Return(str) + + mq := mockmq.NewMockMessageQueue(t) + + meterComp := mockacct.NewMockMeteringComponent(t) + + meter := NewTestConsumerMetering(mq, meterComp, cfg) + res, err := meter.parseMessageData(msg) + require.Nil(t, err) + require.Equal(t, event, *res) +} + +func TestConsumerMetering_pubFeeEventWithReTry(t *testing.T) { + cfg, err := config.LoadConfig() + cfg.Accounting.ChargingEnable = true + require.Nil(t, err) + + var event = types.METERING_EVENT{ + Uuid: uuid.MustParse("e2a0683d-ff52-4caf-915d-1ab052c57322"), + UserUUID: "bd05a582-a185-42d7-bf19-ad8108c4523b", + Value: 5000, + ValueType: 1, + Scene: 22, + OpUID: "", + ResourceID: "Autohub/gui_agent", + ResourceName: "Autohub/gui_agent", + CustomerID: "gui_agent", + CreatedAt: time.Date(2024, time.November, 6, 13, 19, 0, 0, time.UTC), + Extra: "{}", + } + + testData := []struct { + typeStr string + typeValue int + }{ + {"token", types.TokenNumberType}, + {"mintue", types.TimeDurationMinType}, + {"quota", types.QuotaNumberType}, + } + + for _, k := range testData { + t.Run(k.typeStr, func(t *testing.T) { + event.ValueType = k.typeValue + str, err := json.Marshal(event) + require.Nil(t, err) + + msg := mockNats.NewMockMsg(t) + msg.EXPECT().Data().Return(str) + + mq := mockmq.NewMockMessageQueue(t) + if k.typeStr == "token" { + mq.EXPECT().PublishFeeTokenData(str).Return(nil) + } + if k.typeStr == "mintue" { + mq.EXPECT().PublishFeeCreditData(str).Return(nil) + } + if k.typeStr == "quota" { + mq.EXPECT().PublishFeeQuotaData(str).Return(nil) + } + + meterComp := mockacct.NewMockMeteringComponent(t) + + meter := NewTestConsumerMetering(mq, meterComp, cfg) + err = meter.pubFeeEventWithReTry(msg, &event, 1) + require.Nil(t, err) + }) + } +} + +func TestConsumerMetering_moveMsgToDLQWithReTry(t *testing.T) { + cfg, err := config.LoadConfig() + cfg.Accounting.ChargingEnable = true + require.Nil(t, err) + + var event = types.METERING_EVENT{ + Uuid: uuid.MustParse("e2a0683d-ff52-4caf-915d-1ab052c57322"), + UserUUID: "bd05a582-a185-42d7-bf19-ad8108c4523b", + Value: 5000, + ValueType: 1, + Scene: 22, + OpUID: "", + ResourceID: "Autohub/gui_agent", + ResourceName: "Autohub/gui_agent", + CustomerID: "gui_agent", + CreatedAt: time.Date(2024, time.November, 6, 13, 19, 0, 0, time.UTC), + Extra: "{}", + } + + str, err := json.Marshal(event) + require.Nil(t, err) + + msg := mockNats.NewMockMsg(t) + msg.EXPECT().Data().Return(str) + + mq := mockmq.NewMockMessageQueue(t) + mq.EXPECT().PublishMeterDataToDLQ(str).Return(nil) + + meterComp := mockacct.NewMockMeteringComponent(t) + + meter := NewTestConsumerMetering(mq, meterComp, cfg) + err = meter.moveMsgToDLQWithReTry(msg, 3) + require.Nil(t, err) + +} diff --git a/accounting/utils/format_test.go b/accounting/utils/format_test.go new file mode 100644 index 00000000..f3e91ea8 --- /dev/null +++ b/accounting/utils/format_test.go @@ -0,0 +1,22 @@ +package utils + +import ( + "testing" + + "github.com/stretchr/testify/require" +) + +func TestFormat_ValidateDateTimeFormat(t *testing.T) { + timeStr := "2024-11-08 12:13:23" + layout := "2006-01-02 15:04:05" + + res := ValidateDateTimeFormat(timeStr, layout) + + require.True(t, res) + + timeStr = "2024-11-08" + + res = ValidateDateTimeFormat(timeStr, layout) + + require.False(t, res) +} diff --git a/accounting/utils/parameters_test.go b/accounting/utils/parameters_test.go new file mode 100644 index 00000000..9778b6fd --- /dev/null +++ b/accounting/utils/parameters_test.go @@ -0,0 +1,43 @@ +package utils + +import ( + "net/http" + "net/http/httptest" + "net/url" + "testing" + + "github.com/gin-gonic/gin" + "github.com/stretchr/testify/require" +) + +func TestParameters_GetSceneFromContext(t *testing.T) { + t.Run("valid scene value", func(t *testing.T) { + values := url.Values{} + values.Add("scene", "2") + req, err := http.NewRequest(http.MethodGet, "/test?"+values.Encode(), nil) + require.Nil(t, err) + + hr := httptest.NewRecorder() + ginContext, _ := gin.CreateTestContext(hr) + ginContext.Request = req + + res, err := GetSceneFromContext(ginContext) + require.Nil(t, err) + require.Equal(t, 2, res) + }) + + t.Run("invalid scene value", func(t *testing.T) { + values := url.Values{} + values.Add("scene", "a") + req, err := http.NewRequest(http.MethodGet, "/test?"+values.Encode(), nil) + require.Nil(t, err) + + hr := httptest.NewRecorder() + ginContext, _ := gin.CreateTestContext(hr) + ginContext.Request = req + + res, err := GetSceneFromContext(ginContext) + require.NotNil(t, err) + require.Equal(t, 0, res) + }) +} diff --git a/accounting/utils/scene.go b/accounting/utils/scene.go new file mode 100644 index 00000000..a621566f --- /dev/null +++ b/accounting/utils/scene.go @@ -0,0 +1,58 @@ +package utils + +import "opencsg.com/csghub-server/common/types" + +func IsNeedCalculateBill(scene types.SceneType) bool { + switch scene { + case types.SceneModelInference, + types.SceneSpace, + types.SceneModelFinetune, + types.SceneEvaluation, + types.SceneStarship, + types.SceneGuiAgent: + return true + default: + return false + } +} + +func GetSkuUnitTypeByScene(scene types.SceneType) string { + switch scene { + case types.SceneModelInference: + return types.UnitMinute + case types.SceneSpace: + return types.UnitMinute + case types.SceneModelFinetune: + return types.UnitMinute + case types.SceneMultiSync: + return types.UnitRepo + case types.SceneEvaluation: + return types.UnitMinute + case types.SceneStarship: + return types.UnitToken + case types.SceneGuiAgent: + return types.UnitToken + default: + return types.UnitMinute + } +} + +func GetSKUTypeByScene(scene types.SceneType) types.SKUType { + switch scene { + case types.SceneModelInference: + return types.SKUCSGHub + case types.SceneSpace: + return types.SKUCSGHub + case types.SceneModelFinetune: + return types.SKUCSGHub + case types.SceneMultiSync: + return types.SKUCSGHub + case types.SceneEvaluation: + return types.SKUCSGHub + case types.SceneStarship: + return types.SKUStarship + case types.SceneGuiAgent: + return types.SKUStarship + } + return types.SKUReserve +} diff --git a/accounting/utils/scene_test.go b/accounting/utils/scene_test.go new file mode 100644 index 00000000..6542be56 --- /dev/null +++ b/accounting/utils/scene_test.go @@ -0,0 +1,75 @@ +package utils + +import ( + "testing" + + "github.com/stretchr/testify/require" + "opencsg.com/csghub-server/common/types" +) + +func TestScene_IsNeedCalculateBill(t *testing.T) { + + scenes := []types.SceneType{ + types.SceneModelInference, + types.SceneSpace, + types.SceneModelFinetune, + types.SceneEvaluation, + types.SceneStarship, + types.SceneGuiAgent, + } + + for _, scene := range scenes { + res := IsNeedCalculateBill(scene) + require.True(t, res) + } + + scenes = []types.SceneType{ + types.SceneReserve, + types.ScenePortalCharge, + types.ScenePayOrder, + types.SceneCashCharge, + types.SceneMultiSync, + types.SceneUnknow, + } + + for _, scene := range scenes { + res := IsNeedCalculateBill(scene) + require.False(t, res) + } + +} + +func TestScene_GetSkuUnitTypeByScene(t *testing.T) { + + scenes := map[types.SceneType]string{ + types.SceneModelInference: types.UnitMinute, + types.SceneSpace: types.UnitMinute, + types.SceneModelFinetune: types.UnitMinute, + types.SceneMultiSync: types.UnitRepo, + types.SceneEvaluation: types.UnitMinute, + types.SceneStarship: types.UnitToken, + types.SceneGuiAgent: types.UnitToken, + } + + for scene, unit := range scenes { + res := GetSkuUnitTypeByScene(scene) + require.Equal(t, unit, res) + } +} + +func TestScene_GetSKUTypeByScene(t *testing.T) { + scenes := map[types.SceneType]types.SKUType{ + types.SceneModelInference: types.SKUCSGHub, + types.SceneSpace: types.SKUCSGHub, + types.SceneModelFinetune: types.SKUCSGHub, + types.SceneMultiSync: types.SKUCSGHub, + types.SceneEvaluation: types.SKUCSGHub, + types.SceneStarship: types.SKUStarship, + types.SceneGuiAgent: types.SKUStarship, + } + + for scene, skuType := range scenes { + res := GetSKUTypeByScene(scene) + require.Equal(t, skuType, res) + } +} diff --git a/builder/store/database/account_metering.go b/builder/store/database/account_metering.go index 95b46fa9..682ce9e0 100644 --- a/builder/store/database/account_metering.go +++ b/builder/store/database/account_metering.go @@ -7,7 +7,6 @@ import ( "github.com/google/uuid" "opencsg.com/csghub-server/common/types" - commonTypes "opencsg.com/csghub-server/common/types" ) type accountMeteringStoreImpl struct { @@ -17,6 +16,7 @@ type accountMeteringStoreImpl struct { type AccountMeteringStore interface { Create(ctx context.Context, input AccountMetering) error ListByUserIDAndTime(ctx context.Context, req types.ACCT_STATEMENTS_REQ) ([]AccountMetering, int, error) + GetStatByDate(ctx context.Context, req types.ACCT_STATEMENTS_REQ) ([]map[string]interface{}, error) ListAllByUserUUID(ctx context.Context, userUUID string) ([]AccountMetering, error) } @@ -57,7 +57,7 @@ func (am *accountMeteringStoreImpl) Create(ctx context.Context, input AccountMet return nil } -func (am *accountMeteringStoreImpl) ListByUserIDAndTime(ctx context.Context, req commonTypes.ACCT_STATEMENTS_REQ) ([]AccountMetering, int, error) { +func (am *accountMeteringStoreImpl) ListByUserIDAndTime(ctx context.Context, req types.ACCT_STATEMENTS_REQ) ([]AccountMetering, int, error) { var accountMeters []AccountMetering q := am.db.Operator.Core.NewSelect().Model(&accountMeters).Where("user_uuid = ? and scene = ? and customer_id = ? and recorded_at >= ? and recorded_at <= ?", req.UserUUID, req.Scene, req.InstanceName, req.StartTime, req.EndTime) @@ -73,6 +73,31 @@ func (am *accountMeteringStoreImpl) ListByUserIDAndTime(ctx context.Context, req return accountMeters, count, nil } +func (am *accountMeteringStoreImpl) GetStatByDate(ctx context.Context, req types.ACCT_STATEMENTS_REQ) ([]map[string]interface{}, error) { + var meter []AccountMetering + var res []map[string]interface{} + err := am.db.Operator.Core.NewSelect().Model(&meter). + ColumnExpr("users.username"). + ColumnExpr("account_metering.user_uuid"). + ColumnExpr("account_metering.resource_id"). + ColumnExpr("sum(account_metering.value) as value"). + Join("join users on users.uuid = account_metering.user_uuid"). + Where("account_metering.scene = ?", req.Scene). + Where("account_metering.recorded_at >= ?", req.StartTime). + Where("account_metering.recorded_at <= ?", req.EndTime). + Group("users.username"). + Group("account_metering.user_uuid"). + Group("account_metering.resource_id"). + Order("account_metering.resource_id"). + Order("value desc"). + Scan(ctx, &res) + + if err != nil { + return nil, fmt.Errorf("select metering stat, error: %w", err) + } + return res, nil +} + func (am *accountMeteringStoreImpl) ListAllByUserUUID(ctx context.Context, userUUID string) ([]AccountMetering, error) { var accountMeters []AccountMetering err := am.db.Operator.Core.NewSelect().Model(&accountMeters).Where("user_uuid = ?", userUUID).Scan(ctx, &accountMeters) diff --git a/builder/store/database/account_metering_test.go b/builder/store/database/account_metering_test.go index adb4469d..7087f398 100644 --- a/builder/store/database/account_metering_test.go +++ b/builder/store/database/account_metering_test.go @@ -46,42 +46,42 @@ func TestAccountMeteringStore_ListByUserIDAndTime(t *testing.T) { ams := []database.AccountMetering{ { UserUUID: "foo", Value: 12.34, ValueType: 1, - ResourceName: "r1", Scene: types.SceneSpace, CustomerID: "bar", + ResourceName: "r1", Scene: types.ScenePayOrder, CustomerID: "bar", RecordedAt: dt.Add(-1 * time.Hour), EventUUID: uuid.New(), }, { UserUUID: "foo", Value: 12.34, ValueType: 1, - ResourceName: "r2", Scene: types.SceneSpace, CustomerID: "bar", + ResourceName: "r2", Scene: types.ScenePayOrder, CustomerID: "bar", RecordedAt: dt.Add(-2 * time.Hour), EventUUID: uuid.New(), }, { UserUUID: "foo", Value: 12.34, ValueType: 1, - ResourceName: "r3", Scene: types.SceneSpace, CustomerID: "bar", + ResourceName: "r3", Scene: types.ScenePayOrder, CustomerID: "bar", RecordedAt: dt.Add(1 * time.Hour), EventUUID: uuid.New(), }, { UserUUID: "foo", Value: 12.34, ValueType: 1, - ResourceName: "r4", Scene: types.SceneSpace, CustomerID: "bar", + ResourceName: "r4", Scene: types.ScenePayOrder, CustomerID: "bar", RecordedAt: dt.Add(2 * time.Hour), EventUUID: uuid.New(), }, { UserUUID: "foo", Value: 12.34, ValueType: 1, - ResourceName: "r5", Scene: types.SceneSpace, CustomerID: "bar", + ResourceName: "r5", Scene: types.ScenePayOrder, CustomerID: "bar", RecordedAt: dt.Add(-1 * time.Hour), EventUUID: uuid.New(), }, { UserUUID: "foo", Value: 12.34, ValueType: 1, - ResourceName: "r6", Scene: types.SceneSpace, CustomerID: "bar", + ResourceName: "r6", Scene: types.ScenePayOrder, CustomerID: "bar", RecordedAt: dt.Add(-6 * time.Hour), EventUUID: uuid.New(), }, { UserUUID: "foo", Value: 12.34, ValueType: 1, - ResourceName: "r7", Scene: types.SceneSpace, CustomerID: "bar", + ResourceName: "r7", Scene: types.ScenePayOrder, CustomerID: "bar", RecordedAt: dt.Add(6 * time.Hour), EventUUID: uuid.New(), }, { UserUUID: "bar", Value: 12.34, ValueType: 1, - ResourceName: "r8", Scene: types.SceneSpace, CustomerID: "bar", + ResourceName: "r8", Scene: types.ScenePayOrder, CustomerID: "bar", RecordedAt: dt.Add(-1 * time.Hour), EventUUID: uuid.New(), }, { @@ -91,7 +91,7 @@ func TestAccountMeteringStore_ListByUserIDAndTime(t *testing.T) { }, { UserUUID: "foo", Value: 12.34, ValueType: 1, - ResourceName: "r10", Scene: types.SceneSpace, CustomerID: "barz", + ResourceName: "r10", Scene: types.ScenePayOrder, CustomerID: "barz", RecordedAt: dt.Add(-1 * time.Hour), EventUUID: uuid.New(), }, } @@ -103,7 +103,7 @@ func TestAccountMeteringStore_ListByUserIDAndTime(t *testing.T) { ams, total, err := store.ListByUserIDAndTime(ctx, types.ACCT_STATEMENTS_REQ{ UserUUID: "foo", - Scene: 11, + Scene: 2, InstanceName: "bar", StartTime: dt.Add(-5 * time.Hour).Format(time.RFC3339), EndTime: dt.Add(5 * time.Hour).Format(time.RFC3339), @@ -117,6 +117,92 @@ func TestAccountMeteringStore_ListByUserIDAndTime(t *testing.T) { require.Equal(t, []string{"r5", "r4", "r3", "r2", "r1"}, names) } +func TestAccountMeteringStore_GetStatByDate(t *testing.T) { + db := tests.InitTestDB() + defer db.Close() + ctx := context.TODO() + + store := database.NewAccountMeteringStoreWithDB(db) + us := database.NewUserStoreWithDB(db) + err := us.Create(ctx, &database.User{ + Username: "u1", + UUID: "foo", + }, &database.Namespace{Path: "a"}) + require.Nil(t, err) + + err = us.Create(ctx, &database.User{ + Username: "u2", + UUID: "bar", + }, &database.Namespace{Path: "b"}) + require.Nil(t, err) + + dt := time.Date(2022, 11, 22, 3, 0, 0, 0, time.UTC) + ams := []database.AccountMetering{ + { + UserUUID: "foo", Value: 1.1, ValueType: 1, + ResourceID: "r1", Scene: types.ScenePayOrder, + RecordedAt: dt.Add(-1 * time.Hour), EventUUID: uuid.New(), + }, + { + UserUUID: "foo", Value: 1.2, ValueType: 2, + ResourceID: "r1", Scene: types.SceneModelFinetune, + RecordedAt: dt.Add(-2 * time.Hour), EventUUID: uuid.New(), + }, + { + UserUUID: "foo", Value: 1.2, ValueType: 2, + ResourceID: "r1", Scene: types.ScenePayOrder, + RecordedAt: dt.Add(-6 * time.Hour), EventUUID: uuid.New(), + }, + { + UserUUID: "foo", Value: 1.5, ValueType: 1, + ResourceID: "r2", Scene: types.ScenePayOrder, + RecordedAt: dt.Add(-1 * time.Hour), EventUUID: uuid.New(), + }, + { + UserUUID: "bar", Value: 1.1, ValueType: 1, + ResourceID: "r1", Scene: types.ScenePayOrder, + RecordedAt: dt.Add(-1 * time.Hour), EventUUID: uuid.New(), + }, + { + UserUUID: "bar", Value: 1.2, ValueType: 2, + ResourceID: "r1", Scene: types.ScenePayOrder, + RecordedAt: dt.Add(-6 * time.Hour), EventUUID: uuid.New(), + }, + { + UserUUID: "bar", Value: 1.2, ValueType: 2, + ResourceID: "r1", Scene: types.ScenePayOrder, + RecordedAt: dt.Add(-2 * time.Hour), EventUUID: uuid.New(), + }, + { + UserUUID: "bar", Value: 1.5, ValueType: 1, + ResourceID: "r2", Scene: types.ScenePayOrder, + RecordedAt: dt.Add(-1 * time.Hour), EventUUID: uuid.New(), + }, + } + + for _, am := range ams { + err := store.Create(ctx, am) + require.Nil(t, err) + } + + data, err := store.GetStatByDate(ctx, types.ACCT_STATEMENTS_REQ{ + UserUUID: "foo", + Scene: 2, + InstanceName: "bar", + StartTime: dt.Add(-5 * time.Hour).Format(time.RFC3339), + EndTime: dt.Add(5 * time.Hour).Format(time.RFC3339), + }) + require.Nil(t, err) + require.Equal(t, 4, len(data)) + expected := []map[string]interface{}{ + {"resource_id": "r1", "user_uuid": "bar", "username": "u2", "value": 2.3}, + {"resource_id": "r1", "user_uuid": "foo", "username": "u1", "value": 1.1}, + {"resource_id": "r2", "user_uuid": "foo", "username": "u1", "value": 1.5}, + {"resource_id": "r2", "user_uuid": "bar", "username": "u2", "value": 1.5}, + } + require.Equal(t, expected, data) +} + func TestAccountMeteringStore_ListAllByUserUUID(t *testing.T) { db := tests.InitTestDB() defer db.Close() @@ -126,22 +212,22 @@ func TestAccountMeteringStore_ListAllByUserUUID(t *testing.T) { ams := []database.AccountMetering{ { UserUUID: "foo", Value: 12.34, ValueType: 1, - ResourceName: "r1", Scene: types.SceneSpace, CustomerID: "bar", + ResourceName: "r1", Scene: types.ScenePayOrder, CustomerID: "bar", EventUUID: uuid.New(), }, { UserUUID: "foo", Value: 12.34, ValueType: 1, - ResourceName: "r2", Scene: types.SceneSpace, CustomerID: "bar", + ResourceName: "r2", Scene: types.ScenePayOrder, CustomerID: "bar", EventUUID: uuid.New(), }, { UserUUID: "foo", Value: 12.34, ValueType: 1, - ResourceName: "r3", Scene: types.SceneSpace, CustomerID: "bar", + ResourceName: "r3", Scene: types.ScenePayOrder, CustomerID: "bar", EventUUID: uuid.New(), }, { UserUUID: "bar", Value: 12.34, ValueType: 1, - ResourceName: "r4", Scene: types.SceneSpace, CustomerID: "bar", + ResourceName: "r4", Scene: types.ScenePayOrder, CustomerID: "bar", EventUUID: uuid.New(), }, } diff --git a/common/config/config.go b/common/config/config.go index bd87c130..c99d92ca 100644 --- a/common/config/config.go +++ b/common/config/config.go @@ -164,8 +164,9 @@ type Config struct { } Accounting struct { - Host string `env:"OPENCSG_ACCOUNTING_SERVER_HOST, default=http://localhost"` - Port int `env:"OPENCSG_ACCOUNTING_SERVER_PORT, default=8086"` + Host string `env:"OPENCSG_ACCOUNTING_SERVER_HOST, default=http://localhost"` + Port int `env:"OPENCSG_ACCOUNTING_SERVER_PORT, default=8086"` + ChargingEnable bool `env:"OPENCSG_ACCOUNTING_CHARGING_ENABLE, default=false"` } User struct { diff --git a/common/types/accounting.go b/common/types/accounting.go index 303f6a91..6b6028c0 100644 --- a/common/types/accounting.go +++ b/common/types/accounting.go @@ -27,12 +27,16 @@ type SceneType int var ( SceneReserve SceneType = 0 // system reserve + ScenePortalCharge SceneType = 1 // portal charge fee + ScenePayOrder SceneType = 2 // create order to reduce fee + SceneCashCharge SceneType = 3 // cash charge from user payment SceneModelInference SceneType = 10 // model inference endpoint SceneSpace SceneType = 11 // csghub space SceneModelFinetune SceneType = 12 // model finetune SceneMultiSync SceneType = 13 // multi sync SceneEvaluation SceneType = 14 // model evaluation SceneStarship SceneType = 20 // starship + SceneGuiAgent SceneType = 22 // gui agent SceneUnknow SceneType = 99 // unknow ) diff --git a/mq/messagequeue.go b/mq/messagequeue.go index 272ada8b..4ff06cc8 100644 --- a/mq/messagequeue.go +++ b/mq/messagequeue.go @@ -14,10 +14,14 @@ type MessageQueue interface { BuildEventStreamAndConsumer(cfg EventConfig, streamCfg jetstream.StreamConfig, consumerCfg jetstream.ConsumerConfig) (jetstream.Consumer, error) BuildMeterEventStream() error BuildDLQStream() error + FetchMeterEventMessages(batch int) (jetstream.MessageBatch, error) VerifyStreamByName(streamName string) error VerifyMeteringStream() error VerifyDLQStream() error PublishData(subject string, data []byte) error PublishMeterDataToDLQ(data []byte) error PublishMeterDurationData(data []byte) error + PublishFeeCreditData(data []byte) error + PublishFeeTokenData(data []byte) error + PublishFeeQuotaData(data []byte) error } diff --git a/mq/nats.go b/mq/nats.go index ee2e18ca..5f75a995 100644 --- a/mq/nats.go +++ b/mq/nats.go @@ -23,6 +23,7 @@ type DLQEventConfig struct { } type RequestSubject struct { + fee string // fee charging subject token string // token subject quota string // quota subject duration string // duration subject @@ -52,6 +53,7 @@ type NatsHandler struct { conn *nats.Conn msgFetchTimeoutInSec int + feeReqSub RequestSubject meterReqSub RequestSubject dlqEvtCfg jetstream.StreamConfig meterEvtCfg jetstream.StreamConfig @@ -200,3 +202,15 @@ func (nh *NatsHandler) PublishMeterDataToDLQ(data []byte) error { func (nh *NatsHandler) PublishMeterDurationData(data []byte) error { return nh.PublishData(nh.meterReqSub.duration, data) } + +func (nh *NatsHandler) PublishFeeCreditData(data []byte) error { + return nh.PublishData(nh.feeReqSub.fee, data) +} + +func (nh *NatsHandler) PublishFeeTokenData(data []byte) error { + return nh.PublishData(nh.feeReqSub.token, data) +} + +func (nh *NatsHandler) PublishFeeQuotaData(data []byte) error { + return nh.PublishData(nh.feeReqSub.quota, data) +} From 9a9d388f33f32422a134afd840a43aaa396b4686 Mon Sep 17 00:00:00 2001 From: yiling Date: Fri, 27 Dec 2024 16:49:20 +0800 Subject: [PATCH 26/34] Sync runtime arch component with enterprise --- component/runtime_architecture.go | 10 +- component/runtime_architecture_ce.go | 17 +++ component/runtime_architecture_test.go | 146 +++++++++++++------------ 3 files changed, 103 insertions(+), 70 deletions(-) create mode 100644 component/runtime_architecture_ce.go diff --git a/component/runtime_architecture.go b/component/runtime_architecture.go index b6f2f872..e83394a5 100644 --- a/component/runtime_architecture.go +++ b/component/runtime_architecture.go @@ -127,7 +127,7 @@ func (c *runtimeArchitectureComponentImpl) ScanArchitecture(ctx context.Context, if err != nil { return fmt.Errorf("list runtime arch failed, %w", err) } - var archMap map[string]string = make(map[string]string) + var archMap = make(map[string]string) for _, arch := range archs { archMap[arch.ArchitectureName] = arch.ArchitectureName } @@ -232,6 +232,10 @@ func (c *runtimeArchitectureComponentImpl) IsSupportedModelResource(ctx context. if strings.Contains(image, rm.EngineName) { return true, nil } + if matchRuntimeFrameworkWithEngineEE(rf, rm.EngineName) { + return true, nil + } + // special handling for nim models nimImage := strings.ReplaceAll(image, "-", "") nimMatchModel := strings.ReplaceAll(trimModel, "-", "") @@ -311,7 +315,7 @@ func (c *runtimeArchitectureComponentImpl) getConfigContent(ctx context.Context, func (c *runtimeArchitectureComponentImpl) RemoveRuntimeFrameworkTag(ctx context.Context, rftags []*database.Tag, repoId, rfId int64) { rfw, _ := c.runtimeFrameworksStore.FindByID(ctx, rfId) for _, tag := range rftags { - if strings.Contains(rfw.FrameImage, tag.Name) { + if checkTagName(rfw, tag.Name) { err := c.tagStore.RemoveRepoTags(ctx, repoId, []int64{tag.ID}) if err != nil { slog.Warn("fail to remove runtime_framework tag from model repo", slog.Any("repoId", repoId), slog.Any("runtime_framework_id", rfId), slog.Any("error", err)) @@ -327,7 +331,7 @@ func (c *runtimeArchitectureComponentImpl) AddRuntimeFrameworkTag(ctx context.Co return err } for _, tag := range rftags { - if strings.Contains(rfw.FrameImage, tag.Name) { + if checkTagName(rfw, tag.Name) { err := c.tagStore.UpsertRepoTags(ctx, repoId, []int64{}, []int64{tag.ID}) if err != nil { slog.Warn("fail to add runtime_framework tag to model repo", slog.Any("repoId", repoId), slog.Any("runtime_framework_id", rfId), slog.Any("error", err)) diff --git a/component/runtime_architecture_ce.go b/component/runtime_architecture_ce.go new file mode 100644 index 00000000..cbdaef81 --- /dev/null +++ b/component/runtime_architecture_ce.go @@ -0,0 +1,17 @@ +//go:build !ee && !saas + +package component + +import ( + "strings" + + "opencsg.com/csghub-server/builder/store/database" +) + +func matchRuntimeFrameworkWithEngineEE(rf *database.RuntimeFramework, engine string) bool { + return false +} + +func checkTagName(rf *database.RuntimeFramework, tag string) bool { + return strings.Contains(rf.FrameImage, tag) +} diff --git a/component/runtime_architecture_test.go b/component/runtime_architecture_test.go index 5db7dfdf..2755608d 100644 --- a/component/runtime_architecture_test.go +++ b/component/runtime_architecture_test.go @@ -3,6 +3,7 @@ package component import ( "context" "errors" + "fmt" "testing" "github.com/stretchr/testify/require" @@ -118,18 +119,36 @@ func TestRuntimeArchComponent_ScanArchitectures(t *testing.T) { } func TestRuntimeArchComponent_IsSupportedModelResource(t *testing.T) { - ctx := context.TODO() - rc := initializeTestRuntimeArchComponent(ctx, t) - rc.mocks.stores.ResourceModelMock().EXPECT().CheckModelNameNotInRFRepo(ctx, "model", int64(1)).Return( - &database.ResourceModel{EngineName: "a"}, nil, - ) + cases := []struct { + image string + support bool + }{ + {"foo", false}, + {"bar", true}, + {"foo/bar", true}, + {"bar/foo", false}, + {"foo-bar", true}, + {"foo-model", true}, + } + + for _, c := range cases { + t.Run(fmt.Sprintf("%+v", c), func(t *testing.T) { - r, err := rc.IsSupportedModelResource(ctx, "meta-model", &database.RuntimeFramework{ - FrameImage: "a/b", - }, 1) - require.Nil(t, err, nil) - require.False(t, r) + ctx := context.TODO() + rc := initializeTestRuntimeArchComponent(ctx, t) + + rc.mocks.stores.ResourceModelMock().EXPECT().CheckModelNameNotInRFRepo(ctx, "model", int64(1)).Return( + &database.ResourceModel{EngineName: "a"}, nil, + ) + + r, err := rc.IsSupportedModelResource(ctx, "meta-model", &database.RuntimeFramework{ + FrameImage: c.image, + }, 1) + require.Nil(t, err, nil) + require.Equal(t, c.support, r) + }) + } } func TestRuntimeArchComponent_GetArchitectureFromConfig(t *testing.T) { @@ -150,60 +169,53 @@ func TestRuntimeArchComponent_GetArchitectureFromConfig(t *testing.T) { } -// func TestRuntimeArchComponent_RemoveRuntimeFrameworkTag(t *testing.T) { -// ctx := context.TODO() -// rc := initializeTestRuntimeArchComponent(ctx, t) - -// rc.mocks.stores.RuntimeFrameworkMock().EXPECT().FindByID(ctx, int64(2)).Return( -// &database.RuntimeFramework{ -// FrameImage: "img", -// FrameNpuImage: "npu", -// }, nil, -// ) -// rc.mocks.stores.TagMock().EXPECT().RemoveRepoTags(ctx, int64(1), []int64{1}).Return(nil) -// rc.mocks.stores.TagMock().EXPECT().RemoveRepoTags(ctx, int64(1), []int64{2}).Return(nil) - -// rc.RemoveRuntimeFrameworkTag(ctx, []*database.Tag{ -// {Name: "img", ID: 1}, -// {Name: "npu", ID: 2}, -// }, int64(1), int64(2)) -// } - -// func TestRuntimeArchComponent_AddRuntimeFrameworkTag(t *testing.T) { -// ctx := context.TODO() -// rc := initializeTestRuntimeArchComponent(ctx, t) - -// rc.mocks.stores.RuntimeFrameworkMock().EXPECT().FindByID(ctx, int64(2)).Return( -// &database.RuntimeFramework{ -// FrameImage: "img", -// FrameNpuImage: "npu", -// }, nil, -// ) -// rc.mocks.stores.TagMock().EXPECT().UpsertRepoTags(ctx, int64(1), []int64{}, []int64{1}).Return(nil) -// rc.mocks.stores.TagMock().EXPECT().UpsertRepoTags(ctx, int64(1), []int64{}, []int64{2}).Return(nil) - -// err := rc.AddRuntimeFrameworkTag(ctx, []*database.Tag{ -// {Name: "img", ID: 1}, -// {Name: "npu", ID: 2}, -// }, int64(1), int64(2)) -// require.Nil(t, err) -// } - -// func TestRuntimeArchComponent_AddResourceTag(t *testing.T) { -// ctx := context.TODO() -// rc := initializeTestRuntimeArchComponent(ctx, t) - -// rc.mocks.stores.ResourceModelMock().EXPECT().FindByModelName(ctx, "model").Return( -// []*database.ResourceModel{ -// {ResourceName: "r1"}, -// {ResourceName: "r2"}, -// }, nil, -// ) -// rc.mocks.stores.TagMock().EXPECT().UpsertRepoTags(ctx, int64(1), []int64{}, []int64{1}).Return(nil) -// rc.mocks.stores.TagMock().EXPECT().UpsertRepoTags(ctx, int64(1), []int64{}, []int64{2}).Return(nil) - -// err := rc.AddResourceTag(ctx, []*database.Tag{ -// {Name: "r1", ID: 1}, -// }, "model", int64(1)) -// require.Nil(t, err) -// } +func TestRuntimeArchComponent_RemoveRuntimeFrameworkTag(t *testing.T) { + ctx := context.TODO() + rc := initializeTestRuntimeArchComponent(ctx, t) + + rc.mocks.stores.RuntimeFrameworkMock().EXPECT().FindByID(ctx, int64(2)).Return( + &database.RuntimeFramework{ + FrameImage: "img", + }, nil, + ) + rc.mocks.stores.TagMock().EXPECT().RemoveRepoTags(ctx, int64(1), []int64{1}).Return(nil) + + rc.RemoveRuntimeFrameworkTag(ctx, []*database.Tag{ + {Name: "img", ID: 1}, + }, int64(1), int64(2)) +} + +func TestRuntimeArchComponent_AddRuntimeFrameworkTag(t *testing.T) { + ctx := context.TODO() + rc := initializeTestRuntimeArchComponent(ctx, t) + + rc.mocks.stores.RuntimeFrameworkMock().EXPECT().FindByID(ctx, int64(2)).Return( + &database.RuntimeFramework{ + FrameImage: "img", + }, nil, + ) + rc.mocks.stores.TagMock().EXPECT().UpsertRepoTags(ctx, int64(1), []int64{}, []int64{1}).Return(nil) + + err := rc.AddRuntimeFrameworkTag(ctx, []*database.Tag{ + {Name: "img", ID: 1}, + }, int64(1), int64(2)) + require.Nil(t, err) +} + +func TestRuntimeArchComponent_AddResourceTag(t *testing.T) { + ctx := context.TODO() + rc := initializeTestRuntimeArchComponent(ctx, t) + + rc.mocks.stores.ResourceModelMock().EXPECT().FindByModelName(ctx, "model").Return( + []*database.ResourceModel{ + {ResourceName: "r1"}, + {ResourceName: "r2"}, + }, nil, + ) + rc.mocks.stores.TagMock().EXPECT().UpsertRepoTags(ctx, int64(1), []int64{}, []int64{1}).Return(nil) + + err := rc.AddResourceTag(ctx, []*database.Tag{ + {Name: "r1", ID: 1}, + }, "model", int64(1)) + require.Nil(t, err) +} From 30d572499ca187e669c4291696233adf111133d1 Mon Sep 17 00:00:00 2001 From: yiling Date: Fri, 27 Dec 2024 17:25:18 +0800 Subject: [PATCH 27/34] add prompt handler tests --- api/handler/prompt.go | 54 +++--- api/handler/prompt_test.go | 338 +++++++++++++++++++++++++++++++++++++ 2 files changed, 364 insertions(+), 28 deletions(-) create mode 100644 api/handler/prompt_test.go diff --git a/api/handler/prompt.go b/api/handler/prompt.go index a8defe4b..0bd6c750 100644 --- a/api/handler/prompt.go +++ b/api/handler/prompt.go @@ -17,9 +17,9 @@ import ( ) type PromptHandler struct { - pc component.PromptComponent - sc component.SensitiveComponent - repo component.RepoComponent + prompt component.PromptComponent + sensitive component.SensitiveComponent + repo component.RepoComponent } func NewPromptHandler(cfg *config.Config) (*PromptHandler, error) { @@ -33,13 +33,12 @@ func NewPromptHandler(cfg *config.Config) (*PromptHandler, error) { } repo, err := component.NewRepoComponent(cfg) if err != nil { - return nil, fmt.Errorf("error creating repo component:%w", err) + return nil, fmt.Errorf("failed to create repo component: %w", err) } - return &PromptHandler{ - pc: promptComp, - sc: sc, - repo: repo, + prompt: promptComp, + sensitive: sc, + repo: repo, }, nil } @@ -89,7 +88,7 @@ func (h *PromptHandler) Index(ctx *gin.Context) { return } - prompts, total, err := h.pc.IndexPromptRepo(ctx, filter, per, page) + prompts, total, err := h.prompt.IndexPromptRepo(ctx, filter, per, page) if err != nil { slog.Error("Failed to get prompts dataset", slog.Any("error", err)) httpbase.ServerError(ctx, err) @@ -125,7 +124,7 @@ func (h *PromptHandler) ListPrompt(ctx *gin.Context) { return } - detail, err := h.pc.Show(ctx, namespace, name, currentUser) + detail, err := h.prompt.Show(ctx, namespace, name, currentUser) if err != nil { if errors.Is(err, component.ErrUnauthorized) { httpbase.UnauthorizedError(ctx, err) @@ -141,7 +140,7 @@ func (h *PromptHandler) ListPrompt(ctx *gin.Context) { Name: name, CurrentUser: currentUser, } - data, err := h.pc.ListPrompt(ctx, req) + data, err := h.prompt.ListPrompt(ctx, req) if err != nil { slog.Error("Failed to list prompts of repo", slog.Any("error", err)) httpbase.ServerError(ctx, err) @@ -190,7 +189,7 @@ func (h *PromptHandler) GetPrompt(ctx *gin.Context) { CurrentUser: currentUser, Path: convertFilePathFromRoute(filePath), } - data, err := h.pc.GetPrompt(ctx, req) + data, err := h.prompt.GetPrompt(ctx, req) if err != nil { slog.Error("Failed to get prompt of repo", slog.Any("req", req), slog.Any("error", err)) httpbase.ServerError(ctx, err) @@ -232,7 +231,7 @@ func (h *PromptHandler) CreatePrompt(ctx *gin.Context) { httpbase.BadRequest(ctx, err.Error()) return } - _, err = h.sc.CheckRequestV2(ctx, body) + _, err = h.sensitive.CheckRequestV2(ctx, body) if err != nil { slog.Error("failed to check sensitive request", slog.Any("error", err)) httpbase.BadRequest(ctx, fmt.Errorf("sensitive check failed: %w", err).Error()) @@ -244,8 +243,7 @@ func (h *PromptHandler) CreatePrompt(ctx *gin.Context) { Name: name, CurrentUser: currentUser, } - - data, err := h.pc.CreatePrompt(ctx, req, body) + data, err := h.prompt.CreatePrompt(ctx, req, body) if err != nil { slog.Error("Failed to create prompt file of repo", slog.Any("req", req), slog.Any("error", err)) httpbase.ServerError(ctx, err) @@ -294,7 +292,7 @@ func (h *PromptHandler) UpdatePrompt(ctx *gin.Context) { httpbase.BadRequest(ctx, err.Error()) return } - _, err = h.sc.CheckRequestV2(ctx, body) + _, err = h.sensitive.CheckRequestV2(ctx, body) if err != nil { slog.Error("failed to check sensitive request", slog.Any("error", err)) httpbase.BadRequest(ctx, fmt.Errorf("sensitive check failed: %w", err).Error()) @@ -307,7 +305,7 @@ func (h *PromptHandler) UpdatePrompt(ctx *gin.Context) { CurrentUser: currentUser, Path: convertFilePathFromRoute(filePath), } - data, err := h.pc.UpdatePrompt(ctx, req, body) + data, err := h.prompt.UpdatePrompt(ctx, req, body) if err != nil { slog.Error("Failed to update prompt file of repo", slog.Any("req", req), slog.Any("error", err)) httpbase.ServerError(ctx, err) @@ -357,7 +355,7 @@ func (h *PromptHandler) DeletePrompt(ctx *gin.Context) { CurrentUser: currentUser, Path: convertFilePathFromRoute(filePath), } - err = h.pc.DeletePrompt(ctx, req) + err = h.prompt.DeletePrompt(ctx, req) if err != nil { slog.Error("Failed to remove prompt file of repo", slog.Any("req", req), slog.Any("error", err)) httpbase.ServerError(ctx, err) @@ -387,7 +385,7 @@ func (h *PromptHandler) Relations(ctx *gin.Context) { return } currentUser := httpbase.GetCurrentUser(ctx) - detail, err := h.pc.Relations(ctx, namespace, name, currentUser) + detail, err := h.prompt.Relations(ctx, namespace, name, currentUser) if err != nil { if errors.Is(err, component.ErrUnauthorized) { httpbase.UnauthorizedError(ctx, err) @@ -439,7 +437,7 @@ func (h *PromptHandler) SetRelations(ctx *gin.Context) { req.Name = name req.CurrentUser = currentUser - err = h.pc.SetRelationModels(ctx, req) + err = h.prompt.SetRelationModels(ctx, req) if err != nil { slog.Error("Failed to set models for prompt", slog.Any("error", err)) httpbase.ServerError(ctx, err) @@ -486,7 +484,7 @@ func (h *PromptHandler) AddModelRelation(ctx *gin.Context) { req.Name = name req.CurrentUser = currentUser - err = h.pc.AddRelationModel(ctx, req) + err = h.prompt.AddRelationModel(ctx, req) if err != nil { slog.Error("Failed to add model for prompt", slog.Any("error", err)) httpbase.ServerError(ctx, err) @@ -533,7 +531,7 @@ func (h *PromptHandler) DelModelRelation(ctx *gin.Context) { req.Name = name req.CurrentUser = currentUser - err = h.pc.DelRelationModel(ctx, req) + err = h.prompt.DelRelationModel(ctx, req) if err != nil { slog.Error("Failed to delete dataset for model", slog.Any("error", err)) httpbase.ServerError(ctx, err) @@ -567,7 +565,7 @@ func (h *PromptHandler) Create(ctx *gin.Context) { httpbase.BadRequest(ctx, err.Error()) return } - _, err := h.sc.CheckRequestV2(ctx, req) + _, err := h.sensitive.CheckRequestV2(ctx, req) if err != nil { slog.Error("failed to check sensitive request", slog.Any("error", err)) httpbase.BadRequest(ctx, fmt.Errorf("sensitive check failed: %w", err).Error()) @@ -575,7 +573,7 @@ func (h *PromptHandler) Create(ctx *gin.Context) { } req.Username = currentUser - prompt, err := h.pc.CreatePromptRepo(ctx, req) + prompt, err := h.prompt.CreatePromptRepo(ctx, req) if err != nil { slog.Error("Failed to create prompt repo", slog.Any("error", err)) httpbase.ServerError(ctx, err) @@ -616,7 +614,7 @@ func (h *PromptHandler) Update(ctx *gin.Context) { return } - _, err := h.sc.CheckRequestV2(ctx, req) + _, err := h.sensitive.CheckRequestV2(ctx, req) if err != nil { slog.Error("failed to check sensitive request", slog.Any("error", err)) httpbase.BadRequest(ctx, fmt.Errorf("sensitive check failed: %w", err).Error()) @@ -633,7 +631,7 @@ func (h *PromptHandler) Update(ctx *gin.Context) { req.Namespace = namespace req.Name = name - prompt, err := h.pc.UpdatePromptRepo(ctx, req) + prompt, err := h.prompt.UpdatePromptRepo(ctx, req) if err != nil { slog.Error("Failed to update prompt repo", slog.Any("error", err)) httpbase.ServerError(ctx, err) @@ -669,7 +667,7 @@ func (h *PromptHandler) Delete(ctx *gin.Context) { httpbase.BadRequest(ctx, err.Error()) return } - err = h.pc.RemoveRepo(ctx, namespace, name, currentUser) + err = h.prompt.RemoveRepo(ctx, namespace, name, currentUser) if err != nil { slog.Error("Failed to delete prompt repo", slog.Any("error", err)) httpbase.ServerError(ctx, err) @@ -778,7 +776,7 @@ func (h *PromptHandler) Tags(ctx *gin.Context) { func (h *PromptHandler) UpdateTags(ctx *gin.Context) { currentUser := httpbase.GetCurrentUser(ctx) if currentUser == "" { - httpbase.UnauthorizedError(ctx, httpbase.ErrorNeedLogin) + httpbase.UnauthorizedError(ctx, component.ErrUserNotFound) return } namespace, name, err := common.GetNamespaceAndNameFromContext(ctx) diff --git a/api/handler/prompt_test.go b/api/handler/prompt_test.go new file mode 100644 index 00000000..3f7ad019 --- /dev/null +++ b/api/handler/prompt_test.go @@ -0,0 +1,338 @@ +package handler + +import ( + "fmt" + "testing" + "time" + + "github.com/gin-gonic/gin" + "github.com/stretchr/testify/require" + mock_component "opencsg.com/csghub-server/_mocks/opencsg.com/csghub-server/component" + "opencsg.com/csghub-server/builder/store/database" + "opencsg.com/csghub-server/common/types" +) + +type PromptTester struct { + *GinTester + handler *PromptHandler + mocks struct { + prompt *mock_component.MockPromptComponent + sensitive *mock_component.MockSensitiveComponent + repo *mock_component.MockRepoComponent + } +} + +func NewPromptTester(t *testing.T) *PromptTester { + tester := &PromptTester{GinTester: NewGinTester()} + tester.mocks.prompt = mock_component.NewMockPromptComponent(t) + tester.mocks.sensitive = mock_component.NewMockSensitiveComponent(t) + tester.mocks.repo = mock_component.NewMockRepoComponent(t) + tester.handler = &PromptHandler{ + prompt: tester.mocks.prompt, sensitive: tester.mocks.sensitive, + repo: tester.mocks.repo, + } + tester.WithParam("name", "r") + tester.WithParam("namespace", "u") + return tester + +} + +func (t *PromptTester) WithHandleFunc(fn func(h *PromptHandler) gin.HandlerFunc) *PromptTester { + t.ginHandler = fn(t.handler) + return t + +} + +func TestPromptHandler_Index(t *testing.T) { + cases := []struct { + sort string + source string + error bool + }{ + {"most_download", "local", false}, + {"foo", "local", true}, + {"most_download", "bar", true}, + } + + for _, c := range cases { + t.Run(fmt.Sprintf("%+v", c), func(t *testing.T) { + + tester := NewPromptTester(t).WithHandleFunc(func(h *PromptHandler) gin.HandlerFunc { + return h.Index + }) + + if !c.error { + tester.mocks.prompt.EXPECT().IndexPromptRepo(tester.ctx, &types.RepoFilter{ + Search: "foo", + Sort: c.sort, + Source: c.source, + }, 10, 1).Return([]types.PromptRes{ + {Name: "cc"}, + }, 100, nil) + } + + tester.AddPagination(1, 10).WithQuery("search", "foo"). + WithQuery("sort", c.sort). + WithQuery("source", c.source).Execute() + + if c.error { + require.Equal(t, 400, tester.response.Code) + } else { + tester.ResponseEqSimple(t, 200, gin.H{ + "data": []types.PromptRes{{Name: "cc"}}, + "total": 100, + }) + } + }) + } +} + +func TestPromptHandler_ListPrompt(t *testing.T) { + tester := NewPromptTester(t).WithHandleFunc(func(h *PromptHandler) gin.HandlerFunc { + return h.ListPrompt + }) + + tester.WithUser() + tester.mocks.prompt.EXPECT().Show(tester.ctx, "u", "r", "u").Return(&types.PromptRes{Name: "p"}, nil) + tester.mocks.prompt.EXPECT().ListPrompt(tester.ctx, types.PromptReq{ + Namespace: "u", Name: "r", CurrentUser: "u", + }).Return([]types.PromptOutput{{FilePath: "fp"}}, nil) + tester.Execute() + + tester.ResponseEq(t, 200, tester.OKText, gin.H{ + "detail": &types.PromptRes{Name: "p"}, + "prompts": []types.PromptOutput{{FilePath: "fp"}}, + }) +} + +func TestPromptHandler_GetPrompt(t *testing.T) { + tester := NewPromptTester(t).WithHandleFunc(func(h *PromptHandler) gin.HandlerFunc { + return h.GetPrompt + }) + + tester.WithUser().WithParam("file_path", "fp") + tester.mocks.prompt.EXPECT().GetPrompt(tester.ctx, types.PromptReq{ + Namespace: "u", Name: "r", CurrentUser: "u", Path: "fp", + }).Return(&types.PromptOutput{FilePath: "fp"}, nil) + tester.Execute() + + tester.ResponseEq(t, 200, tester.OKText, &types.PromptOutput{FilePath: "fp"}) +} + +func TestPromptHandler_CreatePrompt(t *testing.T) { + tester := NewPromptTester(t).WithHandleFunc(func(h *PromptHandler) gin.HandlerFunc { + return h.CreatePrompt + }) + tester.RequireUser(t) + + req := &types.CreatePromptReq{Prompt: types.Prompt{ + Title: "t", Content: "c", Language: "l", + }} + tester.mocks.sensitive.EXPECT().CheckRequestV2(tester.ctx, req).Return(true, nil) + tester.mocks.prompt.EXPECT().CreatePrompt(tester.ctx, types.PromptReq{ + Namespace: "u", Name: "r", CurrentUser: "u", + }, req).Return(&types.Prompt{Title: "p"}, nil) + tester.WithBody(t, req).Execute() + + tester.ResponseEq(t, 200, tester.OKText, &types.Prompt{Title: "p"}) +} + +func TestPromptHandler_UpdatePrompt(t *testing.T) { + tester := NewPromptTester(t).WithHandleFunc(func(h *PromptHandler) gin.HandlerFunc { + return h.UpdatePrompt + }) + tester.RequireUser(t) + + req := &types.UpdatePromptReq{Prompt: types.Prompt{ + Title: "t", Content: "c", Language: "l", + }} + tester.mocks.sensitive.EXPECT().CheckRequestV2(tester.ctx, req).Return(true, nil) + tester.mocks.prompt.EXPECT().UpdatePrompt(tester.ctx, types.PromptReq{ + Namespace: "u", Name: "r", CurrentUser: "u", Path: "fp", + }, req).Return(&types.Prompt{Title: "p"}, nil) + tester.WithParam("file_path", "fp").WithBody(t, req).Execute() + + tester.ResponseEq(t, 200, tester.OKText, &types.Prompt{Title: "p"}) +} + +func TestPromptHandler_DeletePrompt(t *testing.T) { + tester := NewPromptTester(t).WithHandleFunc(func(h *PromptHandler) gin.HandlerFunc { + return h.DeletePrompt + }) + tester.RequireUser(t) + + tester.WithUser().WithParam("file_path", "fp") + tester.mocks.prompt.EXPECT().DeletePrompt(tester.ctx, types.PromptReq{ + Namespace: "u", Name: "r", CurrentUser: "u", Path: "fp", + }).Return(nil) + tester.Execute() + + tester.ResponseEq(t, 200, tester.OKText, nil) +} + +func TestPromptHandler_Relations(t *testing.T) { + tester := NewPromptTester(t).WithHandleFunc(func(h *PromptHandler) gin.HandlerFunc { + return h.Relations + }) + + tester.WithUser() + tester.mocks.prompt.EXPECT().Relations(tester.ctx, "u", "r", "u").Return(&types.Relations{}, nil) + tester.Execute() + + tester.ResponseEq(t, 200, tester.OKText, &types.Relations{}) +} + +func TestPromptHandler_SetRelations(t *testing.T) { + tester := NewPromptTester(t).WithHandleFunc(func(h *PromptHandler) gin.HandlerFunc { + return h.SetRelations + }) + tester.RequireUser(t) + + req := types.RelationModels{Namespace: "u", Name: "r", CurrentUser: "u"} + tester.mocks.prompt.EXPECT().SetRelationModels(tester.ctx, req).Return(nil) + tester.WithBody(t, types.RelationModels{Name: "rm"}).Execute() + + tester.ResponseEq(t, 200, tester.OKText, nil) +} + +func TestPromptHandler_AddModelRelation(t *testing.T) { + tester := NewPromptTester(t).WithHandleFunc(func(h *PromptHandler) gin.HandlerFunc { + return h.AddModelRelation + }) + tester.RequireUser(t) + + req := types.RelationModel{Namespace: "u", Name: "r", CurrentUser: "u"} + tester.mocks.prompt.EXPECT().AddRelationModel(tester.ctx, req).Return(nil) + tester.WithBody(t, types.RelationModels{Name: "rm"}).Execute() + + tester.ResponseEq(t, 200, tester.OKText, nil) +} + +func TestPromptHandler_DeleteModelRelation(t *testing.T) { + tester := NewPromptTester(t).WithHandleFunc(func(h *PromptHandler) gin.HandlerFunc { + return h.AddModelRelation + }) + tester.RequireUser(t) + + req := types.RelationModel{Namespace: "u", Name: "r", CurrentUser: "u"} + tester.mocks.prompt.EXPECT().AddRelationModel(tester.ctx, req).Return(nil) + tester.WithBody(t, types.RelationModels{Name: "rm"}).Execute() + + tester.ResponseEq(t, 200, tester.OKText, nil) +} + +func TestPromptHandler_CreatePromptRepo(t *testing.T) { + tester := NewPromptTester(t).WithHandleFunc(func(h *PromptHandler) gin.HandlerFunc { + return h.Create + }) + tester.RequireUser(t) + + req := &types.CreatePromptRepoReq{CreateRepoReq: types.CreateRepoReq{}} + tester.mocks.sensitive.EXPECT().CheckRequestV2(tester.ctx, req).Return(true, nil) + reqn := *req + reqn.Username = "u" + tester.mocks.prompt.EXPECT().CreatePromptRepo(tester.ctx, &reqn).Return( + &types.PromptRes{Name: "p"}, nil, + ) + tester.WithBody(t, req).Execute() + + tester.ResponseEqSimple(t, 200, gin.H{ + "data": &types.PromptRes{Name: "p"}, + }) +} + +func TestPromptHandler_UpdatePromptRepo(t *testing.T) { + tester := NewPromptTester(t).WithHandleFunc(func(h *PromptHandler) gin.HandlerFunc { + return h.Update + }) + tester.RequireUser(t) + + req := &types.UpdatePromptRepoReq{UpdateRepoReq: types.UpdateRepoReq{}} + tester.mocks.sensitive.EXPECT().CheckRequestV2(tester.ctx, req).Return(true, nil) + reqn := *req + reqn.Namespace = "u" + reqn.Name = "r" + reqn.Username = "u" + tester.mocks.prompt.EXPECT().UpdatePromptRepo(tester.ctx, &reqn).Return( + &types.PromptRes{Name: "p"}, nil, + ) + tester.WithBody(t, req).Execute() + + tester.ResponseEq(t, 200, tester.OKText, &types.PromptRes{Name: "p"}) +} + +func TestPromptHandler_DeletePromptRepo(t *testing.T) { + tester := NewPromptTester(t).WithHandleFunc(func(h *PromptHandler) gin.HandlerFunc { + return h.Delete + }) + tester.RequireUser(t) + + tester.mocks.prompt.EXPECT().RemoveRepo(tester.ctx, "u", "r", "u").Return(nil) + tester.Execute() + + tester.ResponseEq(t, 200, tester.OKText, nil) +} + +func TestPromptHandler_Branches(t *testing.T) { + tester := NewPromptTester(t).WithHandleFunc(func(h *PromptHandler) gin.HandlerFunc { + return h.Branches + }) + + tester.mocks.repo.EXPECT().Branches(tester.ctx, &types.GetBranchesReq{ + Namespace: "u", + Name: "r", + Per: 10, + Page: 1, + RepoType: types.PromptRepo, + CurrentUser: "u", + }).Return([]types.Branch{{Name: "main"}}, nil) + tester.WithUser().AddPagination(1, 10).Execute() + + tester.ResponseEq(t, 200, tester.OKText, []types.Branch{{Name: "main"}}) +} + +func TestPromptHandler_Tags(t *testing.T) { + tester := NewPromptTester(t).WithHandleFunc(func(h *PromptHandler) gin.HandlerFunc { + return h.Tags + }) + + tester.mocks.repo.EXPECT().Tags(tester.ctx, &types.GetTagsReq{ + Namespace: "u", + Name: "r", + RepoType: types.PromptRepo, + CurrentUser: "u", + }).Return([]database.Tag{{Name: "main"}}, nil) + tester.WithUser().AddPagination(1, 10).Execute() + + tester.ResponseEq(t, 200, tester.OKText, []database.Tag{{Name: "main"}}) +} + +func TestPromptHandler_UpdateTags(t *testing.T) { + tester := NewPromptTester(t).WithHandleFunc(func(h *PromptHandler) gin.HandlerFunc { + return h.UpdateTags + }) + tester.RequireUser(t) + + req := []string{"a", "b"} + tester.mocks.repo.EXPECT().UpdateTags(tester.ctx, "u", "r", types.PromptRepo, "cat", "u", req).Return(nil) + tester.WithBody(t, req).WithParam("category", "cat").Execute() + tester.ResponseEq(t, 200, tester.OKText, nil) +} + +func TestPromptHandler_UpdateDownloads(t *testing.T) { + tester := NewPromptTester(t).WithHandleFunc(func(h *PromptHandler) gin.HandlerFunc { + return h.UpdateDownloads + }) + + tester.mocks.repo.EXPECT().UpdateDownloads(tester.ctx, &types.UpdateDownloadsReq{ + Namespace: "u", + Name: "r", + RepoType: types.PromptRepo, + Date: time.Date(2012, 12, 12, 0, 0, 0, 0, time.UTC), + ReqDate: "2012-12-12", + }).Return(nil) + tester.WithUser().WithBody(t, &types.UpdateDownloadsReq{ + ReqDate: time.Date(2012, 12, 12, 0, 0, 0, 0, time.UTC).Format("2006-01-02"), + }).WithParam("category", "cat").Execute() + tester.ResponseEq(t, 200, tester.OKText, nil) +} From eb9e3038fa5a7517e19c78be3dc9ab2e9f87acba Mon Sep 17 00:00:00 2001 From: SeanHH86 <154984842+SeanHH86@users.noreply.github.com> Date: Mon, 30 Dec 2024 09:52:09 +0800 Subject: [PATCH 28/34] [Tag] add tag category management api (#227) * [Tag] add tag category api for portal --------- Co-authored-by: Haihui.Wang --- .../builder/store/database/mock_TagStore.go | 224 +++++++++++++++++ .../component/mock_TagComponent.go | 227 +++++++++++++++++ api/handler/tag.go | 155 +++++++++++- api/handler/tag_test.go | 141 ++++++++++- api/router/api.go | 7 + builder/store/database/tag.go | 48 +++- builder/store/database/tag_test.go | 46 ++++ common/types/tag.go | 7 + component/tag.go | 69 +++++ component/tag_test.go | 129 ++++++++++ docs/docs.go | 237 ++++++++++++++++++ docs/swagger.json | 237 ++++++++++++++++++ docs/swagger.yaml | 148 +++++++++++ 13 files changed, 1658 insertions(+), 17 deletions(-) diff --git a/_mocks/opencsg.com/csghub-server/builder/store/database/mock_TagStore.go b/_mocks/opencsg.com/csghub-server/builder/store/database/mock_TagStore.go index 6a02617a..bf69b020 100644 --- a/_mocks/opencsg.com/csghub-server/builder/store/database/mock_TagStore.go +++ b/_mocks/opencsg.com/csghub-server/builder/store/database/mock_TagStore.go @@ -24,6 +24,65 @@ func (_m *MockTagStore) EXPECT() *MockTagStore_Expecter { return &MockTagStore_Expecter{mock: &_m.Mock} } +// AllCategories provides a mock function with given fields: ctx, scope +func (_m *MockTagStore) AllCategories(ctx context.Context, scope database.TagScope) ([]database.TagCategory, error) { + ret := _m.Called(ctx, scope) + + if len(ret) == 0 { + panic("no return value specified for AllCategories") + } + + var r0 []database.TagCategory + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, database.TagScope) ([]database.TagCategory, error)); ok { + return rf(ctx, scope) + } + if rf, ok := ret.Get(0).(func(context.Context, database.TagScope) []database.TagCategory); ok { + r0 = rf(ctx, scope) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).([]database.TagCategory) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, database.TagScope) error); ok { + r1 = rf(ctx, scope) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// MockTagStore_AllCategories_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'AllCategories' +type MockTagStore_AllCategories_Call struct { + *mock.Call +} + +// AllCategories is a helper method to define mock.On call +// - ctx context.Context +// - scope database.TagScope +func (_e *MockTagStore_Expecter) AllCategories(ctx interface{}, scope interface{}) *MockTagStore_AllCategories_Call { + return &MockTagStore_AllCategories_Call{Call: _e.mock.On("AllCategories", ctx, scope)} +} + +func (_c *MockTagStore_AllCategories_Call) Run(run func(ctx context.Context, scope database.TagScope)) *MockTagStore_AllCategories_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(database.TagScope)) + }) + return _c +} + +func (_c *MockTagStore_AllCategories_Call) Return(_a0 []database.TagCategory, _a1 error) *MockTagStore_AllCategories_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *MockTagStore_AllCategories_Call) RunAndReturn(run func(context.Context, database.TagScope) ([]database.TagCategory, error)) *MockTagStore_AllCategories_Call { + _c.Call.Return(run) + return _c +} + // AllCodeCategories provides a mock function with given fields: ctx func (_m *MockTagStore) AllCodeCategories(ctx context.Context) ([]database.TagCategory, error) { ret := _m.Called(ctx) @@ -781,6 +840,65 @@ func (_c *MockTagStore_AllTagsByScopeAndCategory_Call) RunAndReturn(run func(con return _c } +// CreateCategory provides a mock function with given fields: ctx, category +func (_m *MockTagStore) CreateCategory(ctx context.Context, category database.TagCategory) (*database.TagCategory, error) { + ret := _m.Called(ctx, category) + + if len(ret) == 0 { + panic("no return value specified for CreateCategory") + } + + var r0 *database.TagCategory + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, database.TagCategory) (*database.TagCategory, error)); ok { + return rf(ctx, category) + } + if rf, ok := ret.Get(0).(func(context.Context, database.TagCategory) *database.TagCategory); ok { + r0 = rf(ctx, category) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*database.TagCategory) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, database.TagCategory) error); ok { + r1 = rf(ctx, category) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// MockTagStore_CreateCategory_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'CreateCategory' +type MockTagStore_CreateCategory_Call struct { + *mock.Call +} + +// CreateCategory is a helper method to define mock.On call +// - ctx context.Context +// - category database.TagCategory +func (_e *MockTagStore_Expecter) CreateCategory(ctx interface{}, category interface{}) *MockTagStore_CreateCategory_Call { + return &MockTagStore_CreateCategory_Call{Call: _e.mock.On("CreateCategory", ctx, category)} +} + +func (_c *MockTagStore_CreateCategory_Call) Run(run func(ctx context.Context, category database.TagCategory)) *MockTagStore_CreateCategory_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(database.TagCategory)) + }) + return _c +} + +func (_c *MockTagStore_CreateCategory_Call) Return(_a0 *database.TagCategory, _a1 error) *MockTagStore_CreateCategory_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *MockTagStore_CreateCategory_Call) RunAndReturn(run func(context.Context, database.TagCategory) (*database.TagCategory, error)) *MockTagStore_CreateCategory_Call { + _c.Call.Return(run) + return _c +} + // CreateTag provides a mock function with given fields: ctx, category, name, group, scope func (_m *MockTagStore) CreateTag(ctx context.Context, category string, name string, group string, scope database.TagScope) (database.Tag, error) { ret := _m.Called(ctx, category, name, group, scope) @@ -841,6 +959,53 @@ func (_c *MockTagStore_CreateTag_Call) RunAndReturn(run func(context.Context, st return _c } +// DeleteCategory provides a mock function with given fields: ctx, id +func (_m *MockTagStore) DeleteCategory(ctx context.Context, id int64) error { + ret := _m.Called(ctx, id) + + if len(ret) == 0 { + panic("no return value specified for DeleteCategory") + } + + var r0 error + if rf, ok := ret.Get(0).(func(context.Context, int64) error); ok { + r0 = rf(ctx, id) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// MockTagStore_DeleteCategory_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'DeleteCategory' +type MockTagStore_DeleteCategory_Call struct { + *mock.Call +} + +// DeleteCategory is a helper method to define mock.On call +// - ctx context.Context +// - id int64 +func (_e *MockTagStore_Expecter) DeleteCategory(ctx interface{}, id interface{}) *MockTagStore_DeleteCategory_Call { + return &MockTagStore_DeleteCategory_Call{Call: _e.mock.On("DeleteCategory", ctx, id)} +} + +func (_c *MockTagStore_DeleteCategory_Call) Run(run func(ctx context.Context, id int64)) *MockTagStore_DeleteCategory_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(int64)) + }) + return _c +} + +func (_c *MockTagStore_DeleteCategory_Call) Return(_a0 error) *MockTagStore_DeleteCategory_Call { + _c.Call.Return(_a0) + return _c +} + +func (_c *MockTagStore_DeleteCategory_Call) RunAndReturn(run func(context.Context, int64) error) *MockTagStore_DeleteCategory_Call { + _c.Call.Return(run) + return _c +} + // DeleteTagByID provides a mock function with given fields: ctx, id func (_m *MockTagStore) DeleteTagByID(ctx context.Context, id int64) error { ret := _m.Called(ctx, id) @@ -1335,6 +1500,65 @@ func (_c *MockTagStore_SetMetaTags_Call) RunAndReturn(run func(context.Context, return _c } +// UpdateCategory provides a mock function with given fields: ctx, category +func (_m *MockTagStore) UpdateCategory(ctx context.Context, category database.TagCategory) (*database.TagCategory, error) { + ret := _m.Called(ctx, category) + + if len(ret) == 0 { + panic("no return value specified for UpdateCategory") + } + + var r0 *database.TagCategory + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, database.TagCategory) (*database.TagCategory, error)); ok { + return rf(ctx, category) + } + if rf, ok := ret.Get(0).(func(context.Context, database.TagCategory) *database.TagCategory); ok { + r0 = rf(ctx, category) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*database.TagCategory) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, database.TagCategory) error); ok { + r1 = rf(ctx, category) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// MockTagStore_UpdateCategory_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'UpdateCategory' +type MockTagStore_UpdateCategory_Call struct { + *mock.Call +} + +// UpdateCategory is a helper method to define mock.On call +// - ctx context.Context +// - category database.TagCategory +func (_e *MockTagStore_Expecter) UpdateCategory(ctx interface{}, category interface{}) *MockTagStore_UpdateCategory_Call { + return &MockTagStore_UpdateCategory_Call{Call: _e.mock.On("UpdateCategory", ctx, category)} +} + +func (_c *MockTagStore_UpdateCategory_Call) Run(run func(ctx context.Context, category database.TagCategory)) *MockTagStore_UpdateCategory_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(database.TagCategory)) + }) + return _c +} + +func (_c *MockTagStore_UpdateCategory_Call) Return(_a0 *database.TagCategory, _a1 error) *MockTagStore_UpdateCategory_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *MockTagStore_UpdateCategory_Call) RunAndReturn(run func(context.Context, database.TagCategory) (*database.TagCategory, error)) *MockTagStore_UpdateCategory_Call { + _c.Call.Return(run) + return _c +} + // UpdateTagByID provides a mock function with given fields: ctx, tag func (_m *MockTagStore) UpdateTagByID(ctx context.Context, tag *database.Tag) (*database.Tag, error) { ret := _m.Called(ctx, tag) diff --git a/_mocks/opencsg.com/csghub-server/component/mock_TagComponent.go b/_mocks/opencsg.com/csghub-server/component/mock_TagComponent.go index d6e7c76a..912dab1e 100644 --- a/_mocks/opencsg.com/csghub-server/component/mock_TagComponent.go +++ b/_mocks/opencsg.com/csghub-server/component/mock_TagComponent.go @@ -24,6 +24,64 @@ func (_m *MockTagComponent) EXPECT() *MockTagComponent_Expecter { return &MockTagComponent_Expecter{mock: &_m.Mock} } +// AllCategories provides a mock function with given fields: ctx +func (_m *MockTagComponent) AllCategories(ctx context.Context) ([]database.TagCategory, error) { + ret := _m.Called(ctx) + + if len(ret) == 0 { + panic("no return value specified for AllCategories") + } + + var r0 []database.TagCategory + var r1 error + if rf, ok := ret.Get(0).(func(context.Context) ([]database.TagCategory, error)); ok { + return rf(ctx) + } + if rf, ok := ret.Get(0).(func(context.Context) []database.TagCategory); ok { + r0 = rf(ctx) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).([]database.TagCategory) + } + } + + if rf, ok := ret.Get(1).(func(context.Context) error); ok { + r1 = rf(ctx) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// MockTagComponent_AllCategories_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'AllCategories' +type MockTagComponent_AllCategories_Call struct { + *mock.Call +} + +// AllCategories is a helper method to define mock.On call +// - ctx context.Context +func (_e *MockTagComponent_Expecter) AllCategories(ctx interface{}) *MockTagComponent_AllCategories_Call { + return &MockTagComponent_AllCategories_Call{Call: _e.mock.On("AllCategories", ctx)} +} + +func (_c *MockTagComponent_AllCategories_Call) Run(run func(ctx context.Context)) *MockTagComponent_AllCategories_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context)) + }) + return _c +} + +func (_c *MockTagComponent_AllCategories_Call) Return(_a0 []database.TagCategory, _a1 error) *MockTagComponent_AllCategories_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *MockTagComponent_AllCategories_Call) RunAndReturn(run func(context.Context) ([]database.TagCategory, error)) *MockTagComponent_AllCategories_Call { + _c.Call.Return(run) + return _c +} + // AllTagsByScopeAndCategory provides a mock function with given fields: ctx, scope, category func (_m *MockTagComponent) AllTagsByScopeAndCategory(ctx context.Context, scope string, category string) ([]*database.Tag, error) { ret := _m.Called(ctx, scope, category) @@ -133,6 +191,66 @@ func (_c *MockTagComponent_ClearMetaTags_Call) RunAndReturn(run func(context.Con return _c } +// CreateCategory provides a mock function with given fields: ctx, username, req +func (_m *MockTagComponent) CreateCategory(ctx context.Context, username string, req types.CreateCategory) (*database.TagCategory, error) { + ret := _m.Called(ctx, username, req) + + if len(ret) == 0 { + panic("no return value specified for CreateCategory") + } + + var r0 *database.TagCategory + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, string, types.CreateCategory) (*database.TagCategory, error)); ok { + return rf(ctx, username, req) + } + if rf, ok := ret.Get(0).(func(context.Context, string, types.CreateCategory) *database.TagCategory); ok { + r0 = rf(ctx, username, req) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*database.TagCategory) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, string, types.CreateCategory) error); ok { + r1 = rf(ctx, username, req) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// MockTagComponent_CreateCategory_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'CreateCategory' +type MockTagComponent_CreateCategory_Call struct { + *mock.Call +} + +// CreateCategory is a helper method to define mock.On call +// - ctx context.Context +// - username string +// - req types.CreateCategory +func (_e *MockTagComponent_Expecter) CreateCategory(ctx interface{}, username interface{}, req interface{}) *MockTagComponent_CreateCategory_Call { + return &MockTagComponent_CreateCategory_Call{Call: _e.mock.On("CreateCategory", ctx, username, req)} +} + +func (_c *MockTagComponent_CreateCategory_Call) Run(run func(ctx context.Context, username string, req types.CreateCategory)) *MockTagComponent_CreateCategory_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(string), args[2].(types.CreateCategory)) + }) + return _c +} + +func (_c *MockTagComponent_CreateCategory_Call) Return(_a0 *database.TagCategory, _a1 error) *MockTagComponent_CreateCategory_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *MockTagComponent_CreateCategory_Call) RunAndReturn(run func(context.Context, string, types.CreateCategory) (*database.TagCategory, error)) *MockTagComponent_CreateCategory_Call { + _c.Call.Return(run) + return _c +} + // CreateTag provides a mock function with given fields: ctx, username, req func (_m *MockTagComponent) CreateTag(ctx context.Context, username string, req types.CreateTag) (*database.Tag, error) { ret := _m.Called(ctx, username, req) @@ -193,6 +311,54 @@ func (_c *MockTagComponent_CreateTag_Call) RunAndReturn(run func(context.Context return _c } +// DeleteCategory provides a mock function with given fields: ctx, username, id +func (_m *MockTagComponent) DeleteCategory(ctx context.Context, username string, id int64) error { + ret := _m.Called(ctx, username, id) + + if len(ret) == 0 { + panic("no return value specified for DeleteCategory") + } + + var r0 error + if rf, ok := ret.Get(0).(func(context.Context, string, int64) error); ok { + r0 = rf(ctx, username, id) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// MockTagComponent_DeleteCategory_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'DeleteCategory' +type MockTagComponent_DeleteCategory_Call struct { + *mock.Call +} + +// DeleteCategory is a helper method to define mock.On call +// - ctx context.Context +// - username string +// - id int64 +func (_e *MockTagComponent_Expecter) DeleteCategory(ctx interface{}, username interface{}, id interface{}) *MockTagComponent_DeleteCategory_Call { + return &MockTagComponent_DeleteCategory_Call{Call: _e.mock.On("DeleteCategory", ctx, username, id)} +} + +func (_c *MockTagComponent_DeleteCategory_Call) Run(run func(ctx context.Context, username string, id int64)) *MockTagComponent_DeleteCategory_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(string), args[2].(int64)) + }) + return _c +} + +func (_c *MockTagComponent_DeleteCategory_Call) Return(_a0 error) *MockTagComponent_DeleteCategory_Call { + _c.Call.Return(_a0) + return _c +} + +func (_c *MockTagComponent_DeleteCategory_Call) RunAndReturn(run func(context.Context, string, int64) error) *MockTagComponent_DeleteCategory_Call { + _c.Call.Return(run) + return _c +} + // DeleteTag provides a mock function with given fields: ctx, username, id func (_m *MockTagComponent) DeleteTag(ctx context.Context, username string, id int64) error { ret := _m.Called(ctx, username, id) @@ -301,6 +467,67 @@ func (_c *MockTagComponent_GetTagByID_Call) RunAndReturn(run func(context.Contex return _c } +// UpdateCategory provides a mock function with given fields: ctx, username, req, id +func (_m *MockTagComponent) UpdateCategory(ctx context.Context, username string, req types.UpdateCategory, id int64) (*database.TagCategory, error) { + ret := _m.Called(ctx, username, req, id) + + if len(ret) == 0 { + panic("no return value specified for UpdateCategory") + } + + var r0 *database.TagCategory + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, string, types.UpdateCategory, int64) (*database.TagCategory, error)); ok { + return rf(ctx, username, req, id) + } + if rf, ok := ret.Get(0).(func(context.Context, string, types.UpdateCategory, int64) *database.TagCategory); ok { + r0 = rf(ctx, username, req, id) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*database.TagCategory) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, string, types.UpdateCategory, int64) error); ok { + r1 = rf(ctx, username, req, id) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// MockTagComponent_UpdateCategory_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'UpdateCategory' +type MockTagComponent_UpdateCategory_Call struct { + *mock.Call +} + +// UpdateCategory is a helper method to define mock.On call +// - ctx context.Context +// - username string +// - req types.UpdateCategory +// - id int64 +func (_e *MockTagComponent_Expecter) UpdateCategory(ctx interface{}, username interface{}, req interface{}, id interface{}) *MockTagComponent_UpdateCategory_Call { + return &MockTagComponent_UpdateCategory_Call{Call: _e.mock.On("UpdateCategory", ctx, username, req, id)} +} + +func (_c *MockTagComponent_UpdateCategory_Call) Run(run func(ctx context.Context, username string, req types.UpdateCategory, id int64)) *MockTagComponent_UpdateCategory_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(string), args[2].(types.UpdateCategory), args[3].(int64)) + }) + return _c +} + +func (_c *MockTagComponent_UpdateCategory_Call) Return(_a0 *database.TagCategory, _a1 error) *MockTagComponent_UpdateCategory_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *MockTagComponent_UpdateCategory_Call) RunAndReturn(run func(context.Context, string, types.UpdateCategory, int64) (*database.TagCategory, error)) *MockTagComponent_UpdateCategory_Call { + _c.Call.Return(run) + return _c +} + // UpdateLibraryTags provides a mock function with given fields: ctx, tagScope, namespace, name, oldFilePath, newFilePath func (_m *MockTagComponent) UpdateLibraryTags(ctx context.Context, tagScope database.TagScope, namespace string, name string, oldFilePath string, newFilePath string) error { ret := _m.Called(ctx, tagScope, namespace, name, oldFilePath, newFilePath) diff --git a/api/handler/tag.go b/api/handler/tag.go index 16689e50..dff4cb99 100644 --- a/api/handler/tag.go +++ b/api/handler/tag.go @@ -1,6 +1,7 @@ package handler import ( + "errors" "log/slog" "net/http" "strconv" @@ -18,12 +19,12 @@ func NewTagHandler(config *config.Config) (*TagsHandler, error) { return nil, err } return &TagsHandler{ - tc: tc, + tag: tc, }, nil } type TagsHandler struct { - tc component.TagComponent + tag component.TagComponent } // GetAllTags godoc @@ -43,7 +44,7 @@ func (t *TagsHandler) AllTags(ctx *gin.Context) { //TODO:validate inputs category := ctx.Query("category") scope := ctx.Query("scope") - tags, err := t.tc.AllTagsByScopeAndCategory(ctx, scope, category) + tags, err := t.tag.AllTagsByScopeAndCategory(ctx, scope, category) if err != nil { slog.Error("Failed to load tags", slog.Any("category", category), slog.Any("scope", scope), slog.Any("error", err)) httpbase.ServerError(ctx, err) @@ -79,7 +80,7 @@ func (t *TagsHandler) CreateTag(ctx *gin.Context) { httpbase.BadRequest(ctx, err.Error()) return } - tag, err := t.tc.CreateTag(ctx, userName, req) + tag, err := t.tag.CreateTag(ctx, userName, req) if err != nil { slog.Error("Failed to create tag", slog.Any("req", req), slog.Any("error", err)) httpbase.ServerError(ctx, err) @@ -112,7 +113,7 @@ func (t *TagsHandler) GetTagByID(ctx *gin.Context) { httpbase.BadRequest(ctx, err.Error()) return } - tag, err := t.tc.GetTagByID(ctx, userName, id) + tag, err := t.tag.GetTagByID(ctx, userName, id) if err != nil { slog.Error("Failed to get tag", slog.Int64("id", id), slog.Any("error", err)) httpbase.ServerError(ctx, err) @@ -152,7 +153,7 @@ func (t *TagsHandler) UpdateTag(ctx *gin.Context) { httpbase.BadRequest(ctx, err.Error()) return } - tag, err := t.tc.UpdateTag(ctx, userName, id, req) + tag, err := t.tag.UpdateTag(ctx, userName, id, req) if err != nil { slog.Error("Failed to update tag", slog.Int64("id", id), slog.Any("error", err)) httpbase.ServerError(ctx, err) @@ -185,7 +186,7 @@ func (t *TagsHandler) DeleteTag(ctx *gin.Context) { httpbase.BadRequest(ctx, err.Error()) return } - err = t.tc.DeleteTag(ctx, userName, id) + err = t.tag.DeleteTag(ctx, userName, id) if err != nil { slog.Error("Failed to delete tag", slog.Int64("id", id), slog.Any("error", err)) httpbase.ServerError(ctx, err) @@ -193,3 +194,143 @@ func (t *TagsHandler) DeleteTag(ctx *gin.Context) { } ctx.JSON(http.StatusOK, nil) } + +// GetAllCategories godoc +// @Security ApiKey +// @Summary Get all Categories +// @Description Get all Categories +// @Tags Tag +// @Accept json +// @Produce json +// @Success 200 {object} types.ResponseWithTotal{data=[]database.TagCategory} "categores" +// @Failure 400 {object} types.APIBadRequest "Bad request" +// @Failure 500 {object} types.APIInternalServerError "Internal server error" +// @Router /tags/categories [get] +func (t *TagsHandler) AllCategories(ctx *gin.Context) { + categories, err := t.tag.AllCategories(ctx) + if err != nil { + slog.Error("Failed to load categories", slog.Any("error", err)) + httpbase.ServerError(ctx, err) + return + } + respData := gin.H{ + "data": categories, + } + ctx.JSON(http.StatusOK, respData) +} + +// CreateCategory godoc +// @Security ApiKey +// @Summary Create new category +// @Description Create new category +// @Tags Tag +// @Accept json +// @Produce json +// @Param body body types.CreateCategory true "body" +// @Success 200 {object} types.Response{database.TagCategory} "OK" +// @Failure 400 {object} types.APIBadRequest "Bad request" +// @Failure 500 {object} types.APIInternalServerError "Internal server error" +// @Router /tags/categories [post] +func (t *TagsHandler) CreateCategory(ctx *gin.Context) { + userName := httpbase.GetCurrentUser(ctx) + if userName == "" { + httpbase.UnauthorizedError(ctx, component.ErrUserNotFound) + return + } + var req types.CreateCategory + if err := ctx.ShouldBindJSON(&req); err != nil { + slog.Error("Bad request format", slog.Any("error", err)) + httpbase.BadRequest(ctx, err.Error()) + return + } + category, err := t.tag.CreateCategory(ctx, userName, req) + if err != nil { + if errors.Is(err, component.ErrForbidden) { + httpbase.ForbiddenError(ctx, err) + return + } + slog.Error("Failed to create category", slog.Any("req", req), slog.Any("error", err)) + httpbase.ServerError(ctx, err) + return + } + ctx.JSON(http.StatusOK, gin.H{"data": category}) +} + +// UpdateCategory godoc +// @Security ApiKey +// @Summary Create new category +// @Description Create new category +// @Tags Tag +// @Accept json +// @Produce json +// @Param body body types.UpdateCategory true "body" +// @Success 200 {object} types.Response{database.TagCategory} "OK" +// @Failure 400 {object} types.APIBadRequest "Bad request" +// @Failure 500 {object} types.APIInternalServerError "Internal server error" +// @Router /tags/categories/id [put] +func (t *TagsHandler) UpdateCategory(ctx *gin.Context) { + userName := httpbase.GetCurrentUser(ctx) + if userName == "" { + httpbase.UnauthorizedError(ctx, component.ErrUserNotFound) + return + } + var req types.UpdateCategory + if err := ctx.ShouldBindJSON(&req); err != nil { + slog.Error("Bad request format", slog.Any("error", err)) + httpbase.BadRequest(ctx, err.Error()) + return + } + id, err := strconv.ParseInt(ctx.Param("id"), 10, 64) + if err != nil { + slog.Error("Bad request format", slog.Any("error", err)) + httpbase.BadRequest(ctx, err.Error()) + return + } + category, err := t.tag.UpdateCategory(ctx, userName, req, id) + if err != nil { + if errors.Is(err, component.ErrForbidden) { + httpbase.ForbiddenError(ctx, err) + return + } + slog.Error("Failed to update category", slog.Any("req", req), slog.Any("id", id), slog.Any("error", err)) + httpbase.ServerError(ctx, err) + return + } + ctx.JSON(http.StatusOK, gin.H{"data": category}) +} + +// DeleteCategory godoc +// @Security ApiKey +// @Summary Delete a category by id +// @Description Delete a category by id +// @Tags Tag +// @Accept json +// @Produce json +// @Success 200 {object} types.Response{} "OK" +// @Failure 400 {object} types.APIBadRequest "Bad request" +// @Failure 500 {object} types.APIInternalServerError "Internal server error" +// @Router /tags/categories/id [delete] +func (t *TagsHandler) DeleteCategory(ctx *gin.Context) { + userName := httpbase.GetCurrentUser(ctx) + if userName == "" { + httpbase.UnauthorizedError(ctx, component.ErrUserNotFound) + return + } + id, err := strconv.ParseInt(ctx.Param("id"), 10, 64) + if err != nil { + slog.Error("Bad request format", slog.Any("error", err)) + httpbase.BadRequest(ctx, err.Error()) + return + } + err = t.tag.DeleteCategory(ctx, userName, id) + if err != nil { + if errors.Is(err, component.ErrForbidden) { + httpbase.ForbiddenError(ctx, err) + return + } + slog.Error("Failed to delete category", slog.Any("id", id), slog.Any("error", err)) + httpbase.ServerError(ctx, err) + return + } + ctx.JSON(http.StatusOK, nil) +} diff --git a/api/handler/tag_test.go b/api/handler/tag_test.go index ed47c226..0a2fa998 100644 --- a/api/handler/tag_test.go +++ b/api/handler/tag_test.go @@ -22,7 +22,7 @@ func NewTestTagHandler( tagComp component.TagComponent, ) (*TagsHandler, error) { return &TagsHandler{ - tc: tagComp, + tag: tagComp, }, nil } @@ -195,3 +195,142 @@ func TestTagHandler_DeleteTag(t *testing.T) { require.Equal(t, "", resp.Msg) require.Nil(t, resp.Data) } + +func TestTagHandler_AllCategories(t *testing.T) { + var categories []database.TagCategory + categories = append(categories, database.TagCategory{ID: 1, Name: "test1", Scope: database.TagScope("scope")}) + + req := httptest.NewRequest("get", "/tags/categories", nil) + + hr := httptest.NewRecorder() + ginContext, _ := gin.CreateTestContext(hr) + ginContext.Request = req + + tagComp := mockcom.NewMockTagComponent(t) + tagComp.EXPECT().AllCategories(ginContext).Return(categories, nil) + + tagHandler, err := NewTestTagHandler(tagComp) + require.Nil(t, err) + + tagHandler.AllCategories(ginContext) + + require.Equal(t, http.StatusOK, hr.Code) + + var resp httpbase.R + + err = json.Unmarshal(hr.Body.Bytes(), &resp) + require.Nil(t, err) + + require.Equal(t, 0, resp.Code) + require.Equal(t, "", resp.Msg) + require.NotNil(t, resp.Data) +} + +func TestTagHandler_CreateCategory(t *testing.T) { + username := "testuser" + data := types.CreateCategory{ + Name: "testcate", + Scope: "testscope", + } + + reqBody, _ := json.Marshal(data) + + req := httptest.NewRequest("post", "/tags/categories", bytes.NewBuffer(reqBody)) + + hr := httptest.NewRecorder() + ginContext, _ := gin.CreateTestContext(hr) + ginContext.Set("currentUser", username) + ginContext.Request = req + + tagComp := mockcom.NewMockTagComponent(t) + tagComp.EXPECT().CreateCategory(ginContext, username, data).Return( + &database.TagCategory{ID: 1, Name: "testcate", Scope: database.TagScope("testscope")}, + nil, + ) + + tagHandler, err := NewTestTagHandler(tagComp) + require.Nil(t, err) + + tagHandler.CreateCategory(ginContext) + + require.Equal(t, http.StatusOK, hr.Code) + + var resp httpbase.R + + err = json.Unmarshal(hr.Body.Bytes(), &resp) + require.Nil(t, err) + + require.Equal(t, 0, resp.Code) + require.Equal(t, "", resp.Msg) + require.NotNil(t, resp.Data) +} + +func TestTagHandler_UpdateCategory(t *testing.T) { + username := "testuser" + data := types.UpdateCategory{ + Name: "testcate", + Scope: "testscope", + } + + reqBody, _ := json.Marshal(data) + + req := httptest.NewRequest("put", "/tags/categories/1", bytes.NewBuffer(reqBody)) + + hr := httptest.NewRecorder() + ginContext, _ := gin.CreateTestContext(hr) + ginContext.Set("currentUser", username) + ginContext.AddParam("id", "1") + ginContext.Request = req + + tagComp := mockcom.NewMockTagComponent(t) + tagComp.EXPECT().UpdateCategory(ginContext, username, data, int64(1)).Return( + &database.TagCategory{ID: 1, Name: "testcate", Scope: database.TagScope("testscope")}, + nil, + ) + + tagHandler, err := NewTestTagHandler(tagComp) + require.Nil(t, err) + + tagHandler.UpdateCategory(ginContext) + + require.Equal(t, http.StatusOK, hr.Code) + + var resp httpbase.R + + err = json.Unmarshal(hr.Body.Bytes(), &resp) + require.Nil(t, err) + + require.Equal(t, 0, resp.Code) + require.Equal(t, "", resp.Msg) + require.NotNil(t, resp.Data) +} + +func TestTagHandler_DeleteCategory(t *testing.T) { + username := "testuser" + + req := httptest.NewRequest("delete", "/tags/categories/1", nil) + + hr := httptest.NewRecorder() + ginContext, _ := gin.CreateTestContext(hr) + ginContext.Set("currentUser", username) + ginContext.AddParam("id", "1") + ginContext.Request = req + + tagComp := mockcom.NewMockTagComponent(t) + tagComp.EXPECT().DeleteCategory(ginContext, username, int64(1)).Return(nil) + + tagHandler, err := NewTestTagHandler(tagComp) + require.Nil(t, err) + + tagHandler.DeleteCategory(ginContext) + + require.Equal(t, http.StatusOK, hr.Code) + + var resp httpbase.R + + err = json.Unmarshal(hr.Body.Bytes(), &resp) + require.Nil(t, err) + + require.Equal(t, 0, resp.Code) + require.Equal(t, "", resp.Msg) +} diff --git a/api/router/api.go b/api/router/api.go index 5b3f197c..7c9b7fef 100644 --- a/api/router/api.go +++ b/api/router/api.go @@ -771,6 +771,13 @@ func createPromptRoutes(apiGroup *gin.RouterGroup, promptHandler *handler.Prompt func createTagsRoutes(apiGroup *gin.RouterGroup, tagHandler *handler.TagsHandler) { tagsGrp := apiGroup.Group("/tags") { + categoryGrp := tagsGrp.Group("/categories") + { + categoryGrp.GET("", tagHandler.AllCategories) + categoryGrp.POST("", tagHandler.CreateCategory) + categoryGrp.PUT("/:id", tagHandler.UpdateCategory) + categoryGrp.DELETE("/:id", tagHandler.DeleteCategory) + } tagsGrp.GET("", tagHandler.AllTags) tagsGrp.POST("", tagHandler.CreateTag) tagsGrp.GET("/:id", tagHandler.GetTagByID) diff --git a/builder/store/database/tag.go b/builder/store/database/tag.go index 2194afc4..14819a89 100644 --- a/builder/store/database/tag.go +++ b/builder/store/database/tag.go @@ -26,6 +26,7 @@ type TagStore interface { AllDatasetTags(ctx context.Context) ([]*Tag, error) AllCodeTags(ctx context.Context) ([]*Tag, error) AllSpaceTags(ctx context.Context) ([]*Tag, error) + AllCategories(ctx context.Context, scope TagScope) ([]TagCategory, error) AllModelCategories(ctx context.Context) ([]TagCategory, error) AllPromptCategories(ctx context.Context) ([]TagCategory, error) AllDatasetCategories(ctx context.Context) ([]TagCategory, error) @@ -43,6 +44,9 @@ type TagStore interface { FindTagByID(ctx context.Context, id int64) (*Tag, error) UpdateTagByID(ctx context.Context, tag *Tag) (*Tag, error) DeleteTagByID(ctx context.Context, id int64) error + CreateCategory(ctx context.Context, category TagCategory) (*TagCategory, error) + UpdateCategory(ctx context.Context, category TagCategory) (*TagCategory, error) + DeleteCategory(ctx context.Context, id int64) error } func NewTagStore() TagStore { @@ -156,30 +160,32 @@ func (ts *tagStoreImpl) AllSpaceTags(ctx context.Context) ([]*Tag, error) { } func (ts *tagStoreImpl) AllModelCategories(ctx context.Context) ([]TagCategory, error) { - return ts.allCategories(ctx, ModelTagScope) + return ts.AllCategories(ctx, ModelTagScope) } func (ts *tagStoreImpl) AllPromptCategories(ctx context.Context) ([]TagCategory, error) { - return ts.allCategories(ctx, PromptTagScope) + return ts.AllCategories(ctx, PromptTagScope) } func (ts *tagStoreImpl) AllDatasetCategories(ctx context.Context) ([]TagCategory, error) { - return ts.allCategories(ctx, DatasetTagScope) + return ts.AllCategories(ctx, DatasetTagScope) } func (ts *tagStoreImpl) AllCodeCategories(ctx context.Context) ([]TagCategory, error) { - return ts.allCategories(ctx, CodeTagScope) + return ts.AllCategories(ctx, CodeTagScope) } func (ts *tagStoreImpl) AllSpaceCategories(ctx context.Context) ([]TagCategory, error) { - return ts.allCategories(ctx, SpaceTagScope) + return ts.AllCategories(ctx, SpaceTagScope) } -func (ts *tagStoreImpl) allCategories(ctx context.Context, scope TagScope) ([]TagCategory, error) { +func (ts *tagStoreImpl) AllCategories(ctx context.Context, scope TagScope) ([]TagCategory, error) { var tags []TagCategory - err := ts.db.Operator.Core.NewSelect().Model(&TagCategory{}). - Where("scope = ?", scope). - Scan(ctx, &tags) + q := ts.db.Operator.Core.NewSelect().Model(&TagCategory{}) + if len(scope) > 0 { + q = q.Where("scope = ?", scope) + } + err := q.Order("id").Scan(ctx, &tags) if err != nil { slog.Error("Failed to select tags", "error", err) return nil, err @@ -440,3 +446,27 @@ func (ts *tagStoreImpl) DeleteTagByID(ctx context.Context, id int64) error { return err } + +func (ts *tagStoreImpl) CreateCategory(ctx context.Context, category TagCategory) (*TagCategory, error) { + _, err := ts.db.Operator.Core.NewInsert().Model(&category).Exec(ctx) + if err != nil { + return nil, fmt.Errorf("insert category error: %w", err) + } + return &category, nil +} + +func (ts *tagStoreImpl) UpdateCategory(ctx context.Context, category TagCategory) (*TagCategory, error) { + _, err := ts.db.Operator.Core.NewUpdate().Model(&category).WherePK().Exec(ctx) + if err != nil { + return nil, fmt.Errorf("update category by id %d, error: %w", category.ID, err) + } + return &category, nil +} + +func (ts *tagStoreImpl) DeleteCategory(ctx context.Context, id int64) error { + _, err := ts.db.Operator.Core.NewDelete().Model(&TagCategory{}).Where("id = ?", id).Exec(ctx) + if err != nil { + return fmt.Errorf("delete category by id %d, error: %w", id, err) + } + return nil +} diff --git a/builder/store/database/tag_test.go b/builder/store/database/tag_test.go index 27d3bc0f..467832e1 100644 --- a/builder/store/database/tag_test.go +++ b/builder/store/database/tag_test.go @@ -625,3 +625,49 @@ func TestTagStore_DeleteTagByID(t *testing.T) { require.NotEmpty(t, err) } + +func TestTagStore_Category_CURD(t *testing.T) { + db := tests.InitTestDB() + defer db.Close() + + ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) + defer cancel() + + ts := database.NewTagStoreWithDB(db) + + categories, err := ts.AllCategories(ctx, "") + require.Empty(t, err) + require.NotEmpty(t, categories) + + total := len(categories) + + catetory, err := ts.CreateCategory(ctx, database.TagCategory{ + Name: "test-category", + Scope: "test-scope", + }) + require.Empty(t, err) + require.NotEmpty(t, catetory) + + id := catetory.ID + + catetory, err = ts.UpdateCategory(ctx, database.TagCategory{ + ID: 1, + Name: "test-category1", + Scope: "test-scope1", + }) + require.Empty(t, err) + require.NotEmpty(t, catetory) + + categories, err = ts.AllCategories(ctx, "") + require.Empty(t, err) + require.NotEmpty(t, categories) + require.Equal(t, "test-category1", categories[0].Name) + require.Equal(t, "test-scope1", string(categories[0].Scope)) + + err = ts.DeleteCategory(ctx, id) + require.Empty(t, err) + + categories, err = ts.AllCategories(ctx, "") + require.Empty(t, err) + require.Equal(t, total, len(categories)) +} diff --git a/common/types/tag.go b/common/types/tag.go index d85fa91c..2cb741e0 100644 --- a/common/types/tag.go +++ b/common/types/tag.go @@ -35,3 +35,10 @@ type CreateTag struct { } type UpdateTag CreateTag + +type CreateCategory struct { + Name string `json:"name" binding:"required"` + Scope string `json:"scope" binding:"required"` +} + +type UpdateCategory CreateCategory diff --git a/component/tag.go b/component/tag.go index 44c41506..2e932af1 100644 --- a/component/tag.go +++ b/component/tag.go @@ -24,6 +24,10 @@ type TagComponent interface { GetTagByID(ctx context.Context, username string, id int64) (*database.Tag, error) UpdateTag(ctx context.Context, username string, id int64, req types.UpdateTag) (*database.Tag, error) DeleteTag(ctx context.Context, username string, id int64) error + AllCategories(ctx context.Context) ([]database.TagCategory, error) + CreateCategory(ctx context.Context, username string, req types.CreateCategory) (*database.TagCategory, error) + UpdateCategory(ctx context.Context, username string, req types.UpdateCategory, id int64) (*database.TagCategory, error) + DeleteCategory(ctx context.Context, username string, id int64) error } func NewTagComponent(config *config.Config) (TagComponent, error) { @@ -291,3 +295,68 @@ func (c *tagComponentImpl) DeleteTag(ctx context.Context, username string, id in } return nil } + +func (c *tagComponentImpl) AllCategories(ctx context.Context) ([]database.TagCategory, error) { + return c.tagStore.AllCategories(ctx, database.TagScope("")) +} + +func (c *tagComponentImpl) CreateCategory(ctx context.Context, username string, req types.CreateCategory) (*database.TagCategory, error) { + user, err := c.userStore.FindByUsername(ctx, username) + if err != nil { + return nil, fmt.Errorf("failed to get user, error: %w", err) + } + if !user.CanAdmin() { + return nil, ErrForbidden + } + + newCategory := database.TagCategory{ + Name: req.Name, + Scope: database.TagScope(req.Scope), + } + + category, err := c.tagStore.CreateCategory(ctx, newCategory) + if err != nil { + return nil, fmt.Errorf("failed to create category, error: %w", err) + } + + return category, nil +} + +func (c *tagComponentImpl) UpdateCategory(ctx context.Context, username string, req types.UpdateCategory, id int64) (*database.TagCategory, error) { + user, err := c.userStore.FindByUsername(ctx, username) + if err != nil { + return nil, fmt.Errorf("failed to get user, error: %w", err) + } + if !user.CanAdmin() { + return nil, ErrForbidden + } + + newCategory := database.TagCategory{ + ID: id, + Name: req.Name, + Scope: database.TagScope(req.Scope), + } + + category, err := c.tagStore.UpdateCategory(ctx, newCategory) + if err != nil { + return nil, fmt.Errorf("failed to update category, error: %w", err) + } + + return category, nil +} + +func (c *tagComponentImpl) DeleteCategory(ctx context.Context, username string, id int64) error { + user, err := c.userStore.FindByUsername(ctx, username) + if err != nil { + return fmt.Errorf("failed to get user, error: %w", err) + } + if !user.CanAdmin() { + return ErrForbidden + } + + err = c.tagStore.DeleteCategory(ctx, id) + if err != nil { + return fmt.Errorf("failed to delete category, error: %w", err) + } + return nil +} diff --git a/component/tag_test.go b/component/tag_test.go index acfa168d..c2a0c0eb 100644 --- a/component/tag_test.go +++ b/component/tag_test.go @@ -235,3 +235,132 @@ func TestTagComponent_UpdateRepoTagsByCategory(t *testing.T) { err := tc.UpdateRepoTagsByCategory(ctx, database.DatasetTagScope, 1, "c", []string{"t1"}) require.Nil(t, err) } + +func TestTagComponent_AllCategories(t *testing.T) { + ctx := context.TODO() + tc := initializeTestTagComponent(ctx, t) + + tc.mocks.stores.TagMock().EXPECT().AllCategories(ctx, database.TagScope("")).Return([]database.TagCategory{}, nil) + + categories, err := tc.AllCategories(ctx) + require.Nil(t, err) + require.Equal(t, []database.TagCategory{}, categories) +} + +func TestTagComponent_CreateCategory(t *testing.T) { + ctx := context.TODO() + + t.Run("admin", func(t *testing.T) { + tc := initializeTestTagComponent(ctx, t) + + tc.mocks.stores.UserMock().EXPECT().FindByUsername(ctx, "admin").Return(database.User{ + Username: "admin", + RoleMask: "admin", + }, nil) + tc.mocks.stores.TagMock().EXPECT().CreateCategory(ctx, database.TagCategory{ + Name: "test-cate", + Scope: database.TagScope("test-scope"), + }).Return(&database.TagCategory{ + ID: 1, + Name: "test-cate", + Scope: "test-scope", + }, nil) + + category, err := tc.CreateCategory(ctx, "admin", types.CreateCategory{ + Name: "test-cate", + Scope: "test-scope", + }) + require.Nil(t, err) + require.NotNil(t, category) + }) + + t.Run("user", func(t *testing.T) { + tc := initializeTestTagComponent(ctx, t) + + tc.mocks.stores.UserMock().EXPECT().FindByUsername(ctx, "user").Return(database.User{ + Username: "user", + RoleMask: "user", + }, nil) + + category, err := tc.CreateCategory(ctx, "user", types.CreateCategory{ + Name: "test-cate", + Scope: "test-scope", + }) + require.NotNil(t, err) + require.Nil(t, category) + }) +} + +func TestTagComponent_UpdateCategory(t *testing.T) { + ctx := context.TODO() + + t.Run("admin", func(t *testing.T) { + tc := initializeTestTagComponent(ctx, t) + + tc.mocks.stores.UserMock().EXPECT().FindByUsername(ctx, "admin").Return(database.User{ + Username: "admin", + RoleMask: "admin", + }, nil) + tc.mocks.stores.TagMock().EXPECT().UpdateCategory(ctx, database.TagCategory{ + ID: int64(1), + Name: "test-cate", + Scope: database.TagScope("test-scope"), + }).Return(&database.TagCategory{ + ID: 1, + Name: "test-cate", + Scope: "test-scope", + }, nil) + + category, err := tc.UpdateCategory(ctx, "admin", types.UpdateCategory{ + Name: "test-cate", + Scope: "test-scope", + }, int64(1)) + require.Nil(t, err) + require.NotNil(t, category) + }) + + t.Run("user", func(t *testing.T) { + tc := initializeTestTagComponent(ctx, t) + + tc.mocks.stores.UserMock().EXPECT().FindByUsername(ctx, "user").Return(database.User{ + Username: "user", + RoleMask: "user", + }, nil) + + category, err := tc.UpdateCategory(ctx, "user", types.UpdateCategory{ + Name: "test-cate", + Scope: "test-scope", + }, int64(1)) + require.NotNil(t, err) + require.Nil(t, category) + }) +} + +func TestTagComponent_DeleteCategory(t *testing.T) { + ctx := context.TODO() + + t.Run("admin", func(t *testing.T) { + tc := initializeTestTagComponent(ctx, t) + + tc.mocks.stores.UserMock().EXPECT().FindByUsername(ctx, "admin").Return(database.User{ + Username: "admin", + RoleMask: "admin", + }, nil) + tc.mocks.stores.TagMock().EXPECT().DeleteCategory(ctx, int64(1)).Return(nil) + + err := tc.DeleteCategory(ctx, "admin", int64(1)) + require.Nil(t, err) + }) + + t.Run("user", func(t *testing.T) { + tc := initializeTestTagComponent(ctx, t) + + tc.mocks.stores.UserMock().EXPECT().FindByUsername(ctx, "user").Return(database.User{ + Username: "user", + RoleMask: "user", + }, nil) + + err := tc.DeleteCategory(ctx, "user", int64(1)) + require.NotNil(t, err) + }) +} diff --git a/docs/docs.go b/docs/docs.go index b2611626..31468db6 100644 --- a/docs/docs.go +++ b/docs/docs.go @@ -10144,6 +10144,199 @@ const docTemplate = `{ } } }, + "/tags/categories": { + "get": { + "security": [ + { + "ApiKey": [] + } + ], + "description": "Get all Categories", + "consumes": [ + "application/json" + ], + "produces": [ + "application/json" + ], + "tags": [ + "Tag" + ], + "summary": "Get all Categories", + "responses": { + "200": { + "description": "categores", + "schema": { + "allOf": [ + { + "$ref": "#/definitions/types.ResponseWithTotal" + }, + { + "type": "object", + "properties": { + "data": { + "type": "array", + "items": { + "$ref": "#/definitions/database.TagCategory" + } + } + } + } + ] + } + }, + "400": { + "description": "Bad request", + "schema": { + "$ref": "#/definitions/types.APIBadRequest" + } + }, + "500": { + "description": "Internal server error", + "schema": { + "$ref": "#/definitions/types.APIInternalServerError" + } + } + } + }, + "post": { + "security": [ + { + "ApiKey": [] + } + ], + "description": "Create new category", + "consumes": [ + "application/json" + ], + "produces": [ + "application/json" + ], + "tags": [ + "Tag" + ], + "summary": "Create new category", + "parameters": [ + { + "description": "body", + "name": "body", + "in": "body", + "required": true, + "schema": { + "$ref": "#/definitions/types.CreateCategory" + } + } + ], + "responses": { + "200": { + "description": "OK", + "schema": { + "$ref": "#/definitions/types.Response" + } + }, + "400": { + "description": "Bad request", + "schema": { + "$ref": "#/definitions/types.APIBadRequest" + } + }, + "500": { + "description": "Internal server error", + "schema": { + "$ref": "#/definitions/types.APIInternalServerError" + } + } + } + } + }, + "/tags/categories/id": { + "put": { + "security": [ + { + "ApiKey": [] + } + ], + "description": "Create new category", + "consumes": [ + "application/json" + ], + "produces": [ + "application/json" + ], + "tags": [ + "Tag" + ], + "summary": "Create new category", + "parameters": [ + { + "description": "body", + "name": "body", + "in": "body", + "required": true, + "schema": { + "$ref": "#/definitions/types.UpdateCategory" + } + } + ], + "responses": { + "200": { + "description": "OK", + "schema": { + "$ref": "#/definitions/types.Response" + } + }, + "400": { + "description": "Bad request", + "schema": { + "$ref": "#/definitions/types.APIBadRequest" + } + }, + "500": { + "description": "Internal server error", + "schema": { + "$ref": "#/definitions/types.APIInternalServerError" + } + } + } + }, + "delete": { + "security": [ + { + "ApiKey": [] + } + ], + "description": "Delete a category by id", + "consumes": [ + "application/json" + ], + "produces": [ + "application/json" + ], + "tags": [ + "Tag" + ], + "summary": "Delete a category by id", + "responses": { + "200": { + "description": "OK", + "schema": { + "$ref": "#/definitions/types.Response" + } + }, + "400": { + "description": "Bad request", + "schema": { + "$ref": "#/definitions/types.APIBadRequest" + } + }, + "500": { + "description": "Internal server error", + "schema": { + "$ref": "#/definitions/types.APIInternalServerError" + } + } + } + } + }, "/tags/{id}": { "get": { "security": [ @@ -15990,6 +16183,20 @@ const docTemplate = `{ } } }, + "database.TagCategory": { + "type": "object", + "properties": { + "id": { + "type": "integer" + }, + "name": { + "type": "string" + }, + "scope": { + "$ref": "#/definitions/database.TagScope" + } + } + }, "database.TagScope": { "type": "string", "enum": [ @@ -16636,6 +16843,21 @@ const docTemplate = `{ } } }, + "types.CreateCategory": { + "type": "object", + "required": [ + "name", + "scope" + ], + "properties": { + "name": { + "type": "string" + }, + "scope": { + "type": "string" + } + } + }, "types.CreateCodeReq": { "type": "object", "properties": { @@ -18570,6 +18792,21 @@ const docTemplate = `{ "TaskTypeLeaderBoard" ] }, + "types.UpdateCategory": { + "type": "object", + "required": [ + "name", + "scope" + ], + "properties": { + "name": { + "type": "string" + }, + "scope": { + "type": "string" + } + } + }, "types.UpdateCodeReq": { "type": "object", "properties": { diff --git a/docs/swagger.json b/docs/swagger.json index 196b34e1..711b3efd 100644 --- a/docs/swagger.json +++ b/docs/swagger.json @@ -10133,6 +10133,199 @@ } } }, + "/tags/categories": { + "get": { + "security": [ + { + "ApiKey": [] + } + ], + "description": "Get all Categories", + "consumes": [ + "application/json" + ], + "produces": [ + "application/json" + ], + "tags": [ + "Tag" + ], + "summary": "Get all Categories", + "responses": { + "200": { + "description": "categores", + "schema": { + "allOf": [ + { + "$ref": "#/definitions/types.ResponseWithTotal" + }, + { + "type": "object", + "properties": { + "data": { + "type": "array", + "items": { + "$ref": "#/definitions/database.TagCategory" + } + } + } + } + ] + } + }, + "400": { + "description": "Bad request", + "schema": { + "$ref": "#/definitions/types.APIBadRequest" + } + }, + "500": { + "description": "Internal server error", + "schema": { + "$ref": "#/definitions/types.APIInternalServerError" + } + } + } + }, + "post": { + "security": [ + { + "ApiKey": [] + } + ], + "description": "Create new category", + "consumes": [ + "application/json" + ], + "produces": [ + "application/json" + ], + "tags": [ + "Tag" + ], + "summary": "Create new category", + "parameters": [ + { + "description": "body", + "name": "body", + "in": "body", + "required": true, + "schema": { + "$ref": "#/definitions/types.CreateCategory" + } + } + ], + "responses": { + "200": { + "description": "OK", + "schema": { + "$ref": "#/definitions/types.Response" + } + }, + "400": { + "description": "Bad request", + "schema": { + "$ref": "#/definitions/types.APIBadRequest" + } + }, + "500": { + "description": "Internal server error", + "schema": { + "$ref": "#/definitions/types.APIInternalServerError" + } + } + } + } + }, + "/tags/categories/id": { + "put": { + "security": [ + { + "ApiKey": [] + } + ], + "description": "Create new category", + "consumes": [ + "application/json" + ], + "produces": [ + "application/json" + ], + "tags": [ + "Tag" + ], + "summary": "Create new category", + "parameters": [ + { + "description": "body", + "name": "body", + "in": "body", + "required": true, + "schema": { + "$ref": "#/definitions/types.UpdateCategory" + } + } + ], + "responses": { + "200": { + "description": "OK", + "schema": { + "$ref": "#/definitions/types.Response" + } + }, + "400": { + "description": "Bad request", + "schema": { + "$ref": "#/definitions/types.APIBadRequest" + } + }, + "500": { + "description": "Internal server error", + "schema": { + "$ref": "#/definitions/types.APIInternalServerError" + } + } + } + }, + "delete": { + "security": [ + { + "ApiKey": [] + } + ], + "description": "Delete a category by id", + "consumes": [ + "application/json" + ], + "produces": [ + "application/json" + ], + "tags": [ + "Tag" + ], + "summary": "Delete a category by id", + "responses": { + "200": { + "description": "OK", + "schema": { + "$ref": "#/definitions/types.Response" + } + }, + "400": { + "description": "Bad request", + "schema": { + "$ref": "#/definitions/types.APIBadRequest" + } + }, + "500": { + "description": "Internal server error", + "schema": { + "$ref": "#/definitions/types.APIInternalServerError" + } + } + } + } + }, "/tags/{id}": { "get": { "security": [ @@ -15979,6 +16172,20 @@ } } }, + "database.TagCategory": { + "type": "object", + "properties": { + "id": { + "type": "integer" + }, + "name": { + "type": "string" + }, + "scope": { + "$ref": "#/definitions/database.TagScope" + } + } + }, "database.TagScope": { "type": "string", "enum": [ @@ -16625,6 +16832,21 @@ } } }, + "types.CreateCategory": { + "type": "object", + "required": [ + "name", + "scope" + ], + "properties": { + "name": { + "type": "string" + }, + "scope": { + "type": "string" + } + } + }, "types.CreateCodeReq": { "type": "object", "properties": { @@ -18559,6 +18781,21 @@ "TaskTypeLeaderBoard" ] }, + "types.UpdateCategory": { + "type": "object", + "required": [ + "name", + "scope" + ], + "properties": { + "name": { + "type": "string" + }, + "scope": { + "type": "string" + } + } + }, "types.UpdateCodeReq": { "type": "object", "properties": { diff --git a/docs/swagger.yaml b/docs/swagger.yaml index 1209dabb..ccfa9f9e 100644 --- a/docs/swagger.yaml +++ b/docs/swagger.yaml @@ -463,6 +463,15 @@ definitions: updated_at: type: string type: object + database.TagCategory: + properties: + id: + type: integer + name: + type: string + scope: + $ref: '#/definitions/database.TagScope' + type: object database.TagScope: enum: - model @@ -900,6 +909,16 @@ definitions: - title - uuid type: object + types.CreateCategory: + properties: + name: + type: string + scope: + type: string + required: + - name + - scope + type: object types.CreateCodeReq: properties: default_branch: @@ -2216,6 +2235,16 @@ definitions: - TaskTypeTraining - TaskTypeComparison - TaskTypeLeaderBoard + types.UpdateCategory: + properties: + name: + type: string + scope: + type: string + required: + - name + - scope + type: object types.UpdateCodeReq: properties: description: @@ -10687,6 +10716,125 @@ paths: summary: Update a tag by id tags: - Tag + /tags/categories: + get: + consumes: + - application/json + description: Get all Categories + produces: + - application/json + responses: + "200": + description: categores + schema: + allOf: + - $ref: '#/definitions/types.ResponseWithTotal' + - properties: + data: + items: + $ref: '#/definitions/database.TagCategory' + type: array + type: object + "400": + description: Bad request + schema: + $ref: '#/definitions/types.APIBadRequest' + "500": + description: Internal server error + schema: + $ref: '#/definitions/types.APIInternalServerError' + security: + - ApiKey: [] + summary: Get all Categories + tags: + - Tag + post: + consumes: + - application/json + description: Create new category + parameters: + - description: body + in: body + name: body + required: true + schema: + $ref: '#/definitions/types.CreateCategory' + produces: + - application/json + responses: + "200": + description: OK + schema: + $ref: '#/definitions/types.Response' + "400": + description: Bad request + schema: + $ref: '#/definitions/types.APIBadRequest' + "500": + description: Internal server error + schema: + $ref: '#/definitions/types.APIInternalServerError' + security: + - ApiKey: [] + summary: Create new category + tags: + - Tag + /tags/categories/id: + delete: + consumes: + - application/json + description: Delete a category by id + produces: + - application/json + responses: + "200": + description: OK + schema: + $ref: '#/definitions/types.Response' + "400": + description: Bad request + schema: + $ref: '#/definitions/types.APIBadRequest' + "500": + description: Internal server error + schema: + $ref: '#/definitions/types.APIInternalServerError' + security: + - ApiKey: [] + summary: Delete a category by id + tags: + - Tag + put: + consumes: + - application/json + description: Create new category + parameters: + - description: body + in: body + name: body + required: true + schema: + $ref: '#/definitions/types.UpdateCategory' + produces: + - application/json + responses: + "200": + description: OK + schema: + $ref: '#/definitions/types.Response' + "400": + description: Bad request + schema: + $ref: '#/definitions/types.APIBadRequest' + "500": + description: Internal server error + schema: + $ref: '#/definitions/types.APIInternalServerError' + security: + - ApiKey: [] + summary: Create new category + tags: + - Tag /telemetry/usage: post: consumes: From 7d421037acc4bd4235a00ca4168354d157557e8c Mon Sep 17 00:00:00 2001 From: yiling Date: Fri, 27 Dec 2024 17:07:36 +0800 Subject: [PATCH 29/34] Sync Space/SpaceResource component with enterprise --- .../component/mock_SpaceComponent.go | 21 +- api/handler/space.go | 2 +- api/handler/space_resource.go | 2 +- common/types/space.go | 21 +- component/space.go | 106 ++++------ component/space_ce.go | 94 +++++++++ component/space_ce_test.go | 168 +++++++++++++++ component/space_resource.go | 51 ++--- component/space_resource_ce.go | 58 +++++ component/space_resource_ce_test.go | 38 ++++ component/space_resource_test.go | 141 ++++++------- component/space_test.go | 199 +++--------------- component/wireset.go | 34 --- component/wireset_ce.go | 48 +++++ 14 files changed, 588 insertions(+), 395 deletions(-) create mode 100644 component/space_ce.go create mode 100644 component/space_ce_test.go create mode 100644 component/space_resource_ce.go create mode 100644 component/space_resource_ce_test.go create mode 100644 component/wireset_ce.go diff --git a/_mocks/opencsg.com/csghub-server/component/mock_SpaceComponent.go b/_mocks/opencsg.com/csghub-server/component/mock_SpaceComponent.go index 026277cb..99d6cde9 100644 --- a/_mocks/opencsg.com/csghub-server/component/mock_SpaceComponent.go +++ b/_mocks/opencsg.com/csghub-server/component/mock_SpaceComponent.go @@ -726,17 +726,17 @@ func (_c *MockSpaceComponent_Status_Call) RunAndReturn(run func(context.Context, return _c } -// Stop provides a mock function with given fields: ctx, namespace, name -func (_m *MockSpaceComponent) Stop(ctx context.Context, namespace string, name string) error { - ret := _m.Called(ctx, namespace, name) +// Stop provides a mock function with given fields: ctx, namespace, name, deleteSpace +func (_m *MockSpaceComponent) Stop(ctx context.Context, namespace string, name string, deleteSpace bool) error { + ret := _m.Called(ctx, namespace, name, deleteSpace) if len(ret) == 0 { panic("no return value specified for Stop") } var r0 error - if rf, ok := ret.Get(0).(func(context.Context, string, string) error); ok { - r0 = rf(ctx, namespace, name) + if rf, ok := ret.Get(0).(func(context.Context, string, string, bool) error); ok { + r0 = rf(ctx, namespace, name, deleteSpace) } else { r0 = ret.Error(0) } @@ -753,13 +753,14 @@ type MockSpaceComponent_Stop_Call struct { // - ctx context.Context // - namespace string // - name string -func (_e *MockSpaceComponent_Expecter) Stop(ctx interface{}, namespace interface{}, name interface{}) *MockSpaceComponent_Stop_Call { - return &MockSpaceComponent_Stop_Call{Call: _e.mock.On("Stop", ctx, namespace, name)} +// - deleteSpace bool +func (_e *MockSpaceComponent_Expecter) Stop(ctx interface{}, namespace interface{}, name interface{}, deleteSpace interface{}) *MockSpaceComponent_Stop_Call { + return &MockSpaceComponent_Stop_Call{Call: _e.mock.On("Stop", ctx, namespace, name, deleteSpace)} } -func (_c *MockSpaceComponent_Stop_Call) Run(run func(ctx context.Context, namespace string, name string)) *MockSpaceComponent_Stop_Call { +func (_c *MockSpaceComponent_Stop_Call) Run(run func(ctx context.Context, namespace string, name string, deleteSpace bool)) *MockSpaceComponent_Stop_Call { _c.Call.Run(func(args mock.Arguments) { - run(args[0].(context.Context), args[1].(string), args[2].(string)) + run(args[0].(context.Context), args[1].(string), args[2].(string), args[3].(bool)) }) return _c } @@ -769,7 +770,7 @@ func (_c *MockSpaceComponent_Stop_Call) Return(_a0 error) *MockSpaceComponent_St return _c } -func (_c *MockSpaceComponent_Stop_Call) RunAndReturn(run func(context.Context, string, string) error) *MockSpaceComponent_Stop_Call { +func (_c *MockSpaceComponent_Stop_Call) RunAndReturn(run func(context.Context, string, string, bool) error) *MockSpaceComponent_Stop_Call { _c.Call.Return(run) return _c } diff --git a/api/handler/space.go b/api/handler/space.go index 04d7c290..da4ddd05 100644 --- a/api/handler/space.go +++ b/api/handler/space.go @@ -387,7 +387,7 @@ func (h *SpaceHandler) Stop(ctx *gin.Context) { return } - err = h.c.Stop(ctx, namespace, name) + err = h.c.Stop(ctx, namespace, name, false) if err != nil { slog.Error("failed to stop space", slog.String("namespace", namespace), slog.String("name", name), slog.Any("error", err)) diff --git a/api/handler/space_resource.go b/api/handler/space_resource.go index 86ad56db..91f7563e 100644 --- a/api/handler/space_resource.go +++ b/api/handler/space_resource.go @@ -51,7 +51,7 @@ func (h *SpaceResourceHandler) Index(ctx *gin.Context) { httpbase.BadRequest(ctx, err.Error()) return } - spaceResources, err := h.c.Index(ctx, clusterId, deployType) + spaceResources, err := h.c.Index(ctx, clusterId, deployType, "") if err != nil { slog.Error("Failed to get space resources", slog.String("cluster_id", clusterId), slog.String("deploy_type", deployTypeStr), slog.Any("error", err)) httpbase.ServerError(ctx, err) diff --git a/common/types/space.go b/common/types/space.go index 2584d626..9b6dc9d0 100644 --- a/common/types/space.go +++ b/common/types/space.go @@ -43,16 +43,17 @@ type Space struct { // the serving endpoint url Endpoint string `json:"endpoint,omitempty" example:"https://localhost/spaces/myname/myspace"` // deploying, running, failed - Status string `json:"status"` - RepositoryID int64 `json:"repository_id,omitempty"` - UserLikes bool `json:"user_likes"` - Source RepositorySource `json:"source"` - SyncStatus RepositorySyncStatus `json:"sync_status"` - SKU string `json:"sku,omitempty"` - SvcName string `json:"svc_name,omitempty"` - CanWrite bool `json:"can_write"` - CanManage bool `json:"can_manage"` - Namespace *Namespace `json:"namespace"` + Status string `json:"status"` + RepositoryID int64 `json:"repository_id,omitempty"` + UserLikes bool `json:"user_likes"` + Source RepositorySource `json:"source"` + SyncStatus RepositorySyncStatus `json:"sync_status"` + SKU string `json:"sku,omitempty"` + SvcName string `json:"svc_name,omitempty"` + CanWrite bool `json:"can_write"` + CanManage bool `json:"can_manage"` + Namespace *Namespace `json:"namespace"` + SensitiveCheckStatus string `json:"sensitive_check_status"` } type UpdateSpaceReq struct { diff --git a/component/space.go b/component/space.go index 422d41b8..1850810e 100644 --- a/component/space.go +++ b/component/space.go @@ -11,12 +11,9 @@ import ( "opencsg.com/csghub-server/builder/deploy" "opencsg.com/csghub-server/builder/deploy/scheduler" - "opencsg.com/csghub-server/builder/git" "opencsg.com/csghub-server/builder/git/gitserver" "opencsg.com/csghub-server/builder/git/membership" - "opencsg.com/csghub-server/builder/rpc" "opencsg.com/csghub-server/builder/store/database" - "opencsg.com/csghub-server/common/config" "opencsg.com/csghub-server/common/types" "opencsg.com/csghub-server/common/utils/common" ) @@ -45,7 +42,7 @@ type SpaceComponent interface { Delete(ctx context.Context, namespace, name, currentUser string) error Deploy(ctx context.Context, namespace, name, currentUser string) (int64, error) Wakeup(ctx context.Context, namespace, name string) error - Stop(ctx context.Context, namespace, name string) error + Stop(ctx context.Context, namespace, name string, deleteSpace bool) error // FixHasEntryFile checks whether git repo has entry point file and update space's HasAppFile property in db FixHasEntryFile(ctx context.Context, s *database.Space) *database.Space Status(ctx context.Context, namespace, name string) (string, string, error) @@ -54,55 +51,6 @@ type SpaceComponent interface { HasEntryFile(ctx context.Context, space *database.Space) bool } -func NewSpaceComponent(config *config.Config) (SpaceComponent, error) { - c := &spaceComponentImpl{} - c.spaceStore = database.NewSpaceStore() - var err error - c.spaceSdkStore = database.NewSpaceSdkStore() - c.spaceResourceStore = database.NewSpaceResourceStore() - c.repoStore = database.NewRepoStore() - c.repoComponent, err = NewRepoComponentImpl(config) - if err != nil { - return nil, err - } - c.deployer = deploy.NewDeployer() - c.publicRootDomain = config.Space.PublicRootDomain - c.userStore = database.NewUserStore() - c.accountingComponent, err = NewAccountingComponent(config) - if err != nil { - return nil, err - } - c.git, err = git.NewGitServer(config) - if err != nil { - return nil, err - } - c.serverBaseUrl = config.APIServer.PublicDomain - c.deployTaskStore = database.NewDeployTaskStore() - c.userSvcClient = rpc.NewUserSvcHttpClient(fmt.Sprintf("%s:%d", config.User.Host, config.User.Port), - rpc.AuthWithApiKey(config.APIToken)) - c.userLikesStore = database.NewUserLikesStore() - c.config = config - return c, nil -} - -type spaceComponentImpl struct { - config *config.Config - repoComponent RepoComponent - git gitserver.GitServer - spaceStore database.SpaceStore - spaceSdkStore database.SpaceSdkStore - spaceResourceStore database.SpaceResourceStore - repoStore database.RepoStore - userStore database.UserStore - deployer deploy.Deployer - publicRootDomain string - accountingComponent AccountingComponent - serverBaseUrl string - deployTaskStore database.DeployTaskStore - userSvcClient rpc.UserSvcClient - userLikesStore database.UserLikesStore -} - func (c *spaceComponentImpl) Create(ctx context.Context, req types.CreateSpaceReq) (*types.Space, error) { var nickname string if req.Nickname != "" { @@ -110,13 +58,19 @@ func (c *spaceComponentImpl) Create(ctx context.Context, req types.CreateSpaceRe } else { nickname = req.Name } + if req.DefaultBranch == "" { req.DefaultBranch = types.MainBranch } + req.Nickname = nickname req.RepoType = types.SpaceRepo req.Readme = generateReadmeData(req.License) resource, err := c.spaceResourceStore.FindByID(ctx, req.ResourceID) + if err != nil { + return nil, fmt.Errorf("fail to find resource by id, %w", err) + } + err = c.checkResourcePurchasableForCreate(ctx, req, resource) if err != nil { return nil, err } @@ -125,7 +79,7 @@ func (c *spaceComponentImpl) Create(ctx context.Context, req types.CreateSpaceRe if err != nil { return nil, fmt.Errorf("invalid hardware setting, %w", err) } - _, err = c.deployer.CheckResourceAvailable(ctx, req.ClusterID, 0, &hardware) + _, err = c.checkResourceAvailable(ctx, req, hardware) if err != nil { return nil, fmt.Errorf("fail to check resource, %w", err) } @@ -145,6 +99,7 @@ func (c *spaceComponentImpl) Create(ctx context.Context, req types.CreateSpaceRe Secrets: req.Secrets, SKU: strconv.FormatInt(resource.ID, 10), } + dbSpace = c.updateSpaceByReq(dbSpace, req) resSpace, err := c.spaceStore.Create(ctx, dbSpace) if err != nil { @@ -302,12 +257,26 @@ func (c *spaceComponentImpl) Show(ctx context.Context, namespace, name, currentU CanManage: permission.CanAdmin, Namespace: ns, } + if permission.CanAdmin { + resModel.SensitiveCheckStatus = space.Repository.SensitiveCheckStatus.String() + } return resModel, nil } func (c *spaceComponentImpl) Update(ctx context.Context, req *types.UpdateSpaceReq) (*types.Space, error) { req.RepoType = types.SpaceRepo + if req.ResourceID != nil { + resource, err := c.spaceResourceStore.FindByID(ctx, *req.ResourceID) + if err != nil { + return nil, fmt.Errorf("fail to find resource by id, %w", err) + } + + err = c.checkResourcePurchasableForUpdate(ctx, *req, resource) + if err != nil { + return nil, err + } + } dbRepo, err := c.repoComponent.UpdateRepo(ctx, req.UpdateRepoReq) if err != nil { return nil, err @@ -317,6 +286,10 @@ func (c *spaceComponentImpl) Update(ctx context.Context, req *types.UpdateSpaceR if err != nil { return nil, fmt.Errorf("failed to find space, error: %w", err) } + // don't support switch reserved resource + if c.resourceReserved(space, req) { + return nil, fmt.Errorf("don't support switch reserved resource so far") + } err = c.mergeUpdateSpaceRequest(ctx, space, req) if err != nil { return nil, fmt.Errorf("failed to merge update space request, error: %w", err) @@ -602,7 +575,12 @@ func (c *spaceComponentImpl) Delete(ctx context.Context, namespace, name, curren } // stop any running space instance - go func() { _ = c.Stop(ctx, namespace, name) }() + go func() { + err := c.Stop(ctx, namespace, name, true) + if err != nil { + slog.Error("stop space failed", slog.Any("error", err)) + } + }() return nil } @@ -641,7 +619,7 @@ func (c *spaceComponentImpl) Deploy(ctx context.Context, namespace, name, curren slog.Info("run space with container image", slog.Any("namespace", namespace), slog.Any("name", name), slog.Any("containerImg", containerImg)) // create deploy for space - return c.deployer.Deploy(ctx, types.DeployRepo{ + dr := types.DeployRepo{ SpaceID: s.ID, Path: s.Repository.Path, GitPath: s.Repository.GitPath, @@ -660,7 +638,9 @@ func (c *spaceComponentImpl) Deploy(ctx context.Context, namespace, name, curren Type: types.SpaceType, UserUUID: user.UUID, SKU: s.SKU, - }) + } + dr = c.updateDeployRepoBySpace(dr, s) + return c.deployer.Deploy(ctx, dr) } func (c *spaceComponentImpl) Wakeup(ctx context.Context, namespace, name string) error { @@ -682,7 +662,7 @@ func (c *spaceComponentImpl) Wakeup(ctx context.Context, namespace, name string) }) } -func (c *spaceComponentImpl) Stop(ctx context.Context, namespace, name string) error { +func (c *spaceComponentImpl) Stop(ctx context.Context, namespace, name string, deleteSpace bool) error { s, err := c.spaceStore.FindByPath(ctx, namespace, name) if err != nil { slog.Error("can't stop space", slog.Any("error", err), slog.String("namespace", namespace), slog.String("name", name)) @@ -698,12 +678,14 @@ func (c *spaceComponentImpl) Stop(ctx context.Context, namespace, name string) e return fmt.Errorf("can't get space deployment") } - err = c.deployer.Stop(ctx, types.DeployRepo{ + dr := types.DeployRepo{ SpaceID: s.ID, Namespace: namespace, Name: name, SvcName: deploy.SvcName, - }) + } + dr = c.updateDeployRepoByDeploy(dr, deploy) + err = c.deployer.Stop(ctx, dr) if err != nil { return fmt.Errorf("can't stop space service deploy for service '%s', %w", deploy.SvcName, err) } @@ -720,9 +702,7 @@ func (c *spaceComponentImpl) FixHasEntryFile(ctx context.Context, s *database.Sp hasAppFile := c.HasEntryFile(ctx, s) if s.HasAppFile != hasAppFile { s.HasAppFile = hasAppFile - if er := c.spaceStore.Update(ctx, *s); er != nil { - slog.Error("update space failed", "error", er) - } + _ = c.spaceStore.Update(ctx, *s) } return s diff --git a/component/space_ce.go b/component/space_ce.go new file mode 100644 index 00000000..c9e93a69 --- /dev/null +++ b/component/space_ce.go @@ -0,0 +1,94 @@ +//go:build !saas + +package component + +import ( + "context" + "fmt" + + "opencsg.com/csghub-server/builder/deploy" + "opencsg.com/csghub-server/builder/git" + "opencsg.com/csghub-server/builder/git/gitserver" + "opencsg.com/csghub-server/builder/rpc" + "opencsg.com/csghub-server/builder/store/database" + "opencsg.com/csghub-server/common/config" + "opencsg.com/csghub-server/common/types" +) + +func NewSpaceComponent(config *config.Config) (SpaceComponent, error) { + c := &spaceComponentImpl{} + c.spaceStore = database.NewSpaceStore() + var err error + c.spaceSdkStore = database.NewSpaceSdkStore() + c.spaceResourceStore = database.NewSpaceResourceStore() + c.repoStore = database.NewRepoStore() + c.repoComponent, err = NewRepoComponentImpl(config) + if err != nil { + return nil, err + } + c.deployer = deploy.NewDeployer() + c.publicRootDomain = config.Space.PublicRootDomain + c.userStore = database.NewUserStore() + c.accountingComponent, err = NewAccountingComponent(config) + if err != nil { + return nil, err + } + c.serverBaseUrl = config.APIServer.PublicDomain + c.userLikesStore = database.NewUserLikesStore() + c.config = config + c.userSvcClient = rpc.NewUserSvcHttpClient(fmt.Sprintf("%s:%d", config.User.Host, config.User.Port), + rpc.AuthWithApiKey(config.APIToken)) + + c.deployTaskStore = database.NewDeployTaskStore() + c.git, err = git.NewGitServer(config) + if err != nil { + return nil, err + } + return c, nil +} + +type spaceComponentImpl struct { + repoComponent RepoComponent + git gitserver.GitServer + spaceStore database.SpaceStore + spaceSdkStore database.SpaceSdkStore + spaceResourceStore database.SpaceResourceStore + repoStore database.RepoStore + userStore database.UserStore + deployer deploy.Deployer + publicRootDomain string + accountingComponent AccountingComponent + serverBaseUrl string + userLikesStore database.UserLikesStore + config *config.Config + userSvcClient rpc.UserSvcClient + deployTaskStore database.DeployTaskStore +} + +func (c *spaceComponentImpl) checkResourcePurchasableForCreate(ctx context.Context, req types.CreateSpaceReq, resource *database.SpaceResource) error { + return nil +} + +func (c *spaceComponentImpl) checkResourcePurchasableForUpdate(ctx context.Context, req types.UpdateSpaceReq, resource *database.SpaceResource) error { + return nil +} + +func (c *spaceComponentImpl) checkResourceAvailable(ctx context.Context, req types.CreateSpaceReq, hardware types.HardWare) (bool, error) { + return c.deployer.CheckResourceAvailable(ctx, req.ClusterID, 0, &hardware) +} + +func (c *spaceComponentImpl) updateSpaceByReq(space database.Space, req types.CreateSpaceReq) database.Space { + return space +} + +func (c *spaceComponentImpl) resourceReserved(space *database.Space, req *types.UpdateSpaceReq) bool { + return false +} + +func (c *spaceComponentImpl) updateDeployRepoBySpace(repo types.DeployRepo, space *database.Space) types.DeployRepo { + return repo +} + +func (c *spaceComponentImpl) updateDeployRepoByDeploy(repo types.DeployRepo, deploy *database.Deploy) types.DeployRepo { + return repo +} diff --git a/component/space_ce_test.go b/component/space_ce_test.go new file mode 100644 index 00000000..c9175045 --- /dev/null +++ b/component/space_ce_test.go @@ -0,0 +1,168 @@ +//go:build !saas + +package component + +import ( + "context" + "testing" + + "github.com/alibabacloud-go/tea/tea" + "github.com/stretchr/testify/require" + "opencsg.com/csghub-server/builder/deploy/scheduler" + "opencsg.com/csghub-server/builder/store/database" + "opencsg.com/csghub-server/common/types" +) + +func TestSpaceComponent_Create(t *testing.T) { + ctx := context.TODO() + sc := initializeTestSpaceComponent(ctx, t) + + sc.mocks.stores.SpaceResourceMock().EXPECT().FindByID(ctx, int64(1)).Return(&database.SpaceResource{ + ID: 1, + Name: "sp", + Resources: `{"memory": "foo"}`, + }, nil) + + sc.mocks.deployer.EXPECT().CheckResourceAvailable(ctx, "cluster", int64(0), &types.HardWare{ + Memory: "foo", + }).Return(true, nil) + + sc.mocks.components.repo.EXPECT().CreateRepo(ctx, types.CreateRepoReq{ + DefaultBranch: "main", + Readme: generateReadmeData("MIT"), + License: "MIT", + Namespace: "ns", + Name: "n", + Nickname: "n", + RepoType: types.SpaceRepo, + Username: "user", + }).Return(nil, &database.Repository{ + ID: 321, + User: database.User{ + Username: "user", + Email: "foo@bar.com", + }, + }, nil) + + sc.mocks.stores.SpaceMock().EXPECT().Create(ctx, database.Space{ + RepositoryID: 321, + Sdk: scheduler.STREAMLIT.Name, + SdkVersion: "v1", + Env: "env", + Hardware: `{"memory": "foo"}`, + Secrets: "sss", + SKU: "1", + }).Return(&database.Space{}, nil) + sc.mocks.gitServer.EXPECT().CreateRepoFile(buildCreateFileReq(&types.CreateFileParams{ + Username: "user", + Email: "foo@bar.com", + Message: initCommitMessage, + Branch: "main", + Content: generateReadmeData("MIT"), + NewBranch: "main", + Namespace: "ns", + Name: "n", + FilePath: readmeFileName, + }, types.SpaceRepo)).Return(nil) + sc.mocks.gitServer.EXPECT().CreateRepoFile(buildCreateFileReq(&types.CreateFileParams{ + Username: "user", + Email: "foo@bar.com", + Message: initCommitMessage, + Branch: "main", + Content: spaceGitattributesContent, + NewBranch: "main", + Namespace: "ns", + Name: "n", + FilePath: gitattributesFileName, + }, types.SpaceRepo)).Return(nil) + sc.mocks.gitServer.EXPECT().CreateRepoFile(buildCreateFileReq(&types.CreateFileParams{ + Username: "user", + Email: "foo@bar.com", + Message: initCommitMessage, + Branch: "main", + Content: streamlitConfigContent, + NewBranch: "main", + Namespace: "ns", + Name: "n", + FilePath: streamlitConfig, + }, types.SpaceRepo)).Return(nil) + + space, err := sc.Create(ctx, types.CreateSpaceReq{ + Sdk: scheduler.STREAMLIT.Name, + SdkVersion: "v1", + Env: "env", + Secrets: "sss", + ResourceID: 1, + ClusterID: "cluster", + CreateRepoReq: types.CreateRepoReq{ + DefaultBranch: "main", + Readme: "readme", + Namespace: "ns", + Name: "n", + License: "MIT", + Username: "user", + }, + }) + require.Nil(t, err) + + require.Equal(t, &types.Space{ + License: "MIT", + Name: "n", + Sdk: "streamlit", + SdkVersion: "v1", + Env: "env", + Secrets: "sss", + Hardware: `{"memory": "foo"}`, + Creator: "user", + }, space) + +} + +func TestSpaceComponent_Update(t *testing.T) { + ctx := context.TODO() + sc := initializeTestSpaceComponent(ctx, t) + + sc.mocks.stores.SpaceResourceMock().EXPECT().FindByID(ctx, int64(12)).Return(&database.SpaceResource{ + ID: 12, + Name: "sp", + Resources: `{"memory": "foo"}`, + }, nil) + + sc.mocks.components.repo.EXPECT().UpdateRepo(ctx, types.UpdateRepoReq{ + Username: "user", + Namespace: "ns", + Name: "n", + RepoType: types.SpaceRepo, + }).Return( + &database.Repository{ + ID: 123, + Name: "repo", + }, nil, + ) + sc.mocks.stores.SpaceMock().EXPECT().ByRepoID(ctx, int64(123)).Return(&database.Space{ + ID: 321, + }, nil) + sc.mocks.stores.SpaceMock().EXPECT().Update(ctx, database.Space{ + ID: 321, + Hardware: `{"memory": "foo"}`, + SKU: "12", + }).Return(nil) + + space, err := sc.Update(ctx, &types.UpdateSpaceReq{ + ResourceID: tea.Int64(12), + UpdateRepoReq: types.UpdateRepoReq{ + Username: "user", + Namespace: "ns", + Name: "n", + }, + }) + require.Nil(t, err) + + require.Equal(t, &types.Space{ + ID: 321, + Name: "repo", + Hardware: `{"memory": "foo"}`, + SKU: "12", + }, space) + +} diff --git a/component/space_resource.go b/component/space_resource.go index b40939e0..28430d88 100644 --- a/component/space_resource.go +++ b/component/space_resource.go @@ -8,30 +8,17 @@ import ( "opencsg.com/csghub-server/builder/deploy" "opencsg.com/csghub-server/builder/store/database" - "opencsg.com/csghub-server/common/config" "opencsg.com/csghub-server/common/types" ) type SpaceResourceComponent interface { - Index(ctx context.Context, clusterId string, deployType int) ([]types.SpaceResource, error) + Index(ctx context.Context, clusterId string, deployType int, currentUser string) ([]types.SpaceResource, error) Update(ctx context.Context, req *types.UpdateSpaceResourceReq) (*types.SpaceResource, error) Create(ctx context.Context, req *types.CreateSpaceResourceReq) (*types.SpaceResource, error) Delete(ctx context.Context, id int64) error } -func NewSpaceResourceComponent(config *config.Config) (SpaceResourceComponent, error) { - c := &spaceResourceComponentImpl{} - c.srs = database.NewSpaceResourceStore() - c.deployer = deploy.NewDeployer() - return c, nil -} - -type spaceResourceComponentImpl struct { - srs database.SpaceResourceStore - deployer deploy.Deployer -} - -func (c *spaceResourceComponentImpl) Index(ctx context.Context, clusterId string, deployType int) ([]types.SpaceResource, error) { +func (c *spaceResourceComponentImpl) Index(ctx context.Context, clusterId string, deployType int, currentUser string) ([]types.SpaceResource, error) { // backward compatibility for old api if clusterId == "" { clusters, err := c.deployer.ListCluster(ctx) @@ -44,7 +31,7 @@ func (c *spaceResourceComponentImpl) Index(ctx context.Context, clusterId string clusterId = clusters[0].ClusterID } var result []types.SpaceResource - databaseSpaceResources, err := c.srs.Index(ctx, clusterId) + databaseSpaceResources, err := c.spaceResourceStore.Index(ctx, clusterId) if err != nil { return nil, err } @@ -52,6 +39,7 @@ func (c *spaceResourceComponentImpl) Index(ctx context.Context, clusterId string if err != nil { return nil, err } + for _, r := range databaseSpaceResources { var isAvailable bool var hardware types.HardWare @@ -61,16 +49,10 @@ func (c *spaceResourceComponentImpl) Index(ctx context.Context, clusterId string } else { isAvailable = deploy.CheckResource(clusterResources, &hardware) } - if deployType == types.FinetuneType { - if hardware.Gpu.Num == "" { - continue - } - } - resourceType := types.ResourceTypeCPU - if hardware.Gpu.Num != "" { - resourceType = types.ResourceTypeGPU + if !c.deployAvailable(deployType, hardware) { + continue } - + resourceType := c.resourceType(hardware) result = append(result, types.SpaceResource{ ID: r.ID, Name: r.Name, @@ -79,12 +61,21 @@ func (c *spaceResourceComponentImpl) Index(ctx context.Context, clusterId string Type: resourceType, }) } + err = c.updatePriceInfo(currentUser, result) + if err != nil { + return nil, err + } + + result, err = c.appendUserResources(ctx, currentUser, clusterId, result) + if err != nil { + return nil, err + } return result, nil } func (c *spaceResourceComponentImpl) Update(ctx context.Context, req *types.UpdateSpaceResourceReq) (*types.SpaceResource, error) { - sr, err := c.srs.FindByID(ctx, req.ID) + sr, err := c.spaceResourceStore.FindByID(ctx, req.ID) if err != nil { slog.Error("error getting space resource", slog.Any("error", err)) return nil, err @@ -92,7 +83,7 @@ func (c *spaceResourceComponentImpl) Update(ctx context.Context, req *types.Upda sr.Name = req.Name sr.Resources = req.Resources - sr, err = c.srs.Update(ctx, *sr) + sr, err = c.spaceResourceStore.Update(ctx, *sr) if err != nil { slog.Error("error updating space resource", slog.Any("error", err)) return nil, err @@ -113,7 +104,7 @@ func (c *spaceResourceComponentImpl) Create(ctx context.Context, req *types.Crea Resources: req.Resources, ClusterID: req.ClusterID, } - res, err := c.srs.Create(ctx, sr) + res, err := c.spaceResourceStore.Create(ctx, sr) if err != nil { slog.Error("error creating space resource", slog.Any("error", err)) return nil, err @@ -129,13 +120,13 @@ func (c *spaceResourceComponentImpl) Create(ctx context.Context, req *types.Crea } func (c *spaceResourceComponentImpl) Delete(ctx context.Context, id int64) error { - sr, err := c.srs.FindByID(ctx, id) + sr, err := c.spaceResourceStore.FindByID(ctx, id) if err != nil { slog.Error("error finding space resource", slog.Any("error", err)) return err } - err = c.srs.Delete(ctx, *sr) + err = c.spaceResourceStore.Delete(ctx, *sr) if err != nil { slog.Error("error deleting space resource", slog.Any("error", err)) return err diff --git a/component/space_resource_ce.go b/component/space_resource_ce.go new file mode 100644 index 00000000..fe43d8fe --- /dev/null +++ b/component/space_resource_ce.go @@ -0,0 +1,58 @@ +//go:build !ee && !saas + +package component + +import ( + "context" + + "opencsg.com/csghub-server/builder/deploy" + "opencsg.com/csghub-server/builder/store/database" + "opencsg.com/csghub-server/common/config" + "opencsg.com/csghub-server/common/types" +) + +func NewSpaceResourceComponent(config *config.Config) (SpaceResourceComponent, error) { + c := &spaceResourceComponentImpl{} + c.spaceResourceStore = database.NewSpaceResourceStore() + c.deployer = deploy.NewDeployer() + c.userStore = database.NewUserStore() + ac, err := NewAccountingComponent(config) + if err != nil { + return nil, err + } + c.accountComponent = ac + return c, nil +} + +type spaceResourceComponentImpl struct { + spaceResourceStore database.SpaceResourceStore + deployer deploy.Deployer + userStore database.UserStore + accountComponent AccountingComponent +} + +func (c *spaceResourceComponentImpl) updatePriceInfo(currentUser string, resources []types.SpaceResource) error { + return nil + +} + +func (c *spaceResourceComponentImpl) appendUserResources(ctx context.Context, currentUser string, clusterID string, resources []types.SpaceResource) ([]types.SpaceResource, error) { + return resources, nil +} + +func (c *spaceResourceComponentImpl) deployAvailable(deployType int, hardware types.HardWare) bool { + if deployType == types.FinetuneType { + if hardware.Gpu.Num == "" { + return false + } + } + return true +} + +func (c *spaceResourceComponentImpl) resourceType(hardware types.HardWare) types.ResourceType { + resourceType := types.ResourceTypeCPU + if hardware.Gpu.Num != "" { + resourceType = types.ResourceTypeGPU + } + return resourceType +} diff --git a/component/space_resource_ce_test.go b/component/space_resource_ce_test.go new file mode 100644 index 00000000..931427dd --- /dev/null +++ b/component/space_resource_ce_test.go @@ -0,0 +1,38 @@ +//go:build !ee && !saas + +package component + +import ( + "context" + "testing" + + "github.com/stretchr/testify/require" + "opencsg.com/csghub-server/builder/store/database" + "opencsg.com/csghub-server/common/types" +) + +func TestSpaceResourceComponent_Index(t *testing.T) { + ctx := context.TODO() + sc := initializeTestSpaceResourceComponent(ctx, t) + + sc.mocks.deployer.EXPECT().ListCluster(ctx).Return([]types.ClusterRes{ + {ClusterID: "c1"}, + }, nil) + sc.mocks.stores.SpaceResourceMock().EXPECT().Index(ctx, "c1").Return( + []database.SpaceResource{ + {ID: 1, Name: "sr", Resources: `{"memory": "1000", "gpu": {"num": "5"}}`}, + {ID: 2, Name: "sr2", Resources: `{"memory": "1000"}`}, + }, nil, + ) + sc.mocks.deployer.EXPECT().GetClusterById(ctx, "c1").Return(&types.ClusterRes{}, nil) + + data, err := sc.Index(ctx, "", types.FinetuneType, "user") + require.Nil(t, err) + require.Equal(t, []types.SpaceResource{ + { + ID: 1, Name: "sr", Resources: `{"memory": "1000", "gpu": {"num": "5"}}`, + IsAvailable: false, Type: "gpu", + }, + }, data) + +} diff --git a/component/space_resource_test.go b/component/space_resource_test.go index 024cc1e9..28a93fe3 100644 --- a/component/space_resource_test.go +++ b/component/space_resource_test.go @@ -1,94 +1,71 @@ package component -// func TestSpaceResourceComponent_Index(t *testing.T) { -// ctx := context.TODO() -// sc := initializeTestSpaceResourceComponent(ctx, t) +import ( + "context" + "testing" -// sc.mocks.deployer.EXPECT().ListCluster(ctx).Return([]types.ClusterRes{ -// {ClusterID: "c1"}, -// }, nil) -// sc.mocks.stores.SpaceResourceMock().EXPECT().Index(ctx, "c1").Return( -// []database.SpaceResource{ -// {ID: 1, Name: "sr", Resources: `{"memory": "1000"}`}, -// }, nil, -// ) -// sc.mocks.deployer.EXPECT().GetClusterById(ctx, "c1").Return(&types.ClusterRes{}, nil) -// sc.mocks.stores.UserMock().EXPECT().FindByUsername(ctx, "user").Return(database.User{ -// UUID: "uid", -// }, nil) + "github.com/stretchr/testify/require" + "opencsg.com/csghub-server/builder/store/database" + "opencsg.com/csghub-server/common/types" +) -// data, err := sc.Index(ctx, "", 1) -// require.Nil(t, err) -// require.Equal(t, []types.SpaceResource{ -// { -// ID: 1, Name: "sr", Resources: "{\"memory\": \"1000\"}", -// IsAvailable: false, Type: "cpu", -// }, -// { -// ID: 0, Name: "", Resources: "{\"memory\": \"2000\"}", IsAvailable: true, -// Type: "cpu", -// }, -// }, data) +func TestSpaceResourceComponent_Update(t *testing.T) { + ctx := context.TODO() + sc := initializeTestSpaceResourceComponent(ctx, t) -// } + sc.mocks.stores.SpaceResourceMock().EXPECT().FindByID(ctx, int64(1)).Return( + &database.SpaceResource{}, nil, + ) + sc.mocks.stores.SpaceResourceMock().EXPECT().Update(ctx, database.SpaceResource{ + Name: "n", + Resources: "r", + }).Return(&database.SpaceResource{ID: 1, Name: "n", Resources: "r"}, nil) -// func TestSpaceResourceComponent_Update(t *testing.T) { -// ctx := context.TODO() -// sc := initializeTestSpaceResourceComponent(ctx, t) + data, err := sc.Update(ctx, &types.UpdateSpaceResourceReq{ + ID: 1, + Name: "n", + Resources: "r", + }) + require.Nil(t, err) + require.Equal(t, &types.SpaceResource{ + ID: 1, + Name: "n", + Resources: "r", + }, data) +} -// sc.mocks.stores.SpaceResourceMock().EXPECT().FindByID(ctx, int64(1)).Return( -// &database.SpaceResource{}, nil, -// ) -// sc.mocks.stores.SpaceResourceMock().EXPECT().Update(ctx, database.SpaceResource{ -// Name: "n", -// Resources: "r", -// }).Return(&database.SpaceResource{ID: 1, Name: "n", Resources: "r"}, nil) +func TestSpaceResourceComponent_Create(t *testing.T) { + ctx := context.TODO() + sc := initializeTestSpaceResourceComponent(ctx, t) -// data, err := sc.Update(ctx, &types.UpdateSpaceResourceReq{ -// ID: 1, -// Name: "n", -// Resources: "r", -// }) -// require.Nil(t, err) -// require.Equal(t, &types.SpaceResource{ -// ID: 1, -// Name: "n", -// Resources: "r", -// }, data) -// } + sc.mocks.stores.SpaceResourceMock().EXPECT().Create(ctx, database.SpaceResource{ + Name: "n", + Resources: "r", + ClusterID: "c", + }).Return(&database.SpaceResource{ID: 1, Name: "n", Resources: "r"}, nil) -// func TestSpaceResourceComponent_Create(t *testing.T) { -// ctx := context.TODO() -// sc := initializeTestSpaceResourceComponent(ctx, t) + data, err := sc.Create(ctx, &types.CreateSpaceResourceReq{ + Name: "n", + Resources: "r", + ClusterID: "c", + }) + require.Nil(t, err) + require.Equal(t, &types.SpaceResource{ + ID: 1, + Name: "n", + Resources: "r", + }, data) +} -// sc.mocks.stores.SpaceResourceMock().EXPECT().Create(ctx, database.SpaceResource{ -// Name: "n", -// Resources: "r", -// ClusterID: "c", -// }).Return(&database.SpaceResource{ID: 1, Name: "n", Resources: "r"}, nil) +func TestSpaceResourceComponent_Delete(t *testing.T) { + ctx := context.TODO() + sc := initializeTestSpaceResourceComponent(ctx, t) -// data, err := sc.Create(ctx, &types.CreateSpaceResourceReq{ -// Name: "n", -// Resources: "r", -// ClusterID: "c", -// }) -// require.Nil(t, err) -// require.Equal(t, &types.SpaceResource{ -// ID: 1, -// Name: "n", -// Resources: "r", -// }, data) -// } + sc.mocks.stores.SpaceResourceMock().EXPECT().FindByID(ctx, int64(1)).Return( + &database.SpaceResource{}, nil, + ) + sc.mocks.stores.SpaceResourceMock().EXPECT().Delete(ctx, database.SpaceResource{}).Return(nil) -// func TestSpaceResourceComponent_Delete(t *testing.T) { -// ctx := context.TODO() -// sc := initializeTestSpaceResourceComponent(ctx, t) - -// sc.mocks.stores.SpaceResourceMock().EXPECT().FindByID(ctx, int64(1)).Return( -// &database.SpaceResource{}, nil, -// ) -// sc.mocks.stores.SpaceResourceMock().EXPECT().Delete(ctx, database.SpaceResource{}).Return(nil) - -// err := sc.Delete(ctx, 1) -// require.Nil(t, err) -// } + err := sc.Delete(ctx, 1) + require.Nil(t, err) +} diff --git a/component/space_test.go b/component/space_test.go index 57d894de..72486077 100644 --- a/component/space_test.go +++ b/component/space_test.go @@ -6,7 +6,6 @@ import ( "testing" "time" - "github.com/alibabacloud-go/tea/tea" "github.com/stretchr/testify/require" "opencsg.com/csghub-server/builder/deploy" "opencsg.com/csghub-server/builder/deploy/scheduler" @@ -16,111 +15,6 @@ import ( "opencsg.com/csghub-server/common/types" ) -// func TestSpaceComponent_Create(t *testing.T) { -// ctx := context.TODO() -// sc := initializeTestSpaceComponent(ctx, t) - -// sc.mocks.stores.SpaceResourceMock().EXPECT().FindByID(ctx, int64(1)).Return(&database.SpaceResource{ -// ID: 1, -// Name: "sp", -// Resources: `{"memory": "foo"}`, -// }, nil) - -// sc.mocks.deployer.EXPECT().CheckResourceAvailable(ctx, int64(0), &types.HardWare{ -// Memory: "foo", -// }).Return(true, nil) - -// sc.mocks.components.repo.EXPECT().CreateRepo(ctx, types.CreateRepoReq{ -// DefaultBranch: "main", -// Readme: generateReadmeData("MIT"), -// License: "MIT", -// Namespace: "ns", -// Name: "n", -// Nickname: "n", -// RepoType: types.SpaceRepo, -// Username: "user", -// }).Return(nil, &database.Repository{ -// ID: 321, -// User: database.User{ -// Username: "user", -// Email: "foo@bar.com", -// }, -// }, nil) - -// sc.mocks.stores.SpaceMock().EXPECT().Create(ctx, database.Space{ -// RepositoryID: 321, -// Sdk: scheduler.STREAMLIT.Name, -// SdkVersion: "v1", -// Env: "env", -// Hardware: `{"memory": "foo"}`, -// Secrets: "sss", -// SKU: "1", -// }).Return(&database.Space{}, nil) -// sc.mocks.gitServer.EXPECT().CreateRepoFile(buildCreateFileReq(&types.CreateFileParams{ -// Username: "user", -// Email: "foo@bar.com", -// Message: initCommitMessage, -// Branch: "main", -// Content: generateReadmeData("MIT"), -// NewBranch: "main", -// Namespace: "ns", -// Name: "n", -// FilePath: readmeFileName, -// }, types.SpaceRepo)).Return(nil) -// sc.mocks.gitServer.EXPECT().CreateRepoFile(buildCreateFileReq(&types.CreateFileParams{ -// Username: "user", -// Email: "foo@bar.com", -// Message: initCommitMessage, -// Branch: "main", -// Content: spaceGitattributesContent, -// NewBranch: "main", -// Namespace: "ns", -// Name: "n", -// FilePath: gitattributesFileName, -// }, types.SpaceRepo)).Return(nil) -// sc.mocks.gitServer.EXPECT().CreateRepoFile(buildCreateFileReq(&types.CreateFileParams{ -// Username: "user", -// Email: "foo@bar.com", -// Message: initCommitMessage, -// Branch: "main", -// Content: streamlitConfigContent, -// NewBranch: "main", -// Namespace: "ns", -// Name: "n", -// FilePath: streamlitConfig, -// }, types.SpaceRepo)).Return(nil) - -// space, err := sc.Create(ctx, types.CreateSpaceReq{ -// Sdk: scheduler.STREAMLIT.Name, -// SdkVersion: "v1", -// Env: "env", -// Secrets: "sss", -// ResourceID: 1, -// ClusterID: "cluster", -// CreateRepoReq: types.CreateRepoReq{ -// DefaultBranch: "main", -// Readme: "readme", -// Namespace: "ns", -// Name: "n", -// License: "MIT", -// Username: "user", -// }, -// }) -// require.Nil(t, err) - -// require.Equal(t, &types.Space{ -// License: "MIT", -// Name: "n", -// Sdk: "streamlit", -// SdkVersion: "v1", -// Env: "env", -// Secrets: "sss", -// Hardware: `{"memory": "foo"}`, -// Creator: "user", -// }, space) - -// } - func TestSpaceComponent_Show(t *testing.T) { ctx := context.TODO() sc := initializeTestSpaceComponent(ctx, t) @@ -152,15 +46,16 @@ func TestSpaceComponent_Show(t *testing.T) { space, err := sc.Show(ctx, "ns", "n", "user") require.Nil(t, err) require.Equal(t, &types.Space{ - ID: 1, - Name: "n", - Namespace: &types.Namespace{Path: "ns"}, - UserLikes: true, - RepositoryID: 123, - Status: "Stopped", - CanManage: true, - User: &types.User{}, - Path: "foo/bar", + ID: 1, + Name: "n", + Namespace: &types.Namespace{Path: "ns"}, + UserLikes: true, + RepositoryID: 123, + Status: "Stopped", + CanManage: true, + User: &types.User{}, + Path: "foo/bar", + SensitiveCheckStatus: "Pending", Repository: &types.Repository{ HTTPCloneURL: "/s/foo/bar.git", SSHCloneURL: ":s/foo/bar.git", @@ -170,55 +65,6 @@ func TestSpaceComponent_Show(t *testing.T) { }, space) } -func TestSpaceComponent_Update(t *testing.T) { - ctx := context.TODO() - sc := initializeTestSpaceComponent(ctx, t) - - sc.mocks.stores.SpaceResourceMock().EXPECT().FindByID(ctx, int64(12)).Return(&database.SpaceResource{ - ID: 12, - Name: "sp", - Resources: `{"memory": "foo"}`, - }, nil) - - sc.mocks.components.repo.EXPECT().UpdateRepo(ctx, types.UpdateRepoReq{ - Username: "user", - Namespace: "ns", - Name: "n", - RepoType: types.SpaceRepo, - }).Return( - &database.Repository{ - ID: 123, - Name: "repo", - }, nil, - ) - sc.mocks.stores.SpaceMock().EXPECT().ByRepoID(ctx, int64(123)).Return(&database.Space{ - ID: 321, - }, nil) - sc.mocks.stores.SpaceMock().EXPECT().Update(ctx, database.Space{ - ID: 321, - Hardware: `{"memory": "foo"}`, - SKU: "12", - }).Return(nil) - - space, err := sc.Update(ctx, &types.UpdateSpaceReq{ - ResourceID: tea.Int64(12), - UpdateRepoReq: types.UpdateRepoReq{ - Username: "user", - Namespace: "ns", - Name: "n", - }, - }) - require.Nil(t, err) - - require.Equal(t, &types.Space{ - ID: 321, - Name: "repo", - Hardware: `{"memory": "foo"}`, - SKU: "12", - }, space) - -} - func TestSpaceComponent_Index(t *testing.T) { ctx := context.TODO() sc := initializeTestSpaceComponent(ctx, t) @@ -457,6 +303,31 @@ func TestSpaceComponent_Wakeup(t *testing.T) { } +func TestSpaceComponent_Stop(t *testing.T) { + ctx := context.TODO() + sc := initializeTestSpaceComponent(ctx, t) + sc.mocks.stores.SpaceMock().EXPECT().FindByPath(ctx, "ns", "n").Return(&database.Space{ + ID: 1, + }, nil) + + sc.mocks.stores.DeployTaskMock().EXPECT().GetLatestDeployBySpaceID(ctx, int64(1)).Return( + &database.Deploy{SvcName: "svc", RepoID: 1, UserID: 2, ID: 3}, nil, + ) + + sc.mocks.deployer.EXPECT().Stop(ctx, types.DeployRepo{ + SpaceID: 1, + Namespace: "ns", + Name: "n", + SvcName: "svc", + }).Return(nil) + sc.mocks.stores.DeployTaskMock().EXPECT().StopDeploy( + ctx, types.SpaceRepo, int64(1), int64(2), int64(3), + ).Return(nil) + + err := sc.Stop(ctx, "ns", "n", false) + require.Nil(t, err) +} + func TestSpaceComponent_FixHasEntryFile(t *testing.T) { cases := []struct { diff --git a/component/wireset.go b/component/wireset.go index e9394a66..a63d2ab5 100644 --- a/component/wireset.go +++ b/component/wireset.go @@ -207,34 +207,6 @@ func NewTestUserComponent( var UserComponentSet = wire.NewSet(NewTestUserComponent) -func NewTestSpaceComponent( - stores *tests.MockStores, - repoComponent RepoComponent, - git gitserver.GitServer, - deployer deploy.Deployer, - accountingComponent AccountingComponent, - config *config.Config, - userSvcClient rpc.UserSvcClient, -) *spaceComponentImpl { - return &spaceComponentImpl{ - repoComponent: repoComponent, - git: git, - spaceStore: stores.Space, - spaceSdkStore: stores.SpaceSdk, - spaceResourceStore: stores.SpaceResource, - repoStore: stores.Repo, - userStore: stores.User, - deployer: deployer, - publicRootDomain: config.Space.PublicRootDomain, - accountingComponent: accountingComponent, - serverBaseUrl: config.APIServer.PublicDomain, - userLikesStore: stores.UserLikes, - config: config, - userSvcClient: userSvcClient, - deployTaskStore: stores.DeployTask, - } -} - var SpaceComponentSet = wire.NewSet(NewTestSpaceComponent) func NewTestModelComponent( @@ -452,12 +424,6 @@ func NewTestMirrorSourceComponent(config *config.Config, stores *tests.MockStore var MirrorSourceComponentSet = wire.NewSet(NewTestMirrorSourceComponent) -func NewTestSpaceResourceComponent(config *config.Config, stores *tests.MockStores, deployer deploy.Deployer, accountComponent AccountingComponent) *spaceResourceComponentImpl { - return &spaceResourceComponentImpl{ - deployer: deployer, - } -} - var SpaceResourceComponentSet = wire.NewSet(NewTestSpaceResourceComponent) func NewTestTagComponent(config *config.Config, stores *tests.MockStores, sensitiveChecker rpc.ModerationSvcClient) *tagComponentImpl { diff --git a/component/wireset_ce.go b/component/wireset_ce.go new file mode 100644 index 00000000..ce845e06 --- /dev/null +++ b/component/wireset_ce.go @@ -0,0 +1,48 @@ +//go:build !saas + +package component + +import ( + "opencsg.com/csghub-server/builder/deploy" + "opencsg.com/csghub-server/builder/git/gitserver" + "opencsg.com/csghub-server/builder/rpc" + "opencsg.com/csghub-server/common/config" + "opencsg.com/csghub-server/common/tests" +) + +func NewTestSpaceComponent( + stores *tests.MockStores, + repoComponent RepoComponent, + git gitserver.GitServer, + deployer deploy.Deployer, + accountingComponent AccountingComponent, + config *config.Config, + userSvcClient rpc.UserSvcClient, +) *spaceComponentImpl { + return &spaceComponentImpl{ + repoComponent: repoComponent, + git: git, + spaceStore: stores.Space, + spaceSdkStore: stores.SpaceSdk, + spaceResourceStore: stores.SpaceResource, + repoStore: stores.Repo, + userStore: stores.User, + deployer: deployer, + publicRootDomain: config.Space.PublicRootDomain, + accountingComponent: accountingComponent, + serverBaseUrl: config.APIServer.PublicDomain, + userLikesStore: stores.UserLikes, + config: config, + userSvcClient: userSvcClient, + deployTaskStore: stores.DeployTask, + } +} + +func NewTestSpaceResourceComponent(config *config.Config, stores *tests.MockStores, deployer deploy.Deployer, accountComponent AccountingComponent) *spaceResourceComponentImpl { + return &spaceResourceComponentImpl{ + spaceResourceStore: stores.SpaceResource, + deployer: deployer, + userStore: stores.User, + accountComponent: accountComponent, + } +} From 92314cdc2a05da13059ac03747f82f99c642c71b Mon Sep 17 00:00:00 2001 From: Yiling-J Date: Thu, 2 Jan 2025 11:29:28 +0800 Subject: [PATCH 30/34] Add some handler tests (#223) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * Merge branch 'feature/handler_tests' into 'main' Add repo handler tests See merge request product/starhub/starhub-server!731 * fix repo handler test * Merge branch 'feature/handler_tests' into 'main' Add code handler tests and fix prompt component cycle import See merge request product/starhub/starhub-server!741 * Merge branch 'feature/handler_tests' into 'main' Add model/user/git-http handler tests See merge request product/starhub/starhub-server!751 * Add opencsg check back to git http handler * Merge branch 'feature/handler_tests' into 'main' Add space/discussion/dataset/collection handler tests See merge request product/starhub/starhub-server!759 * Merge branch 'feature/handler_tests' into 'main' Add some handler tests See merge request product/starhub/starhub-server!768 * Merge branch 'fix/swagger' into 'main' Fix swagger doc, update makefile and ci See merge request product/starhub/starhub-server!762 * add swag to makefile and update CI * fix swag ci * fix deployer resource check bug * Merge branch 'fix-internal-api-error' into 'main' Fix internal API error See merge request product/starhub/starhub-server!772 * bump go mod version --------- Co-authored-by: yiling.ji Co-authored-by: 泽华 --- .github/workflows/test.yaml | 22 +- .mockery.yaml | 18 +- Makefile | 7 +- .../go.temporal.io/sdk/client/mock_Client.go | 2072 +++++++++++++++++ .../component/mock_CollectionComponent.go | 541 +++++ .../component/mock_DatasetComponent.go | 460 ++++ .../component/mock_DiscussionComponent.go | 524 +++++ .../component/mock_EvaluationComponent.go | 202 ++ .../component/mock_InternalComponent.go | 331 +++ .../component/mock_MirrorComponent.go | 386 +++ .../component/mock_MirrorSourceComponent.go | 324 +++ .../component/mock_SpaceResourceComponent.go | 263 +++ api/handler/accounting.go | 10 +- api/handler/accounting_test.go | 60 + api/handler/collection.go | 26 +- api/handler/collection_test.go | 173 ++ api/handler/dataset.go | 28 +- api/handler/dataset_test.go | 189 ++ api/handler/discussion.go | 62 +- api/handler/discussion_test.go | 212 ++ api/handler/evaluation.go | 16 +- api/handler/evaluation_test.go | 85 + api/handler/internal.go | 22 +- api/handler/internal_test.go | 191 ++ api/handler/mirror.go | 46 +- api/handler/mirror_source.go | 61 +- api/handler/mirror_source_test.go | 104 + api/handler/mirror_test.go | 95 + api/handler/organization.go | 36 +- api/handler/organization_test.go | 175 ++ api/handler/runtime_architecture.go | 16 +- api/handler/runtime_architecture_test.go | 94 + api/handler/space.go | 41 +- api/handler/space_resource.go | 12 +- api/handler/space_resource_test.go | 98 + api/handler/space_test.go | 259 +++ builder/deploy/deployer.go | 6 +- builder/deploy/deployer_test.go | 5 + cmd/csghub-server/cmd/start/server.go | 8 +- common/types/discussion.go | 144 ++ component/discussion.go | 193 +- component/discussion_test.go | 24 +- docs/docs.go | 1075 +++------ docs/swagger.json | 1075 +++------ docs/swagger.yaml | 566 +---- go.mod | 5 +- go.sum | 5 - 47 files changed, 7942 insertions(+), 2425 deletions(-) create mode 100644 _mocks/go.temporal.io/sdk/client/mock_Client.go create mode 100644 _mocks/opencsg.com/csghub-server/component/mock_CollectionComponent.go create mode 100644 _mocks/opencsg.com/csghub-server/component/mock_DatasetComponent.go create mode 100644 _mocks/opencsg.com/csghub-server/component/mock_DiscussionComponent.go create mode 100644 _mocks/opencsg.com/csghub-server/component/mock_EvaluationComponent.go create mode 100644 _mocks/opencsg.com/csghub-server/component/mock_InternalComponent.go create mode 100644 _mocks/opencsg.com/csghub-server/component/mock_MirrorComponent.go create mode 100644 _mocks/opencsg.com/csghub-server/component/mock_MirrorSourceComponent.go create mode 100644 _mocks/opencsg.com/csghub-server/component/mock_SpaceResourceComponent.go create mode 100644 api/handler/accounting_test.go create mode 100644 api/handler/collection_test.go create mode 100644 api/handler/dataset_test.go create mode 100644 api/handler/discussion_test.go create mode 100644 api/handler/evaluation_test.go create mode 100644 api/handler/internal_test.go create mode 100644 api/handler/mirror_source_test.go create mode 100644 api/handler/mirror_test.go create mode 100644 api/handler/organization_test.go create mode 100644 api/handler/runtime_architecture_test.go create mode 100644 api/handler/space_resource_test.go create mode 100644 api/handler/space_test.go create mode 100644 common/types/discussion.go diff --git a/.github/workflows/test.yaml b/.github/workflows/test.yaml index 9cd8b36a..7b33645e 100644 --- a/.github/workflows/test.yaml +++ b/.github/workflows/test.yaml @@ -17,17 +17,35 @@ jobs: steps: - uses: actions/setup-go@v5 with: - go-version: '1.21' + go-version: '1.23' - uses: actions/checkout@v4 - name: golangci-lint uses: golangci/golangci-lint-action@v6 with: args: --timeout=5m + swagger: + name: swagger-gen + strategy: + matrix: + go: ["1.23"] + runs-on: ubuntu-latest + steps: + - name: Setup Go + with: + go-version: ${{ matrix.go }} + uses: actions/setup-go@v2 + + - uses: actions/checkout@v2 + + - name: Gen + run: | + go install github.com/swaggo/swag/cmd/swag@latest + make swag test: name: test strategy: matrix: - go: ["1.21.x"] + go: ["1.23"] runs-on: ubuntu-latest steps: - name: Setup Go diff --git a/.mockery.yaml b/.mockery.yaml index 6f7dc0b1..719c4b2b 100644 --- a/.mockery.yaml +++ b/.mockery.yaml @@ -15,6 +15,7 @@ packages: TagComponent: AccountingComponent: SpaceComponent: + SpaceResourceComponent: RuntimeArchitectureComponent: SensitiveComponent: CodeComponent: @@ -22,6 +23,15 @@ packages: ModelComponent: UserComponent: GitHTTPComponent: + DiscussionComponent: + DatasetComponent: + CollectionComponent: + InternalComponent: + MirrorSourceComponent: + MirrorComponent: + EvaluationComponent: + + opencsg.com/csghub-server/user/component: config: interfaces: @@ -67,7 +77,7 @@ packages: opencsg.com/csghub-server/mq: config: interfaces: - MessageQueue: + MessageQueue: opencsg.com/csghub-server/builder/store/s3: config: interfaces: @@ -92,7 +102,7 @@ packages: config: interfaces: Builder: - + opencsg.com/csghub-server/accounting/component: config: interfaces: @@ -116,3 +126,7 @@ packages: config: interfaces: Msg: + go.temporal.io/sdk/client: + config: + interfaces: + Client: diff --git a/Makefile b/Makefile index 739643b5..46ed5ddb 100644 --- a/Makefile +++ b/Makefile @@ -1,5 +1,5 @@ -.PHONY: test lint cover mock_wire mock_gen - +.PHONY: test lint cover mock_wire mock_gen swag + test: go test ./... @@ -25,3 +25,6 @@ mock_wire: mock_gen: mockery + +swag: + swag init --pd -d cmd/csghub-server/cmd/start,api/router,api/handler,builder/store/database,common/types,accounting/handler,user/handler,component -g server.go diff --git a/_mocks/go.temporal.io/sdk/client/mock_Client.go b/_mocks/go.temporal.io/sdk/client/mock_Client.go new file mode 100644 index 00000000..350bfeff --- /dev/null +++ b/_mocks/go.temporal.io/sdk/client/mock_Client.go @@ -0,0 +1,2072 @@ +// Code generated by mockery v2.49.1. DO NOT EDIT. + +package client + +import ( + context "context" + + client "go.temporal.io/sdk/client" + + converter "go.temporal.io/sdk/converter" + + enums "go.temporal.io/api/enums/v1" + + mock "github.com/stretchr/testify/mock" + + operatorservice "go.temporal.io/api/operatorservice/v1" + + workflowservice "go.temporal.io/api/workflowservice/v1" +) + +// MockClient is an autogenerated mock type for the Client type +type MockClient struct { + mock.Mock +} + +type MockClient_Expecter struct { + mock *mock.Mock +} + +func (_m *MockClient) EXPECT() *MockClient_Expecter { + return &MockClient_Expecter{mock: &_m.Mock} +} + +// CancelWorkflow provides a mock function with given fields: ctx, workflowID, runID +func (_m *MockClient) CancelWorkflow(ctx context.Context, workflowID string, runID string) error { + ret := _m.Called(ctx, workflowID, runID) + + if len(ret) == 0 { + panic("no return value specified for CancelWorkflow") + } + + var r0 error + if rf, ok := ret.Get(0).(func(context.Context, string, string) error); ok { + r0 = rf(ctx, workflowID, runID) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// MockClient_CancelWorkflow_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'CancelWorkflow' +type MockClient_CancelWorkflow_Call struct { + *mock.Call +} + +// CancelWorkflow is a helper method to define mock.On call +// - ctx context.Context +// - workflowID string +// - runID string +func (_e *MockClient_Expecter) CancelWorkflow(ctx interface{}, workflowID interface{}, runID interface{}) *MockClient_CancelWorkflow_Call { + return &MockClient_CancelWorkflow_Call{Call: _e.mock.On("CancelWorkflow", ctx, workflowID, runID)} +} + +func (_c *MockClient_CancelWorkflow_Call) Run(run func(ctx context.Context, workflowID string, runID string)) *MockClient_CancelWorkflow_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(string), args[2].(string)) + }) + return _c +} + +func (_c *MockClient_CancelWorkflow_Call) Return(_a0 error) *MockClient_CancelWorkflow_Call { + _c.Call.Return(_a0) + return _c +} + +func (_c *MockClient_CancelWorkflow_Call) RunAndReturn(run func(context.Context, string, string) error) *MockClient_CancelWorkflow_Call { + _c.Call.Return(run) + return _c +} + +// CheckHealth provides a mock function with given fields: ctx, request +func (_m *MockClient) CheckHealth(ctx context.Context, request *client.CheckHealthRequest) (*client.CheckHealthResponse, error) { + ret := _m.Called(ctx, request) + + if len(ret) == 0 { + panic("no return value specified for CheckHealth") + } + + var r0 *client.CheckHealthResponse + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, *client.CheckHealthRequest) (*client.CheckHealthResponse, error)); ok { + return rf(ctx, request) + } + if rf, ok := ret.Get(0).(func(context.Context, *client.CheckHealthRequest) *client.CheckHealthResponse); ok { + r0 = rf(ctx, request) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*client.CheckHealthResponse) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, *client.CheckHealthRequest) error); ok { + r1 = rf(ctx, request) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// MockClient_CheckHealth_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'CheckHealth' +type MockClient_CheckHealth_Call struct { + *mock.Call +} + +// CheckHealth is a helper method to define mock.On call +// - ctx context.Context +// - request *client.CheckHealthRequest +func (_e *MockClient_Expecter) CheckHealth(ctx interface{}, request interface{}) *MockClient_CheckHealth_Call { + return &MockClient_CheckHealth_Call{Call: _e.mock.On("CheckHealth", ctx, request)} +} + +func (_c *MockClient_CheckHealth_Call) Run(run func(ctx context.Context, request *client.CheckHealthRequest)) *MockClient_CheckHealth_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(*client.CheckHealthRequest)) + }) + return _c +} + +func (_c *MockClient_CheckHealth_Call) Return(_a0 *client.CheckHealthResponse, _a1 error) *MockClient_CheckHealth_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *MockClient_CheckHealth_Call) RunAndReturn(run func(context.Context, *client.CheckHealthRequest) (*client.CheckHealthResponse, error)) *MockClient_CheckHealth_Call { + _c.Call.Return(run) + return _c +} + +// Close provides a mock function with given fields: +func (_m *MockClient) Close() { + _m.Called() +} + +// MockClient_Close_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Close' +type MockClient_Close_Call struct { + *mock.Call +} + +// Close is a helper method to define mock.On call +func (_e *MockClient_Expecter) Close() *MockClient_Close_Call { + return &MockClient_Close_Call{Call: _e.mock.On("Close")} +} + +func (_c *MockClient_Close_Call) Run(run func()) *MockClient_Close_Call { + _c.Call.Run(func(args mock.Arguments) { + run() + }) + return _c +} + +func (_c *MockClient_Close_Call) Return() *MockClient_Close_Call { + _c.Call.Return() + return _c +} + +func (_c *MockClient_Close_Call) RunAndReturn(run func()) *MockClient_Close_Call { + _c.Call.Return(run) + return _c +} + +// CompleteActivity provides a mock function with given fields: ctx, taskToken, result, err +func (_m *MockClient) CompleteActivity(ctx context.Context, taskToken []byte, result interface{}, err error) error { + ret := _m.Called(ctx, taskToken, result, err) + + if len(ret) == 0 { + panic("no return value specified for CompleteActivity") + } + + var r0 error + if rf, ok := ret.Get(0).(func(context.Context, []byte, interface{}, error) error); ok { + r0 = rf(ctx, taskToken, result, err) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// MockClient_CompleteActivity_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'CompleteActivity' +type MockClient_CompleteActivity_Call struct { + *mock.Call +} + +// CompleteActivity is a helper method to define mock.On call +// - ctx context.Context +// - taskToken []byte +// - result interface{} +// - err error +func (_e *MockClient_Expecter) CompleteActivity(ctx interface{}, taskToken interface{}, result interface{}, err interface{}) *MockClient_CompleteActivity_Call { + return &MockClient_CompleteActivity_Call{Call: _e.mock.On("CompleteActivity", ctx, taskToken, result, err)} +} + +func (_c *MockClient_CompleteActivity_Call) Run(run func(ctx context.Context, taskToken []byte, result interface{}, err error)) *MockClient_CompleteActivity_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].([]byte), args[2].(interface{}), args[3].(error)) + }) + return _c +} + +func (_c *MockClient_CompleteActivity_Call) Return(_a0 error) *MockClient_CompleteActivity_Call { + _c.Call.Return(_a0) + return _c +} + +func (_c *MockClient_CompleteActivity_Call) RunAndReturn(run func(context.Context, []byte, interface{}, error) error) *MockClient_CompleteActivity_Call { + _c.Call.Return(run) + return _c +} + +// CompleteActivityByID provides a mock function with given fields: ctx, namespace, workflowID, runID, activityID, result, err +func (_m *MockClient) CompleteActivityByID(ctx context.Context, namespace string, workflowID string, runID string, activityID string, result interface{}, err error) error { + ret := _m.Called(ctx, namespace, workflowID, runID, activityID, result, err) + + if len(ret) == 0 { + panic("no return value specified for CompleteActivityByID") + } + + var r0 error + if rf, ok := ret.Get(0).(func(context.Context, string, string, string, string, interface{}, error) error); ok { + r0 = rf(ctx, namespace, workflowID, runID, activityID, result, err) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// MockClient_CompleteActivityByID_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'CompleteActivityByID' +type MockClient_CompleteActivityByID_Call struct { + *mock.Call +} + +// CompleteActivityByID is a helper method to define mock.On call +// - ctx context.Context +// - namespace string +// - workflowID string +// - runID string +// - activityID string +// - result interface{} +// - err error +func (_e *MockClient_Expecter) CompleteActivityByID(ctx interface{}, namespace interface{}, workflowID interface{}, runID interface{}, activityID interface{}, result interface{}, err interface{}) *MockClient_CompleteActivityByID_Call { + return &MockClient_CompleteActivityByID_Call{Call: _e.mock.On("CompleteActivityByID", ctx, namespace, workflowID, runID, activityID, result, err)} +} + +func (_c *MockClient_CompleteActivityByID_Call) Run(run func(ctx context.Context, namespace string, workflowID string, runID string, activityID string, result interface{}, err error)) *MockClient_CompleteActivityByID_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(string), args[2].(string), args[3].(string), args[4].(string), args[5].(interface{}), args[6].(error)) + }) + return _c +} + +func (_c *MockClient_CompleteActivityByID_Call) Return(_a0 error) *MockClient_CompleteActivityByID_Call { + _c.Call.Return(_a0) + return _c +} + +func (_c *MockClient_CompleteActivityByID_Call) RunAndReturn(run func(context.Context, string, string, string, string, interface{}, error) error) *MockClient_CompleteActivityByID_Call { + _c.Call.Return(run) + return _c +} + +// CountWorkflow provides a mock function with given fields: ctx, request +func (_m *MockClient) CountWorkflow(ctx context.Context, request *workflowservice.CountWorkflowExecutionsRequest) (*workflowservice.CountWorkflowExecutionsResponse, error) { + ret := _m.Called(ctx, request) + + if len(ret) == 0 { + panic("no return value specified for CountWorkflow") + } + + var r0 *workflowservice.CountWorkflowExecutionsResponse + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, *workflowservice.CountWorkflowExecutionsRequest) (*workflowservice.CountWorkflowExecutionsResponse, error)); ok { + return rf(ctx, request) + } + if rf, ok := ret.Get(0).(func(context.Context, *workflowservice.CountWorkflowExecutionsRequest) *workflowservice.CountWorkflowExecutionsResponse); ok { + r0 = rf(ctx, request) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*workflowservice.CountWorkflowExecutionsResponse) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, *workflowservice.CountWorkflowExecutionsRequest) error); ok { + r1 = rf(ctx, request) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// MockClient_CountWorkflow_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'CountWorkflow' +type MockClient_CountWorkflow_Call struct { + *mock.Call +} + +// CountWorkflow is a helper method to define mock.On call +// - ctx context.Context +// - request *workflowservice.CountWorkflowExecutionsRequest +func (_e *MockClient_Expecter) CountWorkflow(ctx interface{}, request interface{}) *MockClient_CountWorkflow_Call { + return &MockClient_CountWorkflow_Call{Call: _e.mock.On("CountWorkflow", ctx, request)} +} + +func (_c *MockClient_CountWorkflow_Call) Run(run func(ctx context.Context, request *workflowservice.CountWorkflowExecutionsRequest)) *MockClient_CountWorkflow_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(*workflowservice.CountWorkflowExecutionsRequest)) + }) + return _c +} + +func (_c *MockClient_CountWorkflow_Call) Return(_a0 *workflowservice.CountWorkflowExecutionsResponse, _a1 error) *MockClient_CountWorkflow_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *MockClient_CountWorkflow_Call) RunAndReturn(run func(context.Context, *workflowservice.CountWorkflowExecutionsRequest) (*workflowservice.CountWorkflowExecutionsResponse, error)) *MockClient_CountWorkflow_Call { + _c.Call.Return(run) + return _c +} + +// DescribeTaskQueue provides a mock function with given fields: ctx, taskqueue, taskqueueType +func (_m *MockClient) DescribeTaskQueue(ctx context.Context, taskqueue string, taskqueueType enums.TaskQueueType) (*workflowservice.DescribeTaskQueueResponse, error) { + ret := _m.Called(ctx, taskqueue, taskqueueType) + + if len(ret) == 0 { + panic("no return value specified for DescribeTaskQueue") + } + + var r0 *workflowservice.DescribeTaskQueueResponse + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, string, enums.TaskQueueType) (*workflowservice.DescribeTaskQueueResponse, error)); ok { + return rf(ctx, taskqueue, taskqueueType) + } + if rf, ok := ret.Get(0).(func(context.Context, string, enums.TaskQueueType) *workflowservice.DescribeTaskQueueResponse); ok { + r0 = rf(ctx, taskqueue, taskqueueType) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*workflowservice.DescribeTaskQueueResponse) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, string, enums.TaskQueueType) error); ok { + r1 = rf(ctx, taskqueue, taskqueueType) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// MockClient_DescribeTaskQueue_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'DescribeTaskQueue' +type MockClient_DescribeTaskQueue_Call struct { + *mock.Call +} + +// DescribeTaskQueue is a helper method to define mock.On call +// - ctx context.Context +// - taskqueue string +// - taskqueueType enums.TaskQueueType +func (_e *MockClient_Expecter) DescribeTaskQueue(ctx interface{}, taskqueue interface{}, taskqueueType interface{}) *MockClient_DescribeTaskQueue_Call { + return &MockClient_DescribeTaskQueue_Call{Call: _e.mock.On("DescribeTaskQueue", ctx, taskqueue, taskqueueType)} +} + +func (_c *MockClient_DescribeTaskQueue_Call) Run(run func(ctx context.Context, taskqueue string, taskqueueType enums.TaskQueueType)) *MockClient_DescribeTaskQueue_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(string), args[2].(enums.TaskQueueType)) + }) + return _c +} + +func (_c *MockClient_DescribeTaskQueue_Call) Return(_a0 *workflowservice.DescribeTaskQueueResponse, _a1 error) *MockClient_DescribeTaskQueue_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *MockClient_DescribeTaskQueue_Call) RunAndReturn(run func(context.Context, string, enums.TaskQueueType) (*workflowservice.DescribeTaskQueueResponse, error)) *MockClient_DescribeTaskQueue_Call { + _c.Call.Return(run) + return _c +} + +// DescribeTaskQueueEnhanced provides a mock function with given fields: ctx, options +func (_m *MockClient) DescribeTaskQueueEnhanced(ctx context.Context, options client.DescribeTaskQueueEnhancedOptions) (client.TaskQueueDescription, error) { + ret := _m.Called(ctx, options) + + if len(ret) == 0 { + panic("no return value specified for DescribeTaskQueueEnhanced") + } + + var r0 client.TaskQueueDescription + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, client.DescribeTaskQueueEnhancedOptions) (client.TaskQueueDescription, error)); ok { + return rf(ctx, options) + } + if rf, ok := ret.Get(0).(func(context.Context, client.DescribeTaskQueueEnhancedOptions) client.TaskQueueDescription); ok { + r0 = rf(ctx, options) + } else { + r0 = ret.Get(0).(client.TaskQueueDescription) + } + + if rf, ok := ret.Get(1).(func(context.Context, client.DescribeTaskQueueEnhancedOptions) error); ok { + r1 = rf(ctx, options) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// MockClient_DescribeTaskQueueEnhanced_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'DescribeTaskQueueEnhanced' +type MockClient_DescribeTaskQueueEnhanced_Call struct { + *mock.Call +} + +// DescribeTaskQueueEnhanced is a helper method to define mock.On call +// - ctx context.Context +// - options client.DescribeTaskQueueEnhancedOptions +func (_e *MockClient_Expecter) DescribeTaskQueueEnhanced(ctx interface{}, options interface{}) *MockClient_DescribeTaskQueueEnhanced_Call { + return &MockClient_DescribeTaskQueueEnhanced_Call{Call: _e.mock.On("DescribeTaskQueueEnhanced", ctx, options)} +} + +func (_c *MockClient_DescribeTaskQueueEnhanced_Call) Run(run func(ctx context.Context, options client.DescribeTaskQueueEnhancedOptions)) *MockClient_DescribeTaskQueueEnhanced_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(client.DescribeTaskQueueEnhancedOptions)) + }) + return _c +} + +func (_c *MockClient_DescribeTaskQueueEnhanced_Call) Return(_a0 client.TaskQueueDescription, _a1 error) *MockClient_DescribeTaskQueueEnhanced_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *MockClient_DescribeTaskQueueEnhanced_Call) RunAndReturn(run func(context.Context, client.DescribeTaskQueueEnhancedOptions) (client.TaskQueueDescription, error)) *MockClient_DescribeTaskQueueEnhanced_Call { + _c.Call.Return(run) + return _c +} + +// DescribeWorkflowExecution provides a mock function with given fields: ctx, workflowID, runID +func (_m *MockClient) DescribeWorkflowExecution(ctx context.Context, workflowID string, runID string) (*workflowservice.DescribeWorkflowExecutionResponse, error) { + ret := _m.Called(ctx, workflowID, runID) + + if len(ret) == 0 { + panic("no return value specified for DescribeWorkflowExecution") + } + + var r0 *workflowservice.DescribeWorkflowExecutionResponse + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, string, string) (*workflowservice.DescribeWorkflowExecutionResponse, error)); ok { + return rf(ctx, workflowID, runID) + } + if rf, ok := ret.Get(0).(func(context.Context, string, string) *workflowservice.DescribeWorkflowExecutionResponse); ok { + r0 = rf(ctx, workflowID, runID) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*workflowservice.DescribeWorkflowExecutionResponse) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, string, string) error); ok { + r1 = rf(ctx, workflowID, runID) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// MockClient_DescribeWorkflowExecution_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'DescribeWorkflowExecution' +type MockClient_DescribeWorkflowExecution_Call struct { + *mock.Call +} + +// DescribeWorkflowExecution is a helper method to define mock.On call +// - ctx context.Context +// - workflowID string +// - runID string +func (_e *MockClient_Expecter) DescribeWorkflowExecution(ctx interface{}, workflowID interface{}, runID interface{}) *MockClient_DescribeWorkflowExecution_Call { + return &MockClient_DescribeWorkflowExecution_Call{Call: _e.mock.On("DescribeWorkflowExecution", ctx, workflowID, runID)} +} + +func (_c *MockClient_DescribeWorkflowExecution_Call) Run(run func(ctx context.Context, workflowID string, runID string)) *MockClient_DescribeWorkflowExecution_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(string), args[2].(string)) + }) + return _c +} + +func (_c *MockClient_DescribeWorkflowExecution_Call) Return(_a0 *workflowservice.DescribeWorkflowExecutionResponse, _a1 error) *MockClient_DescribeWorkflowExecution_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *MockClient_DescribeWorkflowExecution_Call) RunAndReturn(run func(context.Context, string, string) (*workflowservice.DescribeWorkflowExecutionResponse, error)) *MockClient_DescribeWorkflowExecution_Call { + _c.Call.Return(run) + return _c +} + +// ExecuteWorkflow provides a mock function with given fields: ctx, options, workflow, args +func (_m *MockClient) ExecuteWorkflow(ctx context.Context, options client.StartWorkflowOptions, workflow interface{}, args ...interface{}) (client.WorkflowRun, error) { + var _ca []interface{} + _ca = append(_ca, ctx, options, workflow) + _ca = append(_ca, args...) + ret := _m.Called(_ca...) + + if len(ret) == 0 { + panic("no return value specified for ExecuteWorkflow") + } + + var r0 client.WorkflowRun + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, client.StartWorkflowOptions, interface{}, ...interface{}) (client.WorkflowRun, error)); ok { + return rf(ctx, options, workflow, args...) + } + if rf, ok := ret.Get(0).(func(context.Context, client.StartWorkflowOptions, interface{}, ...interface{}) client.WorkflowRun); ok { + r0 = rf(ctx, options, workflow, args...) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(client.WorkflowRun) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, client.StartWorkflowOptions, interface{}, ...interface{}) error); ok { + r1 = rf(ctx, options, workflow, args...) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// MockClient_ExecuteWorkflow_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'ExecuteWorkflow' +type MockClient_ExecuteWorkflow_Call struct { + *mock.Call +} + +// ExecuteWorkflow is a helper method to define mock.On call +// - ctx context.Context +// - options client.StartWorkflowOptions +// - workflow interface{} +// - args ...interface{} +func (_e *MockClient_Expecter) ExecuteWorkflow(ctx interface{}, options interface{}, workflow interface{}, args ...interface{}) *MockClient_ExecuteWorkflow_Call { + return &MockClient_ExecuteWorkflow_Call{Call: _e.mock.On("ExecuteWorkflow", + append([]interface{}{ctx, options, workflow}, args...)...)} +} + +func (_c *MockClient_ExecuteWorkflow_Call) Run(run func(ctx context.Context, options client.StartWorkflowOptions, workflow interface{}, args ...interface{})) *MockClient_ExecuteWorkflow_Call { + _c.Call.Run(func(args mock.Arguments) { + variadicArgs := make([]interface{}, len(args)-3) + for i, a := range args[3:] { + if a != nil { + variadicArgs[i] = a.(interface{}) + } + } + run(args[0].(context.Context), args[1].(client.StartWorkflowOptions), args[2].(interface{}), variadicArgs...) + }) + return _c +} + +func (_c *MockClient_ExecuteWorkflow_Call) Return(_a0 client.WorkflowRun, _a1 error) *MockClient_ExecuteWorkflow_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *MockClient_ExecuteWorkflow_Call) RunAndReturn(run func(context.Context, client.StartWorkflowOptions, interface{}, ...interface{}) (client.WorkflowRun, error)) *MockClient_ExecuteWorkflow_Call { + _c.Call.Return(run) + return _c +} + +// GetSearchAttributes provides a mock function with given fields: ctx +func (_m *MockClient) GetSearchAttributes(ctx context.Context) (*workflowservice.GetSearchAttributesResponse, error) { + ret := _m.Called(ctx) + + if len(ret) == 0 { + panic("no return value specified for GetSearchAttributes") + } + + var r0 *workflowservice.GetSearchAttributesResponse + var r1 error + if rf, ok := ret.Get(0).(func(context.Context) (*workflowservice.GetSearchAttributesResponse, error)); ok { + return rf(ctx) + } + if rf, ok := ret.Get(0).(func(context.Context) *workflowservice.GetSearchAttributesResponse); ok { + r0 = rf(ctx) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*workflowservice.GetSearchAttributesResponse) + } + } + + if rf, ok := ret.Get(1).(func(context.Context) error); ok { + r1 = rf(ctx) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// MockClient_GetSearchAttributes_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'GetSearchAttributes' +type MockClient_GetSearchAttributes_Call struct { + *mock.Call +} + +// GetSearchAttributes is a helper method to define mock.On call +// - ctx context.Context +func (_e *MockClient_Expecter) GetSearchAttributes(ctx interface{}) *MockClient_GetSearchAttributes_Call { + return &MockClient_GetSearchAttributes_Call{Call: _e.mock.On("GetSearchAttributes", ctx)} +} + +func (_c *MockClient_GetSearchAttributes_Call) Run(run func(ctx context.Context)) *MockClient_GetSearchAttributes_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context)) + }) + return _c +} + +func (_c *MockClient_GetSearchAttributes_Call) Return(_a0 *workflowservice.GetSearchAttributesResponse, _a1 error) *MockClient_GetSearchAttributes_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *MockClient_GetSearchAttributes_Call) RunAndReturn(run func(context.Context) (*workflowservice.GetSearchAttributesResponse, error)) *MockClient_GetSearchAttributes_Call { + _c.Call.Return(run) + return _c +} + +// GetWorkerBuildIdCompatibility provides a mock function with given fields: ctx, options +func (_m *MockClient) GetWorkerBuildIdCompatibility(ctx context.Context, options *client.GetWorkerBuildIdCompatibilityOptions) (*client.WorkerBuildIDVersionSets, error) { + ret := _m.Called(ctx, options) + + if len(ret) == 0 { + panic("no return value specified for GetWorkerBuildIdCompatibility") + } + + var r0 *client.WorkerBuildIDVersionSets + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, *client.GetWorkerBuildIdCompatibilityOptions) (*client.WorkerBuildIDVersionSets, error)); ok { + return rf(ctx, options) + } + if rf, ok := ret.Get(0).(func(context.Context, *client.GetWorkerBuildIdCompatibilityOptions) *client.WorkerBuildIDVersionSets); ok { + r0 = rf(ctx, options) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*client.WorkerBuildIDVersionSets) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, *client.GetWorkerBuildIdCompatibilityOptions) error); ok { + r1 = rf(ctx, options) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// MockClient_GetWorkerBuildIdCompatibility_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'GetWorkerBuildIdCompatibility' +type MockClient_GetWorkerBuildIdCompatibility_Call struct { + *mock.Call +} + +// GetWorkerBuildIdCompatibility is a helper method to define mock.On call +// - ctx context.Context +// - options *client.GetWorkerBuildIdCompatibilityOptions +func (_e *MockClient_Expecter) GetWorkerBuildIdCompatibility(ctx interface{}, options interface{}) *MockClient_GetWorkerBuildIdCompatibility_Call { + return &MockClient_GetWorkerBuildIdCompatibility_Call{Call: _e.mock.On("GetWorkerBuildIdCompatibility", ctx, options)} +} + +func (_c *MockClient_GetWorkerBuildIdCompatibility_Call) Run(run func(ctx context.Context, options *client.GetWorkerBuildIdCompatibilityOptions)) *MockClient_GetWorkerBuildIdCompatibility_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(*client.GetWorkerBuildIdCompatibilityOptions)) + }) + return _c +} + +func (_c *MockClient_GetWorkerBuildIdCompatibility_Call) Return(_a0 *client.WorkerBuildIDVersionSets, _a1 error) *MockClient_GetWorkerBuildIdCompatibility_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *MockClient_GetWorkerBuildIdCompatibility_Call) RunAndReturn(run func(context.Context, *client.GetWorkerBuildIdCompatibilityOptions) (*client.WorkerBuildIDVersionSets, error)) *MockClient_GetWorkerBuildIdCompatibility_Call { + _c.Call.Return(run) + return _c +} + +// GetWorkerTaskReachability provides a mock function with given fields: ctx, options +func (_m *MockClient) GetWorkerTaskReachability(ctx context.Context, options *client.GetWorkerTaskReachabilityOptions) (*client.WorkerTaskReachability, error) { + ret := _m.Called(ctx, options) + + if len(ret) == 0 { + panic("no return value specified for GetWorkerTaskReachability") + } + + var r0 *client.WorkerTaskReachability + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, *client.GetWorkerTaskReachabilityOptions) (*client.WorkerTaskReachability, error)); ok { + return rf(ctx, options) + } + if rf, ok := ret.Get(0).(func(context.Context, *client.GetWorkerTaskReachabilityOptions) *client.WorkerTaskReachability); ok { + r0 = rf(ctx, options) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*client.WorkerTaskReachability) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, *client.GetWorkerTaskReachabilityOptions) error); ok { + r1 = rf(ctx, options) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// MockClient_GetWorkerTaskReachability_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'GetWorkerTaskReachability' +type MockClient_GetWorkerTaskReachability_Call struct { + *mock.Call +} + +// GetWorkerTaskReachability is a helper method to define mock.On call +// - ctx context.Context +// - options *client.GetWorkerTaskReachabilityOptions +func (_e *MockClient_Expecter) GetWorkerTaskReachability(ctx interface{}, options interface{}) *MockClient_GetWorkerTaskReachability_Call { + return &MockClient_GetWorkerTaskReachability_Call{Call: _e.mock.On("GetWorkerTaskReachability", ctx, options)} +} + +func (_c *MockClient_GetWorkerTaskReachability_Call) Run(run func(ctx context.Context, options *client.GetWorkerTaskReachabilityOptions)) *MockClient_GetWorkerTaskReachability_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(*client.GetWorkerTaskReachabilityOptions)) + }) + return _c +} + +func (_c *MockClient_GetWorkerTaskReachability_Call) Return(_a0 *client.WorkerTaskReachability, _a1 error) *MockClient_GetWorkerTaskReachability_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *MockClient_GetWorkerTaskReachability_Call) RunAndReturn(run func(context.Context, *client.GetWorkerTaskReachabilityOptions) (*client.WorkerTaskReachability, error)) *MockClient_GetWorkerTaskReachability_Call { + _c.Call.Return(run) + return _c +} + +// GetWorkerVersioningRules provides a mock function with given fields: ctx, options +func (_m *MockClient) GetWorkerVersioningRules(ctx context.Context, options client.GetWorkerVersioningOptions) (*client.WorkerVersioningRules, error) { + ret := _m.Called(ctx, options) + + if len(ret) == 0 { + panic("no return value specified for GetWorkerVersioningRules") + } + + var r0 *client.WorkerVersioningRules + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, client.GetWorkerVersioningOptions) (*client.WorkerVersioningRules, error)); ok { + return rf(ctx, options) + } + if rf, ok := ret.Get(0).(func(context.Context, client.GetWorkerVersioningOptions) *client.WorkerVersioningRules); ok { + r0 = rf(ctx, options) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*client.WorkerVersioningRules) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, client.GetWorkerVersioningOptions) error); ok { + r1 = rf(ctx, options) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// MockClient_GetWorkerVersioningRules_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'GetWorkerVersioningRules' +type MockClient_GetWorkerVersioningRules_Call struct { + *mock.Call +} + +// GetWorkerVersioningRules is a helper method to define mock.On call +// - ctx context.Context +// - options client.GetWorkerVersioningOptions +func (_e *MockClient_Expecter) GetWorkerVersioningRules(ctx interface{}, options interface{}) *MockClient_GetWorkerVersioningRules_Call { + return &MockClient_GetWorkerVersioningRules_Call{Call: _e.mock.On("GetWorkerVersioningRules", ctx, options)} +} + +func (_c *MockClient_GetWorkerVersioningRules_Call) Run(run func(ctx context.Context, options client.GetWorkerVersioningOptions)) *MockClient_GetWorkerVersioningRules_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(client.GetWorkerVersioningOptions)) + }) + return _c +} + +func (_c *MockClient_GetWorkerVersioningRules_Call) Return(_a0 *client.WorkerVersioningRules, _a1 error) *MockClient_GetWorkerVersioningRules_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *MockClient_GetWorkerVersioningRules_Call) RunAndReturn(run func(context.Context, client.GetWorkerVersioningOptions) (*client.WorkerVersioningRules, error)) *MockClient_GetWorkerVersioningRules_Call { + _c.Call.Return(run) + return _c +} + +// GetWorkflow provides a mock function with given fields: ctx, workflowID, runID +func (_m *MockClient) GetWorkflow(ctx context.Context, workflowID string, runID string) client.WorkflowRun { + ret := _m.Called(ctx, workflowID, runID) + + if len(ret) == 0 { + panic("no return value specified for GetWorkflow") + } + + var r0 client.WorkflowRun + if rf, ok := ret.Get(0).(func(context.Context, string, string) client.WorkflowRun); ok { + r0 = rf(ctx, workflowID, runID) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(client.WorkflowRun) + } + } + + return r0 +} + +// MockClient_GetWorkflow_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'GetWorkflow' +type MockClient_GetWorkflow_Call struct { + *mock.Call +} + +// GetWorkflow is a helper method to define mock.On call +// - ctx context.Context +// - workflowID string +// - runID string +func (_e *MockClient_Expecter) GetWorkflow(ctx interface{}, workflowID interface{}, runID interface{}) *MockClient_GetWorkflow_Call { + return &MockClient_GetWorkflow_Call{Call: _e.mock.On("GetWorkflow", ctx, workflowID, runID)} +} + +func (_c *MockClient_GetWorkflow_Call) Run(run func(ctx context.Context, workflowID string, runID string)) *MockClient_GetWorkflow_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(string), args[2].(string)) + }) + return _c +} + +func (_c *MockClient_GetWorkflow_Call) Return(_a0 client.WorkflowRun) *MockClient_GetWorkflow_Call { + _c.Call.Return(_a0) + return _c +} + +func (_c *MockClient_GetWorkflow_Call) RunAndReturn(run func(context.Context, string, string) client.WorkflowRun) *MockClient_GetWorkflow_Call { + _c.Call.Return(run) + return _c +} + +// GetWorkflowHistory provides a mock function with given fields: ctx, workflowID, runID, isLongPoll, filterType +func (_m *MockClient) GetWorkflowHistory(ctx context.Context, workflowID string, runID string, isLongPoll bool, filterType enums.HistoryEventFilterType) client.HistoryEventIterator { + ret := _m.Called(ctx, workflowID, runID, isLongPoll, filterType) + + if len(ret) == 0 { + panic("no return value specified for GetWorkflowHistory") + } + + var r0 client.HistoryEventIterator + if rf, ok := ret.Get(0).(func(context.Context, string, string, bool, enums.HistoryEventFilterType) client.HistoryEventIterator); ok { + r0 = rf(ctx, workflowID, runID, isLongPoll, filterType) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(client.HistoryEventIterator) + } + } + + return r0 +} + +// MockClient_GetWorkflowHistory_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'GetWorkflowHistory' +type MockClient_GetWorkflowHistory_Call struct { + *mock.Call +} + +// GetWorkflowHistory is a helper method to define mock.On call +// - ctx context.Context +// - workflowID string +// - runID string +// - isLongPoll bool +// - filterType enums.HistoryEventFilterType +func (_e *MockClient_Expecter) GetWorkflowHistory(ctx interface{}, workflowID interface{}, runID interface{}, isLongPoll interface{}, filterType interface{}) *MockClient_GetWorkflowHistory_Call { + return &MockClient_GetWorkflowHistory_Call{Call: _e.mock.On("GetWorkflowHistory", ctx, workflowID, runID, isLongPoll, filterType)} +} + +func (_c *MockClient_GetWorkflowHistory_Call) Run(run func(ctx context.Context, workflowID string, runID string, isLongPoll bool, filterType enums.HistoryEventFilterType)) *MockClient_GetWorkflowHistory_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(string), args[2].(string), args[3].(bool), args[4].(enums.HistoryEventFilterType)) + }) + return _c +} + +func (_c *MockClient_GetWorkflowHistory_Call) Return(_a0 client.HistoryEventIterator) *MockClient_GetWorkflowHistory_Call { + _c.Call.Return(_a0) + return _c +} + +func (_c *MockClient_GetWorkflowHistory_Call) RunAndReturn(run func(context.Context, string, string, bool, enums.HistoryEventFilterType) client.HistoryEventIterator) *MockClient_GetWorkflowHistory_Call { + _c.Call.Return(run) + return _c +} + +// GetWorkflowUpdateHandle provides a mock function with given fields: ref +func (_m *MockClient) GetWorkflowUpdateHandle(ref client.GetWorkflowUpdateHandleOptions) client.WorkflowUpdateHandle { + ret := _m.Called(ref) + + if len(ret) == 0 { + panic("no return value specified for GetWorkflowUpdateHandle") + } + + var r0 client.WorkflowUpdateHandle + if rf, ok := ret.Get(0).(func(client.GetWorkflowUpdateHandleOptions) client.WorkflowUpdateHandle); ok { + r0 = rf(ref) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(client.WorkflowUpdateHandle) + } + } + + return r0 +} + +// MockClient_GetWorkflowUpdateHandle_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'GetWorkflowUpdateHandle' +type MockClient_GetWorkflowUpdateHandle_Call struct { + *mock.Call +} + +// GetWorkflowUpdateHandle is a helper method to define mock.On call +// - ref client.GetWorkflowUpdateHandleOptions +func (_e *MockClient_Expecter) GetWorkflowUpdateHandle(ref interface{}) *MockClient_GetWorkflowUpdateHandle_Call { + return &MockClient_GetWorkflowUpdateHandle_Call{Call: _e.mock.On("GetWorkflowUpdateHandle", ref)} +} + +func (_c *MockClient_GetWorkflowUpdateHandle_Call) Run(run func(ref client.GetWorkflowUpdateHandleOptions)) *MockClient_GetWorkflowUpdateHandle_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(client.GetWorkflowUpdateHandleOptions)) + }) + return _c +} + +func (_c *MockClient_GetWorkflowUpdateHandle_Call) Return(_a0 client.WorkflowUpdateHandle) *MockClient_GetWorkflowUpdateHandle_Call { + _c.Call.Return(_a0) + return _c +} + +func (_c *MockClient_GetWorkflowUpdateHandle_Call) RunAndReturn(run func(client.GetWorkflowUpdateHandleOptions) client.WorkflowUpdateHandle) *MockClient_GetWorkflowUpdateHandle_Call { + _c.Call.Return(run) + return _c +} + +// ListArchivedWorkflow provides a mock function with given fields: ctx, request +func (_m *MockClient) ListArchivedWorkflow(ctx context.Context, request *workflowservice.ListArchivedWorkflowExecutionsRequest) (*workflowservice.ListArchivedWorkflowExecutionsResponse, error) { + ret := _m.Called(ctx, request) + + if len(ret) == 0 { + panic("no return value specified for ListArchivedWorkflow") + } + + var r0 *workflowservice.ListArchivedWorkflowExecutionsResponse + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, *workflowservice.ListArchivedWorkflowExecutionsRequest) (*workflowservice.ListArchivedWorkflowExecutionsResponse, error)); ok { + return rf(ctx, request) + } + if rf, ok := ret.Get(0).(func(context.Context, *workflowservice.ListArchivedWorkflowExecutionsRequest) *workflowservice.ListArchivedWorkflowExecutionsResponse); ok { + r0 = rf(ctx, request) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*workflowservice.ListArchivedWorkflowExecutionsResponse) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, *workflowservice.ListArchivedWorkflowExecutionsRequest) error); ok { + r1 = rf(ctx, request) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// MockClient_ListArchivedWorkflow_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'ListArchivedWorkflow' +type MockClient_ListArchivedWorkflow_Call struct { + *mock.Call +} + +// ListArchivedWorkflow is a helper method to define mock.On call +// - ctx context.Context +// - request *workflowservice.ListArchivedWorkflowExecutionsRequest +func (_e *MockClient_Expecter) ListArchivedWorkflow(ctx interface{}, request interface{}) *MockClient_ListArchivedWorkflow_Call { + return &MockClient_ListArchivedWorkflow_Call{Call: _e.mock.On("ListArchivedWorkflow", ctx, request)} +} + +func (_c *MockClient_ListArchivedWorkflow_Call) Run(run func(ctx context.Context, request *workflowservice.ListArchivedWorkflowExecutionsRequest)) *MockClient_ListArchivedWorkflow_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(*workflowservice.ListArchivedWorkflowExecutionsRequest)) + }) + return _c +} + +func (_c *MockClient_ListArchivedWorkflow_Call) Return(_a0 *workflowservice.ListArchivedWorkflowExecutionsResponse, _a1 error) *MockClient_ListArchivedWorkflow_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *MockClient_ListArchivedWorkflow_Call) RunAndReturn(run func(context.Context, *workflowservice.ListArchivedWorkflowExecutionsRequest) (*workflowservice.ListArchivedWorkflowExecutionsResponse, error)) *MockClient_ListArchivedWorkflow_Call { + _c.Call.Return(run) + return _c +} + +// ListClosedWorkflow provides a mock function with given fields: ctx, request +func (_m *MockClient) ListClosedWorkflow(ctx context.Context, request *workflowservice.ListClosedWorkflowExecutionsRequest) (*workflowservice.ListClosedWorkflowExecutionsResponse, error) { + ret := _m.Called(ctx, request) + + if len(ret) == 0 { + panic("no return value specified for ListClosedWorkflow") + } + + var r0 *workflowservice.ListClosedWorkflowExecutionsResponse + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, *workflowservice.ListClosedWorkflowExecutionsRequest) (*workflowservice.ListClosedWorkflowExecutionsResponse, error)); ok { + return rf(ctx, request) + } + if rf, ok := ret.Get(0).(func(context.Context, *workflowservice.ListClosedWorkflowExecutionsRequest) *workflowservice.ListClosedWorkflowExecutionsResponse); ok { + r0 = rf(ctx, request) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*workflowservice.ListClosedWorkflowExecutionsResponse) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, *workflowservice.ListClosedWorkflowExecutionsRequest) error); ok { + r1 = rf(ctx, request) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// MockClient_ListClosedWorkflow_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'ListClosedWorkflow' +type MockClient_ListClosedWorkflow_Call struct { + *mock.Call +} + +// ListClosedWorkflow is a helper method to define mock.On call +// - ctx context.Context +// - request *workflowservice.ListClosedWorkflowExecutionsRequest +func (_e *MockClient_Expecter) ListClosedWorkflow(ctx interface{}, request interface{}) *MockClient_ListClosedWorkflow_Call { + return &MockClient_ListClosedWorkflow_Call{Call: _e.mock.On("ListClosedWorkflow", ctx, request)} +} + +func (_c *MockClient_ListClosedWorkflow_Call) Run(run func(ctx context.Context, request *workflowservice.ListClosedWorkflowExecutionsRequest)) *MockClient_ListClosedWorkflow_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(*workflowservice.ListClosedWorkflowExecutionsRequest)) + }) + return _c +} + +func (_c *MockClient_ListClosedWorkflow_Call) Return(_a0 *workflowservice.ListClosedWorkflowExecutionsResponse, _a1 error) *MockClient_ListClosedWorkflow_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *MockClient_ListClosedWorkflow_Call) RunAndReturn(run func(context.Context, *workflowservice.ListClosedWorkflowExecutionsRequest) (*workflowservice.ListClosedWorkflowExecutionsResponse, error)) *MockClient_ListClosedWorkflow_Call { + _c.Call.Return(run) + return _c +} + +// ListOpenWorkflow provides a mock function with given fields: ctx, request +func (_m *MockClient) ListOpenWorkflow(ctx context.Context, request *workflowservice.ListOpenWorkflowExecutionsRequest) (*workflowservice.ListOpenWorkflowExecutionsResponse, error) { + ret := _m.Called(ctx, request) + + if len(ret) == 0 { + panic("no return value specified for ListOpenWorkflow") + } + + var r0 *workflowservice.ListOpenWorkflowExecutionsResponse + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, *workflowservice.ListOpenWorkflowExecutionsRequest) (*workflowservice.ListOpenWorkflowExecutionsResponse, error)); ok { + return rf(ctx, request) + } + if rf, ok := ret.Get(0).(func(context.Context, *workflowservice.ListOpenWorkflowExecutionsRequest) *workflowservice.ListOpenWorkflowExecutionsResponse); ok { + r0 = rf(ctx, request) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*workflowservice.ListOpenWorkflowExecutionsResponse) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, *workflowservice.ListOpenWorkflowExecutionsRequest) error); ok { + r1 = rf(ctx, request) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// MockClient_ListOpenWorkflow_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'ListOpenWorkflow' +type MockClient_ListOpenWorkflow_Call struct { + *mock.Call +} + +// ListOpenWorkflow is a helper method to define mock.On call +// - ctx context.Context +// - request *workflowservice.ListOpenWorkflowExecutionsRequest +func (_e *MockClient_Expecter) ListOpenWorkflow(ctx interface{}, request interface{}) *MockClient_ListOpenWorkflow_Call { + return &MockClient_ListOpenWorkflow_Call{Call: _e.mock.On("ListOpenWorkflow", ctx, request)} +} + +func (_c *MockClient_ListOpenWorkflow_Call) Run(run func(ctx context.Context, request *workflowservice.ListOpenWorkflowExecutionsRequest)) *MockClient_ListOpenWorkflow_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(*workflowservice.ListOpenWorkflowExecutionsRequest)) + }) + return _c +} + +func (_c *MockClient_ListOpenWorkflow_Call) Return(_a0 *workflowservice.ListOpenWorkflowExecutionsResponse, _a1 error) *MockClient_ListOpenWorkflow_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *MockClient_ListOpenWorkflow_Call) RunAndReturn(run func(context.Context, *workflowservice.ListOpenWorkflowExecutionsRequest) (*workflowservice.ListOpenWorkflowExecutionsResponse, error)) *MockClient_ListOpenWorkflow_Call { + _c.Call.Return(run) + return _c +} + +// ListWorkflow provides a mock function with given fields: ctx, request +func (_m *MockClient) ListWorkflow(ctx context.Context, request *workflowservice.ListWorkflowExecutionsRequest) (*workflowservice.ListWorkflowExecutionsResponse, error) { + ret := _m.Called(ctx, request) + + if len(ret) == 0 { + panic("no return value specified for ListWorkflow") + } + + var r0 *workflowservice.ListWorkflowExecutionsResponse + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, *workflowservice.ListWorkflowExecutionsRequest) (*workflowservice.ListWorkflowExecutionsResponse, error)); ok { + return rf(ctx, request) + } + if rf, ok := ret.Get(0).(func(context.Context, *workflowservice.ListWorkflowExecutionsRequest) *workflowservice.ListWorkflowExecutionsResponse); ok { + r0 = rf(ctx, request) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*workflowservice.ListWorkflowExecutionsResponse) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, *workflowservice.ListWorkflowExecutionsRequest) error); ok { + r1 = rf(ctx, request) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// MockClient_ListWorkflow_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'ListWorkflow' +type MockClient_ListWorkflow_Call struct { + *mock.Call +} + +// ListWorkflow is a helper method to define mock.On call +// - ctx context.Context +// - request *workflowservice.ListWorkflowExecutionsRequest +func (_e *MockClient_Expecter) ListWorkflow(ctx interface{}, request interface{}) *MockClient_ListWorkflow_Call { + return &MockClient_ListWorkflow_Call{Call: _e.mock.On("ListWorkflow", ctx, request)} +} + +func (_c *MockClient_ListWorkflow_Call) Run(run func(ctx context.Context, request *workflowservice.ListWorkflowExecutionsRequest)) *MockClient_ListWorkflow_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(*workflowservice.ListWorkflowExecutionsRequest)) + }) + return _c +} + +func (_c *MockClient_ListWorkflow_Call) Return(_a0 *workflowservice.ListWorkflowExecutionsResponse, _a1 error) *MockClient_ListWorkflow_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *MockClient_ListWorkflow_Call) RunAndReturn(run func(context.Context, *workflowservice.ListWorkflowExecutionsRequest) (*workflowservice.ListWorkflowExecutionsResponse, error)) *MockClient_ListWorkflow_Call { + _c.Call.Return(run) + return _c +} + +// OperatorService provides a mock function with given fields: +func (_m *MockClient) OperatorService() operatorservice.OperatorServiceClient { + ret := _m.Called() + + if len(ret) == 0 { + panic("no return value specified for OperatorService") + } + + var r0 operatorservice.OperatorServiceClient + if rf, ok := ret.Get(0).(func() operatorservice.OperatorServiceClient); ok { + r0 = rf() + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(operatorservice.OperatorServiceClient) + } + } + + return r0 +} + +// MockClient_OperatorService_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'OperatorService' +type MockClient_OperatorService_Call struct { + *mock.Call +} + +// OperatorService is a helper method to define mock.On call +func (_e *MockClient_Expecter) OperatorService() *MockClient_OperatorService_Call { + return &MockClient_OperatorService_Call{Call: _e.mock.On("OperatorService")} +} + +func (_c *MockClient_OperatorService_Call) Run(run func()) *MockClient_OperatorService_Call { + _c.Call.Run(func(args mock.Arguments) { + run() + }) + return _c +} + +func (_c *MockClient_OperatorService_Call) Return(_a0 operatorservice.OperatorServiceClient) *MockClient_OperatorService_Call { + _c.Call.Return(_a0) + return _c +} + +func (_c *MockClient_OperatorService_Call) RunAndReturn(run func() operatorservice.OperatorServiceClient) *MockClient_OperatorService_Call { + _c.Call.Return(run) + return _c +} + +// QueryWorkflow provides a mock function with given fields: ctx, workflowID, runID, queryType, args +func (_m *MockClient) QueryWorkflow(ctx context.Context, workflowID string, runID string, queryType string, args ...interface{}) (converter.EncodedValue, error) { + var _ca []interface{} + _ca = append(_ca, ctx, workflowID, runID, queryType) + _ca = append(_ca, args...) + ret := _m.Called(_ca...) + + if len(ret) == 0 { + panic("no return value specified for QueryWorkflow") + } + + var r0 converter.EncodedValue + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, string, string, string, ...interface{}) (converter.EncodedValue, error)); ok { + return rf(ctx, workflowID, runID, queryType, args...) + } + if rf, ok := ret.Get(0).(func(context.Context, string, string, string, ...interface{}) converter.EncodedValue); ok { + r0 = rf(ctx, workflowID, runID, queryType, args...) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(converter.EncodedValue) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, string, string, string, ...interface{}) error); ok { + r1 = rf(ctx, workflowID, runID, queryType, args...) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// MockClient_QueryWorkflow_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'QueryWorkflow' +type MockClient_QueryWorkflow_Call struct { + *mock.Call +} + +// QueryWorkflow is a helper method to define mock.On call +// - ctx context.Context +// - workflowID string +// - runID string +// - queryType string +// - args ...interface{} +func (_e *MockClient_Expecter) QueryWorkflow(ctx interface{}, workflowID interface{}, runID interface{}, queryType interface{}, args ...interface{}) *MockClient_QueryWorkflow_Call { + return &MockClient_QueryWorkflow_Call{Call: _e.mock.On("QueryWorkflow", + append([]interface{}{ctx, workflowID, runID, queryType}, args...)...)} +} + +func (_c *MockClient_QueryWorkflow_Call) Run(run func(ctx context.Context, workflowID string, runID string, queryType string, args ...interface{})) *MockClient_QueryWorkflow_Call { + _c.Call.Run(func(args mock.Arguments) { + variadicArgs := make([]interface{}, len(args)-4) + for i, a := range args[4:] { + if a != nil { + variadicArgs[i] = a.(interface{}) + } + } + run(args[0].(context.Context), args[1].(string), args[2].(string), args[3].(string), variadicArgs...) + }) + return _c +} + +func (_c *MockClient_QueryWorkflow_Call) Return(_a0 converter.EncodedValue, _a1 error) *MockClient_QueryWorkflow_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *MockClient_QueryWorkflow_Call) RunAndReturn(run func(context.Context, string, string, string, ...interface{}) (converter.EncodedValue, error)) *MockClient_QueryWorkflow_Call { + _c.Call.Return(run) + return _c +} + +// QueryWorkflowWithOptions provides a mock function with given fields: ctx, request +func (_m *MockClient) QueryWorkflowWithOptions(ctx context.Context, request *client.QueryWorkflowWithOptionsRequest) (*client.QueryWorkflowWithOptionsResponse, error) { + ret := _m.Called(ctx, request) + + if len(ret) == 0 { + panic("no return value specified for QueryWorkflowWithOptions") + } + + var r0 *client.QueryWorkflowWithOptionsResponse + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, *client.QueryWorkflowWithOptionsRequest) (*client.QueryWorkflowWithOptionsResponse, error)); ok { + return rf(ctx, request) + } + if rf, ok := ret.Get(0).(func(context.Context, *client.QueryWorkflowWithOptionsRequest) *client.QueryWorkflowWithOptionsResponse); ok { + r0 = rf(ctx, request) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*client.QueryWorkflowWithOptionsResponse) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, *client.QueryWorkflowWithOptionsRequest) error); ok { + r1 = rf(ctx, request) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// MockClient_QueryWorkflowWithOptions_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'QueryWorkflowWithOptions' +type MockClient_QueryWorkflowWithOptions_Call struct { + *mock.Call +} + +// QueryWorkflowWithOptions is a helper method to define mock.On call +// - ctx context.Context +// - request *client.QueryWorkflowWithOptionsRequest +func (_e *MockClient_Expecter) QueryWorkflowWithOptions(ctx interface{}, request interface{}) *MockClient_QueryWorkflowWithOptions_Call { + return &MockClient_QueryWorkflowWithOptions_Call{Call: _e.mock.On("QueryWorkflowWithOptions", ctx, request)} +} + +func (_c *MockClient_QueryWorkflowWithOptions_Call) Run(run func(ctx context.Context, request *client.QueryWorkflowWithOptionsRequest)) *MockClient_QueryWorkflowWithOptions_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(*client.QueryWorkflowWithOptionsRequest)) + }) + return _c +} + +func (_c *MockClient_QueryWorkflowWithOptions_Call) Return(_a0 *client.QueryWorkflowWithOptionsResponse, _a1 error) *MockClient_QueryWorkflowWithOptions_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *MockClient_QueryWorkflowWithOptions_Call) RunAndReturn(run func(context.Context, *client.QueryWorkflowWithOptionsRequest) (*client.QueryWorkflowWithOptionsResponse, error)) *MockClient_QueryWorkflowWithOptions_Call { + _c.Call.Return(run) + return _c +} + +// RecordActivityHeartbeat provides a mock function with given fields: ctx, taskToken, details +func (_m *MockClient) RecordActivityHeartbeat(ctx context.Context, taskToken []byte, details ...interface{}) error { + var _ca []interface{} + _ca = append(_ca, ctx, taskToken) + _ca = append(_ca, details...) + ret := _m.Called(_ca...) + + if len(ret) == 0 { + panic("no return value specified for RecordActivityHeartbeat") + } + + var r0 error + if rf, ok := ret.Get(0).(func(context.Context, []byte, ...interface{}) error); ok { + r0 = rf(ctx, taskToken, details...) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// MockClient_RecordActivityHeartbeat_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'RecordActivityHeartbeat' +type MockClient_RecordActivityHeartbeat_Call struct { + *mock.Call +} + +// RecordActivityHeartbeat is a helper method to define mock.On call +// - ctx context.Context +// - taskToken []byte +// - details ...interface{} +func (_e *MockClient_Expecter) RecordActivityHeartbeat(ctx interface{}, taskToken interface{}, details ...interface{}) *MockClient_RecordActivityHeartbeat_Call { + return &MockClient_RecordActivityHeartbeat_Call{Call: _e.mock.On("RecordActivityHeartbeat", + append([]interface{}{ctx, taskToken}, details...)...)} +} + +func (_c *MockClient_RecordActivityHeartbeat_Call) Run(run func(ctx context.Context, taskToken []byte, details ...interface{})) *MockClient_RecordActivityHeartbeat_Call { + _c.Call.Run(func(args mock.Arguments) { + variadicArgs := make([]interface{}, len(args)-2) + for i, a := range args[2:] { + if a != nil { + variadicArgs[i] = a.(interface{}) + } + } + run(args[0].(context.Context), args[1].([]byte), variadicArgs...) + }) + return _c +} + +func (_c *MockClient_RecordActivityHeartbeat_Call) Return(_a0 error) *MockClient_RecordActivityHeartbeat_Call { + _c.Call.Return(_a0) + return _c +} + +func (_c *MockClient_RecordActivityHeartbeat_Call) RunAndReturn(run func(context.Context, []byte, ...interface{}) error) *MockClient_RecordActivityHeartbeat_Call { + _c.Call.Return(run) + return _c +} + +// RecordActivityHeartbeatByID provides a mock function with given fields: ctx, namespace, workflowID, runID, activityID, details +func (_m *MockClient) RecordActivityHeartbeatByID(ctx context.Context, namespace string, workflowID string, runID string, activityID string, details ...interface{}) error { + var _ca []interface{} + _ca = append(_ca, ctx, namespace, workflowID, runID, activityID) + _ca = append(_ca, details...) + ret := _m.Called(_ca...) + + if len(ret) == 0 { + panic("no return value specified for RecordActivityHeartbeatByID") + } + + var r0 error + if rf, ok := ret.Get(0).(func(context.Context, string, string, string, string, ...interface{}) error); ok { + r0 = rf(ctx, namespace, workflowID, runID, activityID, details...) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// MockClient_RecordActivityHeartbeatByID_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'RecordActivityHeartbeatByID' +type MockClient_RecordActivityHeartbeatByID_Call struct { + *mock.Call +} + +// RecordActivityHeartbeatByID is a helper method to define mock.On call +// - ctx context.Context +// - namespace string +// - workflowID string +// - runID string +// - activityID string +// - details ...interface{} +func (_e *MockClient_Expecter) RecordActivityHeartbeatByID(ctx interface{}, namespace interface{}, workflowID interface{}, runID interface{}, activityID interface{}, details ...interface{}) *MockClient_RecordActivityHeartbeatByID_Call { + return &MockClient_RecordActivityHeartbeatByID_Call{Call: _e.mock.On("RecordActivityHeartbeatByID", + append([]interface{}{ctx, namespace, workflowID, runID, activityID}, details...)...)} +} + +func (_c *MockClient_RecordActivityHeartbeatByID_Call) Run(run func(ctx context.Context, namespace string, workflowID string, runID string, activityID string, details ...interface{})) *MockClient_RecordActivityHeartbeatByID_Call { + _c.Call.Run(func(args mock.Arguments) { + variadicArgs := make([]interface{}, len(args)-5) + for i, a := range args[5:] { + if a != nil { + variadicArgs[i] = a.(interface{}) + } + } + run(args[0].(context.Context), args[1].(string), args[2].(string), args[3].(string), args[4].(string), variadicArgs...) + }) + return _c +} + +func (_c *MockClient_RecordActivityHeartbeatByID_Call) Return(_a0 error) *MockClient_RecordActivityHeartbeatByID_Call { + _c.Call.Return(_a0) + return _c +} + +func (_c *MockClient_RecordActivityHeartbeatByID_Call) RunAndReturn(run func(context.Context, string, string, string, string, ...interface{}) error) *MockClient_RecordActivityHeartbeatByID_Call { + _c.Call.Return(run) + return _c +} + +// ResetWorkflowExecution provides a mock function with given fields: ctx, request +func (_m *MockClient) ResetWorkflowExecution(ctx context.Context, request *workflowservice.ResetWorkflowExecutionRequest) (*workflowservice.ResetWorkflowExecutionResponse, error) { + ret := _m.Called(ctx, request) + + if len(ret) == 0 { + panic("no return value specified for ResetWorkflowExecution") + } + + var r0 *workflowservice.ResetWorkflowExecutionResponse + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, *workflowservice.ResetWorkflowExecutionRequest) (*workflowservice.ResetWorkflowExecutionResponse, error)); ok { + return rf(ctx, request) + } + if rf, ok := ret.Get(0).(func(context.Context, *workflowservice.ResetWorkflowExecutionRequest) *workflowservice.ResetWorkflowExecutionResponse); ok { + r0 = rf(ctx, request) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*workflowservice.ResetWorkflowExecutionResponse) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, *workflowservice.ResetWorkflowExecutionRequest) error); ok { + r1 = rf(ctx, request) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// MockClient_ResetWorkflowExecution_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'ResetWorkflowExecution' +type MockClient_ResetWorkflowExecution_Call struct { + *mock.Call +} + +// ResetWorkflowExecution is a helper method to define mock.On call +// - ctx context.Context +// - request *workflowservice.ResetWorkflowExecutionRequest +func (_e *MockClient_Expecter) ResetWorkflowExecution(ctx interface{}, request interface{}) *MockClient_ResetWorkflowExecution_Call { + return &MockClient_ResetWorkflowExecution_Call{Call: _e.mock.On("ResetWorkflowExecution", ctx, request)} +} + +func (_c *MockClient_ResetWorkflowExecution_Call) Run(run func(ctx context.Context, request *workflowservice.ResetWorkflowExecutionRequest)) *MockClient_ResetWorkflowExecution_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(*workflowservice.ResetWorkflowExecutionRequest)) + }) + return _c +} + +func (_c *MockClient_ResetWorkflowExecution_Call) Return(_a0 *workflowservice.ResetWorkflowExecutionResponse, _a1 error) *MockClient_ResetWorkflowExecution_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *MockClient_ResetWorkflowExecution_Call) RunAndReturn(run func(context.Context, *workflowservice.ResetWorkflowExecutionRequest) (*workflowservice.ResetWorkflowExecutionResponse, error)) *MockClient_ResetWorkflowExecution_Call { + _c.Call.Return(run) + return _c +} + +// ScanWorkflow provides a mock function with given fields: ctx, request +func (_m *MockClient) ScanWorkflow(ctx context.Context, request *workflowservice.ScanWorkflowExecutionsRequest) (*workflowservice.ScanWorkflowExecutionsResponse, error) { + ret := _m.Called(ctx, request) + + if len(ret) == 0 { + panic("no return value specified for ScanWorkflow") + } + + var r0 *workflowservice.ScanWorkflowExecutionsResponse + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, *workflowservice.ScanWorkflowExecutionsRequest) (*workflowservice.ScanWorkflowExecutionsResponse, error)); ok { + return rf(ctx, request) + } + if rf, ok := ret.Get(0).(func(context.Context, *workflowservice.ScanWorkflowExecutionsRequest) *workflowservice.ScanWorkflowExecutionsResponse); ok { + r0 = rf(ctx, request) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*workflowservice.ScanWorkflowExecutionsResponse) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, *workflowservice.ScanWorkflowExecutionsRequest) error); ok { + r1 = rf(ctx, request) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// MockClient_ScanWorkflow_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'ScanWorkflow' +type MockClient_ScanWorkflow_Call struct { + *mock.Call +} + +// ScanWorkflow is a helper method to define mock.On call +// - ctx context.Context +// - request *workflowservice.ScanWorkflowExecutionsRequest +func (_e *MockClient_Expecter) ScanWorkflow(ctx interface{}, request interface{}) *MockClient_ScanWorkflow_Call { + return &MockClient_ScanWorkflow_Call{Call: _e.mock.On("ScanWorkflow", ctx, request)} +} + +func (_c *MockClient_ScanWorkflow_Call) Run(run func(ctx context.Context, request *workflowservice.ScanWorkflowExecutionsRequest)) *MockClient_ScanWorkflow_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(*workflowservice.ScanWorkflowExecutionsRequest)) + }) + return _c +} + +func (_c *MockClient_ScanWorkflow_Call) Return(_a0 *workflowservice.ScanWorkflowExecutionsResponse, _a1 error) *MockClient_ScanWorkflow_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *MockClient_ScanWorkflow_Call) RunAndReturn(run func(context.Context, *workflowservice.ScanWorkflowExecutionsRequest) (*workflowservice.ScanWorkflowExecutionsResponse, error)) *MockClient_ScanWorkflow_Call { + _c.Call.Return(run) + return _c +} + +// ScheduleClient provides a mock function with given fields: +func (_m *MockClient) ScheduleClient() client.ScheduleClient { + ret := _m.Called() + + if len(ret) == 0 { + panic("no return value specified for ScheduleClient") + } + + var r0 client.ScheduleClient + if rf, ok := ret.Get(0).(func() client.ScheduleClient); ok { + r0 = rf() + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(client.ScheduleClient) + } + } + + return r0 +} + +// MockClient_ScheduleClient_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'ScheduleClient' +type MockClient_ScheduleClient_Call struct { + *mock.Call +} + +// ScheduleClient is a helper method to define mock.On call +func (_e *MockClient_Expecter) ScheduleClient() *MockClient_ScheduleClient_Call { + return &MockClient_ScheduleClient_Call{Call: _e.mock.On("ScheduleClient")} +} + +func (_c *MockClient_ScheduleClient_Call) Run(run func()) *MockClient_ScheduleClient_Call { + _c.Call.Run(func(args mock.Arguments) { + run() + }) + return _c +} + +func (_c *MockClient_ScheduleClient_Call) Return(_a0 client.ScheduleClient) *MockClient_ScheduleClient_Call { + _c.Call.Return(_a0) + return _c +} + +func (_c *MockClient_ScheduleClient_Call) RunAndReturn(run func() client.ScheduleClient) *MockClient_ScheduleClient_Call { + _c.Call.Return(run) + return _c +} + +// SignalWithStartWorkflow provides a mock function with given fields: ctx, workflowID, signalName, signalArg, options, workflow, workflowArgs +func (_m *MockClient) SignalWithStartWorkflow(ctx context.Context, workflowID string, signalName string, signalArg interface{}, options client.StartWorkflowOptions, workflow interface{}, workflowArgs ...interface{}) (client.WorkflowRun, error) { + var _ca []interface{} + _ca = append(_ca, ctx, workflowID, signalName, signalArg, options, workflow) + _ca = append(_ca, workflowArgs...) + ret := _m.Called(_ca...) + + if len(ret) == 0 { + panic("no return value specified for SignalWithStartWorkflow") + } + + var r0 client.WorkflowRun + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, string, string, interface{}, client.StartWorkflowOptions, interface{}, ...interface{}) (client.WorkflowRun, error)); ok { + return rf(ctx, workflowID, signalName, signalArg, options, workflow, workflowArgs...) + } + if rf, ok := ret.Get(0).(func(context.Context, string, string, interface{}, client.StartWorkflowOptions, interface{}, ...interface{}) client.WorkflowRun); ok { + r0 = rf(ctx, workflowID, signalName, signalArg, options, workflow, workflowArgs...) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(client.WorkflowRun) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, string, string, interface{}, client.StartWorkflowOptions, interface{}, ...interface{}) error); ok { + r1 = rf(ctx, workflowID, signalName, signalArg, options, workflow, workflowArgs...) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// MockClient_SignalWithStartWorkflow_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'SignalWithStartWorkflow' +type MockClient_SignalWithStartWorkflow_Call struct { + *mock.Call +} + +// SignalWithStartWorkflow is a helper method to define mock.On call +// - ctx context.Context +// - workflowID string +// - signalName string +// - signalArg interface{} +// - options client.StartWorkflowOptions +// - workflow interface{} +// - workflowArgs ...interface{} +func (_e *MockClient_Expecter) SignalWithStartWorkflow(ctx interface{}, workflowID interface{}, signalName interface{}, signalArg interface{}, options interface{}, workflow interface{}, workflowArgs ...interface{}) *MockClient_SignalWithStartWorkflow_Call { + return &MockClient_SignalWithStartWorkflow_Call{Call: _e.mock.On("SignalWithStartWorkflow", + append([]interface{}{ctx, workflowID, signalName, signalArg, options, workflow}, workflowArgs...)...)} +} + +func (_c *MockClient_SignalWithStartWorkflow_Call) Run(run func(ctx context.Context, workflowID string, signalName string, signalArg interface{}, options client.StartWorkflowOptions, workflow interface{}, workflowArgs ...interface{})) *MockClient_SignalWithStartWorkflow_Call { + _c.Call.Run(func(args mock.Arguments) { + variadicArgs := make([]interface{}, len(args)-6) + for i, a := range args[6:] { + if a != nil { + variadicArgs[i] = a.(interface{}) + } + } + run(args[0].(context.Context), args[1].(string), args[2].(string), args[3].(interface{}), args[4].(client.StartWorkflowOptions), args[5].(interface{}), variadicArgs...) + }) + return _c +} + +func (_c *MockClient_SignalWithStartWorkflow_Call) Return(_a0 client.WorkflowRun, _a1 error) *MockClient_SignalWithStartWorkflow_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *MockClient_SignalWithStartWorkflow_Call) RunAndReturn(run func(context.Context, string, string, interface{}, client.StartWorkflowOptions, interface{}, ...interface{}) (client.WorkflowRun, error)) *MockClient_SignalWithStartWorkflow_Call { + _c.Call.Return(run) + return _c +} + +// SignalWorkflow provides a mock function with given fields: ctx, workflowID, runID, signalName, arg +func (_m *MockClient) SignalWorkflow(ctx context.Context, workflowID string, runID string, signalName string, arg interface{}) error { + ret := _m.Called(ctx, workflowID, runID, signalName, arg) + + if len(ret) == 0 { + panic("no return value specified for SignalWorkflow") + } + + var r0 error + if rf, ok := ret.Get(0).(func(context.Context, string, string, string, interface{}) error); ok { + r0 = rf(ctx, workflowID, runID, signalName, arg) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// MockClient_SignalWorkflow_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'SignalWorkflow' +type MockClient_SignalWorkflow_Call struct { + *mock.Call +} + +// SignalWorkflow is a helper method to define mock.On call +// - ctx context.Context +// - workflowID string +// - runID string +// - signalName string +// - arg interface{} +func (_e *MockClient_Expecter) SignalWorkflow(ctx interface{}, workflowID interface{}, runID interface{}, signalName interface{}, arg interface{}) *MockClient_SignalWorkflow_Call { + return &MockClient_SignalWorkflow_Call{Call: _e.mock.On("SignalWorkflow", ctx, workflowID, runID, signalName, arg)} +} + +func (_c *MockClient_SignalWorkflow_Call) Run(run func(ctx context.Context, workflowID string, runID string, signalName string, arg interface{})) *MockClient_SignalWorkflow_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(string), args[2].(string), args[3].(string), args[4].(interface{})) + }) + return _c +} + +func (_c *MockClient_SignalWorkflow_Call) Return(_a0 error) *MockClient_SignalWorkflow_Call { + _c.Call.Return(_a0) + return _c +} + +func (_c *MockClient_SignalWorkflow_Call) RunAndReturn(run func(context.Context, string, string, string, interface{}) error) *MockClient_SignalWorkflow_Call { + _c.Call.Return(run) + return _c +} + +// TerminateWorkflow provides a mock function with given fields: ctx, workflowID, runID, reason, details +func (_m *MockClient) TerminateWorkflow(ctx context.Context, workflowID string, runID string, reason string, details ...interface{}) error { + var _ca []interface{} + _ca = append(_ca, ctx, workflowID, runID, reason) + _ca = append(_ca, details...) + ret := _m.Called(_ca...) + + if len(ret) == 0 { + panic("no return value specified for TerminateWorkflow") + } + + var r0 error + if rf, ok := ret.Get(0).(func(context.Context, string, string, string, ...interface{}) error); ok { + r0 = rf(ctx, workflowID, runID, reason, details...) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// MockClient_TerminateWorkflow_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'TerminateWorkflow' +type MockClient_TerminateWorkflow_Call struct { + *mock.Call +} + +// TerminateWorkflow is a helper method to define mock.On call +// - ctx context.Context +// - workflowID string +// - runID string +// - reason string +// - details ...interface{} +func (_e *MockClient_Expecter) TerminateWorkflow(ctx interface{}, workflowID interface{}, runID interface{}, reason interface{}, details ...interface{}) *MockClient_TerminateWorkflow_Call { + return &MockClient_TerminateWorkflow_Call{Call: _e.mock.On("TerminateWorkflow", + append([]interface{}{ctx, workflowID, runID, reason}, details...)...)} +} + +func (_c *MockClient_TerminateWorkflow_Call) Run(run func(ctx context.Context, workflowID string, runID string, reason string, details ...interface{})) *MockClient_TerminateWorkflow_Call { + _c.Call.Run(func(args mock.Arguments) { + variadicArgs := make([]interface{}, len(args)-4) + for i, a := range args[4:] { + if a != nil { + variadicArgs[i] = a.(interface{}) + } + } + run(args[0].(context.Context), args[1].(string), args[2].(string), args[3].(string), variadicArgs...) + }) + return _c +} + +func (_c *MockClient_TerminateWorkflow_Call) Return(_a0 error) *MockClient_TerminateWorkflow_Call { + _c.Call.Return(_a0) + return _c +} + +func (_c *MockClient_TerminateWorkflow_Call) RunAndReturn(run func(context.Context, string, string, string, ...interface{}) error) *MockClient_TerminateWorkflow_Call { + _c.Call.Return(run) + return _c +} + +// UpdateWorkerBuildIdCompatibility provides a mock function with given fields: ctx, options +func (_m *MockClient) UpdateWorkerBuildIdCompatibility(ctx context.Context, options *client.UpdateWorkerBuildIdCompatibilityOptions) error { + ret := _m.Called(ctx, options) + + if len(ret) == 0 { + panic("no return value specified for UpdateWorkerBuildIdCompatibility") + } + + var r0 error + if rf, ok := ret.Get(0).(func(context.Context, *client.UpdateWorkerBuildIdCompatibilityOptions) error); ok { + r0 = rf(ctx, options) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// MockClient_UpdateWorkerBuildIdCompatibility_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'UpdateWorkerBuildIdCompatibility' +type MockClient_UpdateWorkerBuildIdCompatibility_Call struct { + *mock.Call +} + +// UpdateWorkerBuildIdCompatibility is a helper method to define mock.On call +// - ctx context.Context +// - options *client.UpdateWorkerBuildIdCompatibilityOptions +func (_e *MockClient_Expecter) UpdateWorkerBuildIdCompatibility(ctx interface{}, options interface{}) *MockClient_UpdateWorkerBuildIdCompatibility_Call { + return &MockClient_UpdateWorkerBuildIdCompatibility_Call{Call: _e.mock.On("UpdateWorkerBuildIdCompatibility", ctx, options)} +} + +func (_c *MockClient_UpdateWorkerBuildIdCompatibility_Call) Run(run func(ctx context.Context, options *client.UpdateWorkerBuildIdCompatibilityOptions)) *MockClient_UpdateWorkerBuildIdCompatibility_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(*client.UpdateWorkerBuildIdCompatibilityOptions)) + }) + return _c +} + +func (_c *MockClient_UpdateWorkerBuildIdCompatibility_Call) Return(_a0 error) *MockClient_UpdateWorkerBuildIdCompatibility_Call { + _c.Call.Return(_a0) + return _c +} + +func (_c *MockClient_UpdateWorkerBuildIdCompatibility_Call) RunAndReturn(run func(context.Context, *client.UpdateWorkerBuildIdCompatibilityOptions) error) *MockClient_UpdateWorkerBuildIdCompatibility_Call { + _c.Call.Return(run) + return _c +} + +// UpdateWorkerVersioningRules provides a mock function with given fields: ctx, options +func (_m *MockClient) UpdateWorkerVersioningRules(ctx context.Context, options client.UpdateWorkerVersioningRulesOptions) (*client.WorkerVersioningRules, error) { + ret := _m.Called(ctx, options) + + if len(ret) == 0 { + panic("no return value specified for UpdateWorkerVersioningRules") + } + + var r0 *client.WorkerVersioningRules + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, client.UpdateWorkerVersioningRulesOptions) (*client.WorkerVersioningRules, error)); ok { + return rf(ctx, options) + } + if rf, ok := ret.Get(0).(func(context.Context, client.UpdateWorkerVersioningRulesOptions) *client.WorkerVersioningRules); ok { + r0 = rf(ctx, options) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*client.WorkerVersioningRules) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, client.UpdateWorkerVersioningRulesOptions) error); ok { + r1 = rf(ctx, options) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// MockClient_UpdateWorkerVersioningRules_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'UpdateWorkerVersioningRules' +type MockClient_UpdateWorkerVersioningRules_Call struct { + *mock.Call +} + +// UpdateWorkerVersioningRules is a helper method to define mock.On call +// - ctx context.Context +// - options client.UpdateWorkerVersioningRulesOptions +func (_e *MockClient_Expecter) UpdateWorkerVersioningRules(ctx interface{}, options interface{}) *MockClient_UpdateWorkerVersioningRules_Call { + return &MockClient_UpdateWorkerVersioningRules_Call{Call: _e.mock.On("UpdateWorkerVersioningRules", ctx, options)} +} + +func (_c *MockClient_UpdateWorkerVersioningRules_Call) Run(run func(ctx context.Context, options client.UpdateWorkerVersioningRulesOptions)) *MockClient_UpdateWorkerVersioningRules_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(client.UpdateWorkerVersioningRulesOptions)) + }) + return _c +} + +func (_c *MockClient_UpdateWorkerVersioningRules_Call) Return(_a0 *client.WorkerVersioningRules, _a1 error) *MockClient_UpdateWorkerVersioningRules_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *MockClient_UpdateWorkerVersioningRules_Call) RunAndReturn(run func(context.Context, client.UpdateWorkerVersioningRulesOptions) (*client.WorkerVersioningRules, error)) *MockClient_UpdateWorkerVersioningRules_Call { + _c.Call.Return(run) + return _c +} + +// UpdateWorkflow provides a mock function with given fields: ctx, options +func (_m *MockClient) UpdateWorkflow(ctx context.Context, options client.UpdateWorkflowOptions) (client.WorkflowUpdateHandle, error) { + ret := _m.Called(ctx, options) + + if len(ret) == 0 { + panic("no return value specified for UpdateWorkflow") + } + + var r0 client.WorkflowUpdateHandle + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, client.UpdateWorkflowOptions) (client.WorkflowUpdateHandle, error)); ok { + return rf(ctx, options) + } + if rf, ok := ret.Get(0).(func(context.Context, client.UpdateWorkflowOptions) client.WorkflowUpdateHandle); ok { + r0 = rf(ctx, options) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(client.WorkflowUpdateHandle) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, client.UpdateWorkflowOptions) error); ok { + r1 = rf(ctx, options) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// MockClient_UpdateWorkflow_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'UpdateWorkflow' +type MockClient_UpdateWorkflow_Call struct { + *mock.Call +} + +// UpdateWorkflow is a helper method to define mock.On call +// - ctx context.Context +// - options client.UpdateWorkflowOptions +func (_e *MockClient_Expecter) UpdateWorkflow(ctx interface{}, options interface{}) *MockClient_UpdateWorkflow_Call { + return &MockClient_UpdateWorkflow_Call{Call: _e.mock.On("UpdateWorkflow", ctx, options)} +} + +func (_c *MockClient_UpdateWorkflow_Call) Run(run func(ctx context.Context, options client.UpdateWorkflowOptions)) *MockClient_UpdateWorkflow_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(client.UpdateWorkflowOptions)) + }) + return _c +} + +func (_c *MockClient_UpdateWorkflow_Call) Return(_a0 client.WorkflowUpdateHandle, _a1 error) *MockClient_UpdateWorkflow_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *MockClient_UpdateWorkflow_Call) RunAndReturn(run func(context.Context, client.UpdateWorkflowOptions) (client.WorkflowUpdateHandle, error)) *MockClient_UpdateWorkflow_Call { + _c.Call.Return(run) + return _c +} + +// WorkflowService provides a mock function with given fields: +func (_m *MockClient) WorkflowService() workflowservice.WorkflowServiceClient { + ret := _m.Called() + + if len(ret) == 0 { + panic("no return value specified for WorkflowService") + } + + var r0 workflowservice.WorkflowServiceClient + if rf, ok := ret.Get(0).(func() workflowservice.WorkflowServiceClient); ok { + r0 = rf() + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(workflowservice.WorkflowServiceClient) + } + } + + return r0 +} + +// MockClient_WorkflowService_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'WorkflowService' +type MockClient_WorkflowService_Call struct { + *mock.Call +} + +// WorkflowService is a helper method to define mock.On call +func (_e *MockClient_Expecter) WorkflowService() *MockClient_WorkflowService_Call { + return &MockClient_WorkflowService_Call{Call: _e.mock.On("WorkflowService")} +} + +func (_c *MockClient_WorkflowService_Call) Run(run func()) *MockClient_WorkflowService_Call { + _c.Call.Run(func(args mock.Arguments) { + run() + }) + return _c +} + +func (_c *MockClient_WorkflowService_Call) Return(_a0 workflowservice.WorkflowServiceClient) *MockClient_WorkflowService_Call { + _c.Call.Return(_a0) + return _c +} + +func (_c *MockClient_WorkflowService_Call) RunAndReturn(run func() workflowservice.WorkflowServiceClient) *MockClient_WorkflowService_Call { + _c.Call.Return(run) + return _c +} + +// NewMockClient creates a new instance of MockClient. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. +// The first argument is typically a *testing.T value. +func NewMockClient(t interface { + mock.TestingT + Cleanup(func()) +}) *MockClient { + mock := &MockClient{} + mock.Mock.Test(t) + + t.Cleanup(func() { mock.AssertExpectations(t) }) + + return mock +} diff --git a/_mocks/opencsg.com/csghub-server/component/mock_CollectionComponent.go b/_mocks/opencsg.com/csghub-server/component/mock_CollectionComponent.go new file mode 100644 index 00000000..1d1fc50e --- /dev/null +++ b/_mocks/opencsg.com/csghub-server/component/mock_CollectionComponent.go @@ -0,0 +1,541 @@ +// Code generated by mockery v2.49.1. DO NOT EDIT. + +package component + +import ( + context "context" + + mock "github.com/stretchr/testify/mock" + database "opencsg.com/csghub-server/builder/store/database" + + types "opencsg.com/csghub-server/common/types" +) + +// MockCollectionComponent is an autogenerated mock type for the CollectionComponent type +type MockCollectionComponent struct { + mock.Mock +} + +type MockCollectionComponent_Expecter struct { + mock *mock.Mock +} + +func (_m *MockCollectionComponent) EXPECT() *MockCollectionComponent_Expecter { + return &MockCollectionComponent_Expecter{mock: &_m.Mock} +} + +// AddReposToCollection provides a mock function with given fields: ctx, req +func (_m *MockCollectionComponent) AddReposToCollection(ctx context.Context, req types.UpdateCollectionReposReq) error { + ret := _m.Called(ctx, req) + + if len(ret) == 0 { + panic("no return value specified for AddReposToCollection") + } + + var r0 error + if rf, ok := ret.Get(0).(func(context.Context, types.UpdateCollectionReposReq) error); ok { + r0 = rf(ctx, req) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// MockCollectionComponent_AddReposToCollection_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'AddReposToCollection' +type MockCollectionComponent_AddReposToCollection_Call struct { + *mock.Call +} + +// AddReposToCollection is a helper method to define mock.On call +// - ctx context.Context +// - req types.UpdateCollectionReposReq +func (_e *MockCollectionComponent_Expecter) AddReposToCollection(ctx interface{}, req interface{}) *MockCollectionComponent_AddReposToCollection_Call { + return &MockCollectionComponent_AddReposToCollection_Call{Call: _e.mock.On("AddReposToCollection", ctx, req)} +} + +func (_c *MockCollectionComponent_AddReposToCollection_Call) Run(run func(ctx context.Context, req types.UpdateCollectionReposReq)) *MockCollectionComponent_AddReposToCollection_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(types.UpdateCollectionReposReq)) + }) + return _c +} + +func (_c *MockCollectionComponent_AddReposToCollection_Call) Return(_a0 error) *MockCollectionComponent_AddReposToCollection_Call { + _c.Call.Return(_a0) + return _c +} + +func (_c *MockCollectionComponent_AddReposToCollection_Call) RunAndReturn(run func(context.Context, types.UpdateCollectionReposReq) error) *MockCollectionComponent_AddReposToCollection_Call { + _c.Call.Return(run) + return _c +} + +// CreateCollection provides a mock function with given fields: ctx, input +func (_m *MockCollectionComponent) CreateCollection(ctx context.Context, input types.CreateCollectionReq) (*database.Collection, error) { + ret := _m.Called(ctx, input) + + if len(ret) == 0 { + panic("no return value specified for CreateCollection") + } + + var r0 *database.Collection + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, types.CreateCollectionReq) (*database.Collection, error)); ok { + return rf(ctx, input) + } + if rf, ok := ret.Get(0).(func(context.Context, types.CreateCollectionReq) *database.Collection); ok { + r0 = rf(ctx, input) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*database.Collection) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, types.CreateCollectionReq) error); ok { + r1 = rf(ctx, input) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// MockCollectionComponent_CreateCollection_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'CreateCollection' +type MockCollectionComponent_CreateCollection_Call struct { + *mock.Call +} + +// CreateCollection is a helper method to define mock.On call +// - ctx context.Context +// - input types.CreateCollectionReq +func (_e *MockCollectionComponent_Expecter) CreateCollection(ctx interface{}, input interface{}) *MockCollectionComponent_CreateCollection_Call { + return &MockCollectionComponent_CreateCollection_Call{Call: _e.mock.On("CreateCollection", ctx, input)} +} + +func (_c *MockCollectionComponent_CreateCollection_Call) Run(run func(ctx context.Context, input types.CreateCollectionReq)) *MockCollectionComponent_CreateCollection_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(types.CreateCollectionReq)) + }) + return _c +} + +func (_c *MockCollectionComponent_CreateCollection_Call) Return(_a0 *database.Collection, _a1 error) *MockCollectionComponent_CreateCollection_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *MockCollectionComponent_CreateCollection_Call) RunAndReturn(run func(context.Context, types.CreateCollectionReq) (*database.Collection, error)) *MockCollectionComponent_CreateCollection_Call { + _c.Call.Return(run) + return _c +} + +// DeleteCollection provides a mock function with given fields: ctx, id, userName +func (_m *MockCollectionComponent) DeleteCollection(ctx context.Context, id int64, userName string) error { + ret := _m.Called(ctx, id, userName) + + if len(ret) == 0 { + panic("no return value specified for DeleteCollection") + } + + var r0 error + if rf, ok := ret.Get(0).(func(context.Context, int64, string) error); ok { + r0 = rf(ctx, id, userName) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// MockCollectionComponent_DeleteCollection_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'DeleteCollection' +type MockCollectionComponent_DeleteCollection_Call struct { + *mock.Call +} + +// DeleteCollection is a helper method to define mock.On call +// - ctx context.Context +// - id int64 +// - userName string +func (_e *MockCollectionComponent_Expecter) DeleteCollection(ctx interface{}, id interface{}, userName interface{}) *MockCollectionComponent_DeleteCollection_Call { + return &MockCollectionComponent_DeleteCollection_Call{Call: _e.mock.On("DeleteCollection", ctx, id, userName)} +} + +func (_c *MockCollectionComponent_DeleteCollection_Call) Run(run func(ctx context.Context, id int64, userName string)) *MockCollectionComponent_DeleteCollection_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(int64), args[2].(string)) + }) + return _c +} + +func (_c *MockCollectionComponent_DeleteCollection_Call) Return(_a0 error) *MockCollectionComponent_DeleteCollection_Call { + _c.Call.Return(_a0) + return _c +} + +func (_c *MockCollectionComponent_DeleteCollection_Call) RunAndReturn(run func(context.Context, int64, string) error) *MockCollectionComponent_DeleteCollection_Call { + _c.Call.Return(run) + return _c +} + +// GetCollection provides a mock function with given fields: ctx, currentUser, id +func (_m *MockCollectionComponent) GetCollection(ctx context.Context, currentUser string, id int64) (*types.Collection, error) { + ret := _m.Called(ctx, currentUser, id) + + if len(ret) == 0 { + panic("no return value specified for GetCollection") + } + + var r0 *types.Collection + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, string, int64) (*types.Collection, error)); ok { + return rf(ctx, currentUser, id) + } + if rf, ok := ret.Get(0).(func(context.Context, string, int64) *types.Collection); ok { + r0 = rf(ctx, currentUser, id) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*types.Collection) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, string, int64) error); ok { + r1 = rf(ctx, currentUser, id) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// MockCollectionComponent_GetCollection_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'GetCollection' +type MockCollectionComponent_GetCollection_Call struct { + *mock.Call +} + +// GetCollection is a helper method to define mock.On call +// - ctx context.Context +// - currentUser string +// - id int64 +func (_e *MockCollectionComponent_Expecter) GetCollection(ctx interface{}, currentUser interface{}, id interface{}) *MockCollectionComponent_GetCollection_Call { + return &MockCollectionComponent_GetCollection_Call{Call: _e.mock.On("GetCollection", ctx, currentUser, id)} +} + +func (_c *MockCollectionComponent_GetCollection_Call) Run(run func(ctx context.Context, currentUser string, id int64)) *MockCollectionComponent_GetCollection_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(string), args[2].(int64)) + }) + return _c +} + +func (_c *MockCollectionComponent_GetCollection_Call) Return(_a0 *types.Collection, _a1 error) *MockCollectionComponent_GetCollection_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *MockCollectionComponent_GetCollection_Call) RunAndReturn(run func(context.Context, string, int64) (*types.Collection, error)) *MockCollectionComponent_GetCollection_Call { + _c.Call.Return(run) + return _c +} + +// GetCollections provides a mock function with given fields: ctx, filter, per, page +func (_m *MockCollectionComponent) GetCollections(ctx context.Context, filter *types.CollectionFilter, per int, page int) ([]types.Collection, int, error) { + ret := _m.Called(ctx, filter, per, page) + + if len(ret) == 0 { + panic("no return value specified for GetCollections") + } + + var r0 []types.Collection + var r1 int + var r2 error + if rf, ok := ret.Get(0).(func(context.Context, *types.CollectionFilter, int, int) ([]types.Collection, int, error)); ok { + return rf(ctx, filter, per, page) + } + if rf, ok := ret.Get(0).(func(context.Context, *types.CollectionFilter, int, int) []types.Collection); ok { + r0 = rf(ctx, filter, per, page) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).([]types.Collection) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, *types.CollectionFilter, int, int) int); ok { + r1 = rf(ctx, filter, per, page) + } else { + r1 = ret.Get(1).(int) + } + + if rf, ok := ret.Get(2).(func(context.Context, *types.CollectionFilter, int, int) error); ok { + r2 = rf(ctx, filter, per, page) + } else { + r2 = ret.Error(2) + } + + return r0, r1, r2 +} + +// MockCollectionComponent_GetCollections_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'GetCollections' +type MockCollectionComponent_GetCollections_Call struct { + *mock.Call +} + +// GetCollections is a helper method to define mock.On call +// - ctx context.Context +// - filter *types.CollectionFilter +// - per int +// - page int +func (_e *MockCollectionComponent_Expecter) GetCollections(ctx interface{}, filter interface{}, per interface{}, page interface{}) *MockCollectionComponent_GetCollections_Call { + return &MockCollectionComponent_GetCollections_Call{Call: _e.mock.On("GetCollections", ctx, filter, per, page)} +} + +func (_c *MockCollectionComponent_GetCollections_Call) Run(run func(ctx context.Context, filter *types.CollectionFilter, per int, page int)) *MockCollectionComponent_GetCollections_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(*types.CollectionFilter), args[2].(int), args[3].(int)) + }) + return _c +} + +func (_c *MockCollectionComponent_GetCollections_Call) Return(_a0 []types.Collection, _a1 int, _a2 error) *MockCollectionComponent_GetCollections_Call { + _c.Call.Return(_a0, _a1, _a2) + return _c +} + +func (_c *MockCollectionComponent_GetCollections_Call) RunAndReturn(run func(context.Context, *types.CollectionFilter, int, int) ([]types.Collection, int, error)) *MockCollectionComponent_GetCollections_Call { + _c.Call.Return(run) + return _c +} + +// GetPublicRepos provides a mock function with given fields: collection +func (_m *MockCollectionComponent) GetPublicRepos(collection types.Collection) []types.CollectionRepository { + ret := _m.Called(collection) + + if len(ret) == 0 { + panic("no return value specified for GetPublicRepos") + } + + var r0 []types.CollectionRepository + if rf, ok := ret.Get(0).(func(types.Collection) []types.CollectionRepository); ok { + r0 = rf(collection) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).([]types.CollectionRepository) + } + } + + return r0 +} + +// MockCollectionComponent_GetPublicRepos_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'GetPublicRepos' +type MockCollectionComponent_GetPublicRepos_Call struct { + *mock.Call +} + +// GetPublicRepos is a helper method to define mock.On call +// - collection types.Collection +func (_e *MockCollectionComponent_Expecter) GetPublicRepos(collection interface{}) *MockCollectionComponent_GetPublicRepos_Call { + return &MockCollectionComponent_GetPublicRepos_Call{Call: _e.mock.On("GetPublicRepos", collection)} +} + +func (_c *MockCollectionComponent_GetPublicRepos_Call) Run(run func(collection types.Collection)) *MockCollectionComponent_GetPublicRepos_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(types.Collection)) + }) + return _c +} + +func (_c *MockCollectionComponent_GetPublicRepos_Call) Return(_a0 []types.CollectionRepository) *MockCollectionComponent_GetPublicRepos_Call { + _c.Call.Return(_a0) + return _c +} + +func (_c *MockCollectionComponent_GetPublicRepos_Call) RunAndReturn(run func(types.Collection) []types.CollectionRepository) *MockCollectionComponent_GetPublicRepos_Call { + _c.Call.Return(run) + return _c +} + +// OrgCollections provides a mock function with given fields: ctx, req +func (_m *MockCollectionComponent) OrgCollections(ctx context.Context, req *types.OrgCollectionsReq) ([]types.Collection, int, error) { + ret := _m.Called(ctx, req) + + if len(ret) == 0 { + panic("no return value specified for OrgCollections") + } + + var r0 []types.Collection + var r1 int + var r2 error + if rf, ok := ret.Get(0).(func(context.Context, *types.OrgCollectionsReq) ([]types.Collection, int, error)); ok { + return rf(ctx, req) + } + if rf, ok := ret.Get(0).(func(context.Context, *types.OrgCollectionsReq) []types.Collection); ok { + r0 = rf(ctx, req) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).([]types.Collection) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, *types.OrgCollectionsReq) int); ok { + r1 = rf(ctx, req) + } else { + r1 = ret.Get(1).(int) + } + + if rf, ok := ret.Get(2).(func(context.Context, *types.OrgCollectionsReq) error); ok { + r2 = rf(ctx, req) + } else { + r2 = ret.Error(2) + } + + return r0, r1, r2 +} + +// MockCollectionComponent_OrgCollections_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'OrgCollections' +type MockCollectionComponent_OrgCollections_Call struct { + *mock.Call +} + +// OrgCollections is a helper method to define mock.On call +// - ctx context.Context +// - req *types.OrgCollectionsReq +func (_e *MockCollectionComponent_Expecter) OrgCollections(ctx interface{}, req interface{}) *MockCollectionComponent_OrgCollections_Call { + return &MockCollectionComponent_OrgCollections_Call{Call: _e.mock.On("OrgCollections", ctx, req)} +} + +func (_c *MockCollectionComponent_OrgCollections_Call) Run(run func(ctx context.Context, req *types.OrgCollectionsReq)) *MockCollectionComponent_OrgCollections_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(*types.OrgCollectionsReq)) + }) + return _c +} + +func (_c *MockCollectionComponent_OrgCollections_Call) Return(_a0 []types.Collection, _a1 int, _a2 error) *MockCollectionComponent_OrgCollections_Call { + _c.Call.Return(_a0, _a1, _a2) + return _c +} + +func (_c *MockCollectionComponent_OrgCollections_Call) RunAndReturn(run func(context.Context, *types.OrgCollectionsReq) ([]types.Collection, int, error)) *MockCollectionComponent_OrgCollections_Call { + _c.Call.Return(run) + return _c +} + +// RemoveReposFromCollection provides a mock function with given fields: ctx, req +func (_m *MockCollectionComponent) RemoveReposFromCollection(ctx context.Context, req types.UpdateCollectionReposReq) error { + ret := _m.Called(ctx, req) + + if len(ret) == 0 { + panic("no return value specified for RemoveReposFromCollection") + } + + var r0 error + if rf, ok := ret.Get(0).(func(context.Context, types.UpdateCollectionReposReq) error); ok { + r0 = rf(ctx, req) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// MockCollectionComponent_RemoveReposFromCollection_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'RemoveReposFromCollection' +type MockCollectionComponent_RemoveReposFromCollection_Call struct { + *mock.Call +} + +// RemoveReposFromCollection is a helper method to define mock.On call +// - ctx context.Context +// - req types.UpdateCollectionReposReq +func (_e *MockCollectionComponent_Expecter) RemoveReposFromCollection(ctx interface{}, req interface{}) *MockCollectionComponent_RemoveReposFromCollection_Call { + return &MockCollectionComponent_RemoveReposFromCollection_Call{Call: _e.mock.On("RemoveReposFromCollection", ctx, req)} +} + +func (_c *MockCollectionComponent_RemoveReposFromCollection_Call) Run(run func(ctx context.Context, req types.UpdateCollectionReposReq)) *MockCollectionComponent_RemoveReposFromCollection_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(types.UpdateCollectionReposReq)) + }) + return _c +} + +func (_c *MockCollectionComponent_RemoveReposFromCollection_Call) Return(_a0 error) *MockCollectionComponent_RemoveReposFromCollection_Call { + _c.Call.Return(_a0) + return _c +} + +func (_c *MockCollectionComponent_RemoveReposFromCollection_Call) RunAndReturn(run func(context.Context, types.UpdateCollectionReposReq) error) *MockCollectionComponent_RemoveReposFromCollection_Call { + _c.Call.Return(run) + return _c +} + +// UpdateCollection provides a mock function with given fields: ctx, input +func (_m *MockCollectionComponent) UpdateCollection(ctx context.Context, input types.CreateCollectionReq) (*database.Collection, error) { + ret := _m.Called(ctx, input) + + if len(ret) == 0 { + panic("no return value specified for UpdateCollection") + } + + var r0 *database.Collection + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, types.CreateCollectionReq) (*database.Collection, error)); ok { + return rf(ctx, input) + } + if rf, ok := ret.Get(0).(func(context.Context, types.CreateCollectionReq) *database.Collection); ok { + r0 = rf(ctx, input) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*database.Collection) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, types.CreateCollectionReq) error); ok { + r1 = rf(ctx, input) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// MockCollectionComponent_UpdateCollection_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'UpdateCollection' +type MockCollectionComponent_UpdateCollection_Call struct { + *mock.Call +} + +// UpdateCollection is a helper method to define mock.On call +// - ctx context.Context +// - input types.CreateCollectionReq +func (_e *MockCollectionComponent_Expecter) UpdateCollection(ctx interface{}, input interface{}) *MockCollectionComponent_UpdateCollection_Call { + return &MockCollectionComponent_UpdateCollection_Call{Call: _e.mock.On("UpdateCollection", ctx, input)} +} + +func (_c *MockCollectionComponent_UpdateCollection_Call) Run(run func(ctx context.Context, input types.CreateCollectionReq)) *MockCollectionComponent_UpdateCollection_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(types.CreateCollectionReq)) + }) + return _c +} + +func (_c *MockCollectionComponent_UpdateCollection_Call) Return(_a0 *database.Collection, _a1 error) *MockCollectionComponent_UpdateCollection_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *MockCollectionComponent_UpdateCollection_Call) RunAndReturn(run func(context.Context, types.CreateCollectionReq) (*database.Collection, error)) *MockCollectionComponent_UpdateCollection_Call { + _c.Call.Return(run) + return _c +} + +// NewMockCollectionComponent creates a new instance of MockCollectionComponent. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. +// The first argument is typically a *testing.T value. +func NewMockCollectionComponent(t interface { + mock.TestingT + Cleanup(func()) +}) *MockCollectionComponent { + mock := &MockCollectionComponent{} + mock.Mock.Test(t) + + t.Cleanup(func() { mock.AssertExpectations(t) }) + + return mock +} diff --git a/_mocks/opencsg.com/csghub-server/component/mock_DatasetComponent.go b/_mocks/opencsg.com/csghub-server/component/mock_DatasetComponent.go new file mode 100644 index 00000000..d61cf7e8 --- /dev/null +++ b/_mocks/opencsg.com/csghub-server/component/mock_DatasetComponent.go @@ -0,0 +1,460 @@ +// Code generated by mockery v2.49.1. DO NOT EDIT. + +package component + +import ( + context "context" + + mock "github.com/stretchr/testify/mock" + types "opencsg.com/csghub-server/common/types" +) + +// MockDatasetComponent is an autogenerated mock type for the DatasetComponent type +type MockDatasetComponent struct { + mock.Mock +} + +type MockDatasetComponent_Expecter struct { + mock *mock.Mock +} + +func (_m *MockDatasetComponent) EXPECT() *MockDatasetComponent_Expecter { + return &MockDatasetComponent_Expecter{mock: &_m.Mock} +} + +// Create provides a mock function with given fields: ctx, req +func (_m *MockDatasetComponent) Create(ctx context.Context, req *types.CreateDatasetReq) (*types.Dataset, error) { + ret := _m.Called(ctx, req) + + if len(ret) == 0 { + panic("no return value specified for Create") + } + + var r0 *types.Dataset + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, *types.CreateDatasetReq) (*types.Dataset, error)); ok { + return rf(ctx, req) + } + if rf, ok := ret.Get(0).(func(context.Context, *types.CreateDatasetReq) *types.Dataset); ok { + r0 = rf(ctx, req) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*types.Dataset) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, *types.CreateDatasetReq) error); ok { + r1 = rf(ctx, req) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// MockDatasetComponent_Create_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Create' +type MockDatasetComponent_Create_Call struct { + *mock.Call +} + +// Create is a helper method to define mock.On call +// - ctx context.Context +// - req *types.CreateDatasetReq +func (_e *MockDatasetComponent_Expecter) Create(ctx interface{}, req interface{}) *MockDatasetComponent_Create_Call { + return &MockDatasetComponent_Create_Call{Call: _e.mock.On("Create", ctx, req)} +} + +func (_c *MockDatasetComponent_Create_Call) Run(run func(ctx context.Context, req *types.CreateDatasetReq)) *MockDatasetComponent_Create_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(*types.CreateDatasetReq)) + }) + return _c +} + +func (_c *MockDatasetComponent_Create_Call) Return(_a0 *types.Dataset, _a1 error) *MockDatasetComponent_Create_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *MockDatasetComponent_Create_Call) RunAndReturn(run func(context.Context, *types.CreateDatasetReq) (*types.Dataset, error)) *MockDatasetComponent_Create_Call { + _c.Call.Return(run) + return _c +} + +// Delete provides a mock function with given fields: ctx, namespace, name, currentUser +func (_m *MockDatasetComponent) Delete(ctx context.Context, namespace string, name string, currentUser string) error { + ret := _m.Called(ctx, namespace, name, currentUser) + + if len(ret) == 0 { + panic("no return value specified for Delete") + } + + var r0 error + if rf, ok := ret.Get(0).(func(context.Context, string, string, string) error); ok { + r0 = rf(ctx, namespace, name, currentUser) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// MockDatasetComponent_Delete_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Delete' +type MockDatasetComponent_Delete_Call struct { + *mock.Call +} + +// Delete is a helper method to define mock.On call +// - ctx context.Context +// - namespace string +// - name string +// - currentUser string +func (_e *MockDatasetComponent_Expecter) Delete(ctx interface{}, namespace interface{}, name interface{}, currentUser interface{}) *MockDatasetComponent_Delete_Call { + return &MockDatasetComponent_Delete_Call{Call: _e.mock.On("Delete", ctx, namespace, name, currentUser)} +} + +func (_c *MockDatasetComponent_Delete_Call) Run(run func(ctx context.Context, namespace string, name string, currentUser string)) *MockDatasetComponent_Delete_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(string), args[2].(string), args[3].(string)) + }) + return _c +} + +func (_c *MockDatasetComponent_Delete_Call) Return(_a0 error) *MockDatasetComponent_Delete_Call { + _c.Call.Return(_a0) + return _c +} + +func (_c *MockDatasetComponent_Delete_Call) RunAndReturn(run func(context.Context, string, string, string) error) *MockDatasetComponent_Delete_Call { + _c.Call.Return(run) + return _c +} + +// Index provides a mock function with given fields: ctx, filter, per, page +func (_m *MockDatasetComponent) Index(ctx context.Context, filter *types.RepoFilter, per int, page int) ([]types.Dataset, int, error) { + ret := _m.Called(ctx, filter, per, page) + + if len(ret) == 0 { + panic("no return value specified for Index") + } + + var r0 []types.Dataset + var r1 int + var r2 error + if rf, ok := ret.Get(0).(func(context.Context, *types.RepoFilter, int, int) ([]types.Dataset, int, error)); ok { + return rf(ctx, filter, per, page) + } + if rf, ok := ret.Get(0).(func(context.Context, *types.RepoFilter, int, int) []types.Dataset); ok { + r0 = rf(ctx, filter, per, page) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).([]types.Dataset) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, *types.RepoFilter, int, int) int); ok { + r1 = rf(ctx, filter, per, page) + } else { + r1 = ret.Get(1).(int) + } + + if rf, ok := ret.Get(2).(func(context.Context, *types.RepoFilter, int, int) error); ok { + r2 = rf(ctx, filter, per, page) + } else { + r2 = ret.Error(2) + } + + return r0, r1, r2 +} + +// MockDatasetComponent_Index_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Index' +type MockDatasetComponent_Index_Call struct { + *mock.Call +} + +// Index is a helper method to define mock.On call +// - ctx context.Context +// - filter *types.RepoFilter +// - per int +// - page int +func (_e *MockDatasetComponent_Expecter) Index(ctx interface{}, filter interface{}, per interface{}, page interface{}) *MockDatasetComponent_Index_Call { + return &MockDatasetComponent_Index_Call{Call: _e.mock.On("Index", ctx, filter, per, page)} +} + +func (_c *MockDatasetComponent_Index_Call) Run(run func(ctx context.Context, filter *types.RepoFilter, per int, page int)) *MockDatasetComponent_Index_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(*types.RepoFilter), args[2].(int), args[3].(int)) + }) + return _c +} + +func (_c *MockDatasetComponent_Index_Call) Return(_a0 []types.Dataset, _a1 int, _a2 error) *MockDatasetComponent_Index_Call { + _c.Call.Return(_a0, _a1, _a2) + return _c +} + +func (_c *MockDatasetComponent_Index_Call) RunAndReturn(run func(context.Context, *types.RepoFilter, int, int) ([]types.Dataset, int, error)) *MockDatasetComponent_Index_Call { + _c.Call.Return(run) + return _c +} + +// OrgDatasets provides a mock function with given fields: ctx, req +func (_m *MockDatasetComponent) OrgDatasets(ctx context.Context, req *types.OrgDatasetsReq) ([]types.Dataset, int, error) { + ret := _m.Called(ctx, req) + + if len(ret) == 0 { + panic("no return value specified for OrgDatasets") + } + + var r0 []types.Dataset + var r1 int + var r2 error + if rf, ok := ret.Get(0).(func(context.Context, *types.OrgDatasetsReq) ([]types.Dataset, int, error)); ok { + return rf(ctx, req) + } + if rf, ok := ret.Get(0).(func(context.Context, *types.OrgDatasetsReq) []types.Dataset); ok { + r0 = rf(ctx, req) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).([]types.Dataset) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, *types.OrgDatasetsReq) int); ok { + r1 = rf(ctx, req) + } else { + r1 = ret.Get(1).(int) + } + + if rf, ok := ret.Get(2).(func(context.Context, *types.OrgDatasetsReq) error); ok { + r2 = rf(ctx, req) + } else { + r2 = ret.Error(2) + } + + return r0, r1, r2 +} + +// MockDatasetComponent_OrgDatasets_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'OrgDatasets' +type MockDatasetComponent_OrgDatasets_Call struct { + *mock.Call +} + +// OrgDatasets is a helper method to define mock.On call +// - ctx context.Context +// - req *types.OrgDatasetsReq +func (_e *MockDatasetComponent_Expecter) OrgDatasets(ctx interface{}, req interface{}) *MockDatasetComponent_OrgDatasets_Call { + return &MockDatasetComponent_OrgDatasets_Call{Call: _e.mock.On("OrgDatasets", ctx, req)} +} + +func (_c *MockDatasetComponent_OrgDatasets_Call) Run(run func(ctx context.Context, req *types.OrgDatasetsReq)) *MockDatasetComponent_OrgDatasets_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(*types.OrgDatasetsReq)) + }) + return _c +} + +func (_c *MockDatasetComponent_OrgDatasets_Call) Return(_a0 []types.Dataset, _a1 int, _a2 error) *MockDatasetComponent_OrgDatasets_Call { + _c.Call.Return(_a0, _a1, _a2) + return _c +} + +func (_c *MockDatasetComponent_OrgDatasets_Call) RunAndReturn(run func(context.Context, *types.OrgDatasetsReq) ([]types.Dataset, int, error)) *MockDatasetComponent_OrgDatasets_Call { + _c.Call.Return(run) + return _c +} + +// Relations provides a mock function with given fields: ctx, namespace, name, currentUser +func (_m *MockDatasetComponent) Relations(ctx context.Context, namespace string, name string, currentUser string) (*types.Relations, error) { + ret := _m.Called(ctx, namespace, name, currentUser) + + if len(ret) == 0 { + panic("no return value specified for Relations") + } + + var r0 *types.Relations + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, string, string, string) (*types.Relations, error)); ok { + return rf(ctx, namespace, name, currentUser) + } + if rf, ok := ret.Get(0).(func(context.Context, string, string, string) *types.Relations); ok { + r0 = rf(ctx, namespace, name, currentUser) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*types.Relations) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, string, string, string) error); ok { + r1 = rf(ctx, namespace, name, currentUser) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// MockDatasetComponent_Relations_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Relations' +type MockDatasetComponent_Relations_Call struct { + *mock.Call +} + +// Relations is a helper method to define mock.On call +// - ctx context.Context +// - namespace string +// - name string +// - currentUser string +func (_e *MockDatasetComponent_Expecter) Relations(ctx interface{}, namespace interface{}, name interface{}, currentUser interface{}) *MockDatasetComponent_Relations_Call { + return &MockDatasetComponent_Relations_Call{Call: _e.mock.On("Relations", ctx, namespace, name, currentUser)} +} + +func (_c *MockDatasetComponent_Relations_Call) Run(run func(ctx context.Context, namespace string, name string, currentUser string)) *MockDatasetComponent_Relations_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(string), args[2].(string), args[3].(string)) + }) + return _c +} + +func (_c *MockDatasetComponent_Relations_Call) Return(_a0 *types.Relations, _a1 error) *MockDatasetComponent_Relations_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *MockDatasetComponent_Relations_Call) RunAndReturn(run func(context.Context, string, string, string) (*types.Relations, error)) *MockDatasetComponent_Relations_Call { + _c.Call.Return(run) + return _c +} + +// Show provides a mock function with given fields: ctx, namespace, name, currentUser +func (_m *MockDatasetComponent) Show(ctx context.Context, namespace string, name string, currentUser string) (*types.Dataset, error) { + ret := _m.Called(ctx, namespace, name, currentUser) + + if len(ret) == 0 { + panic("no return value specified for Show") + } + + var r0 *types.Dataset + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, string, string, string) (*types.Dataset, error)); ok { + return rf(ctx, namespace, name, currentUser) + } + if rf, ok := ret.Get(0).(func(context.Context, string, string, string) *types.Dataset); ok { + r0 = rf(ctx, namespace, name, currentUser) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*types.Dataset) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, string, string, string) error); ok { + r1 = rf(ctx, namespace, name, currentUser) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// MockDatasetComponent_Show_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Show' +type MockDatasetComponent_Show_Call struct { + *mock.Call +} + +// Show is a helper method to define mock.On call +// - ctx context.Context +// - namespace string +// - name string +// - currentUser string +func (_e *MockDatasetComponent_Expecter) Show(ctx interface{}, namespace interface{}, name interface{}, currentUser interface{}) *MockDatasetComponent_Show_Call { + return &MockDatasetComponent_Show_Call{Call: _e.mock.On("Show", ctx, namespace, name, currentUser)} +} + +func (_c *MockDatasetComponent_Show_Call) Run(run func(ctx context.Context, namespace string, name string, currentUser string)) *MockDatasetComponent_Show_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(string), args[2].(string), args[3].(string)) + }) + return _c +} + +func (_c *MockDatasetComponent_Show_Call) Return(_a0 *types.Dataset, _a1 error) *MockDatasetComponent_Show_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *MockDatasetComponent_Show_Call) RunAndReturn(run func(context.Context, string, string, string) (*types.Dataset, error)) *MockDatasetComponent_Show_Call { + _c.Call.Return(run) + return _c +} + +// Update provides a mock function with given fields: ctx, req +func (_m *MockDatasetComponent) Update(ctx context.Context, req *types.UpdateDatasetReq) (*types.Dataset, error) { + ret := _m.Called(ctx, req) + + if len(ret) == 0 { + panic("no return value specified for Update") + } + + var r0 *types.Dataset + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, *types.UpdateDatasetReq) (*types.Dataset, error)); ok { + return rf(ctx, req) + } + if rf, ok := ret.Get(0).(func(context.Context, *types.UpdateDatasetReq) *types.Dataset); ok { + r0 = rf(ctx, req) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*types.Dataset) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, *types.UpdateDatasetReq) error); ok { + r1 = rf(ctx, req) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// MockDatasetComponent_Update_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Update' +type MockDatasetComponent_Update_Call struct { + *mock.Call +} + +// Update is a helper method to define mock.On call +// - ctx context.Context +// - req *types.UpdateDatasetReq +func (_e *MockDatasetComponent_Expecter) Update(ctx interface{}, req interface{}) *MockDatasetComponent_Update_Call { + return &MockDatasetComponent_Update_Call{Call: _e.mock.On("Update", ctx, req)} +} + +func (_c *MockDatasetComponent_Update_Call) Run(run func(ctx context.Context, req *types.UpdateDatasetReq)) *MockDatasetComponent_Update_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(*types.UpdateDatasetReq)) + }) + return _c +} + +func (_c *MockDatasetComponent_Update_Call) Return(_a0 *types.Dataset, _a1 error) *MockDatasetComponent_Update_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *MockDatasetComponent_Update_Call) RunAndReturn(run func(context.Context, *types.UpdateDatasetReq) (*types.Dataset, error)) *MockDatasetComponent_Update_Call { + _c.Call.Return(run) + return _c +} + +// NewMockDatasetComponent creates a new instance of MockDatasetComponent. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. +// The first argument is typically a *testing.T value. +func NewMockDatasetComponent(t interface { + mock.TestingT + Cleanup(func()) +}) *MockDatasetComponent { + mock := &MockDatasetComponent{} + mock.Mock.Test(t) + + t.Cleanup(func() { mock.AssertExpectations(t) }) + + return mock +} diff --git a/_mocks/opencsg.com/csghub-server/component/mock_DiscussionComponent.go b/_mocks/opencsg.com/csghub-server/component/mock_DiscussionComponent.go new file mode 100644 index 00000000..91c3766d --- /dev/null +++ b/_mocks/opencsg.com/csghub-server/component/mock_DiscussionComponent.go @@ -0,0 +1,524 @@ +// Code generated by mockery v2.49.1. DO NOT EDIT. + +package component + +import ( + context "context" + + mock "github.com/stretchr/testify/mock" + types "opencsg.com/csghub-server/common/types" +) + +// MockDiscussionComponent is an autogenerated mock type for the DiscussionComponent type +type MockDiscussionComponent struct { + mock.Mock +} + +type MockDiscussionComponent_Expecter struct { + mock *mock.Mock +} + +func (_m *MockDiscussionComponent) EXPECT() *MockDiscussionComponent_Expecter { + return &MockDiscussionComponent_Expecter{mock: &_m.Mock} +} + +// CreateDiscussionComment provides a mock function with given fields: ctx, req +func (_m *MockDiscussionComponent) CreateDiscussionComment(ctx context.Context, req types.CreateCommentRequest) (*types.CreateCommentResponse, error) { + ret := _m.Called(ctx, req) + + if len(ret) == 0 { + panic("no return value specified for CreateDiscussionComment") + } + + var r0 *types.CreateCommentResponse + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, types.CreateCommentRequest) (*types.CreateCommentResponse, error)); ok { + return rf(ctx, req) + } + if rf, ok := ret.Get(0).(func(context.Context, types.CreateCommentRequest) *types.CreateCommentResponse); ok { + r0 = rf(ctx, req) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*types.CreateCommentResponse) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, types.CreateCommentRequest) error); ok { + r1 = rf(ctx, req) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// MockDiscussionComponent_CreateDiscussionComment_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'CreateDiscussionComment' +type MockDiscussionComponent_CreateDiscussionComment_Call struct { + *mock.Call +} + +// CreateDiscussionComment is a helper method to define mock.On call +// - ctx context.Context +// - req types.CreateCommentRequest +func (_e *MockDiscussionComponent_Expecter) CreateDiscussionComment(ctx interface{}, req interface{}) *MockDiscussionComponent_CreateDiscussionComment_Call { + return &MockDiscussionComponent_CreateDiscussionComment_Call{Call: _e.mock.On("CreateDiscussionComment", ctx, req)} +} + +func (_c *MockDiscussionComponent_CreateDiscussionComment_Call) Run(run func(ctx context.Context, req types.CreateCommentRequest)) *MockDiscussionComponent_CreateDiscussionComment_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(types.CreateCommentRequest)) + }) + return _c +} + +func (_c *MockDiscussionComponent_CreateDiscussionComment_Call) Return(_a0 *types.CreateCommentResponse, _a1 error) *MockDiscussionComponent_CreateDiscussionComment_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *MockDiscussionComponent_CreateDiscussionComment_Call) RunAndReturn(run func(context.Context, types.CreateCommentRequest) (*types.CreateCommentResponse, error)) *MockDiscussionComponent_CreateDiscussionComment_Call { + _c.Call.Return(run) + return _c +} + +// CreateRepoDiscussion provides a mock function with given fields: ctx, req +func (_m *MockDiscussionComponent) CreateRepoDiscussion(ctx context.Context, req types.CreateRepoDiscussionRequest) (*types.CreateDiscussionResponse, error) { + ret := _m.Called(ctx, req) + + if len(ret) == 0 { + panic("no return value specified for CreateRepoDiscussion") + } + + var r0 *types.CreateDiscussionResponse + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, types.CreateRepoDiscussionRequest) (*types.CreateDiscussionResponse, error)); ok { + return rf(ctx, req) + } + if rf, ok := ret.Get(0).(func(context.Context, types.CreateRepoDiscussionRequest) *types.CreateDiscussionResponse); ok { + r0 = rf(ctx, req) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*types.CreateDiscussionResponse) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, types.CreateRepoDiscussionRequest) error); ok { + r1 = rf(ctx, req) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// MockDiscussionComponent_CreateRepoDiscussion_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'CreateRepoDiscussion' +type MockDiscussionComponent_CreateRepoDiscussion_Call struct { + *mock.Call +} + +// CreateRepoDiscussion is a helper method to define mock.On call +// - ctx context.Context +// - req types.CreateRepoDiscussionRequest +func (_e *MockDiscussionComponent_Expecter) CreateRepoDiscussion(ctx interface{}, req interface{}) *MockDiscussionComponent_CreateRepoDiscussion_Call { + return &MockDiscussionComponent_CreateRepoDiscussion_Call{Call: _e.mock.On("CreateRepoDiscussion", ctx, req)} +} + +func (_c *MockDiscussionComponent_CreateRepoDiscussion_Call) Run(run func(ctx context.Context, req types.CreateRepoDiscussionRequest)) *MockDiscussionComponent_CreateRepoDiscussion_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(types.CreateRepoDiscussionRequest)) + }) + return _c +} + +func (_c *MockDiscussionComponent_CreateRepoDiscussion_Call) Return(_a0 *types.CreateDiscussionResponse, _a1 error) *MockDiscussionComponent_CreateRepoDiscussion_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *MockDiscussionComponent_CreateRepoDiscussion_Call) RunAndReturn(run func(context.Context, types.CreateRepoDiscussionRequest) (*types.CreateDiscussionResponse, error)) *MockDiscussionComponent_CreateRepoDiscussion_Call { + _c.Call.Return(run) + return _c +} + +// DeleteComment provides a mock function with given fields: ctx, currentUser, id +func (_m *MockDiscussionComponent) DeleteComment(ctx context.Context, currentUser string, id int64) error { + ret := _m.Called(ctx, currentUser, id) + + if len(ret) == 0 { + panic("no return value specified for DeleteComment") + } + + var r0 error + if rf, ok := ret.Get(0).(func(context.Context, string, int64) error); ok { + r0 = rf(ctx, currentUser, id) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// MockDiscussionComponent_DeleteComment_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'DeleteComment' +type MockDiscussionComponent_DeleteComment_Call struct { + *mock.Call +} + +// DeleteComment is a helper method to define mock.On call +// - ctx context.Context +// - currentUser string +// - id int64 +func (_e *MockDiscussionComponent_Expecter) DeleteComment(ctx interface{}, currentUser interface{}, id interface{}) *MockDiscussionComponent_DeleteComment_Call { + return &MockDiscussionComponent_DeleteComment_Call{Call: _e.mock.On("DeleteComment", ctx, currentUser, id)} +} + +func (_c *MockDiscussionComponent_DeleteComment_Call) Run(run func(ctx context.Context, currentUser string, id int64)) *MockDiscussionComponent_DeleteComment_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(string), args[2].(int64)) + }) + return _c +} + +func (_c *MockDiscussionComponent_DeleteComment_Call) Return(_a0 error) *MockDiscussionComponent_DeleteComment_Call { + _c.Call.Return(_a0) + return _c +} + +func (_c *MockDiscussionComponent_DeleteComment_Call) RunAndReturn(run func(context.Context, string, int64) error) *MockDiscussionComponent_DeleteComment_Call { + _c.Call.Return(run) + return _c +} + +// DeleteDiscussion provides a mock function with given fields: ctx, currentUser, id +func (_m *MockDiscussionComponent) DeleteDiscussion(ctx context.Context, currentUser string, id int64) error { + ret := _m.Called(ctx, currentUser, id) + + if len(ret) == 0 { + panic("no return value specified for DeleteDiscussion") + } + + var r0 error + if rf, ok := ret.Get(0).(func(context.Context, string, int64) error); ok { + r0 = rf(ctx, currentUser, id) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// MockDiscussionComponent_DeleteDiscussion_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'DeleteDiscussion' +type MockDiscussionComponent_DeleteDiscussion_Call struct { + *mock.Call +} + +// DeleteDiscussion is a helper method to define mock.On call +// - ctx context.Context +// - currentUser string +// - id int64 +func (_e *MockDiscussionComponent_Expecter) DeleteDiscussion(ctx interface{}, currentUser interface{}, id interface{}) *MockDiscussionComponent_DeleteDiscussion_Call { + return &MockDiscussionComponent_DeleteDiscussion_Call{Call: _e.mock.On("DeleteDiscussion", ctx, currentUser, id)} +} + +func (_c *MockDiscussionComponent_DeleteDiscussion_Call) Run(run func(ctx context.Context, currentUser string, id int64)) *MockDiscussionComponent_DeleteDiscussion_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(string), args[2].(int64)) + }) + return _c +} + +func (_c *MockDiscussionComponent_DeleteDiscussion_Call) Return(_a0 error) *MockDiscussionComponent_DeleteDiscussion_Call { + _c.Call.Return(_a0) + return _c +} + +func (_c *MockDiscussionComponent_DeleteDiscussion_Call) RunAndReturn(run func(context.Context, string, int64) error) *MockDiscussionComponent_DeleteDiscussion_Call { + _c.Call.Return(run) + return _c +} + +// GetDiscussion provides a mock function with given fields: ctx, id +func (_m *MockDiscussionComponent) GetDiscussion(ctx context.Context, id int64) (*types.ShowDiscussionResponse, error) { + ret := _m.Called(ctx, id) + + if len(ret) == 0 { + panic("no return value specified for GetDiscussion") + } + + var r0 *types.ShowDiscussionResponse + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, int64) (*types.ShowDiscussionResponse, error)); ok { + return rf(ctx, id) + } + if rf, ok := ret.Get(0).(func(context.Context, int64) *types.ShowDiscussionResponse); ok { + r0 = rf(ctx, id) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*types.ShowDiscussionResponse) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, int64) error); ok { + r1 = rf(ctx, id) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// MockDiscussionComponent_GetDiscussion_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'GetDiscussion' +type MockDiscussionComponent_GetDiscussion_Call struct { + *mock.Call +} + +// GetDiscussion is a helper method to define mock.On call +// - ctx context.Context +// - id int64 +func (_e *MockDiscussionComponent_Expecter) GetDiscussion(ctx interface{}, id interface{}) *MockDiscussionComponent_GetDiscussion_Call { + return &MockDiscussionComponent_GetDiscussion_Call{Call: _e.mock.On("GetDiscussion", ctx, id)} +} + +func (_c *MockDiscussionComponent_GetDiscussion_Call) Run(run func(ctx context.Context, id int64)) *MockDiscussionComponent_GetDiscussion_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(int64)) + }) + return _c +} + +func (_c *MockDiscussionComponent_GetDiscussion_Call) Return(_a0 *types.ShowDiscussionResponse, _a1 error) *MockDiscussionComponent_GetDiscussion_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *MockDiscussionComponent_GetDiscussion_Call) RunAndReturn(run func(context.Context, int64) (*types.ShowDiscussionResponse, error)) *MockDiscussionComponent_GetDiscussion_Call { + _c.Call.Return(run) + return _c +} + +// ListDiscussionComments provides a mock function with given fields: ctx, discussionID +func (_m *MockDiscussionComponent) ListDiscussionComments(ctx context.Context, discussionID int64) ([]*types.DiscussionResponse_Comment, error) { + ret := _m.Called(ctx, discussionID) + + if len(ret) == 0 { + panic("no return value specified for ListDiscussionComments") + } + + var r0 []*types.DiscussionResponse_Comment + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, int64) ([]*types.DiscussionResponse_Comment, error)); ok { + return rf(ctx, discussionID) + } + if rf, ok := ret.Get(0).(func(context.Context, int64) []*types.DiscussionResponse_Comment); ok { + r0 = rf(ctx, discussionID) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).([]*types.DiscussionResponse_Comment) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, int64) error); ok { + r1 = rf(ctx, discussionID) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// MockDiscussionComponent_ListDiscussionComments_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'ListDiscussionComments' +type MockDiscussionComponent_ListDiscussionComments_Call struct { + *mock.Call +} + +// ListDiscussionComments is a helper method to define mock.On call +// - ctx context.Context +// - discussionID int64 +func (_e *MockDiscussionComponent_Expecter) ListDiscussionComments(ctx interface{}, discussionID interface{}) *MockDiscussionComponent_ListDiscussionComments_Call { + return &MockDiscussionComponent_ListDiscussionComments_Call{Call: _e.mock.On("ListDiscussionComments", ctx, discussionID)} +} + +func (_c *MockDiscussionComponent_ListDiscussionComments_Call) Run(run func(ctx context.Context, discussionID int64)) *MockDiscussionComponent_ListDiscussionComments_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(int64)) + }) + return _c +} + +func (_c *MockDiscussionComponent_ListDiscussionComments_Call) Return(_a0 []*types.DiscussionResponse_Comment, _a1 error) *MockDiscussionComponent_ListDiscussionComments_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *MockDiscussionComponent_ListDiscussionComments_Call) RunAndReturn(run func(context.Context, int64) ([]*types.DiscussionResponse_Comment, error)) *MockDiscussionComponent_ListDiscussionComments_Call { + _c.Call.Return(run) + return _c +} + +// ListRepoDiscussions provides a mock function with given fields: ctx, req +func (_m *MockDiscussionComponent) ListRepoDiscussions(ctx context.Context, req types.ListRepoDiscussionRequest) (*types.ListRepoDiscussionResponse, error) { + ret := _m.Called(ctx, req) + + if len(ret) == 0 { + panic("no return value specified for ListRepoDiscussions") + } + + var r0 *types.ListRepoDiscussionResponse + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, types.ListRepoDiscussionRequest) (*types.ListRepoDiscussionResponse, error)); ok { + return rf(ctx, req) + } + if rf, ok := ret.Get(0).(func(context.Context, types.ListRepoDiscussionRequest) *types.ListRepoDiscussionResponse); ok { + r0 = rf(ctx, req) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*types.ListRepoDiscussionResponse) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, types.ListRepoDiscussionRequest) error); ok { + r1 = rf(ctx, req) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// MockDiscussionComponent_ListRepoDiscussions_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'ListRepoDiscussions' +type MockDiscussionComponent_ListRepoDiscussions_Call struct { + *mock.Call +} + +// ListRepoDiscussions is a helper method to define mock.On call +// - ctx context.Context +// - req types.ListRepoDiscussionRequest +func (_e *MockDiscussionComponent_Expecter) ListRepoDiscussions(ctx interface{}, req interface{}) *MockDiscussionComponent_ListRepoDiscussions_Call { + return &MockDiscussionComponent_ListRepoDiscussions_Call{Call: _e.mock.On("ListRepoDiscussions", ctx, req)} +} + +func (_c *MockDiscussionComponent_ListRepoDiscussions_Call) Run(run func(ctx context.Context, req types.ListRepoDiscussionRequest)) *MockDiscussionComponent_ListRepoDiscussions_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(types.ListRepoDiscussionRequest)) + }) + return _c +} + +func (_c *MockDiscussionComponent_ListRepoDiscussions_Call) Return(_a0 *types.ListRepoDiscussionResponse, _a1 error) *MockDiscussionComponent_ListRepoDiscussions_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *MockDiscussionComponent_ListRepoDiscussions_Call) RunAndReturn(run func(context.Context, types.ListRepoDiscussionRequest) (*types.ListRepoDiscussionResponse, error)) *MockDiscussionComponent_ListRepoDiscussions_Call { + _c.Call.Return(run) + return _c +} + +// UpdateComment provides a mock function with given fields: ctx, currentUser, id, content +func (_m *MockDiscussionComponent) UpdateComment(ctx context.Context, currentUser string, id int64, content string) error { + ret := _m.Called(ctx, currentUser, id, content) + + if len(ret) == 0 { + panic("no return value specified for UpdateComment") + } + + var r0 error + if rf, ok := ret.Get(0).(func(context.Context, string, int64, string) error); ok { + r0 = rf(ctx, currentUser, id, content) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// MockDiscussionComponent_UpdateComment_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'UpdateComment' +type MockDiscussionComponent_UpdateComment_Call struct { + *mock.Call +} + +// UpdateComment is a helper method to define mock.On call +// - ctx context.Context +// - currentUser string +// - id int64 +// - content string +func (_e *MockDiscussionComponent_Expecter) UpdateComment(ctx interface{}, currentUser interface{}, id interface{}, content interface{}) *MockDiscussionComponent_UpdateComment_Call { + return &MockDiscussionComponent_UpdateComment_Call{Call: _e.mock.On("UpdateComment", ctx, currentUser, id, content)} +} + +func (_c *MockDiscussionComponent_UpdateComment_Call) Run(run func(ctx context.Context, currentUser string, id int64, content string)) *MockDiscussionComponent_UpdateComment_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(string), args[2].(int64), args[3].(string)) + }) + return _c +} + +func (_c *MockDiscussionComponent_UpdateComment_Call) Return(_a0 error) *MockDiscussionComponent_UpdateComment_Call { + _c.Call.Return(_a0) + return _c +} + +func (_c *MockDiscussionComponent_UpdateComment_Call) RunAndReturn(run func(context.Context, string, int64, string) error) *MockDiscussionComponent_UpdateComment_Call { + _c.Call.Return(run) + return _c +} + +// UpdateDiscussion provides a mock function with given fields: ctx, req +func (_m *MockDiscussionComponent) UpdateDiscussion(ctx context.Context, req types.UpdateDiscussionRequest) error { + ret := _m.Called(ctx, req) + + if len(ret) == 0 { + panic("no return value specified for UpdateDiscussion") + } + + var r0 error + if rf, ok := ret.Get(0).(func(context.Context, types.UpdateDiscussionRequest) error); ok { + r0 = rf(ctx, req) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// MockDiscussionComponent_UpdateDiscussion_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'UpdateDiscussion' +type MockDiscussionComponent_UpdateDiscussion_Call struct { + *mock.Call +} + +// UpdateDiscussion is a helper method to define mock.On call +// - ctx context.Context +// - req types.UpdateDiscussionRequest +func (_e *MockDiscussionComponent_Expecter) UpdateDiscussion(ctx interface{}, req interface{}) *MockDiscussionComponent_UpdateDiscussion_Call { + return &MockDiscussionComponent_UpdateDiscussion_Call{Call: _e.mock.On("UpdateDiscussion", ctx, req)} +} + +func (_c *MockDiscussionComponent_UpdateDiscussion_Call) Run(run func(ctx context.Context, req types.UpdateDiscussionRequest)) *MockDiscussionComponent_UpdateDiscussion_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(types.UpdateDiscussionRequest)) + }) + return _c +} + +func (_c *MockDiscussionComponent_UpdateDiscussion_Call) Return(_a0 error) *MockDiscussionComponent_UpdateDiscussion_Call { + _c.Call.Return(_a0) + return _c +} + +func (_c *MockDiscussionComponent_UpdateDiscussion_Call) RunAndReturn(run func(context.Context, types.UpdateDiscussionRequest) error) *MockDiscussionComponent_UpdateDiscussion_Call { + _c.Call.Return(run) + return _c +} + +// NewMockDiscussionComponent creates a new instance of MockDiscussionComponent. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. +// The first argument is typically a *testing.T value. +func NewMockDiscussionComponent(t interface { + mock.TestingT + Cleanup(func()) +}) *MockDiscussionComponent { + mock := &MockDiscussionComponent{} + mock.Mock.Test(t) + + t.Cleanup(func() { mock.AssertExpectations(t) }) + + return mock +} diff --git a/_mocks/opencsg.com/csghub-server/component/mock_EvaluationComponent.go b/_mocks/opencsg.com/csghub-server/component/mock_EvaluationComponent.go new file mode 100644 index 00000000..324e4991 --- /dev/null +++ b/_mocks/opencsg.com/csghub-server/component/mock_EvaluationComponent.go @@ -0,0 +1,202 @@ +// Code generated by mockery v2.49.1. DO NOT EDIT. + +package component + +import ( + context "context" + + mock "github.com/stretchr/testify/mock" + types "opencsg.com/csghub-server/common/types" +) + +// MockEvaluationComponent is an autogenerated mock type for the EvaluationComponent type +type MockEvaluationComponent struct { + mock.Mock +} + +type MockEvaluationComponent_Expecter struct { + mock *mock.Mock +} + +func (_m *MockEvaluationComponent) EXPECT() *MockEvaluationComponent_Expecter { + return &MockEvaluationComponent_Expecter{mock: &_m.Mock} +} + +// CreateEvaluation provides a mock function with given fields: ctx, req +func (_m *MockEvaluationComponent) CreateEvaluation(ctx context.Context, req types.EvaluationReq) (*types.ArgoWorkFlowRes, error) { + ret := _m.Called(ctx, req) + + if len(ret) == 0 { + panic("no return value specified for CreateEvaluation") + } + + var r0 *types.ArgoWorkFlowRes + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, types.EvaluationReq) (*types.ArgoWorkFlowRes, error)); ok { + return rf(ctx, req) + } + if rf, ok := ret.Get(0).(func(context.Context, types.EvaluationReq) *types.ArgoWorkFlowRes); ok { + r0 = rf(ctx, req) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*types.ArgoWorkFlowRes) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, types.EvaluationReq) error); ok { + r1 = rf(ctx, req) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// MockEvaluationComponent_CreateEvaluation_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'CreateEvaluation' +type MockEvaluationComponent_CreateEvaluation_Call struct { + *mock.Call +} + +// CreateEvaluation is a helper method to define mock.On call +// - ctx context.Context +// - req types.EvaluationReq +func (_e *MockEvaluationComponent_Expecter) CreateEvaluation(ctx interface{}, req interface{}) *MockEvaluationComponent_CreateEvaluation_Call { + return &MockEvaluationComponent_CreateEvaluation_Call{Call: _e.mock.On("CreateEvaluation", ctx, req)} +} + +func (_c *MockEvaluationComponent_CreateEvaluation_Call) Run(run func(ctx context.Context, req types.EvaluationReq)) *MockEvaluationComponent_CreateEvaluation_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(types.EvaluationReq)) + }) + return _c +} + +func (_c *MockEvaluationComponent_CreateEvaluation_Call) Return(_a0 *types.ArgoWorkFlowRes, _a1 error) *MockEvaluationComponent_CreateEvaluation_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *MockEvaluationComponent_CreateEvaluation_Call) RunAndReturn(run func(context.Context, types.EvaluationReq) (*types.ArgoWorkFlowRes, error)) *MockEvaluationComponent_CreateEvaluation_Call { + _c.Call.Return(run) + return _c +} + +// DeleteEvaluation provides a mock function with given fields: ctx, req +func (_m *MockEvaluationComponent) DeleteEvaluation(ctx context.Context, req types.ArgoWorkFlowDeleteReq) error { + ret := _m.Called(ctx, req) + + if len(ret) == 0 { + panic("no return value specified for DeleteEvaluation") + } + + var r0 error + if rf, ok := ret.Get(0).(func(context.Context, types.ArgoWorkFlowDeleteReq) error); ok { + r0 = rf(ctx, req) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// MockEvaluationComponent_DeleteEvaluation_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'DeleteEvaluation' +type MockEvaluationComponent_DeleteEvaluation_Call struct { + *mock.Call +} + +// DeleteEvaluation is a helper method to define mock.On call +// - ctx context.Context +// - req types.ArgoWorkFlowDeleteReq +func (_e *MockEvaluationComponent_Expecter) DeleteEvaluation(ctx interface{}, req interface{}) *MockEvaluationComponent_DeleteEvaluation_Call { + return &MockEvaluationComponent_DeleteEvaluation_Call{Call: _e.mock.On("DeleteEvaluation", ctx, req)} +} + +func (_c *MockEvaluationComponent_DeleteEvaluation_Call) Run(run func(ctx context.Context, req types.ArgoWorkFlowDeleteReq)) *MockEvaluationComponent_DeleteEvaluation_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(types.ArgoWorkFlowDeleteReq)) + }) + return _c +} + +func (_c *MockEvaluationComponent_DeleteEvaluation_Call) Return(_a0 error) *MockEvaluationComponent_DeleteEvaluation_Call { + _c.Call.Return(_a0) + return _c +} + +func (_c *MockEvaluationComponent_DeleteEvaluation_Call) RunAndReturn(run func(context.Context, types.ArgoWorkFlowDeleteReq) error) *MockEvaluationComponent_DeleteEvaluation_Call { + _c.Call.Return(run) + return _c +} + +// GetEvaluation provides a mock function with given fields: ctx, req +func (_m *MockEvaluationComponent) GetEvaluation(ctx context.Context, req types.EvaluationGetReq) (*types.EvaluationRes, error) { + ret := _m.Called(ctx, req) + + if len(ret) == 0 { + panic("no return value specified for GetEvaluation") + } + + var r0 *types.EvaluationRes + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, types.EvaluationGetReq) (*types.EvaluationRes, error)); ok { + return rf(ctx, req) + } + if rf, ok := ret.Get(0).(func(context.Context, types.EvaluationGetReq) *types.EvaluationRes); ok { + r0 = rf(ctx, req) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*types.EvaluationRes) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, types.EvaluationGetReq) error); ok { + r1 = rf(ctx, req) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// MockEvaluationComponent_GetEvaluation_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'GetEvaluation' +type MockEvaluationComponent_GetEvaluation_Call struct { + *mock.Call +} + +// GetEvaluation is a helper method to define mock.On call +// - ctx context.Context +// - req types.EvaluationGetReq +func (_e *MockEvaluationComponent_Expecter) GetEvaluation(ctx interface{}, req interface{}) *MockEvaluationComponent_GetEvaluation_Call { + return &MockEvaluationComponent_GetEvaluation_Call{Call: _e.mock.On("GetEvaluation", ctx, req)} +} + +func (_c *MockEvaluationComponent_GetEvaluation_Call) Run(run func(ctx context.Context, req types.EvaluationGetReq)) *MockEvaluationComponent_GetEvaluation_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(types.EvaluationGetReq)) + }) + return _c +} + +func (_c *MockEvaluationComponent_GetEvaluation_Call) Return(_a0 *types.EvaluationRes, _a1 error) *MockEvaluationComponent_GetEvaluation_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *MockEvaluationComponent_GetEvaluation_Call) RunAndReturn(run func(context.Context, types.EvaluationGetReq) (*types.EvaluationRes, error)) *MockEvaluationComponent_GetEvaluation_Call { + _c.Call.Return(run) + return _c +} + +// NewMockEvaluationComponent creates a new instance of MockEvaluationComponent. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. +// The first argument is typically a *testing.T value. +func NewMockEvaluationComponent(t interface { + mock.TestingT + Cleanup(func()) +}) *MockEvaluationComponent { + mock := &MockEvaluationComponent{} + mock.Mock.Test(t) + + t.Cleanup(func() { mock.AssertExpectations(t) }) + + return mock +} diff --git a/_mocks/opencsg.com/csghub-server/component/mock_InternalComponent.go b/_mocks/opencsg.com/csghub-server/component/mock_InternalComponent.go new file mode 100644 index 00000000..2b8d5397 --- /dev/null +++ b/_mocks/opencsg.com/csghub-server/component/mock_InternalComponent.go @@ -0,0 +1,331 @@ +// Code generated by mockery v2.49.1. DO NOT EDIT. + +package component + +import ( + context "context" + + mock "github.com/stretchr/testify/mock" + database "opencsg.com/csghub-server/builder/store/database" + + types "opencsg.com/csghub-server/common/types" +) + +// MockInternalComponent is an autogenerated mock type for the InternalComponent type +type MockInternalComponent struct { + mock.Mock +} + +type MockInternalComponent_Expecter struct { + mock *mock.Mock +} + +func (_m *MockInternalComponent) EXPECT() *MockInternalComponent_Expecter { + return &MockInternalComponent_Expecter{mock: &_m.Mock} +} + +// Allowed provides a mock function with given fields: ctx +func (_m *MockInternalComponent) Allowed(ctx context.Context) (bool, error) { + ret := _m.Called(ctx) + + if len(ret) == 0 { + panic("no return value specified for Allowed") + } + + var r0 bool + var r1 error + if rf, ok := ret.Get(0).(func(context.Context) (bool, error)); ok { + return rf(ctx) + } + if rf, ok := ret.Get(0).(func(context.Context) bool); ok { + r0 = rf(ctx) + } else { + r0 = ret.Get(0).(bool) + } + + if rf, ok := ret.Get(1).(func(context.Context) error); ok { + r1 = rf(ctx) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// MockInternalComponent_Allowed_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Allowed' +type MockInternalComponent_Allowed_Call struct { + *mock.Call +} + +// Allowed is a helper method to define mock.On call +// - ctx context.Context +func (_e *MockInternalComponent_Expecter) Allowed(ctx interface{}) *MockInternalComponent_Allowed_Call { + return &MockInternalComponent_Allowed_Call{Call: _e.mock.On("Allowed", ctx)} +} + +func (_c *MockInternalComponent_Allowed_Call) Run(run func(ctx context.Context)) *MockInternalComponent_Allowed_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context)) + }) + return _c +} + +func (_c *MockInternalComponent_Allowed_Call) Return(_a0 bool, _a1 error) *MockInternalComponent_Allowed_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *MockInternalComponent_Allowed_Call) RunAndReturn(run func(context.Context) (bool, error)) *MockInternalComponent_Allowed_Call { + _c.Call.Return(run) + return _c +} + +// GetAuthorizedKeys provides a mock function with given fields: ctx, key +func (_m *MockInternalComponent) GetAuthorizedKeys(ctx context.Context, key string) (*database.SSHKey, error) { + ret := _m.Called(ctx, key) + + if len(ret) == 0 { + panic("no return value specified for GetAuthorizedKeys") + } + + var r0 *database.SSHKey + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, string) (*database.SSHKey, error)); ok { + return rf(ctx, key) + } + if rf, ok := ret.Get(0).(func(context.Context, string) *database.SSHKey); ok { + r0 = rf(ctx, key) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*database.SSHKey) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, string) error); ok { + r1 = rf(ctx, key) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// MockInternalComponent_GetAuthorizedKeys_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'GetAuthorizedKeys' +type MockInternalComponent_GetAuthorizedKeys_Call struct { + *mock.Call +} + +// GetAuthorizedKeys is a helper method to define mock.On call +// - ctx context.Context +// - key string +func (_e *MockInternalComponent_Expecter) GetAuthorizedKeys(ctx interface{}, key interface{}) *MockInternalComponent_GetAuthorizedKeys_Call { + return &MockInternalComponent_GetAuthorizedKeys_Call{Call: _e.mock.On("GetAuthorizedKeys", ctx, key)} +} + +func (_c *MockInternalComponent_GetAuthorizedKeys_Call) Run(run func(ctx context.Context, key string)) *MockInternalComponent_GetAuthorizedKeys_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(string)) + }) + return _c +} + +func (_c *MockInternalComponent_GetAuthorizedKeys_Call) Return(_a0 *database.SSHKey, _a1 error) *MockInternalComponent_GetAuthorizedKeys_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *MockInternalComponent_GetAuthorizedKeys_Call) RunAndReturn(run func(context.Context, string) (*database.SSHKey, error)) *MockInternalComponent_GetAuthorizedKeys_Call { + _c.Call.Return(run) + return _c +} + +// GetCommitDiff provides a mock function with given fields: ctx, req +func (_m *MockInternalComponent) GetCommitDiff(ctx context.Context, req types.GetDiffBetweenTwoCommitsReq) (*types.GiteaCallbackPushReq, error) { + ret := _m.Called(ctx, req) + + if len(ret) == 0 { + panic("no return value specified for GetCommitDiff") + } + + var r0 *types.GiteaCallbackPushReq + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, types.GetDiffBetweenTwoCommitsReq) (*types.GiteaCallbackPushReq, error)); ok { + return rf(ctx, req) + } + if rf, ok := ret.Get(0).(func(context.Context, types.GetDiffBetweenTwoCommitsReq) *types.GiteaCallbackPushReq); ok { + r0 = rf(ctx, req) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*types.GiteaCallbackPushReq) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, types.GetDiffBetweenTwoCommitsReq) error); ok { + r1 = rf(ctx, req) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// MockInternalComponent_GetCommitDiff_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'GetCommitDiff' +type MockInternalComponent_GetCommitDiff_Call struct { + *mock.Call +} + +// GetCommitDiff is a helper method to define mock.On call +// - ctx context.Context +// - req types.GetDiffBetweenTwoCommitsReq +func (_e *MockInternalComponent_Expecter) GetCommitDiff(ctx interface{}, req interface{}) *MockInternalComponent_GetCommitDiff_Call { + return &MockInternalComponent_GetCommitDiff_Call{Call: _e.mock.On("GetCommitDiff", ctx, req)} +} + +func (_c *MockInternalComponent_GetCommitDiff_Call) Run(run func(ctx context.Context, req types.GetDiffBetweenTwoCommitsReq)) *MockInternalComponent_GetCommitDiff_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(types.GetDiffBetweenTwoCommitsReq)) + }) + return _c +} + +func (_c *MockInternalComponent_GetCommitDiff_Call) Return(_a0 *types.GiteaCallbackPushReq, _a1 error) *MockInternalComponent_GetCommitDiff_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *MockInternalComponent_GetCommitDiff_Call) RunAndReturn(run func(context.Context, types.GetDiffBetweenTwoCommitsReq) (*types.GiteaCallbackPushReq, error)) *MockInternalComponent_GetCommitDiff_Call { + _c.Call.Return(run) + return _c +} + +// LfsAuthenticate provides a mock function with given fields: ctx, req +func (_m *MockInternalComponent) LfsAuthenticate(ctx context.Context, req types.LfsAuthenticateReq) (*types.LfsAuthenticateResp, error) { + ret := _m.Called(ctx, req) + + if len(ret) == 0 { + panic("no return value specified for LfsAuthenticate") + } + + var r0 *types.LfsAuthenticateResp + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, types.LfsAuthenticateReq) (*types.LfsAuthenticateResp, error)); ok { + return rf(ctx, req) + } + if rf, ok := ret.Get(0).(func(context.Context, types.LfsAuthenticateReq) *types.LfsAuthenticateResp); ok { + r0 = rf(ctx, req) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*types.LfsAuthenticateResp) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, types.LfsAuthenticateReq) error); ok { + r1 = rf(ctx, req) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// MockInternalComponent_LfsAuthenticate_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'LfsAuthenticate' +type MockInternalComponent_LfsAuthenticate_Call struct { + *mock.Call +} + +// LfsAuthenticate is a helper method to define mock.On call +// - ctx context.Context +// - req types.LfsAuthenticateReq +func (_e *MockInternalComponent_Expecter) LfsAuthenticate(ctx interface{}, req interface{}) *MockInternalComponent_LfsAuthenticate_Call { + return &MockInternalComponent_LfsAuthenticate_Call{Call: _e.mock.On("LfsAuthenticate", ctx, req)} +} + +func (_c *MockInternalComponent_LfsAuthenticate_Call) Run(run func(ctx context.Context, req types.LfsAuthenticateReq)) *MockInternalComponent_LfsAuthenticate_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(types.LfsAuthenticateReq)) + }) + return _c +} + +func (_c *MockInternalComponent_LfsAuthenticate_Call) Return(_a0 *types.LfsAuthenticateResp, _a1 error) *MockInternalComponent_LfsAuthenticate_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *MockInternalComponent_LfsAuthenticate_Call) RunAndReturn(run func(context.Context, types.LfsAuthenticateReq) (*types.LfsAuthenticateResp, error)) *MockInternalComponent_LfsAuthenticate_Call { + _c.Call.Return(run) + return _c +} + +// SSHAllowed provides a mock function with given fields: ctx, req +func (_m *MockInternalComponent) SSHAllowed(ctx context.Context, req types.SSHAllowedReq) (*types.SSHAllowedResp, error) { + ret := _m.Called(ctx, req) + + if len(ret) == 0 { + panic("no return value specified for SSHAllowed") + } + + var r0 *types.SSHAllowedResp + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, types.SSHAllowedReq) (*types.SSHAllowedResp, error)); ok { + return rf(ctx, req) + } + if rf, ok := ret.Get(0).(func(context.Context, types.SSHAllowedReq) *types.SSHAllowedResp); ok { + r0 = rf(ctx, req) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*types.SSHAllowedResp) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, types.SSHAllowedReq) error); ok { + r1 = rf(ctx, req) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// MockInternalComponent_SSHAllowed_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'SSHAllowed' +type MockInternalComponent_SSHAllowed_Call struct { + *mock.Call +} + +// SSHAllowed is a helper method to define mock.On call +// - ctx context.Context +// - req types.SSHAllowedReq +func (_e *MockInternalComponent_Expecter) SSHAllowed(ctx interface{}, req interface{}) *MockInternalComponent_SSHAllowed_Call { + return &MockInternalComponent_SSHAllowed_Call{Call: _e.mock.On("SSHAllowed", ctx, req)} +} + +func (_c *MockInternalComponent_SSHAllowed_Call) Run(run func(ctx context.Context, req types.SSHAllowedReq)) *MockInternalComponent_SSHAllowed_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(types.SSHAllowedReq)) + }) + return _c +} + +func (_c *MockInternalComponent_SSHAllowed_Call) Return(_a0 *types.SSHAllowedResp, _a1 error) *MockInternalComponent_SSHAllowed_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *MockInternalComponent_SSHAllowed_Call) RunAndReturn(run func(context.Context, types.SSHAllowedReq) (*types.SSHAllowedResp, error)) *MockInternalComponent_SSHAllowed_Call { + _c.Call.Return(run) + return _c +} + +// NewMockInternalComponent creates a new instance of MockInternalComponent. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. +// The first argument is typically a *testing.T value. +func NewMockInternalComponent(t interface { + mock.TestingT + Cleanup(func()) +}) *MockInternalComponent { + mock := &MockInternalComponent{} + mock.Mock.Test(t) + + t.Cleanup(func() { mock.AssertExpectations(t) }) + + return mock +} diff --git a/_mocks/opencsg.com/csghub-server/component/mock_MirrorComponent.go b/_mocks/opencsg.com/csghub-server/component/mock_MirrorComponent.go new file mode 100644 index 00000000..0a8f398a --- /dev/null +++ b/_mocks/opencsg.com/csghub-server/component/mock_MirrorComponent.go @@ -0,0 +1,386 @@ +// Code generated by mockery v2.49.1. DO NOT EDIT. + +package component + +import ( + context "context" + + mock "github.com/stretchr/testify/mock" + database "opencsg.com/csghub-server/builder/store/database" + + types "opencsg.com/csghub-server/common/types" +) + +// MockMirrorComponent is an autogenerated mock type for the MirrorComponent type +type MockMirrorComponent struct { + mock.Mock +} + +type MockMirrorComponent_Expecter struct { + mock *mock.Mock +} + +func (_m *MockMirrorComponent) EXPECT() *MockMirrorComponent_Expecter { + return &MockMirrorComponent_Expecter{mock: &_m.Mock} +} + +// CheckMirrorProgress provides a mock function with given fields: ctx +func (_m *MockMirrorComponent) CheckMirrorProgress(ctx context.Context) error { + ret := _m.Called(ctx) + + if len(ret) == 0 { + panic("no return value specified for CheckMirrorProgress") + } + + var r0 error + if rf, ok := ret.Get(0).(func(context.Context) error); ok { + r0 = rf(ctx) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// MockMirrorComponent_CheckMirrorProgress_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'CheckMirrorProgress' +type MockMirrorComponent_CheckMirrorProgress_Call struct { + *mock.Call +} + +// CheckMirrorProgress is a helper method to define mock.On call +// - ctx context.Context +func (_e *MockMirrorComponent_Expecter) CheckMirrorProgress(ctx interface{}) *MockMirrorComponent_CheckMirrorProgress_Call { + return &MockMirrorComponent_CheckMirrorProgress_Call{Call: _e.mock.On("CheckMirrorProgress", ctx)} +} + +func (_c *MockMirrorComponent_CheckMirrorProgress_Call) Run(run func(ctx context.Context)) *MockMirrorComponent_CheckMirrorProgress_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context)) + }) + return _c +} + +func (_c *MockMirrorComponent_CheckMirrorProgress_Call) Return(_a0 error) *MockMirrorComponent_CheckMirrorProgress_Call { + _c.Call.Return(_a0) + return _c +} + +func (_c *MockMirrorComponent_CheckMirrorProgress_Call) RunAndReturn(run func(context.Context) error) *MockMirrorComponent_CheckMirrorProgress_Call { + _c.Call.Return(run) + return _c +} + +// CreateMirrorRepo provides a mock function with given fields: ctx, req +func (_m *MockMirrorComponent) CreateMirrorRepo(ctx context.Context, req types.CreateMirrorRepoReq) (*database.Mirror, error) { + ret := _m.Called(ctx, req) + + if len(ret) == 0 { + panic("no return value specified for CreateMirrorRepo") + } + + var r0 *database.Mirror + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, types.CreateMirrorRepoReq) (*database.Mirror, error)); ok { + return rf(ctx, req) + } + if rf, ok := ret.Get(0).(func(context.Context, types.CreateMirrorRepoReq) *database.Mirror); ok { + r0 = rf(ctx, req) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*database.Mirror) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, types.CreateMirrorRepoReq) error); ok { + r1 = rf(ctx, req) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// MockMirrorComponent_CreateMirrorRepo_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'CreateMirrorRepo' +type MockMirrorComponent_CreateMirrorRepo_Call struct { + *mock.Call +} + +// CreateMirrorRepo is a helper method to define mock.On call +// - ctx context.Context +// - req types.CreateMirrorRepoReq +func (_e *MockMirrorComponent_Expecter) CreateMirrorRepo(ctx interface{}, req interface{}) *MockMirrorComponent_CreateMirrorRepo_Call { + return &MockMirrorComponent_CreateMirrorRepo_Call{Call: _e.mock.On("CreateMirrorRepo", ctx, req)} +} + +func (_c *MockMirrorComponent_CreateMirrorRepo_Call) Run(run func(ctx context.Context, req types.CreateMirrorRepoReq)) *MockMirrorComponent_CreateMirrorRepo_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(types.CreateMirrorRepoReq)) + }) + return _c +} + +func (_c *MockMirrorComponent_CreateMirrorRepo_Call) Return(_a0 *database.Mirror, _a1 error) *MockMirrorComponent_CreateMirrorRepo_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *MockMirrorComponent_CreateMirrorRepo_Call) RunAndReturn(run func(context.Context, types.CreateMirrorRepoReq) (*database.Mirror, error)) *MockMirrorComponent_CreateMirrorRepo_Call { + _c.Call.Return(run) + return _c +} + +// CreatePushMirrorForFinishedMirrorTask provides a mock function with given fields: ctx +func (_m *MockMirrorComponent) CreatePushMirrorForFinishedMirrorTask(ctx context.Context) error { + ret := _m.Called(ctx) + + if len(ret) == 0 { + panic("no return value specified for CreatePushMirrorForFinishedMirrorTask") + } + + var r0 error + if rf, ok := ret.Get(0).(func(context.Context) error); ok { + r0 = rf(ctx) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// MockMirrorComponent_CreatePushMirrorForFinishedMirrorTask_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'CreatePushMirrorForFinishedMirrorTask' +type MockMirrorComponent_CreatePushMirrorForFinishedMirrorTask_Call struct { + *mock.Call +} + +// CreatePushMirrorForFinishedMirrorTask is a helper method to define mock.On call +// - ctx context.Context +func (_e *MockMirrorComponent_Expecter) CreatePushMirrorForFinishedMirrorTask(ctx interface{}) *MockMirrorComponent_CreatePushMirrorForFinishedMirrorTask_Call { + return &MockMirrorComponent_CreatePushMirrorForFinishedMirrorTask_Call{Call: _e.mock.On("CreatePushMirrorForFinishedMirrorTask", ctx)} +} + +func (_c *MockMirrorComponent_CreatePushMirrorForFinishedMirrorTask_Call) Run(run func(ctx context.Context)) *MockMirrorComponent_CreatePushMirrorForFinishedMirrorTask_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context)) + }) + return _c +} + +func (_c *MockMirrorComponent_CreatePushMirrorForFinishedMirrorTask_Call) Return(_a0 error) *MockMirrorComponent_CreatePushMirrorForFinishedMirrorTask_Call { + _c.Call.Return(_a0) + return _c +} + +func (_c *MockMirrorComponent_CreatePushMirrorForFinishedMirrorTask_Call) RunAndReturn(run func(context.Context) error) *MockMirrorComponent_CreatePushMirrorForFinishedMirrorTask_Call { + _c.Call.Return(run) + return _c +} + +// Index provides a mock function with given fields: ctx, currentUser, per, page, search +func (_m *MockMirrorComponent) Index(ctx context.Context, currentUser string, per int, page int, search string) ([]types.Mirror, int, error) { + ret := _m.Called(ctx, currentUser, per, page, search) + + if len(ret) == 0 { + panic("no return value specified for Index") + } + + var r0 []types.Mirror + var r1 int + var r2 error + if rf, ok := ret.Get(0).(func(context.Context, string, int, int, string) ([]types.Mirror, int, error)); ok { + return rf(ctx, currentUser, per, page, search) + } + if rf, ok := ret.Get(0).(func(context.Context, string, int, int, string) []types.Mirror); ok { + r0 = rf(ctx, currentUser, per, page, search) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).([]types.Mirror) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, string, int, int, string) int); ok { + r1 = rf(ctx, currentUser, per, page, search) + } else { + r1 = ret.Get(1).(int) + } + + if rf, ok := ret.Get(2).(func(context.Context, string, int, int, string) error); ok { + r2 = rf(ctx, currentUser, per, page, search) + } else { + r2 = ret.Error(2) + } + + return r0, r1, r2 +} + +// MockMirrorComponent_Index_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Index' +type MockMirrorComponent_Index_Call struct { + *mock.Call +} + +// Index is a helper method to define mock.On call +// - ctx context.Context +// - currentUser string +// - per int +// - page int +// - search string +func (_e *MockMirrorComponent_Expecter) Index(ctx interface{}, currentUser interface{}, per interface{}, page interface{}, search interface{}) *MockMirrorComponent_Index_Call { + return &MockMirrorComponent_Index_Call{Call: _e.mock.On("Index", ctx, currentUser, per, page, search)} +} + +func (_c *MockMirrorComponent_Index_Call) Run(run func(ctx context.Context, currentUser string, per int, page int, search string)) *MockMirrorComponent_Index_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(string), args[2].(int), args[3].(int), args[4].(string)) + }) + return _c +} + +func (_c *MockMirrorComponent_Index_Call) Return(_a0 []types.Mirror, _a1 int, _a2 error) *MockMirrorComponent_Index_Call { + _c.Call.Return(_a0, _a1, _a2) + return _c +} + +func (_c *MockMirrorComponent_Index_Call) RunAndReturn(run func(context.Context, string, int, int, string) ([]types.Mirror, int, error)) *MockMirrorComponent_Index_Call { + _c.Call.Return(run) + return _c +} + +// Repos provides a mock function with given fields: ctx, currentUser, per, page +func (_m *MockMirrorComponent) Repos(ctx context.Context, currentUser string, per int, page int) ([]types.MirrorRepo, int, error) { + ret := _m.Called(ctx, currentUser, per, page) + + if len(ret) == 0 { + panic("no return value specified for Repos") + } + + var r0 []types.MirrorRepo + var r1 int + var r2 error + if rf, ok := ret.Get(0).(func(context.Context, string, int, int) ([]types.MirrorRepo, int, error)); ok { + return rf(ctx, currentUser, per, page) + } + if rf, ok := ret.Get(0).(func(context.Context, string, int, int) []types.MirrorRepo); ok { + r0 = rf(ctx, currentUser, per, page) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).([]types.MirrorRepo) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, string, int, int) int); ok { + r1 = rf(ctx, currentUser, per, page) + } else { + r1 = ret.Get(1).(int) + } + + if rf, ok := ret.Get(2).(func(context.Context, string, int, int) error); ok { + r2 = rf(ctx, currentUser, per, page) + } else { + r2 = ret.Error(2) + } + + return r0, r1, r2 +} + +// MockMirrorComponent_Repos_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Repos' +type MockMirrorComponent_Repos_Call struct { + *mock.Call +} + +// Repos is a helper method to define mock.On call +// - ctx context.Context +// - currentUser string +// - per int +// - page int +func (_e *MockMirrorComponent_Expecter) Repos(ctx interface{}, currentUser interface{}, per interface{}, page interface{}) *MockMirrorComponent_Repos_Call { + return &MockMirrorComponent_Repos_Call{Call: _e.mock.On("Repos", ctx, currentUser, per, page)} +} + +func (_c *MockMirrorComponent_Repos_Call) Run(run func(ctx context.Context, currentUser string, per int, page int)) *MockMirrorComponent_Repos_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(string), args[2].(int), args[3].(int)) + }) + return _c +} + +func (_c *MockMirrorComponent_Repos_Call) Return(_a0 []types.MirrorRepo, _a1 int, _a2 error) *MockMirrorComponent_Repos_Call { + _c.Call.Return(_a0, _a1, _a2) + return _c +} + +func (_c *MockMirrorComponent_Repos_Call) RunAndReturn(run func(context.Context, string, int, int) ([]types.MirrorRepo, int, error)) *MockMirrorComponent_Repos_Call { + _c.Call.Return(run) + return _c +} + +// Statistics provides a mock function with given fields: ctx, currentUser +func (_m *MockMirrorComponent) Statistics(ctx context.Context, currentUser string) ([]types.MirrorStatusCount, error) { + ret := _m.Called(ctx, currentUser) + + if len(ret) == 0 { + panic("no return value specified for Statistics") + } + + var r0 []types.MirrorStatusCount + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, string) ([]types.MirrorStatusCount, error)); ok { + return rf(ctx, currentUser) + } + if rf, ok := ret.Get(0).(func(context.Context, string) []types.MirrorStatusCount); ok { + r0 = rf(ctx, currentUser) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).([]types.MirrorStatusCount) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, string) error); ok { + r1 = rf(ctx, currentUser) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// MockMirrorComponent_Statistics_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Statistics' +type MockMirrorComponent_Statistics_Call struct { + *mock.Call +} + +// Statistics is a helper method to define mock.On call +// - ctx context.Context +// - currentUser string +func (_e *MockMirrorComponent_Expecter) Statistics(ctx interface{}, currentUser interface{}) *MockMirrorComponent_Statistics_Call { + return &MockMirrorComponent_Statistics_Call{Call: _e.mock.On("Statistics", ctx, currentUser)} +} + +func (_c *MockMirrorComponent_Statistics_Call) Run(run func(ctx context.Context, currentUser string)) *MockMirrorComponent_Statistics_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(string)) + }) + return _c +} + +func (_c *MockMirrorComponent_Statistics_Call) Return(_a0 []types.MirrorStatusCount, _a1 error) *MockMirrorComponent_Statistics_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *MockMirrorComponent_Statistics_Call) RunAndReturn(run func(context.Context, string) ([]types.MirrorStatusCount, error)) *MockMirrorComponent_Statistics_Call { + _c.Call.Return(run) + return _c +} + +// NewMockMirrorComponent creates a new instance of MockMirrorComponent. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. +// The first argument is typically a *testing.T value. +func NewMockMirrorComponent(t interface { + mock.TestingT + Cleanup(func()) +}) *MockMirrorComponent { + mock := &MockMirrorComponent{} + mock.Mock.Test(t) + + t.Cleanup(func() { mock.AssertExpectations(t) }) + + return mock +} diff --git a/_mocks/opencsg.com/csghub-server/component/mock_MirrorSourceComponent.go b/_mocks/opencsg.com/csghub-server/component/mock_MirrorSourceComponent.go new file mode 100644 index 00000000..c4c038e7 --- /dev/null +++ b/_mocks/opencsg.com/csghub-server/component/mock_MirrorSourceComponent.go @@ -0,0 +1,324 @@ +// Code generated by mockery v2.49.1. DO NOT EDIT. + +package component + +import ( + context "context" + + mock "github.com/stretchr/testify/mock" + database "opencsg.com/csghub-server/builder/store/database" + + types "opencsg.com/csghub-server/common/types" +) + +// MockMirrorSourceComponent is an autogenerated mock type for the MirrorSourceComponent type +type MockMirrorSourceComponent struct { + mock.Mock +} + +type MockMirrorSourceComponent_Expecter struct { + mock *mock.Mock +} + +func (_m *MockMirrorSourceComponent) EXPECT() *MockMirrorSourceComponent_Expecter { + return &MockMirrorSourceComponent_Expecter{mock: &_m.Mock} +} + +// Create provides a mock function with given fields: ctx, req +func (_m *MockMirrorSourceComponent) Create(ctx context.Context, req types.CreateMirrorSourceReq) (*database.MirrorSource, error) { + ret := _m.Called(ctx, req) + + if len(ret) == 0 { + panic("no return value specified for Create") + } + + var r0 *database.MirrorSource + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, types.CreateMirrorSourceReq) (*database.MirrorSource, error)); ok { + return rf(ctx, req) + } + if rf, ok := ret.Get(0).(func(context.Context, types.CreateMirrorSourceReq) *database.MirrorSource); ok { + r0 = rf(ctx, req) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*database.MirrorSource) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, types.CreateMirrorSourceReq) error); ok { + r1 = rf(ctx, req) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// MockMirrorSourceComponent_Create_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Create' +type MockMirrorSourceComponent_Create_Call struct { + *mock.Call +} + +// Create is a helper method to define mock.On call +// - ctx context.Context +// - req types.CreateMirrorSourceReq +func (_e *MockMirrorSourceComponent_Expecter) Create(ctx interface{}, req interface{}) *MockMirrorSourceComponent_Create_Call { + return &MockMirrorSourceComponent_Create_Call{Call: _e.mock.On("Create", ctx, req)} +} + +func (_c *MockMirrorSourceComponent_Create_Call) Run(run func(ctx context.Context, req types.CreateMirrorSourceReq)) *MockMirrorSourceComponent_Create_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(types.CreateMirrorSourceReq)) + }) + return _c +} + +func (_c *MockMirrorSourceComponent_Create_Call) Return(_a0 *database.MirrorSource, _a1 error) *MockMirrorSourceComponent_Create_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *MockMirrorSourceComponent_Create_Call) RunAndReturn(run func(context.Context, types.CreateMirrorSourceReq) (*database.MirrorSource, error)) *MockMirrorSourceComponent_Create_Call { + _c.Call.Return(run) + return _c +} + +// Delete provides a mock function with given fields: ctx, id, currentUser +func (_m *MockMirrorSourceComponent) Delete(ctx context.Context, id int64, currentUser string) error { + ret := _m.Called(ctx, id, currentUser) + + if len(ret) == 0 { + panic("no return value specified for Delete") + } + + var r0 error + if rf, ok := ret.Get(0).(func(context.Context, int64, string) error); ok { + r0 = rf(ctx, id, currentUser) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// MockMirrorSourceComponent_Delete_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Delete' +type MockMirrorSourceComponent_Delete_Call struct { + *mock.Call +} + +// Delete is a helper method to define mock.On call +// - ctx context.Context +// - id int64 +// - currentUser string +func (_e *MockMirrorSourceComponent_Expecter) Delete(ctx interface{}, id interface{}, currentUser interface{}) *MockMirrorSourceComponent_Delete_Call { + return &MockMirrorSourceComponent_Delete_Call{Call: _e.mock.On("Delete", ctx, id, currentUser)} +} + +func (_c *MockMirrorSourceComponent_Delete_Call) Run(run func(ctx context.Context, id int64, currentUser string)) *MockMirrorSourceComponent_Delete_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(int64), args[2].(string)) + }) + return _c +} + +func (_c *MockMirrorSourceComponent_Delete_Call) Return(_a0 error) *MockMirrorSourceComponent_Delete_Call { + _c.Call.Return(_a0) + return _c +} + +func (_c *MockMirrorSourceComponent_Delete_Call) RunAndReturn(run func(context.Context, int64, string) error) *MockMirrorSourceComponent_Delete_Call { + _c.Call.Return(run) + return _c +} + +// Get provides a mock function with given fields: ctx, id, currentUser +func (_m *MockMirrorSourceComponent) Get(ctx context.Context, id int64, currentUser string) (*database.MirrorSource, error) { + ret := _m.Called(ctx, id, currentUser) + + if len(ret) == 0 { + panic("no return value specified for Get") + } + + var r0 *database.MirrorSource + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, int64, string) (*database.MirrorSource, error)); ok { + return rf(ctx, id, currentUser) + } + if rf, ok := ret.Get(0).(func(context.Context, int64, string) *database.MirrorSource); ok { + r0 = rf(ctx, id, currentUser) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*database.MirrorSource) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, int64, string) error); ok { + r1 = rf(ctx, id, currentUser) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// MockMirrorSourceComponent_Get_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Get' +type MockMirrorSourceComponent_Get_Call struct { + *mock.Call +} + +// Get is a helper method to define mock.On call +// - ctx context.Context +// - id int64 +// - currentUser string +func (_e *MockMirrorSourceComponent_Expecter) Get(ctx interface{}, id interface{}, currentUser interface{}) *MockMirrorSourceComponent_Get_Call { + return &MockMirrorSourceComponent_Get_Call{Call: _e.mock.On("Get", ctx, id, currentUser)} +} + +func (_c *MockMirrorSourceComponent_Get_Call) Run(run func(ctx context.Context, id int64, currentUser string)) *MockMirrorSourceComponent_Get_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(int64), args[2].(string)) + }) + return _c +} + +func (_c *MockMirrorSourceComponent_Get_Call) Return(_a0 *database.MirrorSource, _a1 error) *MockMirrorSourceComponent_Get_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *MockMirrorSourceComponent_Get_Call) RunAndReturn(run func(context.Context, int64, string) (*database.MirrorSource, error)) *MockMirrorSourceComponent_Get_Call { + _c.Call.Return(run) + return _c +} + +// Index provides a mock function with given fields: ctx, currentUser +func (_m *MockMirrorSourceComponent) Index(ctx context.Context, currentUser string) ([]database.MirrorSource, error) { + ret := _m.Called(ctx, currentUser) + + if len(ret) == 0 { + panic("no return value specified for Index") + } + + var r0 []database.MirrorSource + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, string) ([]database.MirrorSource, error)); ok { + return rf(ctx, currentUser) + } + if rf, ok := ret.Get(0).(func(context.Context, string) []database.MirrorSource); ok { + r0 = rf(ctx, currentUser) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).([]database.MirrorSource) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, string) error); ok { + r1 = rf(ctx, currentUser) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// MockMirrorSourceComponent_Index_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Index' +type MockMirrorSourceComponent_Index_Call struct { + *mock.Call +} + +// Index is a helper method to define mock.On call +// - ctx context.Context +// - currentUser string +func (_e *MockMirrorSourceComponent_Expecter) Index(ctx interface{}, currentUser interface{}) *MockMirrorSourceComponent_Index_Call { + return &MockMirrorSourceComponent_Index_Call{Call: _e.mock.On("Index", ctx, currentUser)} +} + +func (_c *MockMirrorSourceComponent_Index_Call) Run(run func(ctx context.Context, currentUser string)) *MockMirrorSourceComponent_Index_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(string)) + }) + return _c +} + +func (_c *MockMirrorSourceComponent_Index_Call) Return(_a0 []database.MirrorSource, _a1 error) *MockMirrorSourceComponent_Index_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *MockMirrorSourceComponent_Index_Call) RunAndReturn(run func(context.Context, string) ([]database.MirrorSource, error)) *MockMirrorSourceComponent_Index_Call { + _c.Call.Return(run) + return _c +} + +// Update provides a mock function with given fields: ctx, req +func (_m *MockMirrorSourceComponent) Update(ctx context.Context, req types.UpdateMirrorSourceReq) (*database.MirrorSource, error) { + ret := _m.Called(ctx, req) + + if len(ret) == 0 { + panic("no return value specified for Update") + } + + var r0 *database.MirrorSource + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, types.UpdateMirrorSourceReq) (*database.MirrorSource, error)); ok { + return rf(ctx, req) + } + if rf, ok := ret.Get(0).(func(context.Context, types.UpdateMirrorSourceReq) *database.MirrorSource); ok { + r0 = rf(ctx, req) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*database.MirrorSource) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, types.UpdateMirrorSourceReq) error); ok { + r1 = rf(ctx, req) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// MockMirrorSourceComponent_Update_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Update' +type MockMirrorSourceComponent_Update_Call struct { + *mock.Call +} + +// Update is a helper method to define mock.On call +// - ctx context.Context +// - req types.UpdateMirrorSourceReq +func (_e *MockMirrorSourceComponent_Expecter) Update(ctx interface{}, req interface{}) *MockMirrorSourceComponent_Update_Call { + return &MockMirrorSourceComponent_Update_Call{Call: _e.mock.On("Update", ctx, req)} +} + +func (_c *MockMirrorSourceComponent_Update_Call) Run(run func(ctx context.Context, req types.UpdateMirrorSourceReq)) *MockMirrorSourceComponent_Update_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(types.UpdateMirrorSourceReq)) + }) + return _c +} + +func (_c *MockMirrorSourceComponent_Update_Call) Return(_a0 *database.MirrorSource, _a1 error) *MockMirrorSourceComponent_Update_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *MockMirrorSourceComponent_Update_Call) RunAndReturn(run func(context.Context, types.UpdateMirrorSourceReq) (*database.MirrorSource, error)) *MockMirrorSourceComponent_Update_Call { + _c.Call.Return(run) + return _c +} + +// NewMockMirrorSourceComponent creates a new instance of MockMirrorSourceComponent. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. +// The first argument is typically a *testing.T value. +func NewMockMirrorSourceComponent(t interface { + mock.TestingT + Cleanup(func()) +}) *MockMirrorSourceComponent { + mock := &MockMirrorSourceComponent{} + mock.Mock.Test(t) + + t.Cleanup(func() { mock.AssertExpectations(t) }) + + return mock +} diff --git a/_mocks/opencsg.com/csghub-server/component/mock_SpaceResourceComponent.go b/_mocks/opencsg.com/csghub-server/component/mock_SpaceResourceComponent.go new file mode 100644 index 00000000..9e98ba3d --- /dev/null +++ b/_mocks/opencsg.com/csghub-server/component/mock_SpaceResourceComponent.go @@ -0,0 +1,263 @@ +// Code generated by mockery v2.49.1. DO NOT EDIT. + +package component + +import ( + context "context" + + mock "github.com/stretchr/testify/mock" + types "opencsg.com/csghub-server/common/types" +) + +// MockSpaceResourceComponent is an autogenerated mock type for the SpaceResourceComponent type +type MockSpaceResourceComponent struct { + mock.Mock +} + +type MockSpaceResourceComponent_Expecter struct { + mock *mock.Mock +} + +func (_m *MockSpaceResourceComponent) EXPECT() *MockSpaceResourceComponent_Expecter { + return &MockSpaceResourceComponent_Expecter{mock: &_m.Mock} +} + +// Create provides a mock function with given fields: ctx, req +func (_m *MockSpaceResourceComponent) Create(ctx context.Context, req *types.CreateSpaceResourceReq) (*types.SpaceResource, error) { + ret := _m.Called(ctx, req) + + if len(ret) == 0 { + panic("no return value specified for Create") + } + + var r0 *types.SpaceResource + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, *types.CreateSpaceResourceReq) (*types.SpaceResource, error)); ok { + return rf(ctx, req) + } + if rf, ok := ret.Get(0).(func(context.Context, *types.CreateSpaceResourceReq) *types.SpaceResource); ok { + r0 = rf(ctx, req) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*types.SpaceResource) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, *types.CreateSpaceResourceReq) error); ok { + r1 = rf(ctx, req) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// MockSpaceResourceComponent_Create_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Create' +type MockSpaceResourceComponent_Create_Call struct { + *mock.Call +} + +// Create is a helper method to define mock.On call +// - ctx context.Context +// - req *types.CreateSpaceResourceReq +func (_e *MockSpaceResourceComponent_Expecter) Create(ctx interface{}, req interface{}) *MockSpaceResourceComponent_Create_Call { + return &MockSpaceResourceComponent_Create_Call{Call: _e.mock.On("Create", ctx, req)} +} + +func (_c *MockSpaceResourceComponent_Create_Call) Run(run func(ctx context.Context, req *types.CreateSpaceResourceReq)) *MockSpaceResourceComponent_Create_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(*types.CreateSpaceResourceReq)) + }) + return _c +} + +func (_c *MockSpaceResourceComponent_Create_Call) Return(_a0 *types.SpaceResource, _a1 error) *MockSpaceResourceComponent_Create_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *MockSpaceResourceComponent_Create_Call) RunAndReturn(run func(context.Context, *types.CreateSpaceResourceReq) (*types.SpaceResource, error)) *MockSpaceResourceComponent_Create_Call { + _c.Call.Return(run) + return _c +} + +// Delete provides a mock function with given fields: ctx, id +func (_m *MockSpaceResourceComponent) Delete(ctx context.Context, id int64) error { + ret := _m.Called(ctx, id) + + if len(ret) == 0 { + panic("no return value specified for Delete") + } + + var r0 error + if rf, ok := ret.Get(0).(func(context.Context, int64) error); ok { + r0 = rf(ctx, id) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// MockSpaceResourceComponent_Delete_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Delete' +type MockSpaceResourceComponent_Delete_Call struct { + *mock.Call +} + +// Delete is a helper method to define mock.On call +// - ctx context.Context +// - id int64 +func (_e *MockSpaceResourceComponent_Expecter) Delete(ctx interface{}, id interface{}) *MockSpaceResourceComponent_Delete_Call { + return &MockSpaceResourceComponent_Delete_Call{Call: _e.mock.On("Delete", ctx, id)} +} + +func (_c *MockSpaceResourceComponent_Delete_Call) Run(run func(ctx context.Context, id int64)) *MockSpaceResourceComponent_Delete_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(int64)) + }) + return _c +} + +func (_c *MockSpaceResourceComponent_Delete_Call) Return(_a0 error) *MockSpaceResourceComponent_Delete_Call { + _c.Call.Return(_a0) + return _c +} + +func (_c *MockSpaceResourceComponent_Delete_Call) RunAndReturn(run func(context.Context, int64) error) *MockSpaceResourceComponent_Delete_Call { + _c.Call.Return(run) + return _c +} + +// Index provides a mock function with given fields: ctx, clusterId, deployType, currentUser +func (_m *MockSpaceResourceComponent) Index(ctx context.Context, clusterId string, deployType int, currentUser string) ([]types.SpaceResource, error) { + ret := _m.Called(ctx, clusterId, deployType, currentUser) + + if len(ret) == 0 { + panic("no return value specified for Index") + } + + var r0 []types.SpaceResource + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, string, int, string) ([]types.SpaceResource, error)); ok { + return rf(ctx, clusterId, deployType, currentUser) + } + if rf, ok := ret.Get(0).(func(context.Context, string, int, string) []types.SpaceResource); ok { + r0 = rf(ctx, clusterId, deployType, currentUser) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).([]types.SpaceResource) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, string, int, string) error); ok { + r1 = rf(ctx, clusterId, deployType, currentUser) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// MockSpaceResourceComponent_Index_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Index' +type MockSpaceResourceComponent_Index_Call struct { + *mock.Call +} + +// Index is a helper method to define mock.On call +// - ctx context.Context +// - clusterId string +// - deployType int +// - currentUser string +func (_e *MockSpaceResourceComponent_Expecter) Index(ctx interface{}, clusterId interface{}, deployType interface{}, currentUser interface{}) *MockSpaceResourceComponent_Index_Call { + return &MockSpaceResourceComponent_Index_Call{Call: _e.mock.On("Index", ctx, clusterId, deployType, currentUser)} +} + +func (_c *MockSpaceResourceComponent_Index_Call) Run(run func(ctx context.Context, clusterId string, deployType int, currentUser string)) *MockSpaceResourceComponent_Index_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(string), args[2].(int), args[3].(string)) + }) + return _c +} + +func (_c *MockSpaceResourceComponent_Index_Call) Return(_a0 []types.SpaceResource, _a1 error) *MockSpaceResourceComponent_Index_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *MockSpaceResourceComponent_Index_Call) RunAndReturn(run func(context.Context, string, int, string) ([]types.SpaceResource, error)) *MockSpaceResourceComponent_Index_Call { + _c.Call.Return(run) + return _c +} + +// Update provides a mock function with given fields: ctx, req +func (_m *MockSpaceResourceComponent) Update(ctx context.Context, req *types.UpdateSpaceResourceReq) (*types.SpaceResource, error) { + ret := _m.Called(ctx, req) + + if len(ret) == 0 { + panic("no return value specified for Update") + } + + var r0 *types.SpaceResource + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, *types.UpdateSpaceResourceReq) (*types.SpaceResource, error)); ok { + return rf(ctx, req) + } + if rf, ok := ret.Get(0).(func(context.Context, *types.UpdateSpaceResourceReq) *types.SpaceResource); ok { + r0 = rf(ctx, req) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*types.SpaceResource) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, *types.UpdateSpaceResourceReq) error); ok { + r1 = rf(ctx, req) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// MockSpaceResourceComponent_Update_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Update' +type MockSpaceResourceComponent_Update_Call struct { + *mock.Call +} + +// Update is a helper method to define mock.On call +// - ctx context.Context +// - req *types.UpdateSpaceResourceReq +func (_e *MockSpaceResourceComponent_Expecter) Update(ctx interface{}, req interface{}) *MockSpaceResourceComponent_Update_Call { + return &MockSpaceResourceComponent_Update_Call{Call: _e.mock.On("Update", ctx, req)} +} + +func (_c *MockSpaceResourceComponent_Update_Call) Run(run func(ctx context.Context, req *types.UpdateSpaceResourceReq)) *MockSpaceResourceComponent_Update_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(*types.UpdateSpaceResourceReq)) + }) + return _c +} + +func (_c *MockSpaceResourceComponent_Update_Call) Return(_a0 *types.SpaceResource, _a1 error) *MockSpaceResourceComponent_Update_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *MockSpaceResourceComponent_Update_Call) RunAndReturn(run func(context.Context, *types.UpdateSpaceResourceReq) (*types.SpaceResource, error)) *MockSpaceResourceComponent_Update_Call { + _c.Call.Return(run) + return _c +} + +// NewMockSpaceResourceComponent creates a new instance of MockSpaceResourceComponent. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. +// The first argument is typically a *testing.T value. +func NewMockSpaceResourceComponent(t interface { + mock.TestingT + Cleanup(func()) +}) *MockSpaceResourceComponent { + mock := &MockSpaceResourceComponent{} + mock.Mock.Test(t) + + t.Cleanup(func() { mock.AssertExpectations(t) }) + + return mock +} diff --git a/api/handler/accounting.go b/api/handler/accounting.go index 71ec3346..97e49391 100644 --- a/api/handler/accounting.go +++ b/api/handler/accounting.go @@ -16,8 +16,8 @@ import ( ) type AccountingHandler struct { - ac component.AccountingComponent - apiToken string + accounting component.AccountingComponent + apiToken string } func NewAccountingHandler(config *config.Config) (*AccountingHandler, error) { @@ -26,8 +26,8 @@ func NewAccountingHandler(config *config.Config) (*AccountingHandler, error) { return nil, err } return &AccountingHandler{ - ac: acctComp, - apiToken: config.APIToken, + accounting: acctComp, + apiToken: config.APIToken, }, nil } @@ -106,7 +106,7 @@ func (ah *AccountingHandler) QueryMeteringStatementByUserID(ctx *gin.Context) { Per: per, Page: page, } - data, err := ah.ac.ListMeteringsByUserIDAndTime(ctx, req) + data, err := ah.accounting.ListMeteringsByUserIDAndTime(ctx, req) if err != nil { errTip := "fail to query meterings by user" slog.Error(errTip, slog.Any("req", req), slog.Any("error", err)) diff --git a/api/handler/accounting_test.go b/api/handler/accounting_test.go new file mode 100644 index 00000000..dc066a49 --- /dev/null +++ b/api/handler/accounting_test.go @@ -0,0 +1,60 @@ +package handler + +import ( + "testing" + + "github.com/gin-gonic/gin" + mockcomponent "opencsg.com/csghub-server/_mocks/opencsg.com/csghub-server/component" + "opencsg.com/csghub-server/common/types" +) + +type AccountingTester struct { + *GinTester + handler *AccountingHandler + mocks struct { + accounting *mockcomponent.MockAccountingComponent + } +} + +func NewAccountingTester(t *testing.T) *AccountingTester { + tester := &AccountingTester{GinTester: NewGinTester()} + tester.mocks.accounting = mockcomponent.NewMockAccountingComponent(t) + + tester.handler = &AccountingHandler{ + accounting: tester.mocks.accounting, + apiToken: "testApiToken", // You can set this dynamically if needed + } + tester.WithParam("namespace", "u") + tester.WithParam("name", "r") + return tester +} + +func (t *AccountingTester) WithHandleFunc(fn func(h *AccountingHandler) gin.HandlerFunc) *AccountingTester { + t.ginHandler = fn(t.handler) + return t +} + +func TestAccountingHandler_QueryMeteringStatementByUserID(t *testing.T) { + tester := NewAccountingTester(t).WithHandleFunc(func(h *AccountingHandler) gin.HandlerFunc { + return h.QueryMeteringStatementByUserID + }) + tester.RequireUser(t) + + tester.mocks.accounting.EXPECT().ListMeteringsByUserIDAndTime( + tester.ctx, types.ACCT_STATEMENTS_REQ{ + CurrentUser: "u", + UserUUID: "1", + Scene: 2, + InstanceName: "in", + StartTime: "2020-10-20 12:34:05", + EndTime: "2020-11-21 12:34:05", + Per: 10, + Page: 1, + }, + ).Return("go", nil) + tester.AddPagination(1, 10).WithParam("id", "1").WithQuery("instance_name", "in").WithQuery("scene", "2") + tester.WithQuery("start_time", "2020-10-20 12:34:05").WithQuery("end_time", "2020-11-21 12:34:05") + tester.Execute() + + tester.ResponseEq(t, 200, tester.OKText, "go") +} diff --git a/api/handler/collection.go b/api/handler/collection.go index 37efd370..1d8e28ac 100644 --- a/api/handler/collection.go +++ b/api/handler/collection.go @@ -26,14 +26,14 @@ func NewCollectionHandler(cfg *config.Config) (*CollectionHandler, error) { return nil, fmt.Errorf("error creating sensitive component:%w", err) } return &CollectionHandler{ - cc: cc, - sc: sc, + collection: cc, + sensitive: sc, }, nil } type CollectionHandler struct { - cc component.CollectionComponent - sc component.SensitiveComponent + collection component.CollectionComponent + sensitive component.SensitiveComponent } // GetCollections godoc @@ -65,7 +65,7 @@ func (c *CollectionHandler) Index(ctx *gin.Context) { httpbase.BadRequest(ctx, err.Error()) return } - collections, total, err := c.cc.GetCollections(ctx, filter, per, page) + collections, total, err := c.collection.GetCollections(ctx, filter, per, page) if err != nil { slog.Error("Failed to load collections", "error", err) httpbase.ServerError(ctx, err) @@ -104,7 +104,7 @@ func (c *CollectionHandler) Create(ctx *gin.Context) { return } - _, err := c.sc.CheckRequestV2(ctx, req) + _, err := c.sensitive.CheckRequestV2(ctx, req) if err != nil { slog.Error("failed to check sensitive request", slog.Any("error", err)) httpbase.BadRequest(ctx, fmt.Errorf("sensitive check failed: %w", err).Error()) @@ -112,7 +112,7 @@ func (c *CollectionHandler) Create(ctx *gin.Context) { } req.Username = currentUser - collection, err := c.cc.CreateCollection(ctx, *req) + collection, err := c.collection.CreateCollection(ctx, *req) if err != nil { slog.Error("Failed to create collection", slog.Any("error", err)) httpbase.ServerError(ctx, err) @@ -140,7 +140,7 @@ func (c *CollectionHandler) GetCollection(ctx *gin.Context) { return } currentUser := httpbase.GetCurrentUser(ctx) - collection, err := c.cc.GetCollection(ctx, currentUser, id) + collection, err := c.collection.GetCollection(ctx, currentUser, id) if err != nil { slog.Error("Failed to create space", slog.Any("error", err)) httpbase.ServerError(ctx, err) @@ -176,7 +176,7 @@ func (c *CollectionHandler) UpdateCollection(ctx *gin.Context) { return } - _, err := c.sc.CheckRequestV2(ctx, req) + _, err := c.sensitive.CheckRequestV2(ctx, req) if err != nil { slog.Error("failed to check sensitive request", slog.Any("error", err)) httpbase.BadRequest(ctx, fmt.Errorf("sensitive check failed: %w", err).Error()) @@ -192,7 +192,7 @@ func (c *CollectionHandler) UpdateCollection(ctx *gin.Context) { req.ID = id - collection, err := c.cc.UpdateCollection(ctx, *req) + collection, err := c.collection.UpdateCollection(ctx, *req) if err != nil { slog.Error("Failed to create space", slog.Any("error", err)) httpbase.ServerError(ctx, err) @@ -227,7 +227,7 @@ func (c *CollectionHandler) DeleteCollection(ctx *gin.Context) { return } - err = c.cc.DeleteCollection(ctx, id, currentUser) + err = c.collection.DeleteCollection(ctx, id, currentUser) if err != nil { slog.Error("Failed to delete collection", slog.Any("error", err)) httpbase.ServerError(ctx, err) @@ -271,7 +271,7 @@ func (c *CollectionHandler) AddRepoToCollection(ctx *gin.Context) { } req.ID = id - err = c.cc.AddReposToCollection(ctx, *req) + err = c.collection.AddReposToCollection(ctx, *req) if err != nil { slog.Error("Failed to create collection", slog.Any("error", err)) httpbase.ServerError(ctx, err) @@ -314,7 +314,7 @@ func (c *CollectionHandler) RemoveRepoFromCollection(ctx *gin.Context) { } req.ID = id - err = c.cc.RemoveReposFromCollection(ctx, *req) + err = c.collection.RemoveReposFromCollection(ctx, *req) if err != nil { slog.Error("Failed to create collection", slog.Any("error", err)) httpbase.ServerError(ctx, err) diff --git a/api/handler/collection_test.go b/api/handler/collection_test.go new file mode 100644 index 00000000..e1bb914f --- /dev/null +++ b/api/handler/collection_test.go @@ -0,0 +1,173 @@ +package handler + +import ( + "fmt" + "testing" + + "github.com/gin-gonic/gin" + "github.com/stretchr/testify/require" + mockcomponent "opencsg.com/csghub-server/_mocks/opencsg.com/csghub-server/component" + "opencsg.com/csghub-server/builder/store/database" + "opencsg.com/csghub-server/common/types" +) + +type CollectionTester struct { + *GinTester + handler *CollectionHandler + mocks struct { + collection *mockcomponent.MockCollectionComponent + sensitive *mockcomponent.MockSensitiveComponent + } +} + +func NewCollectionTester(t *testing.T) *CollectionTester { + tester := &CollectionTester{GinTester: NewGinTester()} + tester.mocks.collection = mockcomponent.NewMockCollectionComponent(t) + tester.mocks.sensitive = mockcomponent.NewMockSensitiveComponent(t) + + tester.handler = &CollectionHandler{ + collection: tester.mocks.collection, + sensitive: tester.mocks.sensitive, + } + tester.WithParam("namespace", "u") + tester.WithParam("name", "r") + return tester +} + +func (t *CollectionTester) WithHandleFunc(fn func(h *CollectionHandler) gin.HandlerFunc) *CollectionTester { + t.ginHandler = fn(t.handler) + return t +} + +func TestCollectionHandler_Index(t *testing.T) { + cases := []struct { + sort string + error bool + }{ + {"trending", false}, + {"foo", true}, + } + + for _, c := range cases { + t.Run(fmt.Sprintf("%+v", c), func(t *testing.T) { + + tester := NewCollectionTester(t).WithHandleFunc(func(h *CollectionHandler) gin.HandlerFunc { + return h.Index + }) + + if !c.error { + tester.mocks.collection.EXPECT().GetCollections(tester.ctx, &types.CollectionFilter{ + Search: "foo", + Sort: c.sort, + }, 10, 1).Return([]types.Collection{ + {Name: "cc"}, + }, 100, nil) + } + + tester.AddPagination(1, 10).WithQuery("search", "foo"). + WithQuery("sort", c.sort).Execute() + + if c.error { + require.Equal(t, 400, tester.response.Code) + } else { + tester.ResponseEqSimple(t, 200, gin.H{ + "data": []types.Collection{{Name: "cc"}}, + "total": 100, + }) + } + }) + } +} + +func TestCollectionHandler_Create(t *testing.T) { + tester := NewCollectionTester(t).WithHandleFunc(func(h *CollectionHandler) gin.HandlerFunc { + return h.Create + }) + tester.RequireUser(t) + + tester.mocks.sensitive.EXPECT().CheckRequestV2(tester.ctx, &types.CreateCollectionReq{}).Return(true, nil) + tester.mocks.collection.EXPECT().CreateCollection(tester.ctx, types.CreateCollectionReq{ + Username: "u", + }).Return(&database.Collection{ID: 1}, nil) + tester.WithBody(t, &types.CreateCollectionReq{}).Execute() + + tester.ResponseEq(t, 200, tester.OKText, &database.Collection{ID: 1}) + +} + +func TestCollectionHandler_GetCollection(t *testing.T) { + tester := NewCollectionTester(t).WithHandleFunc(func(h *CollectionHandler) gin.HandlerFunc { + return h.GetCollection + }) + + tester.mocks.collection.EXPECT().GetCollection( + tester.ctx, "u", int64(1), + ).Return(&types.Collection{ID: 1}, nil) + tester.WithUser().WithParam("id", "1").Execute() + + tester.ResponseEq(t, 200, tester.OKText, &types.Collection{ID: 1}) + +} + +func TestCollectionHandler_UpdateCollection(t *testing.T) { + tester := NewCollectionTester(t).WithHandleFunc(func(h *CollectionHandler) gin.HandlerFunc { + return h.UpdateCollection + }) + tester.RequireUser(t) + + tester.mocks.sensitive.EXPECT().CheckRequestV2(tester.ctx, &types.CreateCollectionReq{}).Return(true, nil) + tester.mocks.collection.EXPECT().UpdateCollection(tester.ctx, types.CreateCollectionReq{ + ID: 1, + }).Return(&database.Collection{ID: 1}, nil) + tester.WithParam("id", "1").WithBody(t, &types.CreateCollectionReq{}).Execute() + + tester.ResponseEq(t, 200, tester.OKText, &database.Collection{ID: 1}) + +} + +func TestCollectionHandler_DeleteCollection(t *testing.T) { + tester := NewCollectionTester(t).WithHandleFunc(func(h *CollectionHandler) gin.HandlerFunc { + return h.DeleteCollection + }) + tester.RequireUser(t) + + tester.mocks.collection.EXPECT().DeleteCollection( + tester.ctx, int64(1), "u", + ).Return(nil) + tester.WithParam("id", "1").Execute() + + tester.ResponseEq(t, 200, tester.OKText, nil) + +} + +func TestCollectionHandler_AddRepoToCollection(t *testing.T) { + tester := NewCollectionTester(t).WithHandleFunc(func(h *CollectionHandler) gin.HandlerFunc { + return h.AddRepoToCollection + }) + tester.RequireUser(t) + + tester.mocks.collection.EXPECT().AddReposToCollection(tester.ctx, types.UpdateCollectionReposReq{ + Username: "u", + ID: 1, + }).Return(nil) + tester.WithParam("id", "1").WithBody(t, &types.UpdateCollectionReposReq{}).Execute() + + tester.ResponseEq(t, 200, tester.OKText, nil) + +} + +func TestCollectionHandler_RemoveRepoFromCollection(t *testing.T) { + tester := NewCollectionTester(t).WithHandleFunc(func(h *CollectionHandler) gin.HandlerFunc { + return h.RemoveRepoFromCollection + }) + tester.RequireUser(t) + + tester.mocks.collection.EXPECT().RemoveReposFromCollection(tester.ctx, types.UpdateCollectionReposReq{ + Username: "u", + ID: 1, + }).Return(nil) + tester.WithParam("id", "1").WithBody(t, &types.UpdateCollectionReposReq{}).Execute() + + tester.ResponseEq(t, 200, tester.OKText, nil) + +} diff --git a/api/handler/dataset.go b/api/handler/dataset.go index ca9bc94d..043fe0c2 100644 --- a/api/handler/dataset.go +++ b/api/handler/dataset.go @@ -33,16 +33,16 @@ func NewDatasetHandler(config *config.Config) (*DatasetHandler, error) { } return &DatasetHandler{ - c: tc, - sc: sc, - repo: repo, + dataset: tc, + sensitive: sc, + repo: repo, }, nil } type DatasetHandler struct { - c component.DatasetComponent - sc component.SensitiveComponent - repo component.RepoComponent + dataset component.DatasetComponent + sensitive component.SensitiveComponent + repo component.RepoComponent } // CreateDataset godoc @@ -70,7 +70,7 @@ func (h *DatasetHandler) Create(ctx *gin.Context) { httpbase.BadRequest(ctx, err.Error()) return } - _, err := h.sc.CheckRequestV2(ctx, req) + _, err := h.sensitive.CheckRequestV2(ctx, req) if err != nil { slog.Error("failed to check sensitive request", slog.Any("error", err)) httpbase.BadRequest(ctx, fmt.Errorf("sensitive check failed: %w", err).Error()) @@ -78,7 +78,7 @@ func (h *DatasetHandler) Create(ctx *gin.Context) { } req.Username = currentUser - dataset, err := h.c.Create(ctx, req) + dataset, err := h.dataset.Create(ctx, req) if err != nil { slog.Error("Failed to create dataset", slog.Any("error", err)) httpbase.ServerError(ctx, err) @@ -137,7 +137,7 @@ func (h *DatasetHandler) Index(ctx *gin.Context) { return } - datasets, total, err := h.c.Index(ctx, filter, per, page) + datasets, total, err := h.dataset.Index(ctx, filter, per, page) if err != nil { slog.Error("Failed to get datasets", slog.Any("error", err)) httpbase.ServerError(ctx, err) @@ -179,7 +179,7 @@ func (h *DatasetHandler) Update(ctx *gin.Context) { return } - _, err := h.sc.CheckRequestV2(ctx, req) + _, err := h.sensitive.CheckRequestV2(ctx, req) if err != nil { slog.Error("failed to check sensitive request", slog.Any("error", err)) httpbase.BadRequest(ctx, fmt.Errorf("sensitive check failed: %w", err).Error()) @@ -196,7 +196,7 @@ func (h *DatasetHandler) Update(ctx *gin.Context) { req.Namespace = namespace req.Name = name - dataset, err := h.c.Update(ctx, req) + dataset, err := h.dataset.Update(ctx, req) if err != nil { slog.Error("Failed to update dataset", slog.Any("error", err)) httpbase.ServerError(ctx, err) @@ -233,7 +233,7 @@ func (h *DatasetHandler) Delete(ctx *gin.Context) { httpbase.BadRequest(ctx, err.Error()) return } - err = h.c.Delete(ctx, namespace, name, currentUser) + err = h.dataset.Delete(ctx, namespace, name, currentUser) if err != nil { slog.Error("Failed to delete dataset", slog.Any("error", err)) httpbase.ServerError(ctx, err) @@ -265,7 +265,7 @@ func (h *DatasetHandler) Show(ctx *gin.Context) { return } currentUser := httpbase.GetCurrentUser(ctx) - detail, err := h.c.Show(ctx, namespace, name, currentUser) + detail, err := h.dataset.Show(ctx, namespace, name, currentUser) if err != nil { if errors.Is(err, component.ErrUnauthorized) { httpbase.UnauthorizedError(ctx, err) @@ -301,7 +301,7 @@ func (h *DatasetHandler) Relations(ctx *gin.Context) { return } currentUser := httpbase.GetCurrentUser(ctx) - detail, err := h.c.Relations(ctx, namespace, name, currentUser) + detail, err := h.dataset.Relations(ctx, namespace, name, currentUser) if err != nil { if errors.Is(err, component.ErrUnauthorized) { httpbase.UnauthorizedError(ctx, err) diff --git a/api/handler/dataset_test.go b/api/handler/dataset_test.go new file mode 100644 index 00000000..de181d1a --- /dev/null +++ b/api/handler/dataset_test.go @@ -0,0 +1,189 @@ +package handler + +import ( + "fmt" + "testing" + + "github.com/gin-gonic/gin" + "github.com/stretchr/testify/require" + mockcomponent "opencsg.com/csghub-server/_mocks/opencsg.com/csghub-server/component" + "opencsg.com/csghub-server/common/types" +) + +type DatasetTester struct { + *GinTester + handler *DatasetHandler + mocks struct { + dataset *mockcomponent.MockDatasetComponent + sensitive *mockcomponent.MockSensitiveComponent + repo *mockcomponent.MockRepoComponent + } +} + +func NewDatasetTester(t *testing.T) *DatasetTester { + tester := &DatasetTester{GinTester: NewGinTester()} + tester.mocks.dataset = mockcomponent.NewMockDatasetComponent(t) + tester.mocks.sensitive = mockcomponent.NewMockSensitiveComponent(t) + tester.mocks.repo = mockcomponent.NewMockRepoComponent(t) + + tester.handler = &DatasetHandler{ + dataset: tester.mocks.dataset, + sensitive: tester.mocks.sensitive, + repo: tester.mocks.repo, + } + tester.WithParam("namespace", "u") + tester.WithParam("name", "r") + return tester +} + +func (t *DatasetTester) WithHandleFunc(fn func(h *DatasetHandler) gin.HandlerFunc) *DatasetTester { + t.ginHandler = fn(t.handler) + return t +} + +func TestDatasetHandler_Create(t *testing.T) { + + t.Run("public", func(t *testing.T) { + + tester := NewDatasetTester(t).WithHandleFunc(func(h *DatasetHandler) gin.HandlerFunc { + return h.Create + }) + tester.RequireUser(t) + + tester.mocks.sensitive.EXPECT().CheckRequestV2(tester.ctx, &types.CreateDatasetReq{ + CreateRepoReq: types.CreateRepoReq{Private: true}, + }).Return(true, nil) + tester.mocks.dataset.EXPECT().Create(tester.ctx, &types.CreateDatasetReq{ + CreateRepoReq: types.CreateRepoReq{Private: true, Username: "u"}, + }).Return(&types.Dataset{Name: "d"}, nil) + tester.WithBody(t, &types.CreateDatasetReq{ + CreateRepoReq: types.CreateRepoReq{Private: true}, + }).Execute() + + tester.ResponseEqSimple(t, 200, gin.H{ + "data": &types.Dataset{Name: "d"}, + }) + }) + +} + +func TestDatasetHandler_Index(t *testing.T) { + cases := []struct { + sort string + source string + error bool + }{ + {"most_download", "local", false}, + {"foo", "local", true}, + {"most_download", "bar", true}, + } + + for _, c := range cases { + t.Run(fmt.Sprintf("%+v", c), func(t *testing.T) { + + tester := NewDatasetTester(t).WithHandleFunc(func(h *DatasetHandler) gin.HandlerFunc { + return h.Index + }) + + if !c.error { + tester.mocks.dataset.EXPECT().Index(tester.ctx, &types.RepoFilter{ + Search: "foo", + Sort: c.sort, + Source: c.source, + }, 10, 1).Return([]types.Dataset{ + {Name: "cc"}, + }, 100, nil) + } + + tester.AddPagination(1, 10).WithQuery("search", "foo"). + WithQuery("sort", c.sort). + WithQuery("source", c.source).Execute() + + if c.error { + require.Equal(t, 400, tester.response.Code) + } else { + tester.ResponseEqSimple(t, 200, gin.H{ + "data": []types.Dataset{{Name: "cc"}}, + "total": 100, + }) + } + }) + } +} + +func TestDatasetHandler_Update(t *testing.T) { + tester := NewDatasetTester(t).WithHandleFunc(func(h *DatasetHandler) gin.HandlerFunc { + return h.Update + }) + tester.RequireUser(t) + + tester.mocks.sensitive.EXPECT().CheckRequestV2(tester.ctx, &types.UpdateDatasetReq{}).Return(true, nil) + tester.mocks.dataset.EXPECT().Update(tester.ctx, &types.UpdateDatasetReq{ + UpdateRepoReq: types.UpdateRepoReq{ + Username: "u", + Namespace: "u", + Name: "r", + }, + }).Return(&types.Dataset{Name: "foo"}, nil) + tester.WithBody(t, &types.UpdateDatasetReq{ + UpdateRepoReq: types.UpdateRepoReq{Name: "r"}, + }).Execute() + + tester.ResponseEq(t, 200, tester.OKText, &types.Dataset{Name: "foo"}) +} + +func TestDatasetHandler_Delete(t *testing.T) { + tester := NewDatasetTester(t).WithHandleFunc(func(h *DatasetHandler) gin.HandlerFunc { + return h.Delete + }) + tester.RequireUser(t) + + tester.mocks.dataset.EXPECT().Delete(tester.ctx, "u", "r", "u").Return(nil) + tester.Execute() + + tester.ResponseEq(t, 200, tester.OKText, nil) +} + +func TestDatasetHandler_Show(t *testing.T) { + tester := NewDatasetTester(t).WithHandleFunc(func(h *DatasetHandler) gin.HandlerFunc { + return h.Show + }) + + tester.mocks.dataset.EXPECT().Show(tester.ctx, "u", "r", "u").Return(&types.Dataset{ + Name: "d", + }, nil) + tester.WithUser().Execute() + + tester.ResponseEq(t, 200, tester.OKText, &types.Dataset{Name: "d"}) +} + +func TestDatasetHandler_Relations(t *testing.T) { + tester := NewDatasetTester(t).WithHandleFunc(func(h *DatasetHandler) gin.HandlerFunc { + return h.Relations + }) + + tester.mocks.dataset.EXPECT().Relations(tester.ctx, "u", "r", "u").Return(&types.Relations{ + Models: []*types.Model{{Name: "m"}}, + }, nil) + tester.WithUser().Execute() + + tester.ResponseEq(t, 200, tester.OKText, &types.Relations{ + Models: []*types.Model{{Name: "m"}}, + }) +} + +func TestDatasetHandler_AllFiles(t *testing.T) { + tester := NewDatasetTester(t).WithHandleFunc(func(h *DatasetHandler) gin.HandlerFunc { + return h.AllFiles + }) + + tester.mocks.repo.EXPECT().AllFiles(tester.ctx, types.GetAllFilesReq{ + Namespace: "u", + Name: "r", + RepoType: types.DatasetRepo, + CurrentUser: "u", + }).Return([]*types.File{{Name: "f"}}, nil) + tester.WithUser().Execute() + + tester.ResponseEq(t, 200, tester.OKText, []*types.File{{Name: "f"}}) +} diff --git a/api/handler/discussion.go b/api/handler/discussion.go index 109c5553..c6cb0e4f 100644 --- a/api/handler/discussion.go +++ b/api/handler/discussion.go @@ -15,8 +15,8 @@ import ( ) type DiscussionHandler struct { - c component.DiscussionComponent - sc component.SensitiveComponent + discussion component.DiscussionComponent + sensitive component.SensitiveComponent } func NewDiscussionHandler(cfg *config.Config) (*DiscussionHandler, error) { @@ -26,8 +26,8 @@ func NewDiscussionHandler(cfg *config.Config) (*DiscussionHandler, error) { return nil, fmt.Errorf("failed to create sensitive component: %w", err) } return &DiscussionHandler{ - c: c, - sc: sc, + discussion: c, + sensitive: sc, }, nil } @@ -42,8 +42,8 @@ func NewDiscussionHandler(cfg *config.Config) (*DiscussionHandler, error) { // @Param repo_type path string true "repository type" Enums(models,datasets,codes,spaces) // @Param namespace path string true "namespace" // @Param name path string true "name" -// @Param body body component.CreateRepoDiscussionRequest true "body" -// @Success 200 {object} types.Response{data=component.CreateDiscussionResponse} "OK" +// @Param body body types.CreateRepoDiscussionRequest true "body" +// @Success 200 {object} types.Response{data=types.CreateDiscussionResponse} "OK" // @Failure 400 {object} types.APIBadRequest "Bad request" // @Failure 500 {object} types.APIInternalServerError "Internal server error" // @Router /{repo_type}/{namespace}/{name}/discussions [post] @@ -61,13 +61,13 @@ func (h *DiscussionHandler) CreateRepoDiscussion(ctx *gin.Context) { return } - var req component.CreateRepoDiscussionRequest + var req types.CreateRepoDiscussionRequest if err := ctx.ShouldBindJSON(&req); err != nil { httpbase.BadRequest(ctx, err.Error()) return } - _, err = h.sc.CheckRequestV2(ctx, &req) + _, err = h.sensitive.CheckRequestV2(ctx, &req) if err != nil { slog.Error("failed to check sensitive request", slog.Any("error", err)) httpbase.BadRequest(ctx, fmt.Errorf("sensitive check failed: %w", err).Error()) @@ -78,7 +78,7 @@ func (h *DiscussionHandler) CreateRepoDiscussion(ctx *gin.Context) { req.RepoType = types.RepositoryType(repoType) req.Namespace = namespace req.Name = name - resp, err := h.c.CreateRepoDiscussion(ctx, req) + resp, err := h.discussion.CreateRepoDiscussion(ctx, req) if err != nil { slog.Error("Failed to create repo discussion", "error", err, "request", req) httpbase.ServerError(ctx, fmt.Errorf("failed to create repo discussion: %w", err)) @@ -96,7 +96,7 @@ func (h *DiscussionHandler) CreateRepoDiscussion(ctx *gin.Context) { // @Produce json // @Param id path string true "the discussion id" // @Param current_user query string true "current user, the owner" -// @Param body body component.UpdateDiscussionRequest true "body" +// @Param body body types.UpdateDiscussionRequest true "body" // @Success 200 {object} types.Response "OK" // @Failure 400 {object} types.APIBadRequest "Bad request" // @Failure 500 {object} types.APIInternalServerError "Internal server error" @@ -113,12 +113,12 @@ func (h *DiscussionHandler) UpdateDiscussion(ctx *gin.Context) { httpbase.BadRequest(ctx, "invalid discussion id:"+id) return } - var req component.UpdateDiscussionRequest + var req types.UpdateDiscussionRequest if err := ctx.ShouldBindJSON(&req); err != nil { httpbase.BadRequest(ctx, err.Error()) return } - _, err = h.sc.CheckRequestV2(ctx, &req) + _, err = h.sensitive.CheckRequestV2(ctx, &req) if err != nil { slog.Error("failed to check sensitive request", slog.Any("error", err)) httpbase.BadRequest(ctx, fmt.Errorf("sensitive check failed: %w", err).Error()) @@ -127,7 +127,7 @@ func (h *DiscussionHandler) UpdateDiscussion(ctx *gin.Context) { req.ID = idInt req.CurrentUser = currentUser - err = h.c.UpdateDiscussion(ctx, req) + err = h.discussion.UpdateDiscussion(ctx, req) if err != nil { slog.Error("Failed to update discussion", "error", err, "request", req) httpbase.ServerError(ctx, fmt.Errorf("failed to update discussion: %w", err)) @@ -163,7 +163,7 @@ func (h *DiscussionHandler) DeleteDiscussion(ctx *gin.Context) { httpbase.BadRequest(ctx, err.Error()) return } - err = h.c.DeleteDiscussion(ctx, currentUser, idInt) + err = h.discussion.DeleteDiscussion(ctx, currentUser, idInt) if err != nil { slog.Error("Failed to delete discussion", "error", err, "id", id) httpbase.ServerError(ctx, fmt.Errorf("failed to delete discussion: %w", err)) @@ -180,7 +180,7 @@ func (h *DiscussionHandler) DeleteDiscussion(ctx *gin.Context) { // @Accept json // @Produce json // @Param id path string true "the discussion id" -// @Success 200 {object} types.Response{data=component.ShowDiscussionResponse} "OK" +// @Success 200 {object} types.Response{data=types.ShowDiscussionResponse} "OK" // @Failure 400 {object} types.APIBadRequest "Bad request" // @Failure 500 {object} types.APIInternalServerError "Internal server error" // @Router /discussions/{id} [get] @@ -192,7 +192,7 @@ func (h *DiscussionHandler) ShowDiscussion(ctx *gin.Context) { httpbase.BadRequest(ctx, err.Error()) return } - d, err := h.c.GetDiscussion(ctx, idInt) + d, err := h.discussion.GetDiscussion(ctx, idInt) if err != nil { slog.Error("Failed to get discussion", "error", err, "id", id) httpbase.ServerError(ctx, fmt.Errorf("failed to get discussion: %w", err)) @@ -212,7 +212,7 @@ func (h *DiscussionHandler) ShowDiscussion(ctx *gin.Context) { // @Param repo_type path string true "repository type" Enums(models,datasets,codes,spaces) // @Param namespace path string true "namespace" // @Param name query string true "name" -// @Success 200 {object} types.Response{data=component.ListRepoDiscussionResponse} "OK" +// @Success 200 {object} types.Response{data=types.ListRepoDiscussionResponse} "OK" // @Failure 400 {object} types.APIBadRequest "Bad request" // @Failure 500 {object} types.APIInternalServerError "Internal server error" // @Router /{repo_type}/{namespace}/{name}/discussions [get] @@ -225,12 +225,12 @@ func (h *DiscussionHandler) ListRepoDiscussions(ctx *gin.Context) { return } - var req component.ListRepoDiscussionRequest + var req types.ListRepoDiscussionRequest req.CurrentUser = currentUser req.RepoType = types.RepositoryType(repoType) req.Namespace = namespace req.Name = name - resp, err := h.c.ListRepoDiscussions(ctx, req) + resp, err := h.discussion.ListRepoDiscussions(ctx, req) if err != nil { slog.Error("Failed to list repo discussions", "error", err, "request", req) httpbase.ServerError(ctx, fmt.Errorf("failed to list repo discussions: %w", err)) @@ -247,8 +247,8 @@ func (h *DiscussionHandler) ListRepoDiscussions(ctx *gin.Context) { // @Accept json // @Produce json // @Param id path string true "the discussion id" -// @Param body body component.CreateCommentRequest true "body" -// @Success 200 {object} types.Response{data=component.CreateCommentResponse} "OK" +// @Param body body types.CreateCommentRequest true "body" +// @Success 200 {object} types.Response{data=types.CreateCommentResponse} "OK" // @Failure 400 {object} types.APIBadRequest "Bad request" // @Failure 500 {object} types.APIInternalServerError "Internal server error" // @Router /discussions/{id}/comments [post] @@ -265,12 +265,12 @@ func (h *DiscussionHandler) CreateDiscussionComment(ctx *gin.Context) { httpbase.BadRequest(ctx, fmt.Errorf("invalid discussion id: %w", err).Error()) return } - var req component.CreateCommentRequest + var req types.CreateCommentRequest if err := ctx.ShouldBindJSON(&req); err != nil { httpbase.BadRequest(ctx, err.Error()) return } - _, err = h.sc.CheckRequestV2(ctx, &req) + _, err = h.sensitive.CheckRequestV2(ctx, &req) if err != nil { slog.Error("failed to check sensitive request", slog.Any("error", err)) httpbase.BadRequest(ctx, fmt.Errorf("sensitive check failed: %w", err).Error()) @@ -280,7 +280,7 @@ func (h *DiscussionHandler) CreateDiscussionComment(ctx *gin.Context) { req.CommentableID = idInt req.CurrentUser = currentUser - resp, err := h.c.CreateDiscussionComment(ctx, req) + resp, err := h.discussion.CreateDiscussionComment(ctx, req) if err != nil { slog.Error("Failed to create discussion comment", "error", err, "request", req) httpbase.ServerError(ctx, fmt.Errorf("failed to create discussion comment: %w", err)) @@ -298,7 +298,7 @@ func (h *DiscussionHandler) CreateDiscussionComment(ctx *gin.Context) { // @Produce json // @Param id path string true "the comment id" // @Param current_user query string true "current user, the owner of the comment" -// @Param body body component.UpdateCommentRequest true "body" +// @Param body body types.UpdateCommentRequest true "body" // @Success 200 {object} types.Response "OK" // @Failure 400 {object} types.APIBadRequest "Bad request" // @Failure 500 {object} types.APIInternalServerError "Internal server error" @@ -316,19 +316,19 @@ func (h *DiscussionHandler) UpdateComment(ctx *gin.Context) { return } - var req component.UpdateCommentRequest + var req types.UpdateCommentRequest if err := ctx.ShouldBindJSON(&req); err != nil { httpbase.BadRequest(ctx, err.Error()) return } - _, err = h.sc.CheckRequestV2(ctx, &req) + _, err = h.sensitive.CheckRequestV2(ctx, &req) if err != nil { slog.Error("failed to check sensitive request", slog.Any("error", err)) httpbase.BadRequest(ctx, fmt.Errorf("sensitive check failed: %w", err).Error()) return } - err = h.c.UpdateComment(ctx, currentUser, idInt, req.Content) + err = h.discussion.UpdateComment(ctx, currentUser, idInt, req.Content) if err != nil { slog.Error("Failed to update comment", "error", err, "request", req) httpbase.ServerError(ctx, fmt.Errorf("failed to update comment: %w", err)) @@ -362,7 +362,7 @@ func (h *DiscussionHandler) DeleteComment(ctx *gin.Context) { httpbase.BadRequest(ctx, fmt.Errorf("invalid comment id: %w", err).Error()) return } - err = h.c.DeleteComment(ctx, currentUser, idInt) + err = h.discussion.DeleteComment(ctx, currentUser, idInt) if err != nil { slog.Error("Failed to delete comment", "error", err, "id", id) httpbase.ServerError(ctx, fmt.Errorf("failed to delete comment: %w", err)) @@ -379,7 +379,7 @@ func (h *DiscussionHandler) DeleteComment(ctx *gin.Context) { // @Accept json // @Produce json // @Param id path string true "the discussion id" -// @Success 200 {object} types.Response{data=[]component.DiscussionResponse_Comment} "OK" +// @Success 200 {object} types.Response{data=[]types.DiscussionResponse_Comment} "OK" // @Failure 400 {object} types.APIBadRequest "Bad request" // @Failure 500 {object} types.APIInternalServerError "Internal server error" // @Router /discussions/{id}/comments [get] @@ -389,7 +389,7 @@ func (h *DiscussionHandler) ListDiscussionComments(ctx *gin.Context) { if err != nil { httpbase.BadRequest(ctx, fmt.Errorf("invalid discussion id: %w", err).Error()) } - comments, err := h.c.ListDiscussionComments(ctx, idInt) + comments, err := h.discussion.ListDiscussionComments(ctx, idInt) if err != nil { slog.Error("Failed to list discussion comments", "error", err, "id", id) httpbase.ServerError(ctx, fmt.Errorf("failed to list discussion comments: %w", err)) diff --git a/api/handler/discussion_test.go b/api/handler/discussion_test.go new file mode 100644 index 00000000..0021b74d --- /dev/null +++ b/api/handler/discussion_test.go @@ -0,0 +1,212 @@ +package handler + +import ( + "testing" + + "github.com/gin-gonic/gin" + mockcomponent "opencsg.com/csghub-server/_mocks/opencsg.com/csghub-server/component" + "opencsg.com/csghub-server/common/types" +) + +type DiscussionTester struct { + *GinTester + handler *DiscussionHandler + mocks struct { + discussion *mockcomponent.MockDiscussionComponent + sensitive *mockcomponent.MockSensitiveComponent + } +} + +func NewDiscussionTester(t *testing.T) *DiscussionTester { + tester := &DiscussionTester{GinTester: NewGinTester()} + tester.mocks.discussion = mockcomponent.NewMockDiscussionComponent(t) + tester.mocks.sensitive = mockcomponent.NewMockSensitiveComponent(t) + + tester.handler = &DiscussionHandler{ + discussion: tester.mocks.discussion, + sensitive: tester.mocks.sensitive, + } + tester.WithParam("namespace", "u") + tester.WithParam("name", "r") + return tester +} + +func (t *DiscussionTester) WithHandleFunc(fn func(h *DiscussionHandler) gin.HandlerFunc) *DiscussionTester { + t.ginHandler = fn(t.handler) + return t +} + +func TestDiscussionHandler_CreateRepoDiscussion(t *testing.T) { + tester := NewDiscussionTester(t).WithHandleFunc(func(h *DiscussionHandler) gin.HandlerFunc { + return h.CreateRepoDiscussion + }) + tester.RequireUser(t) + + tester.mocks.sensitive.EXPECT().CheckRequestV2( + tester.ctx, &types.CreateRepoDiscussionRequest{Title: "foo"}, + ).Return(true, nil) + tester.mocks.discussion.EXPECT().CreateRepoDiscussion( + tester.ctx, types.CreateRepoDiscussionRequest{ + CurrentUser: "u", + Namespace: "u", + Name: "r", + RepoType: types.ModelRepo, + Title: "foo", + }, + ).Return(&types.CreateDiscussionResponse{ID: 123}, nil) + tester.WithParam("repo_type", "models").WithBody(t, &types.CreateRepoDiscussionRequest{ + Title: "foo", + }).Execute() + + tester.ResponseEq(t, 200, tester.OKText, &types.CreateDiscussionResponse{ID: 123}) + +} + +func TestDiscussionHandler_UpdateDiscussion(t *testing.T) { + tester := NewDiscussionTester(t).WithHandleFunc(func(h *DiscussionHandler) gin.HandlerFunc { + return h.UpdateDiscussion + }) + tester.RequireUser(t) + + tester.mocks.sensitive.EXPECT().CheckRequestV2( + tester.ctx, &types.UpdateDiscussionRequest{Title: "foo"}, + ).Return(true, nil) + tester.mocks.discussion.EXPECT().UpdateDiscussion( + tester.ctx, types.UpdateDiscussionRequest{ + CurrentUser: "u", + ID: 1, + Title: "foo", + }, + ).Return(nil) + tester.WithParam("id", "1").WithBody(t, &types.UpdateDiscussionRequest{ + Title: "foo", + }).Execute() + + tester.ResponseEq(t, 200, tester.OKText, nil) + +} + +func TestDiscussionHandler_DeleteDiscussion(t *testing.T) { + tester := NewDiscussionTester(t).WithHandleFunc(func(h *DiscussionHandler) gin.HandlerFunc { + return h.DeleteDiscussion + }) + tester.RequireUser(t) + + tester.mocks.discussion.EXPECT().DeleteDiscussion( + tester.ctx, "u", int64(1), + ).Return(nil) + tester.WithParam("id", "1").Execute() + + tester.ResponseEq(t, 200, tester.OKText, nil) + +} + +func TestDiscussionHandler_ShowDiscussion(t *testing.T) { + tester := NewDiscussionTester(t).WithHandleFunc(func(h *DiscussionHandler) gin.HandlerFunc { + return h.ShowDiscussion + }) + + tester.mocks.discussion.EXPECT().GetDiscussion( + tester.ctx, int64(1), + ).Return(&types.ShowDiscussionResponse{Title: "foo"}, nil) + tester.WithParam("id", "1").Execute() + + tester.ResponseEq(t, 200, tester.OKText, &types.ShowDiscussionResponse{Title: "foo"}) + +} + +func TestDiscussionHandler_ListRepoDiscussions(t *testing.T) { + tester := NewDiscussionTester(t).WithHandleFunc(func(h *DiscussionHandler) gin.HandlerFunc { + return h.ListRepoDiscussions + }) + + tester.mocks.discussion.EXPECT().ListRepoDiscussions( + tester.ctx, types.ListRepoDiscussionRequest{ + CurrentUser: "u", + RepoType: types.ModelRepo, + Namespace: "u", + Name: "r", + }, + ).Return(&types.ListRepoDiscussionResponse{Discussions: []*types.CreateDiscussionResponse{ + {ID: 1}, + }}, nil) + tester.WithUser().WithParam("repo_type", "models").Execute() + + tester.ResponseEq(t, 200, tester.OKText, &types.ListRepoDiscussionResponse{ + Discussions: []*types.CreateDiscussionResponse{{ID: 1}}, + }) + +} + +func TestDiscussionHandler_CreateDiscussionComment(t *testing.T) { + tester := NewDiscussionTester(t).WithHandleFunc(func(h *DiscussionHandler) gin.HandlerFunc { + return h.CreateDiscussionComment + }) + tester.RequireUser(t) + + tester.mocks.sensitive.EXPECT().CheckRequestV2( + tester.ctx, &types.CreateCommentRequest{Content: "foo"}, + ).Return(true, nil) + tester.mocks.discussion.EXPECT().CreateDiscussionComment( + tester.ctx, types.CreateCommentRequest{ + CurrentUser: "u", + Content: "foo", + CommentableID: 1, + }, + ).Return(&types.CreateCommentResponse{ID: 1}, nil) + tester.WithParam("id", "1").WithParam("repo_type", "models").WithBody( + t, &types.CreateCommentRequest{Content: "foo"}, + ).Execute() + + tester.ResponseEq(t, 200, tester.OKText, &types.CreateCommentResponse{ID: 1}) + +} + +func TestDiscussionHandler_UpdateComment(t *testing.T) { + tester := NewDiscussionTester(t).WithHandleFunc(func(h *DiscussionHandler) gin.HandlerFunc { + return h.UpdateComment + }) + tester.RequireUser(t) + + tester.mocks.sensitive.EXPECT().CheckRequestV2( + tester.ctx, &types.UpdateCommentRequest{Content: "foo"}, + ).Return(true, nil) + tester.mocks.discussion.EXPECT().UpdateComment( + tester.ctx, "u", int64(1), "foo", + ).Return(nil) + tester.WithParam("id", "1").WithBody( + t, &types.UpdateCommentRequest{Content: "foo"}, + ).Execute() + + tester.ResponseEq(t, 200, tester.OKText, nil) + +} + +func TestDiscussionHandler_DeleteComment(t *testing.T) { + tester := NewDiscussionTester(t).WithHandleFunc(func(h *DiscussionHandler) gin.HandlerFunc { + return h.DeleteComment + }) + tester.RequireUser(t) + + tester.mocks.discussion.EXPECT().DeleteComment( + tester.ctx, "u", int64(1), + ).Return(nil) + tester.WithParam("id", "1").Execute() + + tester.ResponseEq(t, 200, tester.OKText, nil) + +} + +func TestDiscussionHandler_ListDiscussionComments(t *testing.T) { + tester := NewDiscussionTester(t).WithHandleFunc(func(h *DiscussionHandler) gin.HandlerFunc { + return h.ListDiscussionComments + }) + + tester.mocks.discussion.EXPECT().ListDiscussionComments( + tester.ctx, int64(1), + ).Return([]*types.DiscussionResponse_Comment{{Content: "foo"}}, nil) + tester.WithUser().WithParam("id", "1").Execute() + + tester.ResponseEq(t, 200, tester.OKText, []*types.DiscussionResponse_Comment{{Content: "foo"}}) + +} diff --git a/api/handler/evaluation.go b/api/handler/evaluation.go index 74522220..808cb20c 100644 --- a/api/handler/evaluation.go +++ b/api/handler/evaluation.go @@ -23,14 +23,14 @@ func NewEvaluationHandler(config *config.Config) (*EvaluationHandler, error) { return nil, fmt.Errorf("error creating sensitive component:%w", err) } return &EvaluationHandler{ - c: wkf, - sc: sc, + evaluation: wkf, + sensitive: sc, }, nil } type EvaluationHandler struct { - c component.EvaluationComponent - sc component.SensitiveComponent + evaluation component.EvaluationComponent + sensitive component.SensitiveComponent } // create evaluation godoc @@ -59,14 +59,14 @@ func (h *EvaluationHandler) RunEvaluation(ctx *gin.Context) { httpbase.BadRequest(ctx, err.Error()) return } - _, err := h.sc.CheckRequestV2(ctx, &req) + _, err := h.sensitive.CheckRequestV2(ctx, &req) if err != nil { slog.Error("failed to check sensitive request", slog.Any("error", err)) httpbase.BadRequest(ctx, fmt.Errorf("sensitive check failed: %w", err).Error()) return } req.Username = currentUser - evaluation, err := h.c.CreateEvaluation(ctx, req) + evaluation, err := h.evaluation.CreateEvaluation(ctx, req) if err != nil { slog.Error("Failed to create evaluation job", slog.Any("error", err)) httpbase.ServerError(ctx, err) @@ -101,7 +101,7 @@ func (h *EvaluationHandler) GetEvaluation(ctx *gin.Context) { var req = &types.EvaluationGetReq{} req.ID = id req.Username = currentUser - evaluation, err := h.c.GetEvaluation(ctx, *req) + evaluation, err := h.evaluation.GetEvaluation(ctx, *req) if err != nil { slog.Error("Failed to get evaluation job", slog.Any("error", err)) httpbase.ServerError(ctx, err) @@ -137,7 +137,7 @@ func (h *EvaluationHandler) DeleteEvaluation(ctx *gin.Context) { var req = &types.EvaluationDelReq{} req.ID = id req.Username = currentUser - err = h.c.DeleteEvaluation(ctx, *req) + err = h.evaluation.DeleteEvaluation(ctx, *req) if err != nil { slog.Error("Failed to delete evaluation job", slog.Any("error", err)) httpbase.ServerError(ctx, err) diff --git a/api/handler/evaluation_test.go b/api/handler/evaluation_test.go new file mode 100644 index 00000000..dbee6632 --- /dev/null +++ b/api/handler/evaluation_test.go @@ -0,0 +1,85 @@ +package handler + +import ( + "testing" + + "github.com/gin-gonic/gin" + mockcomponent "opencsg.com/csghub-server/_mocks/opencsg.com/csghub-server/component" + "opencsg.com/csghub-server/common/types" +) + +type EvaluationTester struct { + *GinTester + handler *EvaluationHandler + mocks struct { + evaluation *mockcomponent.MockEvaluationComponent + sensitive *mockcomponent.MockSensitiveComponent + } +} + +func NewEvaluationTester(t *testing.T) *EvaluationTester { + tester := &EvaluationTester{GinTester: NewGinTester()} + tester.mocks.evaluation = mockcomponent.NewMockEvaluationComponent(t) + tester.mocks.sensitive = mockcomponent.NewMockSensitiveComponent(t) + + tester.handler = &EvaluationHandler{ + evaluation: tester.mocks.evaluation, + sensitive: tester.mocks.sensitive, + } + tester.WithParam("namespace", "u") + tester.WithParam("name", "r") + return tester +} + +func (t *EvaluationTester) WithHandleFunc(fn func(h *EvaluationHandler) gin.HandlerFunc) *EvaluationTester { + t.ginHandler = fn(t.handler) + return t +} + +func TestEvaluationHandler_Run(t *testing.T) { + tester := NewEvaluationTester(t).WithHandleFunc(func(h *EvaluationHandler) gin.HandlerFunc { + return h.RunEvaluation + }) + tester.RequireUser(t) + + tester.mocks.sensitive.EXPECT().CheckRequestV2(tester.ctx, &types.EvaluationReq{}).Return(true, nil) + tester.mocks.evaluation.EXPECT().CreateEvaluation(tester.ctx, types.EvaluationReq{ + Username: "u", + }).Return(&types.ArgoWorkFlowRes{ID: 1}, nil) + tester.WithBody(t, &types.EvaluationReq{}).Execute() + + tester.ResponseEq(t, 200, tester.OKText, &types.ArgoWorkFlowRes{ID: 1}) + +} + +func TestEvaluationHandler_Get(t *testing.T) { + tester := NewEvaluationTester(t).WithHandleFunc(func(h *EvaluationHandler) gin.HandlerFunc { + return h.GetEvaluation + }) + tester.RequireUser(t) + + tester.mocks.evaluation.EXPECT().GetEvaluation(tester.ctx, types.EvaluationGetReq{ + Username: "u", + ID: 1, + }).Return(&types.EvaluationRes{ID: 1}, nil) + tester.WithParam("id", "1").Execute() + + tester.ResponseEq(t, 200, tester.OKText, &types.EvaluationRes{ID: 1}) + +} + +func TestEvaluationHandler_Delete(t *testing.T) { + tester := NewEvaluationTester(t).WithHandleFunc(func(h *EvaluationHandler) gin.HandlerFunc { + return h.DeleteEvaluation + }) + tester.RequireUser(t) + + tester.mocks.evaluation.EXPECT().DeleteEvaluation(tester.ctx, types.EvaluationGetReq{ + Username: "u", + ID: 1, + }).Return(nil) + tester.WithParam("id", "1").Execute() + + tester.ResponseEq(t, 200, tester.OKText, nil) + +} diff --git a/api/handler/internal.go b/api/handler/internal.go index 020206f7..12ada5cc 100644 --- a/api/handler/internal.go +++ b/api/handler/internal.go @@ -20,19 +20,21 @@ func NewInternalHandler(config *config.Config) (*InternalHandler, error) { return nil, err } return &InternalHandler{ - c: uc, - config: config, + internal: uc, + config: config, + workflowClient: workflow.GetWorkflowClient(), }, nil } type InternalHandler struct { - c component.InternalComponent - config *config.Config + internal component.InternalComponent + config *config.Config + workflowClient client.Client } // TODO: add prmission check func (h *InternalHandler) Allowed(ctx *gin.Context) { - allowed, err := h.c.Allowed(ctx) + allowed, err := h.internal.Allowed(ctx) if err != nil { httpbase.ServerError(ctx, err) return @@ -67,7 +69,7 @@ func (h *InternalHandler) SSHAllowed(ctx *gin.Context) { req.Protocol = rawReq.Protocol req.CheckIP = rawReq.CheckIP - resp, err := h.c.SSHAllowed(ctx, req) + resp, err := h.internal.SSHAllowed(ctx, req) if err != nil { httpbase.ServerError(ctx, err) return @@ -90,7 +92,7 @@ func (h *InternalHandler) LfsAuthenticate(ctx *gin.Context) { return } req.RepoType, req.Namespace, req.Name = getRepoInfoFronClonePath(req.Repo) - resp, err := h.c.LfsAuthenticate(ctx, req) + resp, err := h.internal.LfsAuthenticate(ctx, req) if err != nil { httpbase.ServerError(ctx, err) return @@ -128,7 +130,7 @@ func (h *InternalHandler) PostReceive(ctx *gin.Context) { Ref: ref, RepoType: types.RepositoryType(strings.TrimSuffix(paths[0], "s")), } - callback, err := h.c.GetCommitDiff(ctx, diffReq) + callback, err := h.internal.GetCommitDiff(ctx, diffReq) if err != nil { slog.Error("post receive: failed to get commit diff", slog.Any("error", err)) httpbase.ServerError(ctx, err) @@ -136,7 +138,7 @@ func (h *InternalHandler) PostReceive(ctx *gin.Context) { } callback.Ref = originalRef //start workflow to handle push request - workflowClient := workflow.GetWorkflowClient() + workflowClient := h.workflowClient workflowOptions := client.StartWorkflowOptions{ TaskQueue: workflow.HandlePushQueueName, } @@ -165,7 +167,7 @@ func (h *InternalHandler) PostReceive(ctx *gin.Context) { func (h *InternalHandler) GetAuthorizedKeys(ctx *gin.Context) { key := ctx.Query("key") - sshKey, err := h.c.GetAuthorizedKeys(ctx, key) + sshKey, err := h.internal.GetAuthorizedKeys(ctx, key) if err != nil { slog.Error("failed to get authorize keys", slog.Any("error", err)) httpbase.ServerError(ctx, err) diff --git a/api/handler/internal_test.go b/api/handler/internal_test.go new file mode 100644 index 00000000..82446dbe --- /dev/null +++ b/api/handler/internal_test.go @@ -0,0 +1,191 @@ +package handler + +import ( + "testing" + + "github.com/gin-gonic/gin" + "github.com/stretchr/testify/mock" + "go.temporal.io/sdk/client" + temporal_mock "go.temporal.io/sdk/mocks" + mock_temporal "opencsg.com/csghub-server/_mocks/go.temporal.io/sdk/client" + mockcomponent "opencsg.com/csghub-server/_mocks/opencsg.com/csghub-server/component" + "opencsg.com/csghub-server/api/workflow" + "opencsg.com/csghub-server/builder/store/database" + "opencsg.com/csghub-server/common/config" + "opencsg.com/csghub-server/common/types" +) + +type InternalTester struct { + *GinTester + handler *InternalHandler + mocks struct { + internal *mockcomponent.MockInternalComponent + workflowClient *mock_temporal.MockClient + } +} + +func NewInternalTester(t *testing.T) *InternalTester { + tester := &InternalTester{GinTester: NewGinTester()} + tester.mocks.internal = mockcomponent.NewMockInternalComponent(t) + tester.mocks.workflowClient = mock_temporal.NewMockClient(t) + + tester.handler = &InternalHandler{ + internal: tester.mocks.internal, + workflowClient: tester.mocks.workflowClient, + config: &config.Config{}, + } + tester.WithParam("internalId", "testInternalId") + tester.WithParam("userId", "testUserId") + return tester +} + +func (t *InternalTester) WithHandleFunc(fn func(h *InternalHandler) gin.HandlerFunc) *InternalTester { + t.ginHandler = fn(t.handler) + return t +} + +func TestInternalHandler_Allowed(t *testing.T) { + tester := NewInternalTester(t).WithHandleFunc(func(h *InternalHandler) gin.HandlerFunc { + return h.Allowed + }) + + tester.mocks.internal.EXPECT().Allowed(tester.ctx).Return(true, nil) + tester.Execute() + + tester.ResponseEqSimple(t, 200, gin.H{ + "status": true, + "message": "allowed", + }) +} + +func TestInternalHandler_SSHAllowed(t *testing.T) { + t.Run("https", func(t *testing.T) { + tester := NewInternalTester(t).WithHandleFunc(func(h *InternalHandler) gin.HandlerFunc { + return h.SSHAllowed + }) + + tester.WithBody(t, &types.GitalyAllowedReq{ + Protocol: "https", + }).Execute() + + tester.ResponseEqSimple(t, 200, gin.H{ + "status": true, + "message": "allowed", + }) + }) + + t.Run("ssh", func(t *testing.T) { + tester := NewInternalTester(t).WithHandleFunc(func(h *InternalHandler) gin.HandlerFunc { + return h.SSHAllowed + }) + + tester.mocks.internal.EXPECT().SSHAllowed(tester.ctx, types.SSHAllowedReq{ + RepoType: types.ModelRepo, + Namespace: "u", + Name: "r", + Action: "act", + Changes: "c", + KeyID: "k", + Protocol: "ssh", + CheckIP: "ci", + }).Return(&types.SSHAllowedResp{Message: "msg"}, nil) + tester.WithHeader("Content-Type", "application/json").WithBody(t, &types.GitalyAllowedReq{ + Protocol: "ssh", + GlRepository: "models/u/r", + Action: "act", + KeyID: "k", + Changes: "c", + CheckIP: "ci", + }).Execute() + + tester.ResponseEqSimple(t, 200, &types.SSHAllowedResp{Message: "msg"}) + }) + +} + +func TestInternalHandler_LfsAuthenticate(t *testing.T) { + tester := NewInternalTester(t).WithHandleFunc(func(h *InternalHandler) gin.HandlerFunc { + return h.LfsAuthenticate + }) + + tester.mocks.internal.EXPECT().LfsAuthenticate(tester.ctx, types.LfsAuthenticateReq{ + RepoType: types.ModelRepo, + Namespace: "u", + Name: "r", + Repo: "models/u/r", + }).Return(&types.LfsAuthenticateResp{LfsToken: "t"}, nil) + tester.WithHeader("Content-Type", "application/json").WithBody(t, &types.LfsAuthenticateReq{ + Repo: "models/u/r", + }).Execute() + + tester.ResponseEqSimple(t, 200, &types.LfsAuthenticateResp{LfsToken: "t"}) +} + +func TestInternalHandler_PreReceive(t *testing.T) { + tester := NewInternalTester(t).WithHandleFunc(func(h *InternalHandler) gin.HandlerFunc { + return h.PreReceive + }) + + tester.Execute() + + tester.ResponseEqSimple(t, 200, gin.H{ + "reference_counter_increased": true, + }) +} + +func TestInternalHandler_PostReceive(t *testing.T) { + tester := NewInternalTester(t).WithHandleFunc(func(h *InternalHandler) gin.HandlerFunc { + return h.PostReceive + }) + + tester.mocks.internal.EXPECT().GetCommitDiff(tester.ctx, types.GetDiffBetweenTwoCommitsReq{ + LeftCommitId: "foo", + RightCommitId: "bar", + Namespace: "u", + Name: "r", + Ref: "main", + RepoType: types.ModelRepo, + }).Return(&types.GiteaCallbackPushReq{Ref: "foo"}, nil) + + runMock := &temporal_mock.WorkflowRun{} + runMock.On("GetID").Return("id") + tester.mocks.workflowClient.EXPECT().ExecuteWorkflow( + tester.ctx, client.StartWorkflowOptions{ + TaskQueue: workflow.HandlePushQueueName, + }, mock.Anything, + &types.GiteaCallbackPushReq{Ref: "ref/heads/main"}, &config.Config{}, + ).Return( + runMock, nil, + ) + tester.WithHeader("Content-Type", "application/json").WithBody(t, &types.PostReceiveReq{ + Changes: "foo bar ref/heads/main\n", + GlRepository: "models/u/r", + }).Execute() + + tester.ResponseEqSimple(t, 200, gin.H{ + "reference_counter_decreased": true, + "messages": []Messages{ + { + Message: "Welcome to OpenCSG!", + Type: "alert", + }, + }, + }) +} + +func TestInternalHandler_GetAuthorizedKeys(t *testing.T) { + tester := NewInternalTester(t).WithHandleFunc(func(h *InternalHandler) gin.HandlerFunc { + return h.GetAuthorizedKeys + }) + + tester.mocks.internal.EXPECT().GetAuthorizedKeys(tester.ctx, "k").Return(&database.SSHKey{ + ID: 1, + Content: "kk", + }, nil) + tester.WithQuery("key", "k").Execute() + + tester.ResponseEqSimple(t, 200, gin.H{ + "id": int64(1), + "key": "kk", + }) +} diff --git a/api/handler/mirror.go b/api/handler/mirror.go index bc74824b..a310874b 100644 --- a/api/handler/mirror.go +++ b/api/handler/mirror.go @@ -18,12 +18,12 @@ func NewMirrorHandler(config *config.Config) (*MirrorHandler, error) { return nil, err } return &MirrorHandler{ - mc: mc, + mirror: mc, }, nil } type MirrorHandler struct { - mc component.MirrorComponent + mirror component.MirrorComponent } // CreateMirrorRepo godoc @@ -38,19 +38,21 @@ type MirrorHandler struct { // @Failure 500 {object} types.APIInternalServerError "Internal server error" // @Router /mirror/repo [post] func (h *MirrorHandler) CreateMirrorRepo(ctx *gin.Context) { + currentUser := httpbase.GetCurrentUser(ctx) + if currentUser == "" { + httpbase.UnauthorizedError(ctx, component.ErrUserNotFound) + return + } + var req types.CreateMirrorRepoReq err := ctx.ShouldBindJSON(&req) if err != nil { httpbase.BadRequest(ctx, err.Error()) return } - currentUser := httpbase.GetCurrentUser(ctx) - if currentUser == "" { - httpbase.UnauthorizedError(ctx, component.ErrUserNotFound) - return - } + req.CurrentUser = currentUser - m, err := h.mc.CreateMirrorRepo(ctx, req) + m, err := h.mirror.CreateMirrorRepo(ctx, req) if err != nil { slog.Error("failed to create mirror repo", slog.Any("error", err)) httpbase.ServerError(ctx, err) @@ -74,18 +76,20 @@ func (h *MirrorHandler) CreateMirrorRepo(ctx *gin.Context) { // @Failure 500 {object} types.APIInternalServerError "Internal server error" // @Router /mirror/repos [get] func (h *MirrorHandler) Repos(ctx *gin.Context) { + currentUser := httpbase.GetCurrentUser(ctx) + if currentUser == "" { + httpbase.UnauthorizedError(ctx, component.ErrUserNotFound) + return + } + per, page, err := common.GetPerAndPageFromContext(ctx) if err != nil { slog.Error("Bad request format", "error", err) httpbase.BadRequest(ctx, err.Error()) return } - currentUser := httpbase.GetCurrentUser(ctx) - if currentUser == "" { - httpbase.UnauthorizedError(ctx, component.ErrUserNotFound) - return - } - repos, total, err := h.mc.Repos(ctx, currentUser, per, page) + + repos, total, err := h.mirror.Repos(ctx, currentUser, per, page) if err != nil { slog.Error("failed to get mirror repos", slog.Any("error", err)) httpbase.ServerError(ctx, err) @@ -112,19 +116,21 @@ func (h *MirrorHandler) Repos(ctx *gin.Context) { // @Failure 500 {object} types.APIInternalServerError "Internal server error" // @Router /mirrors [get] func (h *MirrorHandler) Index(ctx *gin.Context) { + currentUser := httpbase.GetCurrentUser(ctx) + if currentUser == "" { + httpbase.UnauthorizedError(ctx, component.ErrUserNotFound) + return + } + per, page, err := common.GetPerAndPageFromContext(ctx) if err != nil { slog.Error("Bad request format", "error", err) httpbase.BadRequest(ctx, err.Error()) return } - currentUser := httpbase.GetCurrentUser(ctx) - if currentUser == "" { - httpbase.UnauthorizedError(ctx, component.ErrUserNotFound) - return - } + search := ctx.Query("search") - repos, total, err := h.mc.Index(ctx, currentUser, per, page, search) + repos, total, err := h.mirror.Index(ctx, currentUser, per, page, search) if err != nil { slog.Error("failed to get mirror repos", slog.Any("error", err)) httpbase.ServerError(ctx, err) diff --git a/api/handler/mirror_source.go b/api/handler/mirror_source.go index b55406b8..89d54629 100644 --- a/api/handler/mirror_source.go +++ b/api/handler/mirror_source.go @@ -18,12 +18,12 @@ func NewMirrorSourceHandler(config *config.Config) (*MirrorSourceHandler, error) return nil, err } return &MirrorSourceHandler{ - c: c, + mirrorSource: c, }, nil } type MirrorSourceHandler struct { - c component.MirrorSourceComponent + mirrorSource component.MirrorSourceComponent } // CreateMirrorSource godoc @@ -38,19 +38,21 @@ type MirrorSourceHandler struct { // @Failure 500 {object} types.APIInternalServerError "Internal server error" // @Router /mirror/sources [post] func (h *MirrorSourceHandler) Create(ctx *gin.Context) { + currentUser := httpbase.GetCurrentUser(ctx) + if currentUser == "" { + httpbase.UnauthorizedError(ctx, component.ErrUserNotFound) + return + } + var msReq types.CreateMirrorSourceReq if err := ctx.ShouldBindJSON(&msReq); err != nil { slog.Error("Bad request format", "error", err) httpbase.BadRequest(ctx, err.Error()) return } - currentUser := httpbase.GetCurrentUser(ctx) - if currentUser == "" { - httpbase.UnauthorizedError(ctx, component.ErrUserNotFound) - return - } + msReq.CurrentUser = currentUser - ms, err := h.c.Create(ctx, msReq) + ms, err := h.mirrorSource.Create(ctx, msReq) if err != nil { slog.Error("Failed to create mirror source", "error", err) httpbase.ServerError(ctx, err) @@ -75,7 +77,7 @@ func (h *MirrorSourceHandler) Index(ctx *gin.Context) { httpbase.UnauthorizedError(ctx, component.ErrUserNotFound) return } - ms, err := h.c.Index(ctx, currentUser) + ms, err := h.mirrorSource.Index(ctx, currentUser) if err != nil { slog.Error("Failed to get mirror sources", "error", err) httpbase.ServerError(ctx, err) @@ -97,6 +99,12 @@ func (h *MirrorSourceHandler) Index(ctx *gin.Context) { // @Failure 500 {object} types.APIInternalServerError "Internal server error" // @Router /mirror/sources/{id} [put] func (h *MirrorSourceHandler) Update(ctx *gin.Context) { + currentUser := httpbase.GetCurrentUser(ctx) + if currentUser == "" { + httpbase.UnauthorizedError(ctx, component.ErrUserNotFound) + return + } + var msReq types.UpdateMirrorSourceReq var msId int64 id := ctx.Param("id") @@ -118,13 +126,8 @@ func (h *MirrorSourceHandler) Update(ctx *gin.Context) { return } msReq.ID = msId - currentUser := httpbase.GetCurrentUser(ctx) - if currentUser == "" { - httpbase.UnauthorizedError(ctx, component.ErrUserNotFound) - return - } msReq.CurrentUser = currentUser - ms, err := h.c.Update(ctx, msReq) + ms, err := h.mirrorSource.Update(ctx, msReq) if err != nil { slog.Error("Failed to get mirror sources", "error", err) httpbase.ServerError(ctx, err) @@ -145,6 +148,12 @@ func (h *MirrorSourceHandler) Update(ctx *gin.Context) { // @Failure 500 {object} types.APIInternalServerError "Internal server error" // @Router /mirror/sources/{id} [get] func (h *MirrorSourceHandler) Get(ctx *gin.Context) { + currentUser := httpbase.GetCurrentUser(ctx) + if currentUser == "" { + httpbase.UnauthorizedError(ctx, component.ErrUserNotFound) + return + } + var msId int64 id := ctx.Param("id") if id == "" { @@ -159,12 +168,8 @@ func (h *MirrorSourceHandler) Get(ctx *gin.Context) { httpbase.BadRequest(ctx, err.Error()) return } - currentUser := httpbase.GetCurrentUser(ctx) - if currentUser == "" { - httpbase.UnauthorizedError(ctx, component.ErrUserNotFound) - return - } - ms, err := h.c.Get(ctx, msId, currentUser) + + ms, err := h.mirrorSource.Get(ctx, msId, currentUser) if err != nil { slog.Error("Failed to get mirror source", "error", err) httpbase.ServerError(ctx, err) @@ -185,6 +190,12 @@ func (h *MirrorSourceHandler) Get(ctx *gin.Context) { // @Failure 500 {object} types.APIInternalServerError "Internal server error" // @Router /mirror/sources/{id} [delete] func (h *MirrorSourceHandler) Delete(ctx *gin.Context) { + currentUser := httpbase.GetCurrentUser(ctx) + if currentUser == "" { + httpbase.UnauthorizedError(ctx, component.ErrUserNotFound) + return + } + var msId int64 id := ctx.Param("id") if id == "" { @@ -199,12 +210,8 @@ func (h *MirrorSourceHandler) Delete(ctx *gin.Context) { httpbase.BadRequest(ctx, err.Error()) return } - currentUser := httpbase.GetCurrentUser(ctx) - if currentUser == "" { - httpbase.UnauthorizedError(ctx, component.ErrUserNotFound) - return - } - err = h.c.Delete(ctx, msId, currentUser) + + err = h.mirrorSource.Delete(ctx, msId, currentUser) if err != nil { slog.Error("Failed to delete mirror source", "error", err) httpbase.ServerError(ctx, err) diff --git a/api/handler/mirror_source_test.go b/api/handler/mirror_source_test.go new file mode 100644 index 00000000..e427e125 --- /dev/null +++ b/api/handler/mirror_source_test.go @@ -0,0 +1,104 @@ +package handler + +import ( + "testing" + + "github.com/gin-gonic/gin" + mockcomponent "opencsg.com/csghub-server/_mocks/opencsg.com/csghub-server/component" + "opencsg.com/csghub-server/builder/store/database" + "opencsg.com/csghub-server/common/types" +) + +type MirrorSourceTester struct { + *GinTester + handler *MirrorSourceHandler + mocks struct { + mirrorSource *mockcomponent.MockMirrorSourceComponent + } +} + +func NewMirrorSourceTester(t *testing.T) *MirrorSourceTester { + tester := &MirrorSourceTester{GinTester: NewGinTester()} + tester.mocks.mirrorSource = mockcomponent.NewMockMirrorSourceComponent(t) + + tester.handler = &MirrorSourceHandler{ + mirrorSource: tester.mocks.mirrorSource, + } + tester.WithParam("namespace", "u") + tester.WithParam("name", "r") + return tester +} + +func (t *MirrorSourceTester) WithHandleFunc(fn func(h *MirrorSourceHandler) gin.HandlerFunc) *MirrorSourceTester { + t.ginHandler = fn(t.handler) + return t +} + +func TestMirrorSourceHandler_Create(t *testing.T) { + tester := NewMirrorSourceTester(t).WithHandleFunc(func(h *MirrorSourceHandler) gin.HandlerFunc { + return h.Create + }) + tester.RequireUser(t) + + tester.mocks.mirrorSource.EXPECT().Create(tester.ctx, types.CreateMirrorSourceReq{ + SourceName: "sn", + CurrentUser: "u", + }).Return(&database.MirrorSource{ID: 1}, nil) + tester.WithBody(t, &types.CreateMirrorSourceReq{SourceName: "sn"}).Execute() + + tester.ResponseEq(t, 200, tester.OKText, &database.MirrorSource{ID: 1}) +} + +func TestMirrorSourceHandler_Index(t *testing.T) { + tester := NewMirrorSourceTester(t).WithHandleFunc(func(h *MirrorSourceHandler) gin.HandlerFunc { + return h.Index + }) + tester.RequireUser(t) + + tester.mocks.mirrorSource.EXPECT().Index(tester.ctx, "u").Return([]database.MirrorSource{{ID: 1}}, nil) + tester.Execute() + + tester.ResponseEq(t, 200, tester.OKText, []database.MirrorSource{{ID: 1}}) +} + +func TestMirrorSourceHandler_Update(t *testing.T) { + tester := NewMirrorSourceTester(t).WithHandleFunc(func(h *MirrorSourceHandler) gin.HandlerFunc { + return h.Update + }) + tester.RequireUser(t) + + tester.mocks.mirrorSource.EXPECT().Update(tester.ctx, types.UpdateMirrorSourceReq{ + ID: 1, + CurrentUser: "u", + SourceName: "sn", + }).Return(&database.MirrorSource{ID: 1}, nil) + tester.WithBody(t, &types.UpdateMirrorSourceReq{ + SourceName: "sn", + }).WithParam("id", "1").Execute() + + tester.ResponseEq(t, 200, tester.OKText, &database.MirrorSource{ID: 1}) +} + +func TestMirrorSourceHandler_Get(t *testing.T) { + tester := NewMirrorSourceTester(t).WithHandleFunc(func(h *MirrorSourceHandler) gin.HandlerFunc { + return h.Get + }) + tester.RequireUser(t) + + tester.mocks.mirrorSource.EXPECT().Get(tester.ctx, int64(1), "u").Return(&database.MirrorSource{ID: 1}, nil) + tester.WithParam("id", "1").Execute() + + tester.ResponseEq(t, 200, tester.OKText, &database.MirrorSource{ID: 1}) +} + +func TestMirrorSourceHandler_Delete(t *testing.T) { + tester := NewMirrorSourceTester(t).WithHandleFunc(func(h *MirrorSourceHandler) gin.HandlerFunc { + return h.Delete + }) + tester.RequireUser(t) + + tester.mocks.mirrorSource.EXPECT().Delete(tester.ctx, int64(1), "u").Return(nil) + tester.WithParam("id", "1").Execute() + + tester.ResponseEq(t, 200, tester.OKText, nil) +} diff --git a/api/handler/mirror_test.go b/api/handler/mirror_test.go new file mode 100644 index 00000000..b18260aa --- /dev/null +++ b/api/handler/mirror_test.go @@ -0,0 +1,95 @@ +package handler + +import ( + "testing" + + "github.com/gin-gonic/gin" + mockcomponent "opencsg.com/csghub-server/_mocks/opencsg.com/csghub-server/component" + "opencsg.com/csghub-server/builder/store/database" + "opencsg.com/csghub-server/common/types" +) + +type MirrorTester struct { + *GinTester + handler *MirrorHandler + mocks struct { + mirror *mockcomponent.MockMirrorComponent + } +} + +func NewMirrorTester(t *testing.T) *MirrorTester { + tester := &MirrorTester{GinTester: NewGinTester()} + tester.mocks.mirror = mockcomponent.NewMockMirrorComponent(t) + + tester.handler = &MirrorHandler{ + mirror: tester.mocks.mirror, + } + tester.WithParam("mirrorId", "testMirrorId") + tester.WithParam("userId", "testUserId") + return tester +} + +func (t *MirrorTester) WithHandleFunc(fn func(h *MirrorHandler) gin.HandlerFunc) *MirrorTester { + t.ginHandler = fn(t.handler) + return t +} + +func TestMirrorHandler_CreateMirrorRepo(t *testing.T) { + tester := NewMirrorTester(t).WithHandleFunc(func(h *MirrorHandler) gin.HandlerFunc { + return h.CreateMirrorRepo + }) + tester.RequireUser(t) + + tester.mocks.mirror.EXPECT().CreateMirrorRepo(tester.ctx, types.CreateMirrorRepoReq{ + SourceNamespace: "ns", + SourceName: "sn", + CurrentUser: "u", + MirrorSourceID: 1, + RepoType: types.ModelRepo, + SourceGitCloneUrl: "url", + }).Return(&database.Mirror{}, nil) + tester.WithBody(t, &types.CreateMirrorRepoReq{ + SourceNamespace: "ns", + SourceName: "sn", + MirrorSourceID: 1, + RepoType: types.ModelRepo, + SourceGitCloneUrl: "url", + }).Execute() + + tester.ResponseEq(t, 200, tester.OKText, nil) + +} + +func TestMirrorHandler_Repos(t *testing.T) { + tester := NewMirrorTester(t).WithHandleFunc(func(h *MirrorHandler) gin.HandlerFunc { + return h.Repos + }) + tester.RequireUser(t) + + tester.mocks.mirror.EXPECT().Repos(tester.ctx, "u", 10, 1).Return( + []types.MirrorRepo{{Path: "p"}}, 100, nil, + ) + tester.AddPagination(1, 10).Execute() + + tester.ResponseEq(t, 200, tester.OKText, gin.H{ + "data": []types.MirrorRepo{{Path: "p"}}, + "total": 100, + }) +} + +func TestMirrorHandler_Index(t *testing.T) { + tester := NewMirrorTester(t).WithHandleFunc(func(h *MirrorHandler) gin.HandlerFunc { + return h.Index + }) + tester.RequireUser(t) + + tester.mocks.mirror.EXPECT().Index(tester.ctx, "u", 10, 1, "foo").Return( + []types.Mirror{{SourceUrl: "p"}}, 100, nil, + ) + tester.AddPagination(1, 10).WithQuery("search", "foo").Execute() + + tester.ResponseEq(t, 200, tester.OKText, gin.H{ + "data": []types.Mirror{{SourceUrl: "p"}}, + "total": 100, + }) +} diff --git a/api/handler/organization.go b/api/handler/organization.go index 69caf46e..2df1fef2 100644 --- a/api/handler/organization.go +++ b/api/handler/organization.go @@ -38,22 +38,22 @@ func NewOrganizationHandler(config *config.Config) (*OrganizationHandler, error) return nil, err } return &OrganizationHandler{ - sc: sc, - cc: cc, - mc: mc, - dsc: dsc, - colc: colc, - pc: pc, + space: sc, + code: cc, + model: mc, + dataset: dsc, + collection: colc, + prompt: pc, }, nil } type OrganizationHandler struct { - sc component.SpaceComponent - cc component.CodeComponent - mc component.ModelComponent - dsc component.DatasetComponent - colc component.CollectionComponent - pc component.PromptComponent + space component.SpaceComponent + code component.CodeComponent + model component.ModelComponent + dataset component.DatasetComponent + collection component.CollectionComponent + prompt component.PromptComponent } // GetOrganizationModels godoc @@ -84,7 +84,7 @@ func (h *OrganizationHandler) Models(ctx *gin.Context) { } req.Page = page req.PageSize = per - models, total, err := h.mc.OrgModels(ctx, &req) + models, total, err := h.model.OrgModels(ctx, &req) if err != nil { slog.Error("Failed to get org models", slog.Any("error", err)) httpbase.ServerError(ctx, err) @@ -129,7 +129,7 @@ func (h *OrganizationHandler) Datasets(ctx *gin.Context) { } req.Page = page req.PageSize = per - datasets, total, err := h.dsc.OrgDatasets(ctx, &req) + datasets, total, err := h.dataset.OrgDatasets(ctx, &req) if err != nil { slog.Error("Failed to get org datasets", slog.Any("error", err)) httpbase.ServerError(ctx, err) @@ -173,7 +173,7 @@ func (h *OrganizationHandler) Codes(ctx *gin.Context) { } req.Page = page req.PageSize = per - datasets, total, err := h.cc.OrgCodes(ctx, &req) + datasets, total, err := h.code.OrgCodes(ctx, &req) if err != nil { slog.Error("Failed to get org codes", slog.Any("error", err)) httpbase.ServerError(ctx, err) @@ -217,7 +217,7 @@ func (h *OrganizationHandler) Spaces(ctx *gin.Context) { } req.Page = page req.PageSize = per - datasets, total, err := h.sc.OrgSpaces(ctx, &req) + datasets, total, err := h.space.OrgSpaces(ctx, &req) if err != nil { slog.Error("Failed to get org spaces", slog.Any("error", err)) httpbase.ServerError(ctx, err) @@ -261,7 +261,7 @@ func (h *OrganizationHandler) Collections(ctx *gin.Context) { } req.Page = page req.PageSize = per - datasets, total, err := h.colc.OrgCollections(ctx, &req) + datasets, total, err := h.collection.OrgCollections(ctx, &req) if err != nil { slog.Error("Failed to get org collections", slog.Any("error", err)) httpbase.ServerError(ctx, err) @@ -304,7 +304,7 @@ func (h *OrganizationHandler) Prompts(ctx *gin.Context) { } req.Page = page req.PageSize = per - prompts, total, err := h.pc.OrgPrompts(ctx, &req) + prompts, total, err := h.prompt.OrgPrompts(ctx, &req) if err != nil { slog.Error("Failed to get org prompts", slog.Any("error", err)) httpbase.ServerError(ctx, err) diff --git a/api/handler/organization_test.go b/api/handler/organization_test.go new file mode 100644 index 00000000..fb0b6ee6 --- /dev/null +++ b/api/handler/organization_test.go @@ -0,0 +1,175 @@ +package handler + +import ( + "testing" + + "github.com/gin-gonic/gin" + mockcomponent "opencsg.com/csghub-server/_mocks/opencsg.com/csghub-server/component" + "opencsg.com/csghub-server/common/types" +) + +type OrganizationTester struct { + *GinTester + handler *OrganizationHandler + mocks struct { + space *mockcomponent.MockSpaceComponent + code *mockcomponent.MockCodeComponent + model *mockcomponent.MockModelComponent + dataset *mockcomponent.MockDatasetComponent + collection *mockcomponent.MockCollectionComponent + prompt *mockcomponent.MockPromptComponent + } +} + +func NewOrganizationTester(t *testing.T) *OrganizationTester { + tester := &OrganizationTester{GinTester: NewGinTester()} + tester.mocks.space = mockcomponent.NewMockSpaceComponent(t) + tester.mocks.code = mockcomponent.NewMockCodeComponent(t) + tester.mocks.model = mockcomponent.NewMockModelComponent(t) + tester.mocks.dataset = mockcomponent.NewMockDatasetComponent(t) + tester.mocks.collection = mockcomponent.NewMockCollectionComponent(t) + tester.mocks.prompt = mockcomponent.NewMockPromptComponent(t) + + tester.handler = &OrganizationHandler{ + space: tester.mocks.space, + code: tester.mocks.code, + model: tester.mocks.model, + dataset: tester.mocks.dataset, + collection: tester.mocks.collection, + prompt: tester.mocks.prompt, + } + tester.WithParam("namespace", "u") + tester.WithParam("name", "n") + return tester +} + +func (t *OrganizationTester) WithHandleFunc(fn func(h *OrganizationHandler) gin.HandlerFunc) *OrganizationTester { + t.ginHandler = fn(t.handler) + return t +} + +func TestOrganizationHandler_Models(t *testing.T) { + tester := NewOrganizationTester(t).WithHandleFunc(func(h *OrganizationHandler) gin.HandlerFunc { + return h.Models + }) + + tester.mocks.model.EXPECT().OrgModels(tester.ctx, &types.OrgModelsReq{ + PageOpts: types.PageOpts{ + Page: 1, + PageSize: 10, + }, + Namespace: "u", + CurrentUser: "u", + }).Return([]types.Model{{Name: "m"}}, 100, nil) + tester.WithUser().AddPagination(1, 10).Execute() + tester.ResponseEqSimple(t, 200, gin.H{ + "message": "OK", + "data": []types.Model{{Name: "m"}}, + "total": 100, + }) +} + +func TestOrganizationHandler_Datasets(t *testing.T) { + tester := NewOrganizationTester(t).WithHandleFunc(func(h *OrganizationHandler) gin.HandlerFunc { + return h.Datasets + }) + + tester.mocks.dataset.EXPECT().OrgDatasets(tester.ctx, &types.OrgDatasetsReq{ + PageOpts: types.PageOpts{ + Page: 1, + PageSize: 10, + }, + Namespace: "u", + CurrentUser: "u", + }).Return([]types.Dataset{{Name: "m"}}, 100, nil) + tester.WithUser().AddPagination(1, 10).Execute() + tester.ResponseEqSimple(t, 200, gin.H{ + "message": "OK", + "data": []types.Dataset{{Name: "m"}}, + "total": 100, + }) +} + +func TestOrganizationHandler_Codes(t *testing.T) { + tester := NewOrganizationTester(t).WithHandleFunc(func(h *OrganizationHandler) gin.HandlerFunc { + return h.Codes + }) + + tester.mocks.code.EXPECT().OrgCodes(tester.ctx, &types.OrgCodesReq{ + PageOpts: types.PageOpts{ + Page: 1, + PageSize: 10, + }, + Namespace: "u", + CurrentUser: "u", + }).Return([]types.Code{{Name: "m"}}, 100, nil) + tester.WithUser().AddPagination(1, 10).Execute() + tester.ResponseEqSimple(t, 200, gin.H{ + "message": "OK", + "data": []types.Code{{Name: "m"}}, + "total": 100, + }) +} + +func TestOrganizationHandler_Spaces(t *testing.T) { + tester := NewOrganizationTester(t).WithHandleFunc(func(h *OrganizationHandler) gin.HandlerFunc { + return h.Spaces + }) + + tester.mocks.space.EXPECT().OrgSpaces(tester.ctx, &types.OrgSpacesReq{ + PageOpts: types.PageOpts{ + Page: 1, + PageSize: 10, + }, + Namespace: "u", + CurrentUser: "u", + }).Return([]types.Space{{Name: "m"}}, 100, nil) + tester.WithUser().AddPagination(1, 10).Execute() + tester.ResponseEqSimple(t, 200, gin.H{ + "message": "OK", + "data": []types.Space{{Name: "m"}}, + "total": 100, + }) +} + +func TestOrganizationHandler_Collections(t *testing.T) { + tester := NewOrganizationTester(t).WithHandleFunc(func(h *OrganizationHandler) gin.HandlerFunc { + return h.Collections + }) + + tester.mocks.collection.EXPECT().OrgCollections(tester.ctx, &types.OrgCollectionsReq{ + PageOpts: types.PageOpts{ + Page: 1, + PageSize: 10, + }, + Namespace: "u", + CurrentUser: "u", + }).Return([]types.Collection{{Name: "m"}}, 100, nil) + tester.WithUser().AddPagination(1, 10).Execute() + tester.ResponseEqSimple(t, 200, gin.H{ + "message": "OK", + "data": []types.Collection{{Name: "m"}}, + "total": 100, + }) +} + +func TestOrganizationHandler_Prompts(t *testing.T) { + tester := NewOrganizationTester(t).WithHandleFunc(func(h *OrganizationHandler) gin.HandlerFunc { + return h.Prompts + }) + + tester.mocks.prompt.EXPECT().OrgPrompts(tester.ctx, &types.OrgPromptsReq{ + PageOpts: types.PageOpts{ + Page: 1, + PageSize: 10, + }, + Namespace: "u", + CurrentUser: "u", + }).Return([]types.PromptRes{{Name: "m"}}, 100, nil) + tester.WithUser().AddPagination(1, 10).Execute() + tester.ResponseEqSimple(t, 200, gin.H{ + "message": "OK", + "data": []types.PromptRes{{Name: "m"}}, + "total": 100, + }) +} diff --git a/api/handler/runtime_architecture.go b/api/handler/runtime_architecture.go index 43062a3e..f2ad9c5b 100644 --- a/api/handler/runtime_architecture.go +++ b/api/handler/runtime_architecture.go @@ -23,14 +23,14 @@ func NewRuntimeArchitectureHandler(config *config.Config) (*RuntimeArchitectureH } return &RuntimeArchitectureHandler{ - rc: nrc, - rac: nrac, + repo: nrc, + runtimeArch: nrac, }, nil } type RuntimeArchitectureHandler struct { - rc component.RepoComponent - rac component.RuntimeArchitectureComponent + repo component.RepoComponent + runtimeArch component.RuntimeArchitectureComponent } // GetArchitectures godoc @@ -53,7 +53,7 @@ func (r *RuntimeArchitectureHandler) ListByRuntimeFrameworkID(ctx *gin.Context) httpbase.BadRequest(ctx, "invalid runtime framework ID format") return } - resp, err := r.rac.ListByRuntimeFrameworkID(ctx, id) + resp, err := r.runtimeArch.ListByRuntimeFrameworkID(ctx, id) if err != nil { slog.Error("fail to list runtime architectures", slog.Any("error", err)) httpbase.ServerError(ctx, err) @@ -90,7 +90,7 @@ func (r *RuntimeArchitectureHandler) UpdateArchitecture(ctx *gin.Context) { return } - res, err := r.rac.SetArchitectures(ctx, id, req.Architectures) + res, err := r.runtimeArch.SetArchitectures(ctx, id, req.Architectures) if err != nil { slog.Error("Failed to set architectures", slog.Any("error", err)) httpbase.ServerError(ctx, err) @@ -128,7 +128,7 @@ func (r *RuntimeArchitectureHandler) DeleteArchitecture(ctx *gin.Context) { return } - list, err := r.rac.DeleteArchitectures(ctx, id, req.Architectures) + list, err := r.runtimeArch.DeleteArchitectures(ctx, id, req.Architectures) if err != nil { slog.Error("Failed to delete architectures", slog.Any("error", err)) httpbase.ServerError(ctx, err) @@ -181,7 +181,7 @@ func (r *RuntimeArchitectureHandler) ScanArchitecture(ctx *gin.Context) { return } - err = r.rac.ScanArchitecture(ctx, id, scanType, req.Models) + err = r.runtimeArch.ScanArchitecture(ctx, id, scanType, req.Models) if err != nil { slog.Error("Failed to scan architecture", slog.Any("error", err)) httpbase.ServerError(ctx, err) diff --git a/api/handler/runtime_architecture_test.go b/api/handler/runtime_architecture_test.go new file mode 100644 index 00000000..269211d5 --- /dev/null +++ b/api/handler/runtime_architecture_test.go @@ -0,0 +1,94 @@ +package handler + +import ( + "testing" + + "github.com/gin-gonic/gin" + mockcomponent "opencsg.com/csghub-server/_mocks/opencsg.com/csghub-server/component" + "opencsg.com/csghub-server/builder/store/database" + "opencsg.com/csghub-server/common/types" +) + +type RuntimeArchitectureTester struct { + *GinTester + handler *RuntimeArchitectureHandler + mocks struct { + repo *mockcomponent.MockRepoComponent + runtimeArch *mockcomponent.MockRuntimeArchitectureComponent + } +} + +func NewRuntimeArchitectureTester(t *testing.T) *RuntimeArchitectureTester { + tester := &RuntimeArchitectureTester{GinTester: NewGinTester()} + tester.mocks.repo = mockcomponent.NewMockRepoComponent(t) + tester.mocks.runtimeArch = mockcomponent.NewMockRuntimeArchitectureComponent(t) + + tester.handler = &RuntimeArchitectureHandler{ + repo: tester.mocks.repo, + runtimeArch: tester.mocks.runtimeArch, + } + tester.WithParam("namespace", "u") + tester.WithParam("name", "r") + return tester +} + +func (t *RuntimeArchitectureTester) WithHandleFunc(fn func(h *RuntimeArchitectureHandler) gin.HandlerFunc) *RuntimeArchitectureTester { + t.ginHandler = fn(t.handler) + return t +} + +func TestRuntimeArchHandler_ListByRuntimeFrameworkID(t *testing.T) { + tester := NewRuntimeArchitectureTester(t).WithHandleFunc(func(h *RuntimeArchitectureHandler) gin.HandlerFunc { + return h.ListByRuntimeFrameworkID + }) + + tester.mocks.runtimeArch.EXPECT().ListByRuntimeFrameworkID(tester.ctx, int64(1)).Return([]database.RuntimeArchitecture{{ID: 1}}, nil) + tester.WithParam("id", "1").Execute() + + tester.ResponseEq(t, 200, tester.OKText, []database.RuntimeArchitecture{{ID: 1}}) +} + +func TestRuntimeArchHandler_UpdateArchitecture(t *testing.T) { + tester := NewRuntimeArchitectureTester(t).WithHandleFunc(func(h *RuntimeArchitectureHandler) gin.HandlerFunc { + return h.UpdateArchitecture + }) + + tester.mocks.runtimeArch.EXPECT().SetArchitectures( + tester.ctx, int64(1), []string{"foo"}, + ).Return([]string{"bar"}, nil) + tester.WithParam("id", "1").WithBody(t, &types.RuntimeArchitecture{ + Architectures: []string{"foo"}, + }).Execute() + + tester.ResponseEq(t, 200, tester.OKText, []string{"bar"}) +} + +func TestRuntimeArchHandler_DeleteArchitecture(t *testing.T) { + tester := NewRuntimeArchitectureTester(t).WithHandleFunc(func(h *RuntimeArchitectureHandler) gin.HandlerFunc { + return h.DeleteArchitecture + }) + + tester.mocks.runtimeArch.EXPECT().DeleteArchitectures( + tester.ctx, int64(1), []string{"foo"}, + ).Return([]string{"bar"}, nil) + tester.WithParam("id", "1").WithBody(t, &types.RuntimeArchitecture{ + Architectures: []string{"foo"}, + }).Execute() + + tester.ResponseEq(t, 200, tester.OKText, []string{"bar"}) +} + +func TestRuntimeArchHandler_ScanArchitecture(t *testing.T) { + tester := NewRuntimeArchitectureTester(t).WithHandleFunc(func(h *RuntimeArchitectureHandler) gin.HandlerFunc { + return h.ScanArchitecture + }) + + tester.mocks.runtimeArch.EXPECT().ScanArchitecture( + tester.ctx, int64(1), 2, []string{"foo"}, + ).Return(nil) + tester.WithParam("id", "1").WithQuery("scan_type", "2").WithBody(t, &types.RuntimeFrameworkModels{ + Models: []string{"foo"}, + }).Execute() + + tester.ResponseEq(t, 200, tester.OKText, nil) +} diff --git a/api/handler/space.go b/api/handler/space.go index da4ddd05..16070a72 100644 --- a/api/handler/space.go +++ b/api/handler/space.go @@ -29,18 +29,19 @@ func NewSpaceHandler(config *config.Config) (*SpaceHandler, error) { if err != nil { return nil, fmt.Errorf("error creating repo component:%w", err) } - return &SpaceHandler{ - c: c, - ssc: ssc, - repo: repo, + space: c, + sensitive: ssc, + repo: repo, + spaceStatusCheckInterval: 5 * time.Second, }, nil } type SpaceHandler struct { - c component.SpaceComponent - ssc component.SensitiveComponent - repo component.RepoComponent + space component.SpaceComponent + sensitive component.SensitiveComponent + repo component.RepoComponent + spaceStatusCheckInterval time.Duration } // GetAllSpaces godoc @@ -89,7 +90,7 @@ func (h *SpaceHandler) Index(ctx *gin.Context) { return } - spaces, total, err := h.c.Index(ctx, filter, per, page) + spaces, total, err := h.space.Index(ctx, filter, per, page) if err != nil { slog.Error("Failed to get spaces", slog.Any("error", err)) httpbase.ServerError(ctx, err) @@ -123,7 +124,7 @@ func (h *SpaceHandler) Show(ctx *gin.Context) { return } currentUser := httpbase.GetCurrentUser(ctx) - detail, err := h.c.Show(ctx, namespace, name, currentUser) + detail, err := h.space.Show(ctx, namespace, name, currentUser) if err != nil { if errors.Is(err, component.ErrUnauthorized) { httpbase.UnauthorizedError(ctx, err) @@ -162,7 +163,7 @@ func (h *SpaceHandler) Create(ctx *gin.Context) { httpbase.BadRequest(ctx, err.Error()) return } - _, err := h.ssc.CheckRequestV2(ctx, &req) + _, err := h.sensitive.CheckRequestV2(ctx, &req) if err != nil { slog.Error("failed to check sensitive request", slog.Any("error", err)) httpbase.BadRequest(ctx, fmt.Errorf("sensitive check failed: %w", err).Error()) @@ -170,7 +171,7 @@ func (h *SpaceHandler) Create(ctx *gin.Context) { } req.Username = currentUser - space, err := h.c.Create(ctx, req) + space, err := h.space.Create(ctx, req) if err != nil { slog.Error("Failed to create space", slog.Any("error", err)) httpbase.ServerError(ctx, err) @@ -207,7 +208,7 @@ func (h *SpaceHandler) Update(ctx *gin.Context) { httpbase.BadRequest(ctx, err.Error()) return } - _, err := h.ssc.CheckRequestV2(ctx, req) + _, err := h.sensitive.CheckRequestV2(ctx, req) if err != nil { slog.Error("failed to check sensitive request", slog.Any("error", err)) httpbase.BadRequest(ctx, fmt.Errorf("sensitive check failed: %w", err).Error()) @@ -224,7 +225,7 @@ func (h *SpaceHandler) Update(ctx *gin.Context) { req.Namespace = namespace req.Name = name - space, err := h.c.Update(ctx, req) + space, err := h.space.Update(ctx, req) if err != nil { slog.Error("Failed to update space", slog.Any("error", err)) httpbase.ServerError(ctx, err) @@ -261,7 +262,7 @@ func (h *SpaceHandler) Delete(ctx *gin.Context) { httpbase.BadRequest(ctx, err.Error()) return } - err = h.c.Delete(ctx, namespace, name, currentUser) + err = h.space.Delete(ctx, namespace, name, currentUser) if err != nil { slog.Error("Failed to delete space", slog.Any("error", err)) httpbase.ServerError(ctx, err) @@ -308,7 +309,7 @@ func (h *SpaceHandler) Run(ctx *gin.Context) { httpbase.UnauthorizedError(ctx, errors.New("user not allowed to run sapce")) return } - deployID, err := h.c.Deploy(ctx, namespace, name, currentUser) + deployID, err := h.space.Deploy(ctx, namespace, name, currentUser) if err != nil { slog.Error("failed to deploy space", slog.String("namespace", namespace), slog.String("name", name), slog.Any("error", err)) @@ -339,7 +340,7 @@ func (h *SpaceHandler) Wakeup(ctx *gin.Context) { httpbase.BadRequest(ctx, err.Error()) return } - err = h.c.Wakeup(ctx, namespace, name) + err = h.space.Wakeup(ctx, namespace, name) if err != nil { slog.Error("failed to wakeup space", slog.String("namespace", namespace), slog.String("name", name), slog.Any("error", err)) @@ -387,7 +388,7 @@ func (h *SpaceHandler) Stop(ctx *gin.Context) { return } - err = h.c.Stop(ctx, namespace, name, false) + err = h.space.Stop(ctx, namespace, name, false) if err != nil { slog.Error("failed to stop space", slog.String("namespace", namespace), slog.String("name", name), slog.Any("error", err)) @@ -452,9 +453,9 @@ func (h *SpaceHandler) Status(ctx *gin.Context) { slog.Info("space handler status request context done", slog.Any("error", ctx.Request.Context().Err())) return default: - time.Sleep(time.Second * 5) + time.Sleep(h.spaceStatusCheckInterval) //user http request context instead of gin context, so that server knows the life cycle of the request - _, status, err := h.c.Status(ctx.Request.Context(), namespace, name) + _, status, err := h.space.Status(ctx.Request.Context(), namespace, name) if err != nil { slog.Error("failed to get space status", slog.Any("error", err), slog.String("namespace", namespace), slog.String("name", name)) @@ -543,7 +544,7 @@ func (h *SpaceHandler) Logs(ctx *gin.Context) { ctx.Writer.Header().Set("Transfer-Encoding", "chunked") //user http request context instead of gin context, so that server knows the life cycle of the request - logReader, err := h.c.Logs(ctx.Request.Context(), namespace, name) + logReader, err := h.space.Logs(ctx.Request.Context(), namespace, name) if err != nil { httpbase.ServerError(ctx, err) return diff --git a/api/handler/space_resource.go b/api/handler/space_resource.go index 91f7563e..f13a08d9 100644 --- a/api/handler/space_resource.go +++ b/api/handler/space_resource.go @@ -17,12 +17,12 @@ func NewSpaceResourceHandler(config *config.Config) (*SpaceResourceHandler, erro return nil, err } return &SpaceResourceHandler{ - c: src, + spaceResource: src, }, nil } type SpaceResourceHandler struct { - c component.SpaceResourceComponent + spaceResource component.SpaceResourceComponent } // GetSpaceResources godoc @@ -51,7 +51,7 @@ func (h *SpaceResourceHandler) Index(ctx *gin.Context) { httpbase.BadRequest(ctx, err.Error()) return } - spaceResources, err := h.c.Index(ctx, clusterId, deployType, "") + spaceResources, err := h.spaceResource.Index(ctx, clusterId, deployType, "") if err != nil { slog.Error("Failed to get space resources", slog.String("cluster_id", clusterId), slog.String("deploy_type", deployTypeStr), slog.Any("error", err)) httpbase.ServerError(ctx, err) @@ -80,7 +80,7 @@ func (h *SpaceResourceHandler) Create(ctx *gin.Context) { httpbase.BadRequest(ctx, err.Error()) return } - spaceResource, err := h.c.Create(ctx, &req) + spaceResource, err := h.spaceResource.Create(ctx, &req) if err != nil { slog.Error("Failed to create space resources", slog.Any("error", err)) httpbase.ServerError(ctx, err) @@ -122,7 +122,7 @@ func (h *SpaceResourceHandler) Update(ctx *gin.Context) { } req.ID = id - spaceResource, err := h.c.Update(ctx, req) + spaceResource, err := h.spaceResource.Update(ctx, req) if err != nil { slog.Error("Failed to update space resource", slog.Any("error", err)) httpbase.ServerError(ctx, err) @@ -156,7 +156,7 @@ func (h *SpaceResourceHandler) Delete(ctx *gin.Context) { return } - err = h.c.Delete(ctx, id) + err = h.spaceResource.Delete(ctx, id) if err != nil { slog.Error("Failed to delete space resource", slog.Any("error", err)) httpbase.ServerError(ctx, err) diff --git a/api/handler/space_resource_test.go b/api/handler/space_resource_test.go new file mode 100644 index 00000000..51a7ec83 --- /dev/null +++ b/api/handler/space_resource_test.go @@ -0,0 +1,98 @@ +package handler + +import ( + "testing" + + "github.com/gin-gonic/gin" + mockcomponent "opencsg.com/csghub-server/_mocks/opencsg.com/csghub-server/component" + "opencsg.com/csghub-server/common/types" +) + +type SpaceResourceTester struct { + *GinTester + handler *SpaceResourceHandler + mocks struct { + spaceResource *mockcomponent.MockSpaceResourceComponent + } +} + +func NewSpaceResourceTester(t *testing.T) *SpaceResourceTester { + tester := &SpaceResourceTester{GinTester: NewGinTester()} + tester.mocks.spaceResource = mockcomponent.NewMockSpaceResourceComponent(t) + + tester.handler = &SpaceResourceHandler{ + spaceResource: tester.mocks.spaceResource, + } + tester.WithParam("namespace", "u") + tester.WithParam("name", "r") + return tester +} + +func (t *SpaceResourceTester) WithHandleFunc(fn func(h *SpaceResourceHandler) gin.HandlerFunc) *SpaceResourceTester { + t.ginHandler = fn(t.handler) + return t +} + +func TestSpaceResourceHandler_Index(t *testing.T) { + tester := NewSpaceResourceTester(t).WithHandleFunc(func(h *SpaceResourceHandler) gin.HandlerFunc { + return h.Index + }) + + tester.mocks.spaceResource.EXPECT().Index(tester.ctx, "c1", types.InferenceType, "").Return( + []types.SpaceResource{{Name: "sp"}}, nil, + ) + tester.WithQuery("cluster_id", "c1").WithQuery("deploy_type", "").WithUser().Execute() + + tester.ResponseEq(t, 200, tester.OKText, []types.SpaceResource{{Name: "sp"}}) +} + +func TestSpaceResourceHandler_Create(t *testing.T) { + tester := NewSpaceResourceTester(t).WithHandleFunc(func(h *SpaceResourceHandler) gin.HandlerFunc { + return h.Create + }) + + tester.mocks.spaceResource.EXPECT().Create(tester.ctx, &types.CreateSpaceResourceReq{ + ClusterID: "c", + Name: "n", + Resources: "r", + }).Return( + &types.SpaceResource{Name: "sp"}, nil, + ) + tester.WithBody(t, &types.CreateSpaceResourceReq{ + ClusterID: "c", Name: "n", Resources: "r", + }).WithUser().Execute() + + tester.ResponseEq(t, 200, tester.OKText, &types.SpaceResource{Name: "sp"}) +} + +func TestSpaceResourceHandler_Update(t *testing.T) { + tester := NewSpaceResourceTester(t).WithHandleFunc(func(h *SpaceResourceHandler) gin.HandlerFunc { + return h.Update + }) + + tester.mocks.spaceResource.EXPECT().Update(tester.ctx, &types.UpdateSpaceResourceReq{ + Name: "n", + Resources: "r", + ID: 1, + }).Return( + &types.SpaceResource{Name: "sp"}, nil, + ) + tester.WithBody(t, &types.UpdateSpaceResourceReq{ + Name: "n", Resources: "r", + }).WithUser().WithParam("id", "1").Execute() + + tester.ResponseEq(t, 200, tester.OKText, &types.SpaceResource{Name: "sp"}) +} + +func TestSpaceResourceHandler_Delete(t *testing.T) { + tester := NewSpaceResourceTester(t).WithHandleFunc(func(h *SpaceResourceHandler) gin.HandlerFunc { + return h.Delete + }) + + tester.mocks.spaceResource.EXPECT().Delete(tester.ctx, int64(1)).Return( + nil, + ) + tester.WithUser().WithParam("id", "1").Execute() + + tester.ResponseEq(t, 200, tester.OKText, nil) +} diff --git a/api/handler/space_test.go b/api/handler/space_test.go new file mode 100644 index 00000000..9f1f88e3 --- /dev/null +++ b/api/handler/space_test.go @@ -0,0 +1,259 @@ +package handler + +import ( + "context" + "fmt" + "testing" + + "github.com/gin-gonic/gin" + "github.com/stretchr/testify/mock" + "github.com/stretchr/testify/require" + mockcomponent "opencsg.com/csghub-server/_mocks/opencsg.com/csghub-server/component" + "opencsg.com/csghub-server/builder/deploy" + "opencsg.com/csghub-server/common/types" +) + +type SpaceTester struct { + *GinTester + handler *SpaceHandler + mocks struct { + space *mockcomponent.MockSpaceComponent + sensitive *mockcomponent.MockSensitiveComponent + repo *mockcomponent.MockRepoComponent + } +} + +func NewSpaceTester(t *testing.T) *SpaceTester { + tester := &SpaceTester{GinTester: NewGinTester()} + tester.mocks.space = mockcomponent.NewMockSpaceComponent(t) + tester.mocks.sensitive = mockcomponent.NewMockSensitiveComponent(t) + tester.mocks.repo = mockcomponent.NewMockRepoComponent(t) + + tester.handler = &SpaceHandler{ + space: tester.mocks.space, + sensitive: tester.mocks.sensitive, + repo: tester.mocks.repo, + } + tester.WithParam("namespace", "u") + tester.WithParam("name", "r") + return tester +} + +func (t *SpaceTester) WithHandleFunc(fn func(h *SpaceHandler) gin.HandlerFunc) *SpaceTester { + t.ginHandler = fn(t.handler) + return t +} + +func TestSpaceHandler_Index(t *testing.T) { + cases := []struct { + sort string + source string + error bool + }{ + {"most_download", "local", false}, + {"foo", "local", true}, + {"most_download", "bar", true}, + } + + for _, c := range cases { + t.Run(fmt.Sprintf("%+v", c), func(t *testing.T) { + + tester := NewSpaceTester(t).WithHandleFunc(func(h *SpaceHandler) gin.HandlerFunc { + return h.Index + }) + + if !c.error { + tester.mocks.space.EXPECT().Index(tester.ctx, &types.RepoFilter{ + Search: "foo", + Sort: c.sort, + Source: c.source, + }, 10, 1).Return([]types.Space{ + {Name: "cc"}, + }, 100, nil) + } + + tester.AddPagination(1, 10).WithQuery("search", "foo"). + WithQuery("sort", c.sort). + WithQuery("source", c.source).Execute() + + if c.error { + require.Equal(t, 400, tester.response.Code) + } else { + tester.ResponseEqSimple(t, 200, gin.H{ + "data": []types.Space{{Name: "cc"}}, + "total": 100, + }) + } + }) + } + +} + +func TestSpaceHandler_Show(t *testing.T) { + tester := NewSpaceTester(t).WithHandleFunc(func(h *SpaceHandler) gin.HandlerFunc { + return h.Show + }) + + tester.WithUser() + tester.mocks.space.EXPECT().Show(tester.ctx, "u", "r", "u").Return(&types.Space{ + Name: "m", + }, nil) + tester.Execute() + + tester.ResponseEq(t, 200, tester.OKText, &types.Space{Name: "m"}) + +} + +func TestSpaceHandler_Create(t *testing.T) { + tester := NewSpaceTester(t).WithHandleFunc(func(h *SpaceHandler) gin.HandlerFunc { + return h.Create + }) + tester.RequireUser(t) + + req := &types.CreateSpaceReq{CreateRepoReq: types.CreateRepoReq{}} + tester.mocks.sensitive.EXPECT().CheckRequestV2(tester.ctx, req).Return(true, nil) + reqn := *req + reqn.Username = "u" + tester.mocks.space.EXPECT().Create(tester.ctx, reqn).Return(&types.Space{Name: "m"}, nil) + tester.WithBody(t, req).Execute() + + tester.ResponseEq(t, 200, tester.OKText, &types.Space{Name: "m"}) +} + +func TestSpaceHandler_Update(t *testing.T) { + tester := NewSpaceTester(t).WithHandleFunc(func(h *SpaceHandler) gin.HandlerFunc { + return h.Update + }) + tester.RequireUser(t) + + req := &types.UpdateSpaceReq{UpdateRepoReq: types.UpdateRepoReq{}} + tester.mocks.sensitive.EXPECT().CheckRequestV2(tester.ctx, req).Return(true, nil) + tester.mocks.space.EXPECT().Update(tester.ctx, &types.UpdateSpaceReq{ + UpdateRepoReq: types.UpdateRepoReq{ + Namespace: "u", + Name: "r", + Username: "u", + }, + }).Return(&types.Space{Name: "m"}, nil) + tester.WithBody(t, req).Execute() + + tester.ResponseEq(t, 200, tester.OKText, &types.Space{Name: "m"}) +} + +func TestSpaceHandler_Delete(t *testing.T) { + tester := NewSpaceTester(t).WithHandleFunc(func(h *SpaceHandler) gin.HandlerFunc { + return h.Delete + }) + tester.RequireUser(t) + + tester.mocks.space.EXPECT().Delete(tester.ctx, "u", "r", "u").Return(nil) + tester.Execute() + + tester.ResponseEq(t, 200, tester.OKText, nil) +} + +func TestSpaceHandler_Run(t *testing.T) { + tester := NewSpaceTester(t).WithHandleFunc(func(h *SpaceHandler) gin.HandlerFunc { + return h.Run + }) + tester.RequireUser(t) + + tester.mocks.repo.EXPECT().AllowAdminAccess(tester.ctx, types.SpaceRepo, "u", "r", "u").Return(true, nil) + tester.mocks.space.EXPECT().Deploy(tester.ctx, "u", "r", "u").Return(123, nil) + tester.Execute() + + tester.ResponseEq(t, 200, tester.OKText, nil) +} + +func TestSpaceHandler_Wakeup(t *testing.T) { + tester := NewSpaceTester(t).WithHandleFunc(func(h *SpaceHandler) gin.HandlerFunc { + return h.Wakeup + }) + + tester.mocks.space.EXPECT().Wakeup(tester.ctx, "u", "r").Return(nil) + tester.Execute() + + tester.ResponseEq(t, 200, tester.OKText, nil) +} + +func TestSpaceHandler_Stop(t *testing.T) { + tester := NewSpaceTester(t).WithHandleFunc(func(h *SpaceHandler) gin.HandlerFunc { + return h.Stop + }) + tester.RequireUser(t) + + tester.mocks.repo.EXPECT().AllowAdminAccess(tester.ctx, types.SpaceRepo, "u", "r", "u").Return(true, nil) + tester.mocks.space.EXPECT().Stop(tester.ctx, "u", "r", false).Return(nil) + tester.Execute() + + tester.ResponseEq(t, 200, tester.OKText, nil) +} + +func TestSpaceHandler_Status(t *testing.T) { + tester := NewSpaceTester(t).WithHandleFunc(func(h *SpaceHandler) gin.HandlerFunc { + return h.Status + }) + tester.handler.spaceStatusCheckInterval = 0 + + tester.mocks.repo.EXPECT().AllowReadAccess( + tester.ctx, types.SpaceRepo, "u", "r", "u", + ).Return(true, nil) + cc, cancel := context.WithCancel(tester.ctx.Request.Context()) + tester.ctx.Request = tester.ctx.Request.WithContext(cc) + tester.mocks.space.EXPECT().Status( + mock.Anything, "u", "r", + ).Return("", "s1", nil).Once() + tester.mocks.space.EXPECT().Status( + mock.Anything, "u", "r", + ).RunAndReturn(func(ctx context.Context, s1, s2 string) (string, string, error) { + cancel() + return "", "s3", nil + }).Once() + + tester.WithUser().Execute() + require.Equal(t, 200, tester.response.Code) + headers := tester.response.Header() + require.Equal(t, "text/event-stream", headers.Get("Content-Type")) + require.Equal(t, "no-cache", headers.Get("Cache-Control")) + require.Equal(t, "keep-alive", headers.Get("Connection")) + require.Equal(t, "chunked", headers.Get("Transfer-Encoding")) + require.Equal( + t, "event:status\ndata:s1\n\nevent:status\ndata:s3\n\n", + tester.response.Body.String(), + ) + +} + +func TestSpaceHandler_Logs(t *testing.T) { + tester := NewSpaceTester(t).WithHandleFunc(func(h *SpaceHandler) gin.HandlerFunc { + return h.Logs + }) + + tester.mocks.repo.EXPECT().AllowReadAccess( + tester.ctx, types.SpaceRepo, "u", "r", "u", + ).Return(true, nil) + runlogChan := make(chan string) + tester.mocks.space.EXPECT().Logs( + mock.Anything, "u", "r", + ).Return(deploy.NewMultiLogReader(nil, runlogChan), nil) + cc, cancel := context.WithCancel(tester.ctx.Request.Context()) + tester.ctx.Request = tester.ctx.Request.WithContext(cc) + go func() { + runlogChan <- "foo" + runlogChan <- "bar" + close(runlogChan) + cancel() + }() + + tester.WithUser().Execute() + require.Equal(t, 200, tester.response.Code) + headers := tester.response.Header() + require.Equal(t, "text/event-stream", headers.Get("Content-Type")) + require.Equal(t, "no-cache", headers.Get("Cache-Control")) + require.Equal(t, "keep-alive", headers.Get("Connection")) + require.Equal(t, "chunked", headers.Get("Transfer-Encoding")) + require.Equal( + t, "event:Container\ndata:foo\n\nevent:Container\ndata:bar\n\n", + tester.response.Body.String(), + ) +} diff --git a/builder/deploy/deployer.go b/builder/deploy/deployer.go index 0838f217..3c981a45 100644 --- a/builder/deploy/deployer.go +++ b/builder/deploy/deployer.go @@ -728,7 +728,11 @@ func CheckResource(clusterResources *types.ClusterRes, hardware *types.HardWare) } for _, node := range clusterResources.Resources { if float32(mem) <= node.AvailableMem { - return checkNodeResource(node, hardware) + isAvailable := checkNodeResource(node, hardware) + if isAvailable { + // if true return, otherwise continue check next node + return true + } } } return false diff --git a/builder/deploy/deployer_test.go b/builder/deploy/deployer_test.go index 07ac662c..23e2654d 100644 --- a/builder/deploy/deployer_test.go +++ b/builder/deploy/deployer_test.go @@ -691,6 +691,10 @@ func TestDeployer_CheckResource(t *testing.T) { Gpu: types.GPU{Num: "1", Type: "t1"}, Cpu: types.CPU{Num: "20"}, }, false}, + {&types.HardWare{ + Gpu: types.GPU{Num: "1", Type: "t1"}, + Cpu: types.CPU{Num: "12"}, + }, true}, } for _, c := range cases { @@ -698,6 +702,7 @@ func TestDeployer_CheckResource(t *testing.T) { v := CheckResource(&types.ClusterRes{ Resources: []types.NodeResourceInfo{ {AvailableXPU: 10, XPUModel: "t1", AvailableCPU: 10, AvailableMem: 10000}, + {AvailableXPU: 12, XPUModel: "t1", AvailableCPU: 12, AvailableMem: 10000}, }, }, c.hardware) require.Equal(t, c.available, v, c.hardware) diff --git a/cmd/csghub-server/cmd/start/server.go b/cmd/csghub-server/cmd/start/server.go index ff24eb0d..c9e66bb6 100644 --- a/cmd/csghub-server/cmd/start/server.go +++ b/cmd/csghub-server/cmd/start/server.go @@ -83,14 +83,14 @@ var serverCmd = &cobra.Command{ if err != nil { return fmt.Errorf("failed to init deploy: %w", err) } - r, err := router.NewRouter(cfg, enableSwagger) - if err != nil { - return fmt.Errorf("failed to init router: %w", err) - } err = workflow.StartWorker(cfg) if err != nil { return fmt.Errorf("failed to start worker: %w", err) } + r, err := router.NewRouter(cfg, enableSwagger) + if err != nil { + return fmt.Errorf("failed to init router: %w", err) + } err = workflow.RegisterCronJobs(cfg) if err != nil { diff --git a/common/types/discussion.go b/common/types/discussion.go new file mode 100644 index 00000000..c966f000 --- /dev/null +++ b/common/types/discussion.go @@ -0,0 +1,144 @@ +package types + +import ( + "time" + + "opencsg.com/csghub-server/builder/sensitive" +) + +type CreateRepoDiscussionRequest struct { + Title string `json:"title" binding:"required"` + RepoType RepositoryType `json:"-"` + Namespace string `json:"-"` + Name string `json:"-"` + CurrentUser string `json:"-"` +} + +// CreateRepoDiscussionRequest implements SensitiveRequestV2 +var _ SensitiveRequestV2 = (*CreateRepoDiscussionRequest)(nil) + +func (req *CreateRepoDiscussionRequest) GetSensitiveFields() []SensitiveField { + return []SensitiveField{ + { + Name: "title", + Value: func() string { + return req.Title + }, + Scenario: string(sensitive.ScenarioCommentDetection), + }, + } +} + +type CreateDiscussionResponse struct { + ID int64 `json:"id"` + User *DiscussionResponse_User `json:"user"` + Title string `json:"title"` + // DiscussionableID int64 `json:"discussionable_id"` + // DiscussionableType string `json:"discussionable_type"` + CommentCount int64 `json:"comment_count"` + CreatedAt time.Time `json:"created_at"` + // UpdatedAt time.Time `json:"updated_at"` +} + +type DiscussionResponse_User struct { + ID int64 `json:"id"` + Username string `json:"name"` + Avatar string `json:"avatar"` +} + +type UpdateDiscussionRequest struct { + ID int64 `json:"-"` + Title string `json:"title" binding:"required"` + CurrentUser string `json:"-"` +} + +// UpdateDiscussionRequest implements SensitiveRequestV2 +var _ SensitiveRequestV2 = (*UpdateDiscussionRequest)(nil) + +func (req *UpdateDiscussionRequest) GetSensitiveFields() []SensitiveField { + return []SensitiveField{ + { + Name: "title", + Value: func() string { + return req.Title + }, + Scenario: string(sensitive.ScenarioCommentDetection), + }, + } +} + +type ShowDiscussionResponse struct { + ID int64 `json:"id"` + Title string `json:"title"` + User *DiscussionResponse_User `json:"user"` + CommentCount int64 `json:"comment_count"` + Comments []*DiscussionResponse_Comment `json:"comments,omitempty"` +} + +type DiscussionResponse_Comment struct { + ID int64 `json:"id"` + Content string `json:"content"` + User *DiscussionResponse_User `json:"user"` + CreatedAt time.Time `json:"created_at"` +} + +type ListRepoDiscussionRequest struct { + RepoType RepositoryType `json:"-"` + Namespace string `json:"-"` + Name string `json:"-"` + CurrentUser string `json:"-"` +} + +type ListRepoDiscussionResponse struct { + Discussions []*CreateDiscussionResponse `json:"discussions"` +} + +type CreateCommentRequest struct { + Content string `json:"content" binding:"required"` + CommentableID int64 `json:"commentable_id"` + CommentableType string `json:"commentable_type"` + CurrentUser string `json:"-"` +} + +// CreateCommentRequest implements SensitiveRequestV2 +var _ SensitiveRequestV2 = (*CreateCommentRequest)(nil) + +func (req *CreateCommentRequest) GetSensitiveFields() []SensitiveField { + return []SensitiveField{ + { + Name: "content", + Value: func() string { + return req.Content + }, + Scenario: string(sensitive.ScenarioCommentDetection), + }, + } +} + +type CreateCommentResponse struct { + ID int64 `json:"id"` + CommentableID int64 `json:"commentable_id"` + CommentableType string `json:"commentable_type"` + CreatedAt time.Time `json:"created_at"` + User *DiscussionResponse_User `json:"user"` +} + +type UpdateCommentRequest struct { + ID int64 `json:"-"` + Content string `json:"content" binding:"required"` +} + +// UpdateCommentRequest implements SensitiveRequestV2 +var _ SensitiveRequestV2 = (*UpdateCommentRequest)(nil) + +func (req *UpdateCommentRequest) GetSensitiveFields() []SensitiveField { + return []SensitiveField{ + { + Name: "content", + Value: func() string { + return req.Content + }, + Scenario: string(sensitive.ScenarioCommentDetection), + }, + } +} diff --git a/component/discussion.go b/component/discussion.go index 06f30e9d..e00e6a8e 100644 --- a/component/discussion.go +++ b/component/discussion.go @@ -4,9 +4,7 @@ import ( "context" "database/sql" "fmt" - "time" - "opencsg.com/csghub-server/builder/sensitive" "opencsg.com/csghub-server/builder/store/database" "opencsg.com/csghub-server/common/types" ) @@ -18,15 +16,15 @@ type discussionComponentImpl struct { } type DiscussionComponent interface { - CreateRepoDiscussion(ctx context.Context, req CreateRepoDiscussionRequest) (*CreateDiscussionResponse, error) - GetDiscussion(ctx context.Context, id int64) (*ShowDiscussionResponse, error) - UpdateDiscussion(ctx context.Context, req UpdateDiscussionRequest) error + CreateRepoDiscussion(ctx context.Context, req types.CreateRepoDiscussionRequest) (*types.CreateDiscussionResponse, error) + GetDiscussion(ctx context.Context, id int64) (*types.ShowDiscussionResponse, error) + UpdateDiscussion(ctx context.Context, req types.UpdateDiscussionRequest) error DeleteDiscussion(ctx context.Context, currentUser string, id int64) error - ListRepoDiscussions(ctx context.Context, req ListRepoDiscussionRequest) (*ListRepoDiscussionResponse, error) - CreateDiscussionComment(ctx context.Context, req CreateCommentRequest) (*CreateCommentResponse, error) + ListRepoDiscussions(ctx context.Context, req types.ListRepoDiscussionRequest) (*types.ListRepoDiscussionResponse, error) + CreateDiscussionComment(ctx context.Context, req types.CreateCommentRequest) (*types.CreateCommentResponse, error) UpdateComment(ctx context.Context, currentUser string, id int64, content string) error DeleteComment(ctx context.Context, currentUser string, id int64) error - ListDiscussionComments(ctx context.Context, discussionID int64) ([]*DiscussionResponse_Comment, error) + ListDiscussionComments(ctx context.Context, discussionID int64) ([]*types.DiscussionResponse_Comment, error) } func NewDiscussionComponent() DiscussionComponent { @@ -36,7 +34,7 @@ func NewDiscussionComponent() DiscussionComponent { return &discussionComponentImpl{discussionStore: ds, repoStore: rs, userStore: us} } -func (c *discussionComponentImpl) CreateRepoDiscussion(ctx context.Context, req CreateRepoDiscussionRequest) (*CreateDiscussionResponse, error) { +func (c *discussionComponentImpl) CreateRepoDiscussion(ctx context.Context, req types.CreateRepoDiscussionRequest) (*types.CreateDiscussionResponse, error) { //TODO:check if the user can access the repo //get repo by namespace and name @@ -57,9 +55,9 @@ func (c *discussionComponentImpl) CreateRepoDiscussion(ctx context.Context, req if err != nil { return nil, fmt.Errorf("failed to create discussion: %w", err) } - resp := &CreateDiscussionResponse{ + resp := &types.CreateDiscussionResponse{ ID: discussion.ID, - User: &DiscussionResponse_User{ + User: &types.DiscussionResponse_User{ ID: user.ID, Username: user.Username, Avatar: user.Avatar, @@ -71,7 +69,7 @@ func (c *discussionComponentImpl) CreateRepoDiscussion(ctx context.Context, req return resp, nil } -func (c *discussionComponentImpl) GetDiscussion(ctx context.Context, id int64) (*ShowDiscussionResponse, error) { +func (c *discussionComponentImpl) GetDiscussion(ctx context.Context, id int64) (*types.ShowDiscussionResponse, error) { discussion, err := c.discussionStore.FindByID(ctx, id) if err != nil { return nil, fmt.Errorf("failed to find discussion by id '%d': %w", id, err) @@ -80,20 +78,20 @@ func (c *discussionComponentImpl) GetDiscussion(ctx context.Context, id int64) ( if err != nil && err != sql.ErrNoRows { return nil, fmt.Errorf("failed to find discussion comments by discussion id '%d': %w", discussion.ID, err) } - resp := &ShowDiscussionResponse{ + resp := &types.ShowDiscussionResponse{ ID: discussion.ID, Title: discussion.Title, - User: &DiscussionResponse_User{ + User: &types.DiscussionResponse_User{ ID: discussion.User.ID, Username: discussion.User.Username, Avatar: discussion.User.Avatar, }, } for _, comment := range comments { - resp.Comments = append(resp.Comments, &DiscussionResponse_Comment{ + resp.Comments = append(resp.Comments, &types.DiscussionResponse_Comment{ ID: comment.ID, Content: comment.Content, - User: &DiscussionResponse_User{ + User: &types.DiscussionResponse_User{ ID: comment.User.ID, Username: comment.User.Username, Avatar: comment.User.Avatar, @@ -103,7 +101,7 @@ func (c *discussionComponentImpl) GetDiscussion(ctx context.Context, id int64) ( return resp, nil } -func (c *discussionComponentImpl) UpdateDiscussion(ctx context.Context, req UpdateDiscussionRequest) error { +func (c *discussionComponentImpl) UpdateDiscussion(ctx context.Context, req types.UpdateDiscussionRequest) error { //check if the user is the owner of the discussion user, err := c.userStore.FindByUsername(ctx, req.CurrentUser) if err != nil { @@ -138,7 +136,7 @@ func (c *discussionComponentImpl) DeleteDiscussion(ctx context.Context, currentU return nil } -func (c *discussionComponentImpl) ListRepoDiscussions(ctx context.Context, req ListRepoDiscussionRequest) (*ListRepoDiscussionResponse, error) { +func (c *discussionComponentImpl) ListRepoDiscussions(ctx context.Context, req types.ListRepoDiscussionRequest) (*types.ListRepoDiscussionResponse, error) { //TODO:check if the user can access the repo repo, err := c.repoStore.FindByPath(ctx, req.RepoType, req.Namespace, req.Name) if err != nil { @@ -148,14 +146,14 @@ func (c *discussionComponentImpl) ListRepoDiscussions(ctx context.Context, req L if err != nil { return nil, fmt.Errorf("failed to list repo discussions by repo type '%s', namespace '%s', name '%s': %w", req.RepoType, req.Namespace, req.Name, err) } - resp := &ListRepoDiscussionResponse{} + resp := &types.ListRepoDiscussionResponse{} for _, discussion := range discussions { - resp.Discussions = append(resp.Discussions, &CreateDiscussionResponse{ + resp.Discussions = append(resp.Discussions, &types.CreateDiscussionResponse{ ID: discussion.ID, Title: discussion.Title, CommentCount: discussion.CommentCount, CreatedAt: discussion.CreatedAt, - User: &DiscussionResponse_User{ + User: &types.DiscussionResponse_User{ ID: discussion.User.ID, Username: discussion.User.Username, Avatar: discussion.User.Avatar, @@ -165,7 +163,7 @@ func (c *discussionComponentImpl) ListRepoDiscussions(ctx context.Context, req L return resp, nil } -func (c *discussionComponentImpl) CreateDiscussionComment(ctx context.Context, req CreateCommentRequest) (*CreateCommentResponse, error) { +func (c *discussionComponentImpl) CreateDiscussionComment(ctx context.Context, req types.CreateCommentRequest) (*types.CreateCommentResponse, error) { req.CommentableType = database.CommentableTypeDiscussion // get discussion by id _, err := c.discussionStore.FindByID(ctx, req.CommentableID) @@ -188,12 +186,12 @@ func (c *discussionComponentImpl) CreateDiscussionComment(ctx context.Context, r if err != nil { return nil, fmt.Errorf("failed to create discussion comment: %w", err) } - return &CreateCommentResponse{ + return &types.CreateCommentResponse{ ID: comment.ID, CommentableID: comment.CommentableID, CommentableType: comment.CommentableType, CreatedAt: comment.CreatedAt, - User: &DiscussionResponse_User{ + User: &types.DiscussionResponse_User{ ID: user.ID, Username: user.Username, Avatar: user.Avatar, @@ -243,17 +241,17 @@ func (c *discussionComponentImpl) DeleteComment(ctx context.Context, currentUser return nil } -func (c *discussionComponentImpl) ListDiscussionComments(ctx context.Context, discussionID int64) ([]*DiscussionResponse_Comment, error) { +func (c *discussionComponentImpl) ListDiscussionComments(ctx context.Context, discussionID int64) ([]*types.DiscussionResponse_Comment, error) { comments, err := c.discussionStore.FindDiscussionComments(ctx, discussionID) if err != nil { return nil, fmt.Errorf("failed to find discussion comments by discussion id '%d': %w", discussionID, err) } - resp := make([]*DiscussionResponse_Comment, 0, len(comments)) + resp := make([]*types.DiscussionResponse_Comment, 0, len(comments)) for _, comment := range comments { - resp = append(resp, &DiscussionResponse_Comment{ + resp = append(resp, &types.DiscussionResponse_Comment{ ID: comment.ID, Content: comment.Content, - User: &DiscussionResponse_User{ + User: &types.DiscussionResponse_User{ ID: comment.User.ID, Username: comment.User.Username, Avatar: comment.User.Avatar, @@ -263,142 +261,3 @@ func (c *discussionComponentImpl) ListDiscussionComments(ctx context.Context, di } return resp, nil } - -//--- request and response ---// - -type CreateRepoDiscussionRequest struct { - Title string `json:"title" binding:"required"` - RepoType types.RepositoryType `json:"-"` - Namespace string `json:"-"` - Name string `json:"-"` - CurrentUser string `json:"-"` -} - -// CreateRepoDiscussionRequest implements types.SensitiveRequestV2 -var _ types.SensitiveRequestV2 = (*CreateRepoDiscussionRequest)(nil) - -func (req *CreateRepoDiscussionRequest) GetSensitiveFields() []types.SensitiveField { - return []types.SensitiveField{ - { - Name: "title", - Value: func() string { - return req.Title - }, - Scenario: string(sensitive.ScenarioCommentDetection), - }, - } -} - -type CreateDiscussionResponse struct { - ID int64 `json:"id"` - User *DiscussionResponse_User `json:"user"` - Title string `json:"title"` - // DiscussionableID int64 `json:"discussionable_id"` - // DiscussionableType string `json:"discussionable_type"` - CommentCount int64 `json:"comment_count"` - CreatedAt time.Time `json:"created_at"` - // UpdatedAt time.Time `json:"updated_at"` -} - -type DiscussionResponse_User struct { - ID int64 `json:"id"` - Username string `json:"name"` - Avatar string `json:"avatar"` -} - -type UpdateDiscussionRequest struct { - ID int64 `json:"-"` - Title string `json:"title" binding:"required"` - CurrentUser string `json:"-"` -} - -// UpdateDiscussionRequest implements types.SensitiveRequestV2 -var _ types.SensitiveRequestV2 = (*UpdateDiscussionRequest)(nil) - -func (req *UpdateDiscussionRequest) GetSensitiveFields() []types.SensitiveField { - return []types.SensitiveField{ - { - Name: "title", - Value: func() string { - return req.Title - }, - Scenario: string(sensitive.ScenarioCommentDetection), - }, - } -} - -type ShowDiscussionResponse struct { - ID int64 `json:"id"` - Title string `json:"title"` - User *DiscussionResponse_User `json:"user"` - CommentCount int64 `json:"comment_count"` - Comments []*DiscussionResponse_Comment `json:"comments,omitempty"` -} - -type DiscussionResponse_Comment struct { - ID int64 `json:"id"` - Content string `json:"content"` - User *DiscussionResponse_User `json:"user"` - CreatedAt time.Time `json:"created_at"` -} - -type ListRepoDiscussionRequest struct { - RepoType types.RepositoryType `json:"-"` - Namespace string `json:"-"` - Name string `json:"-"` - CurrentUser string `json:"-"` -} - -type ListRepoDiscussionResponse struct { - Discussions []*CreateDiscussionResponse `json:"discussions"` -} - -type CreateCommentRequest struct { - Content string `json:"content" binding:"required"` - CommentableID int64 `json:"commentable_id"` - CommentableType string `json:"commentable_type"` - CurrentUser string `json:"-"` -} - -// CreateCommentRequest implements types.SensitiveRequestV2 -var _ types.SensitiveRequestV2 = (*CreateCommentRequest)(nil) - -func (req *CreateCommentRequest) GetSensitiveFields() []types.SensitiveField { - return []types.SensitiveField{ - { - Name: "content", - Value: func() string { - return req.Content - }, - Scenario: string(sensitive.ScenarioCommentDetection), - }, - } -} - -type CreateCommentResponse struct { - ID int64 `json:"id"` - CommentableID int64 `json:"commentable_id"` - CommentableType string `json:"commentable_type"` - CreatedAt time.Time `json:"created_at"` - User *DiscussionResponse_User `json:"user"` -} - -type UpdateCommentRequest struct { - ID int64 `json:"-"` - Content string `json:"content" binding:"required"` -} - -// UpdateCommentRequest implements types.SensitiveRequestV2 -var _ types.SensitiveRequestV2 = (*UpdateCommentRequest)(nil) - -func (req *UpdateCommentRequest) GetSensitiveFields() []types.SensitiveField { - return []types.SensitiveField{ - { - Name: "content", - Value: func() string { - return req.Content - }, - Scenario: string(sensitive.ScenarioCommentDetection), - }, - } -} diff --git a/component/discussion_test.go b/component/discussion_test.go index ddceff7a..afb7bd27 100644 --- a/component/discussion_test.go +++ b/component/discussion_test.go @@ -46,7 +46,7 @@ func TestDiscussionComponent_CreateDisucssion(t *testing.T) { dbdisc.CreatedAt = time.Now() mockDiscussionStore.EXPECT().Create(mock.Anything, disc).Return(&dbdisc, nil).Once() - req := CreateRepoDiscussionRequest{ + req := types.CreateRepoDiscussionRequest{ Title: "test discussion", RepoType: "model", Namespace: "namespace", @@ -56,12 +56,12 @@ func TestDiscussionComponent_CreateDisucssion(t *testing.T) { actualDisc, err := comp.CreateRepoDiscussion(context.TODO(), req) require.Nil(t, err) - expectedDisc := &CreateDiscussionResponse{ + expectedDisc := &types.CreateDiscussionResponse{ ID: 1, Title: "test discussion", CommentCount: 0, CreatedAt: dbdisc.CreatedAt, - User: &DiscussionResponse_User{ + User: &types.DiscussionResponse_User{ ID: 1, Username: "user", Avatar: "avatar", @@ -131,7 +131,7 @@ func TestDiscussionComponent_UpdateDisussion(t *testing.T) { discussionStore: mockDiscussionStore, } - req := UpdateDiscussionRequest{ + req := types.UpdateDiscussionRequest{ ID: 1, Title: "test discussion", CurrentUser: "user", @@ -222,7 +222,7 @@ func TestDiscussionComponent_ListRepoDiscussions(t *testing.T) { }) mockDiscussionStore.EXPECT().FindByDiscussionableID(mock.Anything, database.DiscussionableTypeRepo, repo.ID).Return(discussions, nil).Once() - resp, err := comp.ListRepoDiscussions(context.TODO(), ListRepoDiscussionRequest{ + resp, err := comp.ListRepoDiscussions(context.TODO(), types.ListRepoDiscussionRequest{ RepoType: types.ModelRepo, Namespace: "namespace", Name: "name", @@ -244,7 +244,7 @@ func TestDiscussionComponent_CreateDisussionComment(t *testing.T) { discussionStore: mockDiscussionStore, } - req := CreateCommentRequest{ + req := types.CreateCommentRequest{ Content: "test comment", CommentableID: 1, CommentableType: database.CommentableTypeDiscussion, @@ -298,7 +298,7 @@ func TestDiscussionComponent_UpdateComment(t *testing.T) { discussionStore: mockDiscussionStore, } - req := CreateCommentRequest{ + req := types.CreateCommentRequest{ Content: "test comment", CommentableID: 1, CommentableType: database.CommentableTypeDiscussion, @@ -337,7 +337,7 @@ func TestDiscussionComponent_DeleteComment(t *testing.T) { discussionStore: mockDiscussionStore, } - req := CreateCommentRequest{ + req := types.CreateCommentRequest{ Content: "test comment", CommentableID: 1, CommentableType: database.CommentableTypeDiscussion, @@ -399,7 +399,7 @@ func TestDiscussionComponent_ListDiscussionComments(t *testing.T) { func TestCreateRepoDiscussionRequest_GetSensitiveFields(t *testing.T) { - req := CreateRepoDiscussionRequest{ + req := types.CreateRepoDiscussionRequest{ Title: "title", } fields := req.GetSensitiveFields() @@ -410,7 +410,7 @@ func TestCreateRepoDiscussionRequest_GetSensitiveFields(t *testing.T) { func TestUpdateDiscussionRequest_GetSensitiveFields(t *testing.T) { - req := UpdateDiscussionRequest{ + req := types.UpdateDiscussionRequest{ Title: "title", } fields := req.GetSensitiveFields() @@ -420,7 +420,7 @@ func TestUpdateDiscussionRequest_GetSensitiveFields(t *testing.T) { } func TestCreateCommentRequest_GetSensitiveFields(t *testing.T) { - req := CreateCommentRequest{ + req := types.CreateCommentRequest{ Content: "content", } fields := req.GetSensitiveFields() @@ -431,7 +431,7 @@ func TestCreateCommentRequest_GetSensitiveFields(t *testing.T) { func TestUpdateCommentRequest_GetSensitiveFields(t *testing.T) { - req := UpdateCommentRequest{ + req := types.UpdateCommentRequest{ Content: "content", } fields := req.GetSensitiveFields() diff --git a/docs/docs.go b/docs/docs.go index 31468db6..e3c2369e 100644 --- a/docs/docs.go +++ b/docs/docs.go @@ -1173,7 +1173,7 @@ const docTemplate = `{ "in": "body", "required": true, "schema": { - "$ref": "#/definitions/component.UpdateCommentRequest" + "$ref": "#/definitions/types.UpdateCommentRequest" } } ], @@ -1911,7 +1911,7 @@ const docTemplate = `{ "type": "object", "properties": { "data": { - "$ref": "#/definitions/component.ShowDiscussionResponse" + "$ref": "#/definitions/types.ShowDiscussionResponse" } } } @@ -1970,7 +1970,7 @@ const docTemplate = `{ "in": "body", "required": true, "schema": { - "$ref": "#/definitions/component.UpdateDiscussionRequest" + "$ref": "#/definitions/types.UpdateDiscussionRequest" } } ], @@ -2091,7 +2091,7 @@ const docTemplate = `{ "data": { "type": "array", "items": { - "$ref": "#/definitions/component.DiscussionResponse_Comment" + "$ref": "#/definitions/types.DiscussionResponse_Comment" } } } @@ -2144,7 +2144,7 @@ const docTemplate = `{ "in": "body", "required": true, "schema": { - "$ref": "#/definitions/component.CreateCommentRequest" + "$ref": "#/definitions/types.CreateCommentRequest" } } ], @@ -2160,7 +2160,7 @@ const docTemplate = `{ "type": "object", "properties": { "data": { - "$ref": "#/definitions/component.CreateCommentResponse" + "$ref": "#/definitions/types.CreateCommentResponse" } } } @@ -4023,77 +4023,6 @@ const docTemplate = `{ } } }, - "/models/{namespace}/{name}/predict": { - "post": { - "security": [ - { - "ApiKey": [] - } - ], - "description": "invoke model prediction", - "consumes": [ - "application/json" - ], - "produces": [ - "application/json" - ], - "tags": [ - "Model" - ], - "summary": "Invoke model prediction", - "parameters": [ - { - "type": "string", - "description": "namespace", - "name": "namespace", - "in": "path", - "required": true - }, - { - "type": "string", - "description": "name", - "name": "name", - "in": "path", - "required": true - }, - { - "type": "string", - "description": "current user", - "name": "current_user", - "in": "query" - }, - { - "description": "input for model prediction", - "name": "body", - "in": "body", - "required": true, - "schema": { - "$ref": "#/definitions/types.ModelPredictReq" - } - } - ], - "responses": { - "200": { - "description": "OK", - "schema": { - "type": "string" - } - }, - "400": { - "description": "Bad request", - "schema": { - "$ref": "#/definitions/types.APIBadRequest" - } - }, - "500": { - "description": "Internal server error", - "schema": { - "$ref": "#/definitions/types.APIInternalServerError" - } - } - } - } - }, "/models/{namespace}/{name}/relations": { "get": { "security": [ @@ -6489,478 +6418,70 @@ const docTemplate = `{ { "type": "string", "description": "filter by license tag", - "name": "license_tag", - "in": "query" - }, - { - "type": "string", - "description": "filter by language tag", - "name": "language_tag", - "in": "query" - }, - { - "type": "string", - "description": "sort by", - "name": "sort", - "in": "query" - }, - { - "enum": [ - "opencsg", - "huggingface", - "local" - ], - "type": "string", - "description": "source", - "name": "source", - "in": "query" - }, - { - "type": "integer", - "default": 20, - "description": "per", - "name": "per", - "in": "query" - }, - { - "type": "integer", - "default": 1, - "description": "per page", - "name": "page", - "in": "query" - } - ], - "responses": { - "200": { - "description": "OK", - "schema": { - "allOf": [ - { - "$ref": "#/definitions/types.ResponseWithTotal" - }, - { - "type": "object", - "properties": { - "data": { - "type": "array", - "items": { - "$ref": "#/definitions/types.PromptRes" - } - }, - "total": { - "type": "integer" - } - } - } - ] - } - }, - "400": { - "description": "Bad request", - "schema": { - "$ref": "#/definitions/types.APIBadRequest" - } - }, - "500": { - "description": "Internal server error", - "schema": { - "$ref": "#/definitions/types.APIInternalServerError" - } - } - } - }, - "post": { - "security": [ - { - "ApiKey": [] - } - ], - "description": "create a new prompt repo", - "consumes": [ - "application/json" - ], - "produces": [ - "application/json" - ], - "tags": [ - "Prompt" - ], - "summary": "Create a new prompt repo", - "parameters": [ - { - "type": "string", - "description": "current user, the owner", - "name": "current_user", - "in": "query" - }, - { - "description": "body", - "name": "body", - "in": "body", - "required": true, - "schema": { - "$ref": "#/definitions/types.CreatePromptRepoReq" - } - } - ], - "responses": { - "200": { - "description": "OK", - "schema": { - "$ref": "#/definitions/types.Response" - } - }, - "400": { - "description": "Bad request", - "schema": { - "$ref": "#/definitions/types.APIBadRequest" - } - }, - "500": { - "description": "Internal server error", - "schema": { - "$ref": "#/definitions/types.APIInternalServerError" - } - } - } - } - }, - "/prompts/conversations": { - "get": { - "security": [ - { - "ApiKey": [] - } - ], - "description": "List conversations of user", - "consumes": [ - "application/json" - ], - "produces": [ - "application/json" - ], - "tags": [ - "Prompt" - ], - "summary": "List conversations of user", - "responses": { - "200": { - "description": "OK", - "schema": { - "$ref": "#/definitions/types.Response" - } - }, - "400": { - "description": "Bad request", - "schema": { - "$ref": "#/definitions/types.APIBadRequest" - } - }, - "500": { - "description": "Internal server error", - "schema": { - "$ref": "#/definitions/types.APIInternalServerError" - } - } - } - }, - "post": { - "security": [ - { - "ApiKey": [] - } - ], - "description": "Create new conversation", - "consumes": [ - "application/json" - ], - "produces": [ - "application/json" - ], - "tags": [ - "Prompt" - ], - "summary": "Create new conversation", - "parameters": [ - { - "description": "body", - "name": "body", - "in": "body", - "required": true, - "schema": { - "$ref": "#/definitions/types.Conversation" - } - } - ], - "responses": { - "200": { - "description": "OK", - "schema": { - "$ref": "#/definitions/types.Response" - } - }, - "400": { - "description": "Bad request", - "schema": { - "$ref": "#/definitions/types.APIBadRequest" - } - }, - "500": { - "description": "Internal server error", - "schema": { - "$ref": "#/definitions/types.APIInternalServerError" - } - } - } - } - }, - "/prompts/conversations/{id}": { - "get": { - "security": [ - { - "ApiKey": [] - } - ], - "description": "Get a conversation by uuid", - "consumes": [ - "application/json" - ], - "produces": [ - "application/json" - ], - "tags": [ - "Prompt" - ], - "summary": "Get a conversation by uuid", - "parameters": [ - { - "type": "string", - "description": "conversation uuid", - "name": "id", - "in": "path", - "required": true - } - ], - "responses": { - "200": { - "description": "OK", - "schema": { - "$ref": "#/definitions/types.Response" - } - }, - "400": { - "description": "Bad request", - "schema": { - "$ref": "#/definitions/types.APIBadRequest" - } - }, - "500": { - "description": "Internal server error", - "schema": { - "$ref": "#/definitions/types.APIInternalServerError" - } - } - } - }, - "put": { - "security": [ - { - "ApiKey": [] - } - ], - "description": "Update a conversation title", - "consumes": [ - "application/json" - ], - "produces": [ - "application/json" - ], - "tags": [ - "Prompt" - ], - "summary": "Update a conversation title", - "parameters": [ - { - "type": "string", - "description": "conversation uuid", - "name": "id", - "in": "path", - "required": true - }, - { - "description": "body", - "name": "body", - "in": "body", - "required": true, - "schema": { - "$ref": "#/definitions/types.ConversationTitle" - } - } - ], - "responses": { - "200": { - "description": "OK", - "schema": { - "$ref": "#/definitions/types.Response" - } - }, - "400": { - "description": "Bad request", - "schema": { - "$ref": "#/definitions/types.APIBadRequest" - } - }, - "500": { - "description": "Internal server error", - "schema": { - "$ref": "#/definitions/types.APIInternalServerError" - } - } - } - }, - "post": { - "security": [ - { - "ApiKey": [] - } - ], - "description": "Submit a conversation message", - "consumes": [ - "application/json" - ], - "produces": [ - "application/json" - ], - "tags": [ - "Prompt" - ], - "summary": "Submit a conversation message", - "parameters": [ - { - "type": "string", - "description": "conversation uuid", - "name": "id", - "in": "path", - "required": true - }, - { - "description": "body", - "name": "body", - "in": "body", - "required": true, - "schema": { - "$ref": "#/definitions/types.Conversation" - } - } - ], - "responses": { - "200": { - "description": "OK", - "schema": { - "$ref": "#/definitions/types.Response" - } - }, - "400": { - "description": "Bad request", - "schema": { - "$ref": "#/definitions/types.APIBadRequest" - } - }, - "500": { - "description": "Internal server error", - "schema": { - "$ref": "#/definitions/types.APIInternalServerError" - } - } - } - }, - "delete": { - "security": [ - { - "ApiKey": [] - } - ], - "description": "Delete a conversation", - "consumes": [ - "application/json" - ], - "produces": [ - "application/json" - ], - "tags": [ - "Prompt" - ], - "summary": "Delete a conversation", - "parameters": [ - { - "type": "string", - "description": "conversation uuid", - "name": "id", - "in": "path", - "required": true - } - ], - "responses": { - "200": { - "description": "OK", - "schema": { - "$ref": "#/definitions/types.Response" - } - }, - "400": { - "description": "Bad request", - "schema": { - "$ref": "#/definitions/types.APIBadRequest" - } + "name": "license_tag", + "in": "query" }, - "500": { - "description": "Internal server error", - "schema": { - "$ref": "#/definitions/types.APIInternalServerError" - } - } - } - } - }, - "/prompts/conversations/{id}/message/{msgid}/hate": { - "put": { - "security": [ { - "ApiKey": [] - } - ], - "description": "Hate a conversation message", - "consumes": [ - "application/json" - ], - "produces": [ - "application/json" - ], - "tags": [ - "Prompt" - ], - "summary": "Hate a conversation message", - "parameters": [ + "type": "string", + "description": "filter by language tag", + "name": "language_tag", + "in": "query" + }, { "type": "string", - "description": "conversation uuid", - "name": "uuid", - "in": "path", - "required": true + "description": "sort by", + "name": "sort", + "in": "query" }, { + "enum": [ + "opencsg", + "huggingface", + "local" + ], "type": "string", - "description": "message id", - "name": "id", - "in": "path", - "required": true + "description": "source", + "name": "source", + "in": "query" + }, + { + "type": "integer", + "default": 20, + "description": "per", + "name": "per", + "in": "query" + }, + { + "type": "integer", + "default": 1, + "description": "per page", + "name": "page", + "in": "query" } ], "responses": { "200": { "description": "OK", "schema": { - "$ref": "#/definitions/types.Response" + "allOf": [ + { + "$ref": "#/definitions/types.ResponseWithTotal" + }, + { + "type": "object", + "properties": { + "data": { + "type": "array", + "items": { + "$ref": "#/definitions/types.PromptRes" + } + }, + "total": { + "type": "integer" + } + } + } + ] } }, "400": { @@ -6976,16 +6497,14 @@ const docTemplate = `{ } } } - } - }, - "/prompts/conversations/{id}/message/{msgid}/like": { - "put": { + }, + "post": { "security": [ { "ApiKey": [] } ], - "description": "Like a conversation message", + "description": "create a new prompt repo", "consumes": [ "application/json" ], @@ -6995,21 +6514,22 @@ const docTemplate = `{ "tags": [ "Prompt" ], - "summary": "Like a conversation message", + "summary": "Create a new prompt repo", "parameters": [ { "type": "string", - "description": "conversation uuid", - "name": "uuid", - "in": "path", - "required": true + "description": "current user, the owner", + "name": "current_user", + "in": "query" }, { - "type": "string", - "description": "message id", - "name": "id", - "in": "path", - "required": true + "description": "body", + "name": "body", + "in": "body", + "required": true, + "schema": { + "$ref": "#/definitions/types.CreatePromptRepoReq" + } } ], "responses": { @@ -13189,7 +12709,7 @@ const docTemplate = `{ "type": "object", "properties": { "data": { - "$ref": "#/definitions/component.ListRepoDiscussionResponse" + "$ref": "#/definitions/types.ListRepoDiscussionResponse" } } } @@ -13268,7 +12788,7 @@ const docTemplate = `{ "in": "body", "required": true, "schema": { - "$ref": "#/definitions/component.CreateRepoDiscussionRequest" + "$ref": "#/definitions/types.CreateRepoDiscussionRequest" } } ], @@ -13284,7 +12804,7 @@ const docTemplate = `{ "type": "object", "properties": { "data": { - "$ref": "#/definitions/component.CreateDiscussionResponse" + "$ref": "#/definitions/types.CreateDiscussionResponse" } } } @@ -15435,206 +14955,50 @@ const docTemplate = `{ }, { "type": "string", - "description": "branch or tag", - "name": "ref", - "in": "query" - } - ], - "responses": { - "200": { - "description": "OK", - "schema": { - "allOf": [ - { - "$ref": "#/definitions/types.ResponseWithTotal" - }, - { - "type": "object", - "properties": { - "data": { - "type": "array", - "items": { - "$ref": "#/definitions/types.File" - } - } - } - } - ] - } - }, - "400": { - "description": "Bad request", - "schema": { - "$ref": "#/definitions/types.APIBadRequest" - } - }, - "500": { - "description": "Internal server error", - "schema": { - "$ref": "#/definitions/types.APIInternalServerError" - } - } - } - } - } - }, - "definitions": { - "component.CreateCommentRequest": { - "type": "object", - "required": [ - "content" - ], - "properties": { - "commentable_id": { - "type": "integer" - }, - "commentable_type": { - "type": "string" - }, - "content": { - "type": "string" - } - } - }, - "component.CreateCommentResponse": { - "type": "object", - "properties": { - "commentable_id": { - "type": "integer" - }, - "commentable_type": { - "type": "string" - }, - "created_at": { - "type": "string" - }, - "id": { - "type": "integer" - }, - "user": { - "$ref": "#/definitions/component.DiscussionResponse_User" - } - } - }, - "component.CreateDiscussionResponse": { - "type": "object", - "properties": { - "comment_count": { - "description": "DiscussionableID int64 ` + "`" + `json:\"discussionable_id\"` + "`" + `\nDiscussionableType string ` + "`" + `json:\"discussionable_type\"` + "`" + `", - "type": "integer" - }, - "created_at": { - "type": "string" - }, - "id": { - "type": "integer" - }, - "title": { - "type": "string" - }, - "user": { - "$ref": "#/definitions/component.DiscussionResponse_User" - } - } - }, - "component.CreateRepoDiscussionRequest": { - "type": "object", - "required": [ - "title" - ], - "properties": { - "title": { - "type": "string" - } - } - }, - "component.DiscussionResponse_Comment": { - "type": "object", - "properties": { - "content": { - "type": "string" - }, - "created_at": { - "type": "string" - }, - "id": { - "type": "integer" - }, - "user": { - "$ref": "#/definitions/component.DiscussionResponse_User" - } - } - }, - "component.DiscussionResponse_User": { - "type": "object", - "properties": { - "avatar": { - "type": "string" - }, - "id": { - "type": "integer" - }, - "name": { - "type": "string" - } - } - }, - "component.ListRepoDiscussionResponse": { - "type": "object", - "properties": { - "discussions": { - "type": "array", - "items": { - "$ref": "#/definitions/component.CreateDiscussionResponse" - } - } - } - }, - "component.ShowDiscussionResponse": { - "type": "object", - "properties": { - "comment_count": { - "type": "integer" - }, - "comments": { - "type": "array", - "items": { - "$ref": "#/definitions/component.DiscussionResponse_Comment" + "description": "branch or tag", + "name": "ref", + "in": "query" + } + ], + "responses": { + "200": { + "description": "OK", + "schema": { + "allOf": [ + { + "$ref": "#/definitions/types.ResponseWithTotal" + }, + { + "type": "object", + "properties": { + "data": { + "type": "array", + "items": { + "$ref": "#/definitions/types.File" + } + } + } + } + ] + } + }, + "400": { + "description": "Bad request", + "schema": { + "$ref": "#/definitions/types.APIBadRequest" + } + }, + "500": { + "description": "Internal server error", + "schema": { + "$ref": "#/definitions/types.APIInternalServerError" + } } - }, - "id": { - "type": "integer" - }, - "title": { - "type": "string" - }, - "user": { - "$ref": "#/definitions/component.DiscussionResponse_User" - } - } - }, - "component.UpdateCommentRequest": { - "type": "object", - "required": [ - "content" - ], - "properties": { - "content": { - "type": "string" - } - } - }, - "component.UpdateDiscussionRequest": { - "type": "object", - "required": [ - "title" - ], - "properties": { - "title": { - "type": "string" } } - }, + } + }, + "definitions": { "database.AccessToken": { "type": "object", "properties": { @@ -16574,6 +15938,9 @@ const docTemplate = `{ "repository_id": { "type": "integer" }, + "sensitive_check_status": { + "type": "string" + }, "source": { "$ref": "#/definitions/types.RepositorySource" }, @@ -16810,39 +16177,6 @@ const docTemplate = `{ } } }, - "types.Conversation": { - "type": "object", - "required": [ - "message", - "uuid" - ], - "properties": { - "message": { - "type": "string" - }, - "temperature": { - "type": "number" - }, - "uuid": { - "type": "string" - } - } - }, - "types.ConversationTitle": { - "type": "object", - "required": [ - "title", - "uuid" - ], - "properties": { - "title": { - "type": "string" - }, - "uuid": { - "type": "string" - } - } - }, "types.CreateCategory": { "type": "object", "required": [ @@ -16923,6 +16257,43 @@ const docTemplate = `{ } } }, + "types.CreateCommentRequest": { + "type": "object", + "required": [ + "content" + ], + "properties": { + "commentable_id": { + "type": "integer" + }, + "commentable_type": { + "type": "string" + }, + "content": { + "type": "string" + } + } + }, + "types.CreateCommentResponse": { + "type": "object", + "properties": { + "commentable_id": { + "type": "integer" + }, + "commentable_type": { + "type": "string" + }, + "created_at": { + "type": "string" + }, + "id": { + "type": "integer" + }, + "user": { + "$ref": "#/definitions/types.DiscussionResponse_User" + } + } + }, "types.CreateDatasetReq": { "type": "object", "properties": { @@ -16964,6 +16335,27 @@ const docTemplate = `{ } } }, + "types.CreateDiscussionResponse": { + "type": "object", + "properties": { + "comment_count": { + "description": "DiscussionableID int64 ` + "`" + `json:\"discussionable_id\"` + "`" + `\nDiscussionableType string ` + "`" + `json:\"discussionable_type\"` + "`" + `", + "type": "integer" + }, + "created_at": { + "type": "string" + }, + "id": { + "type": "integer" + }, + "title": { + "type": "string" + }, + "user": { + "$ref": "#/definitions/types.DiscussionResponse_User" + } + } + }, "types.CreateFileReq": { "type": "object", "properties": { @@ -17216,6 +16608,17 @@ const docTemplate = `{ } } }, + "types.CreateRepoDiscussionRequest": { + "type": "object", + "required": [ + "title" + ], + "properties": { + "title": { + "type": "string" + } + } + }, "types.CreateSSHKeyRequest": { "type": "object", "properties": { @@ -17453,6 +16856,9 @@ const docTemplate = `{ "repository_id": { "type": "integer" }, + "sensitive_check_status": { + "type": "string" + }, "source": { "$ref": "#/definitions/types.RepositorySource" }, @@ -17568,6 +16974,37 @@ const docTemplate = `{ } } }, + "types.DiscussionResponse_Comment": { + "type": "object", + "properties": { + "content": { + "type": "string" + }, + "created_at": { + "type": "string" + }, + "id": { + "type": "integer" + }, + "user": { + "$ref": "#/definitions/types.DiscussionResponse_User" + } + } + }, + "types.DiscussionResponse_User": { + "type": "object", + "properties": { + "avatar": { + "type": "string" + }, + "id": { + "type": "integer" + }, + "name": { + "type": "string" + } + } + }, "types.EditOrgReq": { "type": "object", "properties": { @@ -17803,6 +17240,17 @@ const docTemplate = `{ } } }, + "types.ListRepoDiscussionResponse": { + "type": "object", + "properties": { + "discussions": { + "type": "array", + "items": { + "$ref": "#/definitions/types.CreateDiscussionResponse" + } + } + } + }, "types.Member": { "type": "object", "properties": { @@ -18044,20 +17492,6 @@ const docTemplate = `{ } } }, - "types.ModelPredictReq": { - "type": "object", - "properties": { - "current_user": { - "type": "string" - }, - "input": { - "type": "string" - }, - "version": { - "type": "string" - } - } - }, "types.ModelResp": { "type": "object", "properties": { @@ -18578,6 +18012,29 @@ const docTemplate = `{ "SensitiveCheckException" ] }, + "types.ShowDiscussionResponse": { + "type": "object", + "properties": { + "comment_count": { + "type": "integer" + }, + "comments": { + "type": "array", + "items": { + "$ref": "#/definitions/types.DiscussionResponse_Comment" + } + }, + "id": { + "type": "integer" + }, + "title": { + "type": "string" + }, + "user": { + "$ref": "#/definitions/types.DiscussionResponse_User" + } + } + }, "types.Space": { "type": "object", "properties": { @@ -18834,6 +18291,17 @@ const docTemplate = `{ } } }, + "types.UpdateCommentRequest": { + "type": "object", + "required": [ + "content" + ], + "properties": { + "content": { + "type": "string" + } + } + }, "types.UpdateDatasetReq": { "type": "object", "properties": { @@ -18850,6 +18318,17 @@ const docTemplate = `{ } } }, + "types.UpdateDiscussionRequest": { + "type": "object", + "required": [ + "title" + ], + "properties": { + "title": { + "type": "string" + } + } + }, "types.UpdateFileReq": { "type": "object", "properties": { diff --git a/docs/swagger.json b/docs/swagger.json index 711b3efd..e868e757 100644 --- a/docs/swagger.json +++ b/docs/swagger.json @@ -1162,7 +1162,7 @@ "in": "body", "required": true, "schema": { - "$ref": "#/definitions/component.UpdateCommentRequest" + "$ref": "#/definitions/types.UpdateCommentRequest" } } ], @@ -1900,7 +1900,7 @@ "type": "object", "properties": { "data": { - "$ref": "#/definitions/component.ShowDiscussionResponse" + "$ref": "#/definitions/types.ShowDiscussionResponse" } } } @@ -1959,7 +1959,7 @@ "in": "body", "required": true, "schema": { - "$ref": "#/definitions/component.UpdateDiscussionRequest" + "$ref": "#/definitions/types.UpdateDiscussionRequest" } } ], @@ -2080,7 +2080,7 @@ "data": { "type": "array", "items": { - "$ref": "#/definitions/component.DiscussionResponse_Comment" + "$ref": "#/definitions/types.DiscussionResponse_Comment" } } } @@ -2133,7 +2133,7 @@ "in": "body", "required": true, "schema": { - "$ref": "#/definitions/component.CreateCommentRequest" + "$ref": "#/definitions/types.CreateCommentRequest" } } ], @@ -2149,7 +2149,7 @@ "type": "object", "properties": { "data": { - "$ref": "#/definitions/component.CreateCommentResponse" + "$ref": "#/definitions/types.CreateCommentResponse" } } } @@ -4012,77 +4012,6 @@ } } }, - "/models/{namespace}/{name}/predict": { - "post": { - "security": [ - { - "ApiKey": [] - } - ], - "description": "invoke model prediction", - "consumes": [ - "application/json" - ], - "produces": [ - "application/json" - ], - "tags": [ - "Model" - ], - "summary": "Invoke model prediction", - "parameters": [ - { - "type": "string", - "description": "namespace", - "name": "namespace", - "in": "path", - "required": true - }, - { - "type": "string", - "description": "name", - "name": "name", - "in": "path", - "required": true - }, - { - "type": "string", - "description": "current user", - "name": "current_user", - "in": "query" - }, - { - "description": "input for model prediction", - "name": "body", - "in": "body", - "required": true, - "schema": { - "$ref": "#/definitions/types.ModelPredictReq" - } - } - ], - "responses": { - "200": { - "description": "OK", - "schema": { - "type": "string" - } - }, - "400": { - "description": "Bad request", - "schema": { - "$ref": "#/definitions/types.APIBadRequest" - } - }, - "500": { - "description": "Internal server error", - "schema": { - "$ref": "#/definitions/types.APIInternalServerError" - } - } - } - } - }, "/models/{namespace}/{name}/relations": { "get": { "security": [ @@ -6478,478 +6407,70 @@ { "type": "string", "description": "filter by license tag", - "name": "license_tag", - "in": "query" - }, - { - "type": "string", - "description": "filter by language tag", - "name": "language_tag", - "in": "query" - }, - { - "type": "string", - "description": "sort by", - "name": "sort", - "in": "query" - }, - { - "enum": [ - "opencsg", - "huggingface", - "local" - ], - "type": "string", - "description": "source", - "name": "source", - "in": "query" - }, - { - "type": "integer", - "default": 20, - "description": "per", - "name": "per", - "in": "query" - }, - { - "type": "integer", - "default": 1, - "description": "per page", - "name": "page", - "in": "query" - } - ], - "responses": { - "200": { - "description": "OK", - "schema": { - "allOf": [ - { - "$ref": "#/definitions/types.ResponseWithTotal" - }, - { - "type": "object", - "properties": { - "data": { - "type": "array", - "items": { - "$ref": "#/definitions/types.PromptRes" - } - }, - "total": { - "type": "integer" - } - } - } - ] - } - }, - "400": { - "description": "Bad request", - "schema": { - "$ref": "#/definitions/types.APIBadRequest" - } - }, - "500": { - "description": "Internal server error", - "schema": { - "$ref": "#/definitions/types.APIInternalServerError" - } - } - } - }, - "post": { - "security": [ - { - "ApiKey": [] - } - ], - "description": "create a new prompt repo", - "consumes": [ - "application/json" - ], - "produces": [ - "application/json" - ], - "tags": [ - "Prompt" - ], - "summary": "Create a new prompt repo", - "parameters": [ - { - "type": "string", - "description": "current user, the owner", - "name": "current_user", - "in": "query" - }, - { - "description": "body", - "name": "body", - "in": "body", - "required": true, - "schema": { - "$ref": "#/definitions/types.CreatePromptRepoReq" - } - } - ], - "responses": { - "200": { - "description": "OK", - "schema": { - "$ref": "#/definitions/types.Response" - } - }, - "400": { - "description": "Bad request", - "schema": { - "$ref": "#/definitions/types.APIBadRequest" - } - }, - "500": { - "description": "Internal server error", - "schema": { - "$ref": "#/definitions/types.APIInternalServerError" - } - } - } - } - }, - "/prompts/conversations": { - "get": { - "security": [ - { - "ApiKey": [] - } - ], - "description": "List conversations of user", - "consumes": [ - "application/json" - ], - "produces": [ - "application/json" - ], - "tags": [ - "Prompt" - ], - "summary": "List conversations of user", - "responses": { - "200": { - "description": "OK", - "schema": { - "$ref": "#/definitions/types.Response" - } - }, - "400": { - "description": "Bad request", - "schema": { - "$ref": "#/definitions/types.APIBadRequest" - } - }, - "500": { - "description": "Internal server error", - "schema": { - "$ref": "#/definitions/types.APIInternalServerError" - } - } - } - }, - "post": { - "security": [ - { - "ApiKey": [] - } - ], - "description": "Create new conversation", - "consumes": [ - "application/json" - ], - "produces": [ - "application/json" - ], - "tags": [ - "Prompt" - ], - "summary": "Create new conversation", - "parameters": [ - { - "description": "body", - "name": "body", - "in": "body", - "required": true, - "schema": { - "$ref": "#/definitions/types.Conversation" - } - } - ], - "responses": { - "200": { - "description": "OK", - "schema": { - "$ref": "#/definitions/types.Response" - } - }, - "400": { - "description": "Bad request", - "schema": { - "$ref": "#/definitions/types.APIBadRequest" - } - }, - "500": { - "description": "Internal server error", - "schema": { - "$ref": "#/definitions/types.APIInternalServerError" - } - } - } - } - }, - "/prompts/conversations/{id}": { - "get": { - "security": [ - { - "ApiKey": [] - } - ], - "description": "Get a conversation by uuid", - "consumes": [ - "application/json" - ], - "produces": [ - "application/json" - ], - "tags": [ - "Prompt" - ], - "summary": "Get a conversation by uuid", - "parameters": [ - { - "type": "string", - "description": "conversation uuid", - "name": "id", - "in": "path", - "required": true - } - ], - "responses": { - "200": { - "description": "OK", - "schema": { - "$ref": "#/definitions/types.Response" - } - }, - "400": { - "description": "Bad request", - "schema": { - "$ref": "#/definitions/types.APIBadRequest" - } - }, - "500": { - "description": "Internal server error", - "schema": { - "$ref": "#/definitions/types.APIInternalServerError" - } - } - } - }, - "put": { - "security": [ - { - "ApiKey": [] - } - ], - "description": "Update a conversation title", - "consumes": [ - "application/json" - ], - "produces": [ - "application/json" - ], - "tags": [ - "Prompt" - ], - "summary": "Update a conversation title", - "parameters": [ - { - "type": "string", - "description": "conversation uuid", - "name": "id", - "in": "path", - "required": true - }, - { - "description": "body", - "name": "body", - "in": "body", - "required": true, - "schema": { - "$ref": "#/definitions/types.ConversationTitle" - } - } - ], - "responses": { - "200": { - "description": "OK", - "schema": { - "$ref": "#/definitions/types.Response" - } - }, - "400": { - "description": "Bad request", - "schema": { - "$ref": "#/definitions/types.APIBadRequest" - } - }, - "500": { - "description": "Internal server error", - "schema": { - "$ref": "#/definitions/types.APIInternalServerError" - } - } - } - }, - "post": { - "security": [ - { - "ApiKey": [] - } - ], - "description": "Submit a conversation message", - "consumes": [ - "application/json" - ], - "produces": [ - "application/json" - ], - "tags": [ - "Prompt" - ], - "summary": "Submit a conversation message", - "parameters": [ - { - "type": "string", - "description": "conversation uuid", - "name": "id", - "in": "path", - "required": true - }, - { - "description": "body", - "name": "body", - "in": "body", - "required": true, - "schema": { - "$ref": "#/definitions/types.Conversation" - } - } - ], - "responses": { - "200": { - "description": "OK", - "schema": { - "$ref": "#/definitions/types.Response" - } - }, - "400": { - "description": "Bad request", - "schema": { - "$ref": "#/definitions/types.APIBadRequest" - } - }, - "500": { - "description": "Internal server error", - "schema": { - "$ref": "#/definitions/types.APIInternalServerError" - } - } - } - }, - "delete": { - "security": [ - { - "ApiKey": [] - } - ], - "description": "Delete a conversation", - "consumes": [ - "application/json" - ], - "produces": [ - "application/json" - ], - "tags": [ - "Prompt" - ], - "summary": "Delete a conversation", - "parameters": [ - { - "type": "string", - "description": "conversation uuid", - "name": "id", - "in": "path", - "required": true - } - ], - "responses": { - "200": { - "description": "OK", - "schema": { - "$ref": "#/definitions/types.Response" - } - }, - "400": { - "description": "Bad request", - "schema": { - "$ref": "#/definitions/types.APIBadRequest" - } + "name": "license_tag", + "in": "query" }, - "500": { - "description": "Internal server error", - "schema": { - "$ref": "#/definitions/types.APIInternalServerError" - } - } - } - } - }, - "/prompts/conversations/{id}/message/{msgid}/hate": { - "put": { - "security": [ { - "ApiKey": [] - } - ], - "description": "Hate a conversation message", - "consumes": [ - "application/json" - ], - "produces": [ - "application/json" - ], - "tags": [ - "Prompt" - ], - "summary": "Hate a conversation message", - "parameters": [ + "type": "string", + "description": "filter by language tag", + "name": "language_tag", + "in": "query" + }, { "type": "string", - "description": "conversation uuid", - "name": "uuid", - "in": "path", - "required": true + "description": "sort by", + "name": "sort", + "in": "query" }, { + "enum": [ + "opencsg", + "huggingface", + "local" + ], "type": "string", - "description": "message id", - "name": "id", - "in": "path", - "required": true + "description": "source", + "name": "source", + "in": "query" + }, + { + "type": "integer", + "default": 20, + "description": "per", + "name": "per", + "in": "query" + }, + { + "type": "integer", + "default": 1, + "description": "per page", + "name": "page", + "in": "query" } ], "responses": { "200": { "description": "OK", "schema": { - "$ref": "#/definitions/types.Response" + "allOf": [ + { + "$ref": "#/definitions/types.ResponseWithTotal" + }, + { + "type": "object", + "properties": { + "data": { + "type": "array", + "items": { + "$ref": "#/definitions/types.PromptRes" + } + }, + "total": { + "type": "integer" + } + } + } + ] } }, "400": { @@ -6965,16 +6486,14 @@ } } } - } - }, - "/prompts/conversations/{id}/message/{msgid}/like": { - "put": { + }, + "post": { "security": [ { "ApiKey": [] } ], - "description": "Like a conversation message", + "description": "create a new prompt repo", "consumes": [ "application/json" ], @@ -6984,21 +6503,22 @@ "tags": [ "Prompt" ], - "summary": "Like a conversation message", + "summary": "Create a new prompt repo", "parameters": [ { "type": "string", - "description": "conversation uuid", - "name": "uuid", - "in": "path", - "required": true + "description": "current user, the owner", + "name": "current_user", + "in": "query" }, { - "type": "string", - "description": "message id", - "name": "id", - "in": "path", - "required": true + "description": "body", + "name": "body", + "in": "body", + "required": true, + "schema": { + "$ref": "#/definitions/types.CreatePromptRepoReq" + } } ], "responses": { @@ -13178,7 +12698,7 @@ "type": "object", "properties": { "data": { - "$ref": "#/definitions/component.ListRepoDiscussionResponse" + "$ref": "#/definitions/types.ListRepoDiscussionResponse" } } } @@ -13257,7 +12777,7 @@ "in": "body", "required": true, "schema": { - "$ref": "#/definitions/component.CreateRepoDiscussionRequest" + "$ref": "#/definitions/types.CreateRepoDiscussionRequest" } } ], @@ -13273,7 +12793,7 @@ "type": "object", "properties": { "data": { - "$ref": "#/definitions/component.CreateDiscussionResponse" + "$ref": "#/definitions/types.CreateDiscussionResponse" } } } @@ -15424,206 +14944,50 @@ }, { "type": "string", - "description": "branch or tag", - "name": "ref", - "in": "query" - } - ], - "responses": { - "200": { - "description": "OK", - "schema": { - "allOf": [ - { - "$ref": "#/definitions/types.ResponseWithTotal" - }, - { - "type": "object", - "properties": { - "data": { - "type": "array", - "items": { - "$ref": "#/definitions/types.File" - } - } - } - } - ] - } - }, - "400": { - "description": "Bad request", - "schema": { - "$ref": "#/definitions/types.APIBadRequest" - } - }, - "500": { - "description": "Internal server error", - "schema": { - "$ref": "#/definitions/types.APIInternalServerError" - } - } - } - } - } - }, - "definitions": { - "component.CreateCommentRequest": { - "type": "object", - "required": [ - "content" - ], - "properties": { - "commentable_id": { - "type": "integer" - }, - "commentable_type": { - "type": "string" - }, - "content": { - "type": "string" - } - } - }, - "component.CreateCommentResponse": { - "type": "object", - "properties": { - "commentable_id": { - "type": "integer" - }, - "commentable_type": { - "type": "string" - }, - "created_at": { - "type": "string" - }, - "id": { - "type": "integer" - }, - "user": { - "$ref": "#/definitions/component.DiscussionResponse_User" - } - } - }, - "component.CreateDiscussionResponse": { - "type": "object", - "properties": { - "comment_count": { - "description": "DiscussionableID int64 `json:\"discussionable_id\"`\nDiscussionableType string `json:\"discussionable_type\"`", - "type": "integer" - }, - "created_at": { - "type": "string" - }, - "id": { - "type": "integer" - }, - "title": { - "type": "string" - }, - "user": { - "$ref": "#/definitions/component.DiscussionResponse_User" - } - } - }, - "component.CreateRepoDiscussionRequest": { - "type": "object", - "required": [ - "title" - ], - "properties": { - "title": { - "type": "string" - } - } - }, - "component.DiscussionResponse_Comment": { - "type": "object", - "properties": { - "content": { - "type": "string" - }, - "created_at": { - "type": "string" - }, - "id": { - "type": "integer" - }, - "user": { - "$ref": "#/definitions/component.DiscussionResponse_User" - } - } - }, - "component.DiscussionResponse_User": { - "type": "object", - "properties": { - "avatar": { - "type": "string" - }, - "id": { - "type": "integer" - }, - "name": { - "type": "string" - } - } - }, - "component.ListRepoDiscussionResponse": { - "type": "object", - "properties": { - "discussions": { - "type": "array", - "items": { - "$ref": "#/definitions/component.CreateDiscussionResponse" - } - } - } - }, - "component.ShowDiscussionResponse": { - "type": "object", - "properties": { - "comment_count": { - "type": "integer" - }, - "comments": { - "type": "array", - "items": { - "$ref": "#/definitions/component.DiscussionResponse_Comment" + "description": "branch or tag", + "name": "ref", + "in": "query" + } + ], + "responses": { + "200": { + "description": "OK", + "schema": { + "allOf": [ + { + "$ref": "#/definitions/types.ResponseWithTotal" + }, + { + "type": "object", + "properties": { + "data": { + "type": "array", + "items": { + "$ref": "#/definitions/types.File" + } + } + } + } + ] + } + }, + "400": { + "description": "Bad request", + "schema": { + "$ref": "#/definitions/types.APIBadRequest" + } + }, + "500": { + "description": "Internal server error", + "schema": { + "$ref": "#/definitions/types.APIInternalServerError" + } } - }, - "id": { - "type": "integer" - }, - "title": { - "type": "string" - }, - "user": { - "$ref": "#/definitions/component.DiscussionResponse_User" - } - } - }, - "component.UpdateCommentRequest": { - "type": "object", - "required": [ - "content" - ], - "properties": { - "content": { - "type": "string" - } - } - }, - "component.UpdateDiscussionRequest": { - "type": "object", - "required": [ - "title" - ], - "properties": { - "title": { - "type": "string" } } - }, + } + }, + "definitions": { "database.AccessToken": { "type": "object", "properties": { @@ -16563,6 +15927,9 @@ "repository_id": { "type": "integer" }, + "sensitive_check_status": { + "type": "string" + }, "source": { "$ref": "#/definitions/types.RepositorySource" }, @@ -16799,39 +16166,6 @@ } } }, - "types.Conversation": { - "type": "object", - "required": [ - "message", - "uuid" - ], - "properties": { - "message": { - "type": "string" - }, - "temperature": { - "type": "number" - }, - "uuid": { - "type": "string" - } - } - }, - "types.ConversationTitle": { - "type": "object", - "required": [ - "title", - "uuid" - ], - "properties": { - "title": { - "type": "string" - }, - "uuid": { - "type": "string" - } - } - }, "types.CreateCategory": { "type": "object", "required": [ @@ -16912,6 +16246,43 @@ } } }, + "types.CreateCommentRequest": { + "type": "object", + "required": [ + "content" + ], + "properties": { + "commentable_id": { + "type": "integer" + }, + "commentable_type": { + "type": "string" + }, + "content": { + "type": "string" + } + } + }, + "types.CreateCommentResponse": { + "type": "object", + "properties": { + "commentable_id": { + "type": "integer" + }, + "commentable_type": { + "type": "string" + }, + "created_at": { + "type": "string" + }, + "id": { + "type": "integer" + }, + "user": { + "$ref": "#/definitions/types.DiscussionResponse_User" + } + } + }, "types.CreateDatasetReq": { "type": "object", "properties": { @@ -16953,6 +16324,27 @@ } } }, + "types.CreateDiscussionResponse": { + "type": "object", + "properties": { + "comment_count": { + "description": "DiscussionableID int64 `json:\"discussionable_id\"`\nDiscussionableType string `json:\"discussionable_type\"`", + "type": "integer" + }, + "created_at": { + "type": "string" + }, + "id": { + "type": "integer" + }, + "title": { + "type": "string" + }, + "user": { + "$ref": "#/definitions/types.DiscussionResponse_User" + } + } + }, "types.CreateFileReq": { "type": "object", "properties": { @@ -17205,6 +16597,17 @@ } } }, + "types.CreateRepoDiscussionRequest": { + "type": "object", + "required": [ + "title" + ], + "properties": { + "title": { + "type": "string" + } + } + }, "types.CreateSSHKeyRequest": { "type": "object", "properties": { @@ -17442,6 +16845,9 @@ "repository_id": { "type": "integer" }, + "sensitive_check_status": { + "type": "string" + }, "source": { "$ref": "#/definitions/types.RepositorySource" }, @@ -17557,6 +16963,37 @@ } } }, + "types.DiscussionResponse_Comment": { + "type": "object", + "properties": { + "content": { + "type": "string" + }, + "created_at": { + "type": "string" + }, + "id": { + "type": "integer" + }, + "user": { + "$ref": "#/definitions/types.DiscussionResponse_User" + } + } + }, + "types.DiscussionResponse_User": { + "type": "object", + "properties": { + "avatar": { + "type": "string" + }, + "id": { + "type": "integer" + }, + "name": { + "type": "string" + } + } + }, "types.EditOrgReq": { "type": "object", "properties": { @@ -17792,6 +17229,17 @@ } } }, + "types.ListRepoDiscussionResponse": { + "type": "object", + "properties": { + "discussions": { + "type": "array", + "items": { + "$ref": "#/definitions/types.CreateDiscussionResponse" + } + } + } + }, "types.Member": { "type": "object", "properties": { @@ -18033,20 +17481,6 @@ } } }, - "types.ModelPredictReq": { - "type": "object", - "properties": { - "current_user": { - "type": "string" - }, - "input": { - "type": "string" - }, - "version": { - "type": "string" - } - } - }, "types.ModelResp": { "type": "object", "properties": { @@ -18567,6 +18001,29 @@ "SensitiveCheckException" ] }, + "types.ShowDiscussionResponse": { + "type": "object", + "properties": { + "comment_count": { + "type": "integer" + }, + "comments": { + "type": "array", + "items": { + "$ref": "#/definitions/types.DiscussionResponse_Comment" + } + }, + "id": { + "type": "integer" + }, + "title": { + "type": "string" + }, + "user": { + "$ref": "#/definitions/types.DiscussionResponse_User" + } + } + }, "types.Space": { "type": "object", "properties": { @@ -18823,6 +18280,17 @@ } } }, + "types.UpdateCommentRequest": { + "type": "object", + "required": [ + "content" + ], + "properties": { + "content": { + "type": "string" + } + } + }, "types.UpdateDatasetReq": { "type": "object", "properties": { @@ -18839,6 +18307,17 @@ } } }, + "types.UpdateDiscussionRequest": { + "type": "object", + "required": [ + "title" + ], + "properties": { + "title": { + "type": "string" + } + } + }, "types.UpdateFileReq": { "type": "object", "properties": { diff --git a/docs/swagger.yaml b/docs/swagger.yaml index ccfa9f9e..625d0f9a 100644 --- a/docs/swagger.yaml +++ b/docs/swagger.yaml @@ -1,107 +1,4 @@ definitions: - component.CreateCommentRequest: - properties: - commentable_id: - type: integer - commentable_type: - type: string - content: - type: string - required: - - content - type: object - component.CreateCommentResponse: - properties: - commentable_id: - type: integer - commentable_type: - type: string - created_at: - type: string - id: - type: integer - user: - $ref: '#/definitions/component.DiscussionResponse_User' - type: object - component.CreateDiscussionResponse: - properties: - comment_count: - description: |- - DiscussionableID int64 `json:"discussionable_id"` - DiscussionableType string `json:"discussionable_type"` - type: integer - created_at: - type: string - id: - type: integer - title: - type: string - user: - $ref: '#/definitions/component.DiscussionResponse_User' - type: object - component.CreateRepoDiscussionRequest: - properties: - title: - type: string - required: - - title - type: object - component.DiscussionResponse_Comment: - properties: - content: - type: string - created_at: - type: string - id: - type: integer - user: - $ref: '#/definitions/component.DiscussionResponse_User' - type: object - component.DiscussionResponse_User: - properties: - avatar: - type: string - id: - type: integer - name: - type: string - type: object - component.ListRepoDiscussionResponse: - properties: - discussions: - items: - $ref: '#/definitions/component.CreateDiscussionResponse' - type: array - type: object - component.ShowDiscussionResponse: - properties: - comment_count: - type: integer - comments: - items: - $ref: '#/definitions/component.DiscussionResponse_Comment' - type: array - id: - type: integer - title: - type: string - user: - $ref: '#/definitions/component.DiscussionResponse_User' - type: object - component.UpdateCommentRequest: - properties: - content: - type: string - required: - - content - type: object - component.UpdateDiscussionRequest: - properties: - title: - type: string - required: - - title - type: object database.AccessToken: properties: application: @@ -732,6 +629,8 @@ definitions: $ref: '#/definitions/types.Repository' repository_id: type: integer + sensitive_check_status: + type: string source: $ref: '#/definitions/types.RepositorySource' sync_status: @@ -887,28 +786,6 @@ definitions: total: type: integer type: object - types.Conversation: - properties: - message: - type: string - temperature: - type: number - uuid: - type: string - required: - - message - - uuid - type: object - types.ConversationTitle: - properties: - title: - type: string - uuid: - type: string - required: - - title - - uuid - type: object types.CreateCategory: properties: name: @@ -965,6 +842,30 @@ definitions: example: '#fff000' type: string type: object + types.CreateCommentRequest: + properties: + commentable_id: + type: integer + commentable_type: + type: string + content: + type: string + required: + - content + type: object + types.CreateCommentResponse: + properties: + commentable_id: + type: integer + commentable_type: + type: string + created_at: + type: string + id: + type: integer + user: + $ref: '#/definitions/types.DiscussionResponse_User' + type: object types.CreateDatasetReq: properties: default_branch: @@ -994,6 +895,22 @@ definitions: type: type: integer type: object + types.CreateDiscussionResponse: + properties: + comment_count: + description: |- + DiscussionableID int64 `json:"discussionable_id"` + DiscussionableType string `json:"discussionable_type"` + type: integer + created_at: + type: string + id: + type: integer + title: + type: string + user: + $ref: '#/definitions/types.DiscussionResponse_User' + type: object types.CreateFileReq: properties: branch: @@ -1167,6 +1084,13 @@ definitions: readme: type: string type: object + types.CreateRepoDiscussionRequest: + properties: + title: + type: string + required: + - title + type: object types.CreateSSHKeyRequest: properties: content: @@ -1326,6 +1250,8 @@ definitions: $ref: '#/definitions/types.Repository' repository_id: type: integer + sensitive_check_status: + type: string source: $ref: '#/definitions/types.RepositorySource' sync_status: @@ -1401,6 +1327,26 @@ definitions: secure_level: type: integer type: object + types.DiscussionResponse_Comment: + properties: + content: + type: string + created_at: + type: string + id: + type: integer + user: + $ref: '#/definitions/types.DiscussionResponse_User' + type: object + types.DiscussionResponse_User: + properties: + avatar: + type: string + id: + type: integer + name: + type: string + type: object types.EditOrgReq: properties: description: @@ -1561,6 +1507,13 @@ definitions: type: string type: array type: object + types.ListRepoDiscussionResponse: + properties: + discussions: + items: + $ref: '#/definitions/types.CreateDiscussionResponse' + type: array + type: object types.Member: properties: avatar: @@ -1725,15 +1678,6 @@ definitions: description: 'widget UI style: generation,chat' example: generation type: object - types.ModelPredictReq: - properties: - current_user: - type: string - input: - type: string - version: - type: string - type: object types.ModelResp: properties: description: @@ -2088,6 +2032,21 @@ definitions: - SensitiveCheckPass - SensitiveCheckSkip - SensitiveCheckException + types.ShowDiscussionResponse: + properties: + comment_count: + type: integer + comments: + items: + $ref: '#/definitions/types.DiscussionResponse_Comment' + type: array + id: + type: integer + title: + type: string + user: + $ref: '#/definitions/types.DiscussionResponse_User' + type: object types.Space: properties: can_manage: @@ -2263,6 +2222,13 @@ definitions: type: integer type: array type: object + types.UpdateCommentRequest: + properties: + content: + type: string + required: + - content + type: object types.UpdateDatasetReq: properties: description: @@ -2274,6 +2240,13 @@ definitions: example: false type: boolean type: object + types.UpdateDiscussionRequest: + properties: + title: + type: string + required: + - title + type: object types.UpdateFileReq: properties: branch: @@ -2773,7 +2746,7 @@ paths: - $ref: '#/definitions/types.Response' - properties: data: - $ref: '#/definitions/component.ListRepoDiscussionResponse' + $ref: '#/definitions/types.ListRepoDiscussionResponse' type: object "400": description: Bad request @@ -2823,7 +2796,7 @@ paths: name: body required: true schema: - $ref: '#/definitions/component.CreateRepoDiscussionRequest' + $ref: '#/definitions/types.CreateRepoDiscussionRequest' produces: - application/json responses: @@ -2834,7 +2807,7 @@ paths: - $ref: '#/definitions/types.Response' - properties: data: - $ref: '#/definitions/component.CreateDiscussionResponse' + $ref: '#/definitions/types.CreateDiscussionResponse' type: object "400": description: Bad request @@ -5013,7 +4986,7 @@ paths: name: body required: true schema: - $ref: '#/definitions/component.UpdateCommentRequest' + $ref: '#/definitions/types.UpdateCommentRequest' produces: - application/json responses: @@ -5475,7 +5448,7 @@ paths: - $ref: '#/definitions/types.Response' - properties: data: - $ref: '#/definitions/component.ShowDiscussionResponse' + $ref: '#/definitions/types.ShowDiscussionResponse' type: object "400": description: Bad request @@ -5510,7 +5483,7 @@ paths: name: body required: true schema: - $ref: '#/definitions/component.UpdateDiscussionRequest' + $ref: '#/definitions/types.UpdateDiscussionRequest' produces: - application/json responses: @@ -5553,7 +5526,7 @@ paths: - properties: data: items: - $ref: '#/definitions/component.DiscussionResponse_Comment' + $ref: '#/definitions/types.DiscussionResponse_Comment' type: array type: object "400": @@ -5584,7 +5557,7 @@ paths: name: body required: true schema: - $ref: '#/definitions/component.CreateCommentRequest' + $ref: '#/definitions/types.CreateCommentRequest' produces: - application/json responses: @@ -5595,7 +5568,7 @@ paths: - $ref: '#/definitions/types.Response' - properties: data: - $ref: '#/definitions/component.CreateCommentResponse' + $ref: '#/definitions/types.CreateCommentResponse' type: object "400": description: Bad request @@ -6696,52 +6669,6 @@ paths: summary: Stop a finetune instance tags: - Model - /models/{namespace}/{name}/predict: - post: - consumes: - - application/json - description: invoke model prediction - parameters: - - description: namespace - in: path - name: namespace - required: true - type: string - - description: name - in: path - name: name - required: true - type: string - - description: current user - in: query - name: current_user - type: string - - description: input for model prediction - in: body - name: body - required: true - schema: - $ref: '#/definitions/types.ModelPredictReq' - produces: - - application/json - responses: - "200": - description: OK - schema: - type: string - "400": - description: Bad request - schema: - $ref: '#/definitions/types.APIBadRequest' - "500": - description: Internal server error - schema: - $ref: '#/definitions/types.APIInternalServerError' - security: - - ApiKey: [] - summary: Invoke model prediction - tags: - - Model /models/{namespace}/{name}/relations: get: consumes: @@ -9022,267 +8949,6 @@ paths: summary: update the tags of a certain category tags: - Prompt - /prompts/conversations: - get: - consumes: - - application/json - description: List conversations of user - produces: - - application/json - responses: - "200": - description: OK - schema: - $ref: '#/definitions/types.Response' - "400": - description: Bad request - schema: - $ref: '#/definitions/types.APIBadRequest' - "500": - description: Internal server error - schema: - $ref: '#/definitions/types.APIInternalServerError' - security: - - ApiKey: [] - summary: List conversations of user - tags: - - Prompt - post: - consumes: - - application/json - description: Create new conversation - parameters: - - description: body - in: body - name: body - required: true - schema: - $ref: '#/definitions/types.Conversation' - produces: - - application/json - responses: - "200": - description: OK - schema: - $ref: '#/definitions/types.Response' - "400": - description: Bad request - schema: - $ref: '#/definitions/types.APIBadRequest' - "500": - description: Internal server error - schema: - $ref: '#/definitions/types.APIInternalServerError' - security: - - ApiKey: [] - summary: Create new conversation - tags: - - Prompt - /prompts/conversations/{id}: - delete: - consumes: - - application/json - description: Delete a conversation - parameters: - - description: conversation uuid - in: path - name: id - required: true - type: string - produces: - - application/json - responses: - "200": - description: OK - schema: - $ref: '#/definitions/types.Response' - "400": - description: Bad request - schema: - $ref: '#/definitions/types.APIBadRequest' - "500": - description: Internal server error - schema: - $ref: '#/definitions/types.APIInternalServerError' - security: - - ApiKey: [] - summary: Delete a conversation - tags: - - Prompt - get: - consumes: - - application/json - description: Get a conversation by uuid - parameters: - - description: conversation uuid - in: path - name: id - required: true - type: string - produces: - - application/json - responses: - "200": - description: OK - schema: - $ref: '#/definitions/types.Response' - "400": - description: Bad request - schema: - $ref: '#/definitions/types.APIBadRequest' - "500": - description: Internal server error - schema: - $ref: '#/definitions/types.APIInternalServerError' - security: - - ApiKey: [] - summary: Get a conversation by uuid - tags: - - Prompt - post: - consumes: - - application/json - description: Submit a conversation message - parameters: - - description: conversation uuid - in: path - name: id - required: true - type: string - - description: body - in: body - name: body - required: true - schema: - $ref: '#/definitions/types.Conversation' - produces: - - application/json - responses: - "200": - description: OK - schema: - $ref: '#/definitions/types.Response' - "400": - description: Bad request - schema: - $ref: '#/definitions/types.APIBadRequest' - "500": - description: Internal server error - schema: - $ref: '#/definitions/types.APIInternalServerError' - security: - - ApiKey: [] - summary: Submit a conversation message - tags: - - Prompt - put: - consumes: - - application/json - description: Update a conversation title - parameters: - - description: conversation uuid - in: path - name: id - required: true - type: string - - description: body - in: body - name: body - required: true - schema: - $ref: '#/definitions/types.ConversationTitle' - produces: - - application/json - responses: - "200": - description: OK - schema: - $ref: '#/definitions/types.Response' - "400": - description: Bad request - schema: - $ref: '#/definitions/types.APIBadRequest' - "500": - description: Internal server error - schema: - $ref: '#/definitions/types.APIInternalServerError' - security: - - ApiKey: [] - summary: Update a conversation title - tags: - - Prompt - /prompts/conversations/{id}/message/{msgid}/hate: - put: - consumes: - - application/json - description: Hate a conversation message - parameters: - - description: conversation uuid - in: path - name: uuid - required: true - type: string - - description: message id - in: path - name: id - required: true - type: string - produces: - - application/json - responses: - "200": - description: OK - schema: - $ref: '#/definitions/types.Response' - "400": - description: Bad request - schema: - $ref: '#/definitions/types.APIBadRequest' - "500": - description: Internal server error - schema: - $ref: '#/definitions/types.APIInternalServerError' - security: - - ApiKey: [] - summary: Hate a conversation message - tags: - - Prompt - /prompts/conversations/{id}/message/{msgid}/like: - put: - consumes: - - application/json - description: Like a conversation message - parameters: - - description: conversation uuid - in: path - name: uuid - required: true - type: string - - description: message id - in: path - name: id - required: true - type: string - produces: - - application/json - responses: - "200": - description: OK - schema: - $ref: '#/definitions/types.Response' - "400": - description: Bad request - schema: - $ref: '#/definitions/types.APIBadRequest' - "500": - description: Internal server error - schema: - $ref: '#/definitions/types.APIInternalServerError' - security: - - ApiKey: [] - summary: Like a conversation message - tags: - - Prompt /recom/opweight: post: consumes: diff --git a/go.mod b/go.mod index a4f188ca..8e9d06a5 100644 --- a/go.mod +++ b/go.mod @@ -1,8 +1,6 @@ module opencsg.com/csghub-server -go 1.22.0 - -toolchain go1.22.6 +go 1.23 require ( github.com/DATA-DOG/go-txdb v0.2.0 @@ -20,7 +18,6 @@ require ( github.com/golang-jwt/jwt/v5 v5.2.1 github.com/golang/gddo v0.0.0-20210115222349-20d68f94ee1f github.com/google/wire v0.6.0 - github.com/jarcoal/httpmock v1.3.1 github.com/marcboeker/go-duckdb v1.5.6 github.com/minio/minio-go/v7 v7.0.66 github.com/minio/sha256-simd v1.0.1 diff --git a/go.sum b/go.sum index db84171c..41d5b88b 100644 --- a/go.sum +++ b/go.sum @@ -307,7 +307,6 @@ github.com/google/pprof v0.0.0-20240525223248-4bfdf5a9a2af h1:kmjWCqn2qkEml422C2 github.com/google/pprof v0.0.0-20240525223248-4bfdf5a9a2af/go.mod h1:K1liHPHnj73Fdn/EKuT8nrFqBihUSKXoLYU0BuatOYo= github.com/google/s2a-go v0.1.7 h1:60BLSyTrOV4/haCDW4zb1guZItoSq8foHCXrAnjBo/o= github.com/google/s2a-go v0.1.7/go.mod h1:50CgR4k1jNlWBu4UfS4AcfhVe1r6pdZPygJ3R8F0Qdw= -github.com/google/subcommands v1.2.0 h1:vWQspBTo2nEqTUFita5/KeEWlUL8kQObDFbub/EN9oE= github.com/google/subcommands v1.2.0/go.mod h1:ZjhPrFU+Olkh9WazFPsl27BQ4UPiG37m3yTrtFlrHVk= github.com/google/uuid v1.0.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= github.com/google/uuid v1.1.2/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= @@ -360,8 +359,6 @@ github.com/jackc/pgx/v5 v5.5.5 h1:amBjrZVmksIdNjxGW/IiIMzxMKZFelXbUoPNb+8sjQw= github.com/jackc/pgx/v5 v5.5.5/go.mod h1:ez9gk+OAat140fv9ErkZDYFWmXLfV+++K0uAOiwgm1A= github.com/jackc/puddle/v2 v2.2.1 h1:RhxXJtFG022u4ibrCSMSiu5aOq1i77R3OHKNJj77OAk= github.com/jackc/puddle/v2 v2.2.1/go.mod h1:vriiEXHvEE654aYKXXjOvZM39qJ0q+azkZFrfEOc3H4= -github.com/jarcoal/httpmock v1.3.1 h1:iUx3whfZWVf3jT01hQTO/Eo5sAYtB2/rqaUuOtpInww= -github.com/jarcoal/httpmock v1.3.1/go.mod h1:3yb8rc4BI7TCBhFY8ng0gjuLKJNquuDNiPaZjnENuYg= github.com/jellydator/ttlcache/v2 v2.11.1 h1:AZGME43Eh2Vv3giG6GeqeLeFXxwxn1/qHItqWZl6U64= github.com/jellydator/ttlcache/v2 v2.11.1/go.mod h1:RtE5Snf0/57e+2cLWFYWCCsLas2Hy3c5Z4n14XmSvTI= github.com/jinzhu/inflection v1.0.0 h1:K317FqzuhWc8YvSVlFMCCUb36O/S9MCKRDI7QkRKD/E= @@ -428,8 +425,6 @@ github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWE github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= github.com/mattn/go-sqlite3 v2.0.3+incompatible h1:gXHsfypPkaMZrKbD5209QV9jbUTJKjyR5WD3HYQSd+U= github.com/mattn/go-sqlite3 v2.0.3+incompatible/go.mod h1:FPy6KqzDD04eiIsT53CuJW3U88zkxoIYsOqkbpncsNc= -github.com/maxatome/go-testdeep v1.12.0 h1:Ql7Go8Tg0C1D/uMMX59LAoYK7LffeJQ6X2T04nTH68g= -github.com/maxatome/go-testdeep v1.12.0/go.mod h1:lPZc/HAcJMP92l7yI6TRz1aZN5URwUBUAfUNvrclaNM= github.com/miekg/dns v1.1.58 h1:ca2Hdkz+cDg/7eNF6V56jjzuZ4aCAE+DbVkILdQWG/4= github.com/miekg/dns v1.1.58/go.mod h1:Ypv+3b/KadlvW9vJfXOTf300O4UqaHFzFCuHz+rPkBY= github.com/minio/md5-simd v1.1.2 h1:Gdi1DZK69+ZVMoNHRXJyNcxrMA4dSxoYHZSQbirFg34= From ce23982ff496aa3cfb6cc6da1aadd8f8b9e66425 Mon Sep 17 00:00:00 2001 From: James <370036720@qq.com> Date: Fri, 3 Jan 2025 12:04:24 +0800 Subject: [PATCH 31/34] fix: model task detection issue and model download issue (#229) * mini fix bug * Fix model download issue --------- Co-authored-by: James --- .../Dockerfile.lm-evaluation-harness | 1 + docker/evaluation/Dockerfile.opencompass | 4 +-- .../evaluation/lm-evaluation-harness/start.sh | 9 ++++-- docker/evaluation/opencompass/start.sh | 30 +++++++++++++++++-- docker/finetune/Dockerfile.llamafactory | 11 ++++--- docker/inference/Dockerfile.tgi | 2 +- docker/inference/Dockerfile.vllm | 2 +- docker/inference/Dockerfile.vllm-cpu | 2 +- docker/inference/README.md | 22 +++++++------- docker/spaces/Dockerfile.nginx | 7 +++-- 10 files changed, 59 insertions(+), 31 deletions(-) diff --git a/docker/evaluation/Dockerfile.lm-evaluation-harness b/docker/evaluation/Dockerfile.lm-evaluation-harness index 7ef795d8..c9913664 100644 --- a/docker/evaluation/Dockerfile.lm-evaluation-harness +++ b/docker/evaluation/Dockerfile.lm-evaluation-harness @@ -8,6 +8,7 @@ WORKDIR /workspace/ RUN git clone --depth 1 https://github.com/EleutherAI/lm-evaluation-harness.git --branch v0.4.6 --single-branch && \ cd lm-evaluation-harness && pip install setuptools --upgrade --no-cache-dir -e \ ".[ifeval,math,multilingual,sentencepiece]" +RUN pip install --no-cache-dir huggingface-hub==0.27.0 COPY ./lm-evaluation-harness/ /etc/csghub/ RUN ln -s /usr/bin/python3 /usr/bin/python &&\ chmod +x /etc/csghub/*.sh diff --git a/docker/evaluation/Dockerfile.opencompass b/docker/evaluation/Dockerfile.opencompass index e4a59551..d0e5f6de 100644 --- a/docker/evaluation/Dockerfile.opencompass +++ b/docker/evaluation/Dockerfile.opencompass @@ -3,8 +3,8 @@ RUN apt-get update && apt-get -y install python3.10 python3-pip dumb-init \ && apt-get clean && rm -rf /var/lib/apt/lists/* RUN pip config set global.index-url https://pypi.tuna.tsinghua.edu.cn/simple \ && pip install --no-cache-dir opencompass==0.3.5 \ - csghub-sdk==0.4.4 minio oss2 -RUN pip install --no-cache-dir vllm==0.6.3.post1 openpyxl + csghub-sdk==0.4.6 minio oss2 +RUN pip install --no-cache-dir vllm==0.6.3.post1 openpyxl modelscope==1.20.1 datasets==2.20.0 COPY ./opencompass/ /etc/csghub/ RUN ln -s /usr/bin/python3 /usr/bin/python &&\ chmod +x /etc/csghub/*.sh diff --git a/docker/evaluation/lm-evaluation-harness/start.sh b/docker/evaluation/lm-evaluation-harness/start.sh index 7d687d16..74b53215 100644 --- a/docker/evaluation/lm-evaluation-harness/start.sh +++ b/docker/evaluation/lm-evaluation-harness/start.sh @@ -25,6 +25,7 @@ search_path_with_most_term() { echo $max_count_path return 0 } +export HF_ENDPOINT="$HF_ENDPOINT/hf" #download datasets if [ ! -z "$DATASET_IDS" ]; then echo "Downloading datasets..." @@ -44,8 +45,6 @@ if [ $? -ne 0 ]; then exit 1 fi -export HF_ENDPOINT="$HF_ENDPOINT/hf" - tasks="" task_dir="/workspace/lm-evaluation-harness/lm_eval/tasks" IFS=',' read -r -a dataset_repos <<< "$DATASET_IDS" @@ -53,12 +52,16 @@ if [ -z "$NUM_FEW_SHOT" ]; then NUM_FEW_SHOT=0 fi script_dts_array=("allenai/winogrande" "facebook/anli" "aps/super_glue" "Rowan/hellaswag" "nyu-mll/blimp" "EdinburghNLP/orange_sum" "facebook/xnli" "nyu-mll/glue" "openai/gsm8k" "cimec/lambada" "allenai/math_qa" "openlifescienceai/medmcqa" "google-research-datasets/nq_open" "allenai/openbookqa" "google-research-datasets/paws-x" "ybisk/piqa" "community-datasets/qa4mre" "allenai/sciq" "allenai/social_i_qa" "LSDSem/story_cloze" "allenai/swag" "IWSLT/iwslt2017" "wmt/wmt14" "wmt/wmt16","mandarjoshi/trivia_qa" "truthfulqa/truthful_qa" "Stanford/web_questions" "ErnestSDavis/winograd_wsc" "cambridgeltl/xcopa" "google/xquad") +script_dts_multi_config_array=("allenai/winogrande") for repo in "${dataset_repos[@]}"; do repo_name="${repo#*/}" if [[ " ${script_dts_array[@]} " =~ " ${repo} " ]]; then #need replace with real path echo "replace script repo with namespace repo" - find . -type f -exec sed -i "s|dataset_path: $repo_name|dataset_path: $repo|g" {} + + find $task_dir -type f -exec sed -i "s|dataset_path: $repo_name|dataset_path: $repo|g" {} + + if [[ " ${script_dts_multi_config_array[@]} " =~ " ${repo} " ]]; then + grep -rl "dataset_path: $repo" "$task_dir" | xargs sed -i "s|dataset_name: .*|dataset_name: null|g" + fi fi # search full id to cover mirror repo id mapfile -t yaml_files < <(grep -Rl -E "(dataset_path: ${repo}($|\s))" $task_dir) diff --git a/docker/evaluation/opencompass/start.sh b/docker/evaluation/opencompass/start.sh index 8ee5ef04..e1c356da 100644 --- a/docker/evaluation/opencompass/start.sh +++ b/docker/evaluation/opencompass/start.sh @@ -20,6 +20,24 @@ if ! grep -q "chat_template" "$repo_tokenizer_config"; then cp "/workspace/$MODEL_ID/tokenizer_config.json" $filename awk -v ins="$insert_string" '/tokenizer_class/ {print; print ins; next}1' "$filename" > tmpfile && mv -f tmpfile $repo_tokenizer_config fi +#fix: use local dataset +export DATASET_SOURCE=ModelScope +export COMPASS_DATA_CACHE=/workspace/data/ +dataset_path="/usr/local/lib/python3.10/dist-packages/opencompass/datasets/" +find $dataset_path -type f -name "*.py" -exec sed -i 's/get_data_path(path)/"\/workspace\/data\/"+get_data_path(path)/g' {} + +declare -A dataset_alias +dataset_alias["ai2_arc"]="ARC-c" +dataset_alias["ceval-exam"]="ceval" +dataset_alias["OCNLI"]="ocnli" +dataset_alias["cmrc_dev"]="CMRC_dev" +dataset_alias["drcd_dev"]="DRCD_dev" +dataset_alias["humaneval"]="openai_humaneval" +dataset_alias["LCSTS"]="lcsts" +dataset_alias["natural_question"]="nq" +dataset_alias["strategy_qa"]="strategyqa" +dataset_alias["boolq"]="BoolQ" +dataset_alias["trivia_qa"]="triviaqa" +dataset_alias["xsum"]="Xsum" # download datasets IFS=',' read -r -a dataset_repos <<< "$DATASET_IDS" # Loop through the array and print each value @@ -29,9 +47,12 @@ for repo in "${dataset_repos[@]}"; do # check $dataset existing if [ ! -d "/workspace/data/$repo" ]; then echo "Start downloading dataset $repo..." - csghub-cli download $repo -t dataset -k $ACCESS_TOKEN -e $HF_ENDPOINT -cd /tmp/ - # mv "$dataset" to "/workspace/data/" - mv -f "/tmp/$repo" "/workspace/data/" + csghub-cli download $repo -t dataset -r master -k $ACCESS_TOKEN -e $HF_ENDPOINT -cd /workspace/data/ + if [ $? -ne 0 ]; then + echo "Download dataset $repo failed,retry with main branch" + #for some special case which use main branch + csghub-cli download $repo -t dataset -k $ACCESS_TOKEN -e $HF_ENDPOINT -cd /workspace/data/ + fi fi # get answer mode task_path=`python -W ignore /etc/csghub/get_answer_mode.py $repo` @@ -52,6 +73,9 @@ for repo in "${dataset_repos[@]}"; do task=`basename $task_conf_file | cut -d'.' -f1` dataset_tasks="$dataset_tasks $task" ori_name=`basename $repo` + if [ -n "${dataset_alias[$ori_name]}" ]; then + ori_name="${dataset_alias[$ori_name]}" + fi dataset_tasks_ori="$dataset_tasks_ori $ori_name" continue fi diff --git a/docker/finetune/Dockerfile.llamafactory b/docker/finetune/Dockerfile.llamafactory index 99270ca3..c28fdccb 100644 --- a/docker/finetune/Dockerfile.llamafactory +++ b/docker/finetune/Dockerfile.llamafactory @@ -26,19 +26,18 @@ RUN ln -sf /usr/bin/python3 /usr/bin/python && \ pip config set global.index-url https://pypi.tuna.tsinghua.edu.cn/simple && \ pip install --no-cache-dir jupyterlab numpy==1.26.4 \ torch==2.1.2 torchvision==0.16.2 torchaudio==2.1.2 \ - jupyter-server-proxy==4.2.0 + jupyter-server-proxy==4.4.0 fastapi==0.112.2 # Create a working directory WORKDIR /etc/csghub -RUN git clone --depth 1 https://github.com/hiyouga/LLaMA-Factory.git --branch v0.8.3 --single-branch && cd LLaMA-Factory && \ +RUN git clone --depth 1 https://github.com/hiyouga/LLaMA-Factory.git --branch v0.9.1 --single-branch && cd LLaMA-Factory && \ pip install --no-cache-dir -e ".[metrics,deepspeed]" # Setup supervisord -COPY script/supervisord.conf /etc/supervisor/conf.d/supervisord.conf -COPY script/jupyter_notebook_config.py /root/.jupyter/jupyter_notebook_config.py -COPY script/ /etc/csghub/ -COPY script/handlers.py /usr/local/lib/python3.10/dist-packages/jupyter_server_proxy/handlers.py +COPY llama-factory/supervisord.conf /etc/supervisor/conf.d/supervisord.conf +COPY llama-factory/jupyter_notebook_config.py /root/.jupyter/jupyter_notebook_config.py +COPY llama-factory/ /etc/csghub/ RUN mkdir -p /var/log/supervisord && \ chmod +x /etc/csghub/*.sh && \ diff --git a/docker/inference/Dockerfile.tgi b/docker/inference/Dockerfile.tgi index 562a4713..dee6b07e 100644 --- a/docker/inference/Dockerfile.tgi +++ b/docker/inference/Dockerfile.tgi @@ -1,6 +1,6 @@ FROM ghcr.io/huggingface/text-generation-inference:2.4.0 RUN pip config set global.index-url https://pypi.tuna.tsinghua.edu.cn/simple -RUN pip install --no-cache-dir csghub-sdk==0.4.5 +RUN pip install --no-cache-dir csghub-sdk==0.4.5 huggingface-hub==0.27.0 RUN apt-get update && apt-get install -y dumb-init && apt-get clean && rm -rf /var/lib/apt/lists/* COPY ./tgi/ /etc/csghub/ RUN chmod +x /etc/csghub/*.sh diff --git a/docker/inference/Dockerfile.vllm b/docker/inference/Dockerfile.vllm index aaf8dc40..0879d3aa 100644 --- a/docker/inference/Dockerfile.vllm +++ b/docker/inference/Dockerfile.vllm @@ -1,6 +1,6 @@ FROM vllm/vllm-openai:v0.6.3.post1 RUN pip config set global.index-url https://pypi.tuna.tsinghua.edu.cn/simple -RUN pip install --no-cache-dir csghub-sdk==0.4.3 ray supervisor +RUN pip install --no-cache-dir csghub-sdk==0.4.3 ray supervisor huggingface-hub==0.27.0 RUN apt-get update && apt-get install -y supervisor RUN mkdir -p /var/log/supervisord COPY ./supervisord.conf /etc/supervisor/conf.d/supervisord.conf diff --git a/docker/inference/Dockerfile.vllm-cpu b/docker/inference/Dockerfile.vllm-cpu index c0b3ae36..17f94b55 100644 --- a/docker/inference/Dockerfile.vllm-cpu +++ b/docker/inference/Dockerfile.vllm-cpu @@ -1,6 +1,6 @@ FROM cledge/vllm-cpu:0.4.12-fix1 RUN pip config set global.index-url https://pypi.tuna.tsinghua.edu.cn/simple -RUN pip install --no-cache-dir csghub-sdk==0.3.1 +RUN pip install --no-cache-dir csghub-sdk==0.4.6 WORKDIR /workspace/ diff --git a/docker/inference/README.md b/docker/inference/README.md index 35b86d5a..cf9fbd68 100644 --- a/docker/inference/README.md +++ b/docker/inference/README.md @@ -12,8 +12,8 @@ echo "$OPENCSG_ACR_PASSWORD" | docker login $OPENCSG_ACR -u $OPENCSG_ACR_USERNAM ```bash export BUILDX_NO_DEFAULT_ATTESTATIONS=1 -# For vllm: opencsg-registry.cn-beijing.cr.aliyuncs.com/public/vllm-local:2.7 -export IMAGE_TAG=2.8 +# For vllm: opencsg-registry.cn-beijing.cr.aliyuncs.com/public/vllm-local:3.2 +export IMAGE_TAG=3.2 docker buildx build --platform linux/amd64,linux/arm64 \ -t ${OPENCSG_ACR}/public/vllm-local:${IMAGE_TAG} \ -t ${OPENCSG_ACR}/public/vllm-local:latest \ @@ -28,8 +28,8 @@ docker buildx build --platform linux/amd64,linux/arm64 \ -f Dockerfile.vllm-cpu \ --push . -# For tgi: opencsg-registry.cn-beijing.cr.aliyuncs.com/public/tgi:2.2 -export IMAGE_TAG=2.2 +# For tgi: opencsg-registry.cn-beijing.cr.aliyuncs.com/public/tgi:3.2 +export IMAGE_TAG=3.2 docker buildx build --platform linux/amd64 \ -t ${OPENCSG_ACR}/public/tgi:${IMAGE_TAG} \ -t ${OPENCSG_ACR}/public/tgi:latest \ @@ -62,13 +62,13 @@ docker run -d \ *Note: HF_ENDPOINT should be use the real csghub address.* ## inference image name, version and cuda version -| Image Name | Version | CUDA Version | -| --- | --- | --- | -| vllm | 2.8 | 12.1 | -| vllm | 3.0 | 12.4 | -| vllm-cpu | 2.4 | -| -| tgi | 2.2 | 12.1 | -| tgi | 3.0 | 12.4 | +| Image Name | Version | CUDA Version | Fix +| --- | --- | --- |--- | +| vllm | 2.8 | 12.1 | - | +| vllm | 3.2 | 12.4 |fix hf hub timestamp| +| vllm-cpu | 2.4 | -|fix hf hub timestamp | +| tgi | 2.2 | 12.1 |- | +| tgi | 3.2 | 12.4 |fix hf hub timestamp| ## API to Call Inference diff --git a/docker/spaces/Dockerfile.nginx b/docker/spaces/Dockerfile.nginx index 9d0768a8..b3bfe3d7 100644 --- a/docker/spaces/Dockerfile.nginx +++ b/docker/spaces/Dockerfile.nginx @@ -1,6 +1,7 @@ FROM opencsg-registry.cn-beijing.cr.aliyuncs.com/opencsg_public/nginx:latest - +RUN apt-get update && apt-get install -y git WORKDIR /usr/share/nginx/html -COPY ./nginx/serve.sh serve.sh +COPY ./nginx/serve.sh /etc/serve.sh +RUN chmod +x /etc/serve.sh -CMD ["serve.sh"] \ No newline at end of file +CMD ["/etc/serve.sh"] \ No newline at end of file From 5bcc387688d7549e150230365998ff7923f3d86c Mon Sep 17 00:00:00 2001 From: James <370036720@qq.com> Date: Fri, 3 Jan 2025 12:05:35 +0800 Subject: [PATCH 32/34] runner service refactor (#230) * Refactor/runner service informer * fix runner bug * fix space status issue * runner testing --------- Co-authored-by: James --- .../store/database/mock_ClusterInfoStore.go | 28 +- .../database/mock_KnativeServiceStore.go | 298 ++++++++ builder/deploy/cluster/cluster_manager.go | 39 +- builder/deploy/deployer.go | 75 +- builder/deploy/deployer_ce.go | 1 - builder/deploy/deployer_test.go | 67 +- builder/store/database/cluster.go | 19 +- builder/store/database/cluster_test.go | 4 +- builder/store/database/knative_service.go | 92 +++ .../store/database/knative_service_test.go | 94 +++ ...1225124808_create_table_knative_service.go | 54 ++ common/types/service_runner.go | 52 +- go.mod | 6 +- go.sum | 12 +- runner/component/service.go | 643 ++++++++++++++++- runner/component/service_test.go | 670 ++++++++++++++++++ runner/component/workflow.go | 6 +- runner/component/workflow_test.go | 170 +++++ runner/handler/service.go | 572 ++------------- runner/handler/workflow.go | 2 + 20 files changed, 2209 insertions(+), 695 deletions(-) create mode 100644 _mocks/opencsg.com/csghub-server/builder/store/database/mock_KnativeServiceStore.go create mode 100644 builder/store/database/knative_service.go create mode 100644 builder/store/database/knative_service_test.go create mode 100644 builder/store/database/migrations/20241225124808_create_table_knative_service.go create mode 100644 runner/component/service_test.go create mode 100644 runner/component/workflow_test.go diff --git a/_mocks/opencsg.com/csghub-server/builder/store/database/mock_ClusterInfoStore.go b/_mocks/opencsg.com/csghub-server/builder/store/database/mock_ClusterInfoStore.go index 86fbeb25..82c5d5a2 100644 --- a/_mocks/opencsg.com/csghub-server/builder/store/database/mock_ClusterInfoStore.go +++ b/_mocks/opencsg.com/csghub-server/builder/store/database/mock_ClusterInfoStore.go @@ -23,21 +23,33 @@ func (_m *MockClusterInfoStore) EXPECT() *MockClusterInfoStore_Expecter { } // Add provides a mock function with given fields: ctx, clusterConfig, region -func (_m *MockClusterInfoStore) Add(ctx context.Context, clusterConfig string, region string) error { +func (_m *MockClusterInfoStore) Add(ctx context.Context, clusterConfig string, region string) (*database.ClusterInfo, error) { ret := _m.Called(ctx, clusterConfig, region) if len(ret) == 0 { panic("no return value specified for Add") } - var r0 error - if rf, ok := ret.Get(0).(func(context.Context, string, string) error); ok { + var r0 *database.ClusterInfo + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, string, string) (*database.ClusterInfo, error)); ok { + return rf(ctx, clusterConfig, region) + } + if rf, ok := ret.Get(0).(func(context.Context, string, string) *database.ClusterInfo); ok { r0 = rf(ctx, clusterConfig, region) } else { - r0 = ret.Error(0) + if ret.Get(0) != nil { + r0 = ret.Get(0).(*database.ClusterInfo) + } } - return r0 + if rf, ok := ret.Get(1).(func(context.Context, string, string) error); ok { + r1 = rf(ctx, clusterConfig, region) + } else { + r1 = ret.Error(1) + } + + return r0, r1 } // MockClusterInfoStore_Add_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Add' @@ -60,12 +72,12 @@ func (_c *MockClusterInfoStore_Add_Call) Run(run func(ctx context.Context, clust return _c } -func (_c *MockClusterInfoStore_Add_Call) Return(_a0 error) *MockClusterInfoStore_Add_Call { - _c.Call.Return(_a0) +func (_c *MockClusterInfoStore_Add_Call) Return(_a0 *database.ClusterInfo, _a1 error) *MockClusterInfoStore_Add_Call { + _c.Call.Return(_a0, _a1) return _c } -func (_c *MockClusterInfoStore_Add_Call) RunAndReturn(run func(context.Context, string, string) error) *MockClusterInfoStore_Add_Call { +func (_c *MockClusterInfoStore_Add_Call) RunAndReturn(run func(context.Context, string, string) (*database.ClusterInfo, error)) *MockClusterInfoStore_Add_Call { _c.Call.Return(run) return _c } diff --git a/_mocks/opencsg.com/csghub-server/builder/store/database/mock_KnativeServiceStore.go b/_mocks/opencsg.com/csghub-server/builder/store/database/mock_KnativeServiceStore.go new file mode 100644 index 00000000..fa330fc6 --- /dev/null +++ b/_mocks/opencsg.com/csghub-server/builder/store/database/mock_KnativeServiceStore.go @@ -0,0 +1,298 @@ +// Code generated by mockery v2.49.1. DO NOT EDIT. + +package database + +import ( + context "context" + + mock "github.com/stretchr/testify/mock" + database "opencsg.com/csghub-server/builder/store/database" +) + +// MockKnativeServiceStore is an autogenerated mock type for the KnativeServiceStore type +type MockKnativeServiceStore struct { + mock.Mock +} + +type MockKnativeServiceStore_Expecter struct { + mock *mock.Mock +} + +func (_m *MockKnativeServiceStore) EXPECT() *MockKnativeServiceStore_Expecter { + return &MockKnativeServiceStore_Expecter{mock: &_m.Mock} +} + +// Add provides a mock function with given fields: ctx, service +func (_m *MockKnativeServiceStore) Add(ctx context.Context, service *database.KnativeService) error { + ret := _m.Called(ctx, service) + + if len(ret) == 0 { + panic("no return value specified for Add") + } + + var r0 error + if rf, ok := ret.Get(0).(func(context.Context, *database.KnativeService) error); ok { + r0 = rf(ctx, service) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// MockKnativeServiceStore_Add_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Add' +type MockKnativeServiceStore_Add_Call struct { + *mock.Call +} + +// Add is a helper method to define mock.On call +// - ctx context.Context +// - service *database.KnativeService +func (_e *MockKnativeServiceStore_Expecter) Add(ctx interface{}, service interface{}) *MockKnativeServiceStore_Add_Call { + return &MockKnativeServiceStore_Add_Call{Call: _e.mock.On("Add", ctx, service)} +} + +func (_c *MockKnativeServiceStore_Add_Call) Run(run func(ctx context.Context, service *database.KnativeService)) *MockKnativeServiceStore_Add_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(*database.KnativeService)) + }) + return _c +} + +func (_c *MockKnativeServiceStore_Add_Call) Return(_a0 error) *MockKnativeServiceStore_Add_Call { + _c.Call.Return(_a0) + return _c +} + +func (_c *MockKnativeServiceStore_Add_Call) RunAndReturn(run func(context.Context, *database.KnativeService) error) *MockKnativeServiceStore_Add_Call { + _c.Call.Return(run) + return _c +} + +// Delete provides a mock function with given fields: ctx, svcName, clusterID +func (_m *MockKnativeServiceStore) Delete(ctx context.Context, svcName string, clusterID string) error { + ret := _m.Called(ctx, svcName, clusterID) + + if len(ret) == 0 { + panic("no return value specified for Delete") + } + + var r0 error + if rf, ok := ret.Get(0).(func(context.Context, string, string) error); ok { + r0 = rf(ctx, svcName, clusterID) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// MockKnativeServiceStore_Delete_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Delete' +type MockKnativeServiceStore_Delete_Call struct { + *mock.Call +} + +// Delete is a helper method to define mock.On call +// - ctx context.Context +// - svcName string +// - clusterID string +func (_e *MockKnativeServiceStore_Expecter) Delete(ctx interface{}, svcName interface{}, clusterID interface{}) *MockKnativeServiceStore_Delete_Call { + return &MockKnativeServiceStore_Delete_Call{Call: _e.mock.On("Delete", ctx, svcName, clusterID)} +} + +func (_c *MockKnativeServiceStore_Delete_Call) Run(run func(ctx context.Context, svcName string, clusterID string)) *MockKnativeServiceStore_Delete_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(string), args[2].(string)) + }) + return _c +} + +func (_c *MockKnativeServiceStore_Delete_Call) Return(_a0 error) *MockKnativeServiceStore_Delete_Call { + _c.Call.Return(_a0) + return _c +} + +func (_c *MockKnativeServiceStore_Delete_Call) RunAndReturn(run func(context.Context, string, string) error) *MockKnativeServiceStore_Delete_Call { + _c.Call.Return(run) + return _c +} + +// Get provides a mock function with given fields: ctx, svcName, clusterID +func (_m *MockKnativeServiceStore) Get(ctx context.Context, svcName string, clusterID string) (*database.KnativeService, error) { + ret := _m.Called(ctx, svcName, clusterID) + + if len(ret) == 0 { + panic("no return value specified for Get") + } + + var r0 *database.KnativeService + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, string, string) (*database.KnativeService, error)); ok { + return rf(ctx, svcName, clusterID) + } + if rf, ok := ret.Get(0).(func(context.Context, string, string) *database.KnativeService); ok { + r0 = rf(ctx, svcName, clusterID) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*database.KnativeService) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, string, string) error); ok { + r1 = rf(ctx, svcName, clusterID) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// MockKnativeServiceStore_Get_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Get' +type MockKnativeServiceStore_Get_Call struct { + *mock.Call +} + +// Get is a helper method to define mock.On call +// - ctx context.Context +// - svcName string +// - clusterID string +func (_e *MockKnativeServiceStore_Expecter) Get(ctx interface{}, svcName interface{}, clusterID interface{}) *MockKnativeServiceStore_Get_Call { + return &MockKnativeServiceStore_Get_Call{Call: _e.mock.On("Get", ctx, svcName, clusterID)} +} + +func (_c *MockKnativeServiceStore_Get_Call) Run(run func(ctx context.Context, svcName string, clusterID string)) *MockKnativeServiceStore_Get_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(string), args[2].(string)) + }) + return _c +} + +func (_c *MockKnativeServiceStore_Get_Call) Return(_a0 *database.KnativeService, _a1 error) *MockKnativeServiceStore_Get_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *MockKnativeServiceStore_Get_Call) RunAndReturn(run func(context.Context, string, string) (*database.KnativeService, error)) *MockKnativeServiceStore_Get_Call { + _c.Call.Return(run) + return _c +} + +// GetByCluster provides a mock function with given fields: ctx, clusterID +func (_m *MockKnativeServiceStore) GetByCluster(ctx context.Context, clusterID string) ([]database.KnativeService, error) { + ret := _m.Called(ctx, clusterID) + + if len(ret) == 0 { + panic("no return value specified for GetByCluster") + } + + var r0 []database.KnativeService + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, string) ([]database.KnativeService, error)); ok { + return rf(ctx, clusterID) + } + if rf, ok := ret.Get(0).(func(context.Context, string) []database.KnativeService); ok { + r0 = rf(ctx, clusterID) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).([]database.KnativeService) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, string) error); ok { + r1 = rf(ctx, clusterID) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// MockKnativeServiceStore_GetByCluster_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'GetByCluster' +type MockKnativeServiceStore_GetByCluster_Call struct { + *mock.Call +} + +// GetByCluster is a helper method to define mock.On call +// - ctx context.Context +// - clusterID string +func (_e *MockKnativeServiceStore_Expecter) GetByCluster(ctx interface{}, clusterID interface{}) *MockKnativeServiceStore_GetByCluster_Call { + return &MockKnativeServiceStore_GetByCluster_Call{Call: _e.mock.On("GetByCluster", ctx, clusterID)} +} + +func (_c *MockKnativeServiceStore_GetByCluster_Call) Run(run func(ctx context.Context, clusterID string)) *MockKnativeServiceStore_GetByCluster_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(string)) + }) + return _c +} + +func (_c *MockKnativeServiceStore_GetByCluster_Call) Return(_a0 []database.KnativeService, _a1 error) *MockKnativeServiceStore_GetByCluster_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *MockKnativeServiceStore_GetByCluster_Call) RunAndReturn(run func(context.Context, string) ([]database.KnativeService, error)) *MockKnativeServiceStore_GetByCluster_Call { + _c.Call.Return(run) + return _c +} + +// Update provides a mock function with given fields: ctx, service +func (_m *MockKnativeServiceStore) Update(ctx context.Context, service *database.KnativeService) error { + ret := _m.Called(ctx, service) + + if len(ret) == 0 { + panic("no return value specified for Update") + } + + var r0 error + if rf, ok := ret.Get(0).(func(context.Context, *database.KnativeService) error); ok { + r0 = rf(ctx, service) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// MockKnativeServiceStore_Update_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Update' +type MockKnativeServiceStore_Update_Call struct { + *mock.Call +} + +// Update is a helper method to define mock.On call +// - ctx context.Context +// - service *database.KnativeService +func (_e *MockKnativeServiceStore_Expecter) Update(ctx interface{}, service interface{}) *MockKnativeServiceStore_Update_Call { + return &MockKnativeServiceStore_Update_Call{Call: _e.mock.On("Update", ctx, service)} +} + +func (_c *MockKnativeServiceStore_Update_Call) Run(run func(ctx context.Context, service *database.KnativeService)) *MockKnativeServiceStore_Update_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(*database.KnativeService)) + }) + return _c +} + +func (_c *MockKnativeServiceStore_Update_Call) Return(_a0 error) *MockKnativeServiceStore_Update_Call { + _c.Call.Return(_a0) + return _c +} + +func (_c *MockKnativeServiceStore_Update_Call) RunAndReturn(run func(context.Context, *database.KnativeService) error) *MockKnativeServiceStore_Update_Call { + _c.Call.Return(run) + return _c +} + +// NewMockKnativeServiceStore creates a new instance of MockKnativeServiceStore. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. +// The first argument is typically a *testing.T value. +func NewMockKnativeServiceStore(t interface { + mock.TestingT + Cleanup(func()) +}) *MockKnativeServiceStore { + mock := &MockKnativeServiceStore{} + mock.Mock.Test(t) + + t.Cleanup(func() { mock.AssertExpectations(t) }) + + return mock +} diff --git a/builder/deploy/cluster/cluster_manager.go b/builder/deploy/cluster/cluster_manager.go index b87e25be..9e31c4fc 100644 --- a/builder/deploy/cluster/cluster_manager.go +++ b/builder/deploy/cluster/cluster_manager.go @@ -8,6 +8,7 @@ import ( "math/rand" "path/filepath" "strings" + "time" "github.com/argoproj/argo-workflows/v3/pkg/client/clientset/versioned" v1 "k8s.io/api/core/v1" @@ -23,12 +24,14 @@ import ( ) // Cluster holds basic information about a Kubernetes cluster + type Cluster struct { - ID string // Unique identifier for the cluster - ConfigPath string // Path to the kubeconfig file - Client *kubernetes.Clientset // Kubernetes client - KnativeClient *knative.Clientset // Knative client - ArgoClient *versioned.Clientset // Argo client + CID string // config id + ID string // unique id + ConfigPath string // Path to the kubeconfig file + Client kubernetes.Interface // Kubernetes client + KnativeClient knative.Interface // Knative client + ArgoClient versioned.Interface // Argo client StorageClass string } @@ -69,22 +72,26 @@ func NewClusterPool() (*ClusterPool, error) { } knativeClient, err := knative.NewForConfig(config) if err != nil { - slog.Error("falied to create knative client", "error", err) - return nil, fmt.Errorf("falied to create knative client,%w", err) + slog.Error("failed to create knative client", "error", err) + return nil, fmt.Errorf("failed to create knative client,%w", err) } id := filepath.Base(kubeconfig) + ctxTimeout, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + cluster, err := pool.ClusterStore.Add(ctxTimeout, id, fmt.Sprintf("region-%d", i)) + if err != nil { + slog.Error("failed to add cluster info to db", slog.Any("error", err), slog.Any("congfig id", id)) + return nil, fmt.Errorf("failed to add cluster info to db,%v", err) + } pool.Clusters = append(pool.Clusters, Cluster{ - ID: id, + CID: id, + ID: cluster.ClusterID, ConfigPath: kubeconfig, Client: client, KnativeClient: knativeClient, ArgoClient: argoClient, }) - err = pool.ClusterStore.Add(context.TODO(), id, fmt.Sprintf("region-%d", i)) - if err != nil { - slog.Error("falied to add cluster info to db", "error", err) - return nil, fmt.Errorf("falied to add cluster info to db,%w", err) - } + } return pool, nil @@ -114,7 +121,7 @@ func (p *ClusterPool) GetClusterByID(ctx context.Context, id string) (*Cluster, storageClass = cInfo.StorageClass } for _, Cluster := range p.Clusters { - if Cluster.ID == cfId { + if Cluster.CID == cfId { Cluster.StorageClass = storageClass return &Cluster, nil } @@ -123,7 +130,7 @@ func (p *ClusterPool) GetClusterByID(ctx context.Context, id string) (*Cluster, } // getNodeResources retrieves all node cpu and gpu info -func GetNodeResources(clientset *kubernetes.Clientset, config *config.Config) (map[string]types.NodeResourceInfo, error) { +func GetNodeResources(clientset kubernetes.Interface, config *config.Config) (map[string]types.NodeResourceInfo, error) { nodes, err := clientset.CoreV1().Nodes().List(context.TODO(), metav1.ListOptions{}) if err != nil { return nil, err @@ -140,7 +147,7 @@ func GetNodeResources(clientset *kubernetes.Clientset, config *config.Config) (m memCapacity := node.Status.Capacity["memory"] memQuantity, ok := memCapacity.AsInt64() if !ok { - slog.Error("falied to get node memory", "node", node.Name, "error", err) + slog.Error("failed to get node memory", "node", node.Name, "error", err) continue } totalMem := getMem(memQuantity) diff --git a/builder/deploy/deployer.go b/builder/deploy/deployer.go index 3c981a45..0ca29c79 100644 --- a/builder/deploy/deployer.go +++ b/builder/deploy/deployer.go @@ -210,22 +210,6 @@ func (d *deployer) Deploy(ctx context.Context, dr types.DeployRepo) (int64, erro return deploy.ID, nil } -func (d *deployer) refreshStatus() { - for { - ctxTimeout, cancel := context.WithTimeout(context.Background(), 3*time.Second) - status, err := d.imageRunner.StatusAll(ctxTimeout) - cancel() - if err != nil { - slog.Error("refresh status all failed", slog.Any("error", err)) - } else { - slog.Debug("status all cached", slog.Any("status", d.runnerStatusCache)) - d.runnerStatusCache = status - } - - time.Sleep(5 * time.Second) - } -} - func (d *deployer) Status(ctx context.Context, dr types.DeployRepo, needDetails bool) (string, int, []types.Instance, error) { deploy, err := d.deployTaskStore.GetDeployByID(ctx, dr.DeployID) if err != nil || deploy == nil { @@ -233,39 +217,19 @@ func (d *deployer) Status(ctx context.Context, dr types.DeployRepo, needDetails return "", common.Stopped, nil, fmt.Errorf("can't get deploy, %w", err) } svcName := deploy.SvcName - // srvName := common.UniqueSpaceAppName(dr.Namespace, dr.Name, dr.SpaceID) - rstatus, found := d.runnerStatusCache[svcName] - if !found { - slog.Debug("status cache miss", slog.String("svc_name", svcName)) - if deploy.Status == common.Running { - // service was Stopped or delete, so no running instance - return svcName, common.Stopped, nil, nil - } - return svcName, deploy.Status, nil, nil - } - deployStatus := rstatus.Code - if dr.ModelID > 0 { - targetID := dr.DeployID // support model deploy with multi-instance - status, err := d.imageRunner.Status(ctx, &types.StatusRequest{ - ClusterID: dr.ClusterID, - OrgName: dr.Namespace, - RepoName: dr.Name, - SvcName: deploy.SvcName, - ID: targetID, - NeedDetails: needDetails, - }) - if err != nil { - slog.Error("fail to get status by deploy id", slog.Any("DeployID", deploy.ID), slog.Any("error", err)) - return "", common.RunTimeError, nil, fmt.Errorf("can't get deploy status, %w", err) - } - rstatus.Instances = status.Instances - deployStatus = status.Code - + svc, err := d.imageRunner.Exist(ctx, &types.CheckRequest{ + SvcName: svcName, + ClusterID: deploy.ClusterID, + }) + if err != nil { + slog.Error("fail to get deploy by service name", slog.Any("Service NamE", svcName), slog.Any("error", err)) + return "", common.Stopped, nil, fmt.Errorf("can't get svc, %w", err) } - if rstatus.DeployID == 0 || rstatus.DeployID >= deploy.ID { - return svcName, deployStatus, rstatus.Instances, nil + if svc.Code == common.Stopped || svc.Code == -1 { + // like queuing, or stopped, use status from deploy + return svcName, deploy.Status, nil, nil } - return svcName, deployStatus, rstatus.Instances, nil + return svcName, svc.Code, svc.Instances, nil } func (d *deployer) Logs(ctx context.Context, dr types.DeployRepo) (*MultiLogReader, error) { @@ -395,12 +359,11 @@ func (d *deployer) Exist(ctx context.Context, dr types.DeployRepo) (bool, error) // service check with error slog.Error("deploy check result", slog.Any("resp", resp)) return true, errors.New("fail to check deploy instance") - } else if resp.Code == 1 { - // service exist - return true, nil + } else if resp.Code == common.Stopped { + // service not exist + return false, nil } - // service not exist - return false, nil + return true, nil } func (d *deployer) GetReplica(ctx context.Context, dr types.DeployRepo) (int, int, []types.Instance, error) { @@ -622,7 +585,13 @@ func (d *deployer) startAcctFeeing() { for { resMap := d.getResourceMap() slog.Debug("get resources map", slog.Any("resMap", resMap)) - for _, svc := range d.runnerStatusCache { + ctxTimeout, cancel := context.WithTimeout(context.Background(), 10*time.Second) + cancel() + status, err := d.imageRunner.StatusAll(ctxTimeout) + if err != nil { + slog.Error("failed to get all service status", slog.Any("error", err)) + } + for _, svc := range status { d.startAcctRequestFee(resMap, svc) } // accounting interval in min, get from env config diff --git a/builder/deploy/deployer_ce.go b/builder/deploy/deployer_ce.go index ddf6f5b8..21ce24d9 100644 --- a/builder/deploy/deployer_ce.go +++ b/builder/deploy/deployer_ce.go @@ -55,7 +55,6 @@ func newDeployer(s scheduler.Scheduler, ib imagebuilder.Builder, ir imagerunner. userStore: database.NewUserStore(), } - go d.refreshStatus() d.startJobs() return d, nil } diff --git a/builder/deploy/deployer_test.go b/builder/deploy/deployer_test.go index 23e2654d..0840af2c 100644 --- a/builder/deploy/deployer_test.go +++ b/builder/deploy/deployer_test.go @@ -243,29 +243,36 @@ func TestDeployer_Status(t *testing.T) { }) t.Run("cache miss and running", func(t *testing.T) { dr := types.DeployRepo{ - DeployID: 1, - UserUUID: "1", - Path: "namespace/name", - Type: types.InferenceType, + DeployID: 1, + UserUUID: "1", + Path: "namespace/name", + Type: types.InferenceType, + ClusterID: "test", } deploy := &database.Deploy{ - Status: common.Running, + Status: common.Building, SvcName: "svc", } - + mockRunner := mockrunner.NewMockRunner(t) mockDeployTaskStore := mockdb.NewMockDeployTaskStore(t) mockDeployTaskStore.EXPECT().GetDeployByID(mock.Anything, dr.DeployID). Return(deploy, nil) d := &deployer{ deployTaskStore: mockDeployTaskStore, + imageRunner: mockRunner, } - d.runnerStatusCache = make(map[string]types.StatusResponse) + mockRunner.EXPECT().Exist(mock.Anything, mock.Anything). + Return(&types.StatusResponse{ + DeployID: 1, + UserID: "", + Code: common.Stopped, + }, nil) svcName, deployStatus, instances, err := d.Status(context.TODO(), dr, false) require.Nil(t, err) require.Equal(t, "svc", svcName) - require.Equal(t, common.Stopped, deployStatus) + require.Equal(t, common.Building, deployStatus) require.Nil(t, instances) }) @@ -281,15 +288,21 @@ func TestDeployer_Status(t *testing.T) { Status: common.BuildSuccess, SvcName: "svc", } - + mockRunner := mockrunner.NewMockRunner(t) mockDeployTaskStore := mockdb.NewMockDeployTaskStore(t) mockDeployTaskStore.EXPECT().GetDeployByID(mock.Anything, dr.DeployID). Return(deploy, nil) d := &deployer{ deployTaskStore: mockDeployTaskStore, + imageRunner: mockRunner, } - d.runnerStatusCache = make(map[string]types.StatusResponse) + mockRunner.EXPECT().Exist(mock.Anything, mock.Anything). + Return(&types.StatusResponse{ + DeployID: 1, + UserID: "", + Code: int(common.BuildSuccess), + }, nil) svcName, deployStatus, instances, err := d.Status(context.TODO(), dr, false) require.Nil(t, err) @@ -318,37 +331,21 @@ func TestDeployer_Status(t *testing.T) { Return(deploy, nil) mockRunner := mockrunner.NewMockRunner(t) - mockRunner.EXPECT().Status(mock.Anything, mock.Anything). + + d := &deployer{ + deployTaskStore: mockDeployTaskStore, + imageRunner: mockRunner, + } + mockRunner.EXPECT().Exist(mock.Anything, mock.Anything). Return(&types.StatusResponse{ DeployID: 1, UserID: "", - // running status from the runner (latest) - Code: int(common.Running), - Message: "", - Endpoint: "http://localhost", + Code: common.Running, Instances: []types.Instance{{ - Name: "instance1", - Status: "ready", + Name: "instance1", }}, - Replica: 1, - DeployType: 0, - ServiceName: "svc", - DeploySku: "", }, nil) - d := &deployer{ - deployTaskStore: mockDeployTaskStore, - imageRunner: mockRunner, - } - d.runnerStatusCache = make(map[string]types.StatusResponse) - // deploying status in cache - d.runnerStatusCache["svc"] = types.StatusResponse{ - DeployID: 1, - UserID: "", - Code: int(common.Deploying), - Message: "", - } - svcName, deployStatus, instances, err := d.Status(context.TODO(), dr, false) require.Nil(t, err) require.Equal(t, "svc", svcName) @@ -475,7 +472,7 @@ func TestDeployer_Exists(t *testing.T) { mockRunner := mockrunner.NewMockRunner(t) mockRunner.EXPECT().Exist(mock.Anything, mock.Anything). Return(&types.StatusResponse{ - Code: 2, + Code: common.Stopped, }, nil) d := &deployer{ diff --git a/builder/store/database/cluster.go b/builder/store/database/cluster.go index 0b0fb226..dabff5f2 100644 --- a/builder/store/database/cluster.go +++ b/builder/store/database/cluster.go @@ -14,7 +14,7 @@ type clusterInfoStoreImpl struct { } type ClusterInfoStore interface { - Add(ctx context.Context, clusterConfig string, region string) error + Add(ctx context.Context, clusterConfig string, region string) (*ClusterInfo, error) Update(ctx context.Context, clusterInfo ClusterInfo) error ByClusterID(ctx context.Context, clusterId string) (clusterInfo ClusterInfo, err error) ByClusterConfig(ctx context.Context, clusterConfig string) (clusterInfo ClusterInfo, err error) @@ -43,22 +43,21 @@ type ClusterInfo struct { Enable bool `bun:",notnull" json:"enable"` } -func (r *clusterInfoStoreImpl) Add(ctx context.Context, clusterConfig string, region string) error { - err := r.db.Operator.Core.RunInTx(ctx, nil, func(ctx context.Context, tx bun.Tx) error { +func (r *clusterInfoStoreImpl) Add(ctx context.Context, clusterConfig string, region string) (*ClusterInfo, error) { + cluster, err := r.ByClusterConfig(ctx, clusterConfig) + if errors.Is(err, sql.ErrNoRows) { cluster := &ClusterInfo{ ClusterID: uuid.New().String(), ClusterConfig: clusterConfig, Region: region, Enable: true, } - - _, err := r.ByClusterConfig(ctx, clusterConfig) - if errors.Is(err, sql.ErrNoRows) { - return assertAffectedOneRow(r.db.Operator.Core.NewInsert().Model(cluster).Exec(ctx)) + _, err = r.db.Operator.Core.NewInsert().Model(cluster).Exec(ctx) + if err != nil { + return nil, err } - return err - }) - return err + } + return &cluster, err } func (r *clusterInfoStoreImpl) Update(ctx context.Context, clusterInfo ClusterInfo) error { diff --git a/builder/store/database/cluster_test.go b/builder/store/database/cluster_test.go index e44c37c0..4ea1fe6f 100644 --- a/builder/store/database/cluster_test.go +++ b/builder/store/database/cluster_test.go @@ -16,7 +16,7 @@ func TestClusterStore_CRUD(t *testing.T) { store := database.NewClusterInfoStoreWithDB(db) - err := store.Add(ctx, "foo", "bar") + _, err := store.Add(ctx, "foo", "bar") require.Nil(t, err) cfg := &database.ClusterInfo{} @@ -25,7 +25,7 @@ func TestClusterStore_CRUD(t *testing.T) { require.Equal(t, "bar", cfg.Region) // already exist, do nothing - err = store.Add(ctx, "foo", "bar2") + _, err = store.Add(ctx, "foo", "bar2") require.Nil(t, err) err = db.Core.NewSelect().Model(cfg).Where("cluster_config=?", "foo").Scan(ctx) require.Nil(t, err) diff --git a/builder/store/database/knative_service.go b/builder/store/database/knative_service.go new file mode 100644 index 00000000..f3431596 --- /dev/null +++ b/builder/store/database/knative_service.go @@ -0,0 +1,92 @@ +package database + +import ( + "context" + + "opencsg.com/csghub-server/common/types" + + corev1 "k8s.io/api/core/v1" +) + +type knativeServiceImpl struct { + db *DB +} + +type KnativeServiceStore interface { + Get(ctx context.Context, svcName, clusterID string) (*KnativeService, error) + GetByCluster(ctx context.Context, clusterID string) ([]KnativeService, error) + Add(ctx context.Context, service *KnativeService) error + Update(ctx context.Context, service *KnativeService) error + Delete(ctx context.Context, svcName, clusterID string) error +} + +func NewKnativeServiceStore() KnativeServiceStore { + return &knativeServiceImpl{ + db: defaultDB, + } +} + +func NewKnativeServiceWithDB(db *DB) KnativeServiceStore { + return &knativeServiceImpl{ + db: db, + } +} + +type KnativeService struct { + ID int64 `bun:",pk,autoincrement" json:"id"` + Name string `bun:",notnull" json:"name"` + Status corev1.ConditionStatus `bun:",notnull" json:"status"` + Code int `bun:",notnull" json:"code"` + ClusterID string `bun:",notnull" json:"cluster_id"` + Endpoint string `bun:"," json:"endpoint"` + ActualReplica int `bun:"," json:"actual_replica"` + DesiredReplica int `bun:"," json:"desired_replica"` + Instances []types.Instance `bun:"type:jsonb" json:"instances,omitempty"` + UserUUID string `bun:"," json:"user_uuid"` + DeployID int64 `bun:"," json:"deploy_id"` + DeployType int `bun:"," json:"deploy_type"` + DeploySKU string `bun:"," json:"deploy_sku"` + OrderDetailID int64 `bun:"," json:"order_detail_id"` + times +} + +// get +func (s *knativeServiceImpl) Get(ctx context.Context, svcName, clusterID string) (*KnativeService, error) { + var service KnativeService + var err error + if clusterID == "" { + // backward compatibility, some space has no cluster id + err = s.db.Operator.Core.NewSelect().Model(&service).Where("name = ?", svcName).Scan(ctx) + } else { + err = s.db.Operator.Core.NewSelect().Model(&service).Where("name = ? and cluster_id = ?", svcName, clusterID).Scan(ctx) + } + return &service, err +} + +// add +func (s *knativeServiceImpl) Add(ctx context.Context, service *KnativeService) error { + _, err := s.db.Operator.Core.NewInsert().Model(service).On("CONFLICT(name, cluster_id) DO UPDATE").Exec(ctx) + return err +} + +// update +func (s *knativeServiceImpl) Update(ctx context.Context, service *KnativeService) error { + _, err := s.db.Operator.Core.NewUpdate().Model(service).WherePK().Exec(ctx) + return err +} + +// delete +func (s *knativeServiceImpl) Delete(ctx context.Context, svcName, clusterID string) error { + _, err := s.db.Operator.Core.NewDelete(). + Model(&KnativeService{}). + Where("name = ? and cluster_id = ?", svcName, clusterID). + Exec(ctx) + return err +} + +// GetByCluster +func (s *knativeServiceImpl) GetByCluster(ctx context.Context, clusterID string) ([]KnativeService, error) { + var services []KnativeService + err := s.db.Operator.Core.NewSelect().Model(&services).Where("cluster_id = ?", clusterID).Scan(ctx) + return services, err +} diff --git a/builder/store/database/knative_service_test.go b/builder/store/database/knative_service_test.go new file mode 100644 index 00000000..697860fb --- /dev/null +++ b/builder/store/database/knative_service_test.go @@ -0,0 +1,94 @@ +package database_test + +import ( + "context" + "testing" + + "github.com/stretchr/testify/require" + corev1 "k8s.io/api/core/v1" + "opencsg.com/csghub-server/builder/deploy/common" + "opencsg.com/csghub-server/builder/store/database" + "opencsg.com/csghub-server/common/tests" +) + +func TestKnativeServiceStore_Get(t *testing.T) { + db := tests.InitTestDB() + defer db.Close() + ctx := context.TODO() + + store := database.NewKnativeServiceWithDB(db) + err := store.Add(ctx, &database.KnativeService{ + Name: "test", + Status: corev1.ConditionTrue, + Code: common.Running, + ClusterID: "cluster1", + }) + require.Nil(t, err) + err = store.Add(ctx, &database.KnativeService{ + Name: "test2", + Status: corev1.ConditionTrue, + Code: common.Running, + ClusterID: "cluster1", + }) + require.Nil(t, err) + err = store.Add(ctx, &database.KnativeService{ + Name: "test3", + Status: corev1.ConditionTrue, + Code: common.Running, + ClusterID: "cluster2", + }) + require.Nil(t, err) + ks, err := store.Get(ctx, "test", "cluster1") + require.Nil(t, err) + require.Equal(t, "test", ks.Name) + list, err := store.GetByCluster(ctx, "cluster1") + require.Nil(t, err) + require.Equal(t, 2, len(list)) +} + +func TestKnativeServiceStore_Delete(t *testing.T) { + db := tests.InitTestDB() + defer db.Close() + ctx := context.TODO() + + store := database.NewKnativeServiceWithDB(db) + err := store.Add(ctx, &database.KnativeService{ + Name: "test", + Status: corev1.ConditionTrue, + Code: common.Running, + ClusterID: "cluster1", + }) + require.Nil(t, err) + err = store.Delete(ctx, "test", "cluster1") + require.Nil(t, err) + _, err = store.Get(ctx, "test", "cluster1") + require.NotNil(t, err) +} + +func TestKnativeServiceStore_Update(t *testing.T) { + db := tests.InitTestDB() + defer db.Close() + ctx := context.TODO() + + store := database.NewKnativeServiceWithDB(db) + err := store.Add(ctx, &database.KnativeService{ + ID: 1, + Name: "test", + Status: corev1.ConditionFalse, + Code: common.Deploying, + ClusterID: "cluster1", + }) + require.Nil(t, err) + err = store.Update(ctx, &database.KnativeService{ + ID: 1, + Name: "test", + Status: corev1.ConditionTrue, + Code: common.Running, + ClusterID: "cluster1", + }) + require.Nil(t, err) + ks, err := store.Get(ctx, "test", "cluster1") + require.Nil(t, err) + require.Equal(t, corev1.ConditionTrue, ks.Status) + require.Equal(t, common.Running, ks.Code) +} diff --git a/builder/store/database/migrations/20241225124808_create_table_knative_service.go b/builder/store/database/migrations/20241225124808_create_table_knative_service.go new file mode 100644 index 00000000..b9b07174 --- /dev/null +++ b/builder/store/database/migrations/20241225124808_create_table_knative_service.go @@ -0,0 +1,54 @@ +package migrations + +import ( + "context" + "fmt" + "go/types" + + "github.com/uptrace/bun" + corev1 "k8s.io/api/core/v1" +) + +type KnativeService struct { + ID int64 `bun:",pk,autoincrement" json:"id"` + Name string `bun:",notnull" json:"name"` + Status corev1.ConditionStatus `bun:",notnull" json:"status"` + Code int `bun:",notnull" json:"code"` + ClusterID string `bun:",notnull" json:"cluster_id"` + Endpoint string `bun:"," json:"endpoint"` + ActualReplica int `bun:"," json:"actual_replica"` + DesiredReplica int `bun:"," json:"desired_replica"` + Instances []types.Instance `bun:"type:jsonb" json:"instances,omitempty"` + UserUUID string `bun:"," json:"user_uuid"` + DeployID int64 `bun:"," json:"deploy_id"` + DeployType int `bun:"," json:"deploy_type"` + DeploySKU string `bun:"," json:"deploy_sku"` + OrderDetailID int64 `bun:"," json:"order_detail_id"` + times +} + +func init() { + Migrations.MustRegister(func(ctx context.Context, db *bun.DB) error { + err := createTables(ctx, db, KnativeService{}) + if err != nil { + return err + } + _, err = db.ExecContext(ctx, "ALTER TABLE knative_services ADD CONSTRAINT unique_cluster_svc UNIQUE (cluster_id, name)") + if err != nil { + return fmt.Errorf("failed to add unique for knative_services table: %w", err) + } + _, err = db.NewCreateIndex(). + Model((*KnativeService)(nil)). + Index("idx_knative_name_cluster"). + Column("name", "cluster_id"). + Unique(). + IfNotExists(). + Exec(ctx) + if err != nil { + return fmt.Errorf("fail to create index idx_knative_name_cluster_user : %w", err) + } + return err + }, func(ctx context.Context, db *bun.DB) error { + return dropTables(ctx, db, KnativeService{}) + }) +} diff --git a/common/types/service_runner.go b/common/types/service_runner.go index c18a1b78..6a5bd799 100644 --- a/common/types/service_runner.go +++ b/common/types/service_runner.go @@ -67,16 +67,19 @@ type ( } StatusResponse struct { - DeployID int64 `json:"deploy_id"` - UserID string `json:"user_id"` - Code int `json:"code"` - Message string `json:"message"` - Endpoint string `json:"url"` - Instances []Instance `json:"instance"` - Replica int `json:"replica"` - DeployType int `json:"deploy_type"` - ServiceName string `json:"service_name"` - DeploySku string `json:"deploy_sku"` + DeployID int64 `json:"deploy_id"` + UserID string `json:"user_id"` + Code int `json:"code"` + Message string `json:"message"` + Endpoint string `json:"url"` + Instances []Instance `json:"instance"` + Replica int `json:"replica"` + DeployType int `json:"deploy_type"` + ServiceName string `json:"service_name"` + DeploySku string `json:"deploy_sku"` + OrderDetailID int64 `json:"order_detail_id"` + ActualReplica int `json:"actual_replica"` + DesiredReplica int `json:"desired_replica"` } LogsRequest struct { @@ -130,7 +133,8 @@ type ( } ServiceRequest struct { - ClusterID string `json:"cluster_id"` + ServiceName string `json:"-"` + ClusterID string `json:"cluster_id"` } ServiceInfoResponse struct { @@ -169,17 +173,19 @@ type ( KnativeClient *knative.Clientset // Knative client } SVCRequest struct { - ImageID string `json:"image_id" binding:"required"` - Hardware HardWare `json:"hardware,omitempty"` - Env map[string]string `json:"env,omitempty"` - Annotation map[string]string `json:"annotation,omitempty"` - DeployID int64 `json:"deploy_id" binding:"required"` - RepoType string `json:"repo_type"` - MinReplica int `json:"min_replica"` - MaxReplica int `json:"max_replica"` - ClusterID string `json:"cluster_id"` - DeployType int `json:"deploy_type"` - UserID string `json:"user_id"` - Sku string `json:"sku"` + ImageID string `json:"image_id" binding:"required"` + Hardware HardWare `json:"hardware,omitempty"` + Env map[string]string `json:"env,omitempty"` + Annotation map[string]string `json:"annotation,omitempty"` + DeployID int64 `json:"deploy_id" binding:"required"` + RepoType string `json:"repo_type"` + MinReplica int `json:"min_replica"` + MaxReplica int `json:"max_replica"` + ClusterID string `json:"cluster_id"` + DeployType int `json:"deploy_type"` + UserID string `json:"user_id"` + Sku string `json:"sku"` + OrderDetailID int64 `json:"order_detail_id"` + SrvName string `json:"-"` } ) diff --git a/go.mod b/go.mod index 8e9d06a5..93e4fadf 100644 --- a/go.mod +++ b/go.mod @@ -22,7 +22,7 @@ require ( github.com/minio/minio-go/v7 v7.0.66 github.com/minio/sha256-simd v1.0.1 github.com/naoina/toml v0.1.1 - github.com/redis/go-redis/v9 v9.3.0 + github.com/redis/go-redis/v9 v9.5.1 github.com/sethvargo/go-envconfig v1.1.0 github.com/spf13/cast v1.5.1 github.com/spf13/cobra v1.8.0 @@ -76,6 +76,7 @@ require ( github.com/docker/docker v27.1.1+incompatible // indirect github.com/docker/go-connections v0.5.0 // indirect github.com/docker/go-units v0.5.0 // indirect + github.com/evanphx/json-patch v5.9.0+incompatible // indirect github.com/facebookgo/clock v0.0.0-20150410010913-600d898af40a // indirect github.com/felixge/httpsnoop v1.0.4 // indirect github.com/fxamacker/cbor/v2 v2.7.0 // indirect @@ -145,6 +146,7 @@ require ( google.golang.org/genproto/googleapis/api v0.0.0-20240827150818-7e3bb234dfed // indirect google.golang.org/genproto/googleapis/rpc v0.0.0-20240827150818-7e3bb234dfed // indirect gopkg.in/DataDog/dd-trace-go.v1 v1.32.0 // indirect + gopkg.in/evanphx/json-patch.v4 v4.12.0 // indirect modernc.org/gc/v3 v3.0.0-20240107210532-573471604cb6 // indirect ) @@ -183,7 +185,7 @@ require ( github.com/go-playground/universal-translator v0.18.1 // indirect github.com/go-playground/validator/v10 v10.20.0 github.com/go-redis/redis/v8 v8.11.5 // indirect - github.com/goccy/go-json v0.10.2 // indirect + github.com/goccy/go-json v0.10.3 // indirect github.com/gogo/protobuf v1.3.2 // indirect github.com/golang-jwt/jwt/v4 v4.5.1 // indirect github.com/golang/protobuf v1.5.4 // indirect diff --git a/go.sum b/go.sum index 41d5b88b..52183d31 100644 --- a/go.sum +++ b/go.sum @@ -161,8 +161,8 @@ github.com/envoyproxy/go-control-plane v0.9.0/go.mod h1:YTl/9mNaCwkRvm6d1a2C3ymF github.com/envoyproxy/go-control-plane v0.9.1-0.20191026205805-5f8ba28d4473/go.mod h1:YTl/9mNaCwkRvm6d1a2C3ymFceY/DCBVvsKhRF0iEA4= github.com/envoyproxy/go-control-plane v0.9.4/go.mod h1:6rpuAdCZL397s3pYoYcLgu1mIlRU8Am5FuJP05cCM98= github.com/envoyproxy/protoc-gen-validate v0.1.0/go.mod h1:iSmxcyjqTsJpI2R4NaDN7+kN2VEUnK/pcBlmesArF7c= -github.com/evanphx/json-patch v5.8.0+incompatible h1:1Av9pn2FyxPdvrWNQszj1g6D6YthSmvCfcN6SYclTJg= -github.com/evanphx/json-patch v5.8.0+incompatible/go.mod h1:50XU6AFN0ol/bzJsmQLiYLvXMP4fmwYFNcr97nuDLSk= +github.com/evanphx/json-patch v5.9.0+incompatible h1:fBXyNpNMuTTDdquAq/uisOr2lShz4oaXpDTX2bLe7ls= +github.com/evanphx/json-patch v5.9.0+incompatible/go.mod h1:50XU6AFN0ol/bzJsmQLiYLvXMP4fmwYFNcr97nuDLSk= github.com/evanphx/json-patch/v5 v5.8.0 h1:lRj6N9Nci7MvzrXuX6HFzU8XjmhPiXPlsKEy1u0KQro= github.com/evanphx/json-patch/v5 v5.8.0/go.mod h1:VNkHZ/282BpEyt/tObQO8s5CMPmYYq14uClGH4abBuQ= github.com/facebookgo/clock v0.0.0-20150410010913-600d898af40a h1:yDWHCSQ40h88yih2JAcL6Ls/kVkSE8GFACTGVnMPruw= @@ -243,8 +243,8 @@ github.com/go-task/slim-sprig v0.0.0-20210107165309-348f09dbbbc0 h1:p104kn46Q8Wd github.com/go-task/slim-sprig v0.0.0-20210107165309-348f09dbbbc0/go.mod h1:fyg7847qk6SyHyPtNmDHnmrv/HOrqktSC+C9fM+CJOE= github.com/go-task/slim-sprig/v3 v3.0.0 h1:sUs3vkvUymDpBKi3qH1YSqBQk9+9D/8M2mN1vB6EwHI= github.com/go-task/slim-sprig/v3 v3.0.0/go.mod h1:W848ghGpv3Qj3dhTPRyJypKRiqCdHZiAzKg9hl15HA8= -github.com/goccy/go-json v0.10.2 h1:CrxCmQqYDkv1z7lO7Wbh2HN93uovUHgrECaO5ZrCXAU= -github.com/goccy/go-json v0.10.2/go.mod h1:6MelG93GURQebXPDq3khkgXZkazVtN9CRI+MGFi0w8I= +github.com/goccy/go-json v0.10.3 h1:KZ5WoDbxAIgm2HNbYckL0se1fHD6rz5j4ywS6ebzDqA= +github.com/goccy/go-json v0.10.3/go.mod h1:oq7eo15ShAhp70Anwd5lgX2pLfOS3QCiwU/PULtXL6M= github.com/gogo/protobuf v1.3.2 h1:Ov1cvc58UF3b5XjBnZv7+opcTcQFZebYjWzi34vdm4Q= github.com/gogo/protobuf v1.3.2/go.mod h1:P1XiOD3dCwIKUDQYPy72D8LYyHL2YPYrpS2s69NZV8Q= github.com/goji/httpauth v0.0.0-20160601135302-2da839ab0f4d/go.mod h1:nnjvkQ9ptGaCkuDUx6wNykzzlUixGxvkme+H/lnzb+A= @@ -530,8 +530,8 @@ github.com/prometheus/prometheus v0.50.1 h1:N2L+DYrxqPh4WZStU+o1p/gQlBaqFbcLBTjl github.com/prometheus/prometheus v0.50.1/go.mod h1:FvE8dtQ1Ww63IlyKBn1V4s+zMwF9kHkVNkQBR1pM4CU= github.com/prometheus/statsd_exporter v0.22.7 h1:7Pji/i2GuhK6Lu7DHrtTkFmNBCudCPT1pX2CziuyQR0= github.com/prometheus/statsd_exporter v0.22.7/go.mod h1:N/TevpjkIh9ccs6nuzY3jQn9dFqnUakOjnEuMPJJJnI= -github.com/redis/go-redis/v9 v9.3.0 h1:RiVDjmig62jIWp7Kk4XVLs0hzV6pI3PyTnnL0cnn0u0= -github.com/redis/go-redis/v9 v9.3.0/go.mod h1:hdY0cQFCN4fnSYT6TkisLufl/4W5UIXyv0b/CLO2V2M= +github.com/redis/go-redis/v9 v9.5.1 h1:H1X4D3yHPaYrkL5X06Wh6xNVM/pX0Ft4RV0vMGvLBh8= +github.com/redis/go-redis/v9 v9.5.1/go.mod h1:hdY0cQFCN4fnSYT6TkisLufl/4W5UIXyv0b/CLO2V2M= github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec h1:W09IVJc94icq4NjY3clb7Lk8O1qJ8BdBEF8z0ibU0rE= github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec/go.mod h1:qqbHyh8v60DhA7CoWK5oRCqLrMHRGoxYCSS9EjAz6Eo= github.com/robfig/cron v1.2.0 h1:ZjScXvvxeQ63Dbyxy76Fj3AT3Ut0aKsyd2/tl3DTMuQ= diff --git a/runner/component/service.go b/runner/component/service.go index 0b7276e0..e6501615 100644 --- a/runner/component/service.go +++ b/runner/component/service.go @@ -2,48 +2,88 @@ package component import ( "context" + "errors" "fmt" "log/slog" + "net/http" "path" "strconv" "strings" + "sync" + "time" corev1 "k8s.io/api/core/v1" + k8serrors "k8s.io/apimachinery/pkg/api/errors" "k8s.io/apimachinery/pkg/api/resource" metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" + "k8s.io/apimachinery/pkg/util/runtime" + "k8s.io/client-go/tools/cache" v1 "knative.dev/serving/pkg/apis/serving/v1" + "knative.dev/serving/pkg/client/informers/externalversions" "opencsg.com/csghub-server/builder/deploy/cluster" + "opencsg.com/csghub-server/builder/deploy/common" + "opencsg.com/csghub-server/builder/store/database" "opencsg.com/csghub-server/common/config" "opencsg.com/csghub-server/common/types" ) var ( - KeyDeployID string = "deploy_id" - KeyDeployType string = "deploy_type" - KeyUserID string = "user_id" - KeyDeploySKU string = "deploy_sku" + KeyDeployID string = "deploy_id" + KeyDeployType string = "deploy_type" + KeyUserID string = "user_id" + KeyDeploySKU string = "deploy_sku" + KeyOrderDetailID string = "order-detail-id" + KeyMinScale string = "autoscaling.knative.dev/min-scale" ) -type ServiceComponent struct { +type serviceComponentImpl struct { k8sNameSpace string env *config.Config spaceDockerRegBase string modelDockerRegBase string imagePullSecret string + serviceStore database.KnativeServiceStore + clusterPool *cluster.ClusterPool } -func NewServiceComponent(config *config.Config, k8sNameSpace string) *ServiceComponent { - sc := &ServiceComponent{ - k8sNameSpace: k8sNameSpace, +type ServiceComponent interface { + RunService(ctx context.Context, req types.SVCRequest) error + StopService(ctx context.Context, req types.StopRequest) (*types.StopResponse, error) + PurgeService(ctx context.Context, req types.PurgeRequest) (*types.PurgeResponse, error) + UpdateService(ctx context.Context, req types.ModelUpdateRequest) (*types.ModelUpdateResponse, error) + GenerateService(ctx context.Context, cluster cluster.Cluster, request types.SVCRequest) (*v1.Service, error) + // get secret from k8s + // notes: admin should create nim secret "ngc-secret" and "nvidia-nim-secrets" in related namespace before deploy + GetNimSecret(ctx context.Context, cluster cluster.Cluster) (string, error) + GetServicePodsWithStatus(ctx context.Context, cluster cluster.Cluster, srvName string, namespace string) ([]types.Instance, error) + // NewPersistentVolumeClaim creates a new k8s PVC with some default values set. + NewPersistentVolumeClaim(name string, ctx context.Context, cluster cluster.Cluster, hardware types.HardWare) error + RunInformer() + GetServicePods(ctx context.Context, cluster cluster.Cluster, srvName string, namespace string, limit int64) ([]string, error) + GetAllServiceStatus(ctx context.Context) (map[string]*types.StatusResponse, error) + GetServiceByName(ctx context.Context, srvName, clusterId string) (*types.StatusResponse, error) + RemoveServiceForcely(ctx context.Context, cluster *cluster.Cluster, svcName string) error + GetServiceInfo(ctx context.Context, req types.ServiceRequest) (*types.ServiceInfoResponse, error) + AddServiceInDB(srv v1.Service, clusterID string) error + DeleteServiceInDB(srv v1.Service, clusterID string) error + UpdateServiceInDB(srv v1.Service, revision *v1.Revision, clusterID string) error +} + +func NewServiceComponent(config *config.Config, clusterPool *cluster.ClusterPool) ServiceComponent { + domainParts := strings.SplitN(config.Space.InternalRootDomain, ".", 2) + sc := &serviceComponentImpl{ + k8sNameSpace: domainParts[0], env: config, spaceDockerRegBase: config.Space.DockerRegBase, modelDockerRegBase: config.Model.DockerRegBase, imagePullSecret: config.Space.ImagePullSecret, + serviceStore: database.NewKnativeServiceStore(), + clusterPool: clusterPool, } return sc } -func (s *ServiceComponent) GenerateService(ctx context.Context, cluster cluster.Cluster, request types.SVCRequest, srvName string) (*v1.Service, error) { +func (s *serviceComponentImpl) GenerateService(ctx context.Context, cluster cluster.Cluster, request types.SVCRequest) (*v1.Service, error) { annotations := request.Annotation environments := []corev1.EnvVar{} @@ -93,6 +133,7 @@ func (s *ServiceComponent) GenerateService(ctx context.Context, cluster cluster. annotations[KeyDeployType] = strconv.Itoa(request.DeployType) annotations[KeyUserID] = request.UserID annotations[KeyDeploySKU] = request.Sku + annotations[KeyOrderDetailID] = strconv.FormatInt(request.OrderDetailID, 10) containerImg := request.ImageID // add prefix if image is not full path @@ -148,7 +189,7 @@ func (s *ServiceComponent) GenerateService(ctx context.Context, cluster cluster. service := &v1.Service{ ObjectMeta: metav1.ObjectMeta{ - Name: srvName, + Name: request.SrvName, Namespace: s.k8sNameSpace, Annotations: annotations, }, @@ -162,8 +203,6 @@ func (s *ServiceComponent) GenerateService(ctx context.Context, cluster cluster. PodSpec: corev1.PodSpec{ NodeSelector: nodeSelector, Containers: []corev1.Container{{ - // TODO: docker registry url + image id - // Image: "ghcr.io/knative/helloworld-go:latest", Image: containerImg, Ports: exposePorts, Resources: resources, @@ -186,7 +225,7 @@ func (s *ServiceComponent) GenerateService(ctx context.Context, cluster cluster. // get secret from k8s // notes: admin should create nim secret "ngc-secret" and "nvidia-nim-secrets" in related namespace before deploy -func (s *ServiceComponent) GetNimSecret(ctx context.Context, cluster cluster.Cluster) (string, error) { +func (s *serviceComponentImpl) GetNimSecret(ctx context.Context, cluster cluster.Cluster) (string, error) { secret, err := cluster.Client.CoreV1().Secrets(s.k8sNameSpace).Get(ctx, s.env.Model.NimNGCSecretName, metav1.GetOptions{}) if err != nil { return "", err @@ -194,7 +233,7 @@ func (s *ServiceComponent) GetNimSecret(ctx context.Context, cluster cluster.Clu return string(secret.Data["NGC_API_KEY"]), nil } -func (s *ServiceComponent) GetServicePodsWithStatus(ctx context.Context, cluster cluster.Cluster, srvName string, namespace string) ([]types.Instance, error) { +func (s *serviceComponentImpl) GetServicePodsWithStatus(ctx context.Context, cluster cluster.Cluster, srvName string, namespace string) ([]types.Instance, error) { labelSelector := fmt.Sprintf("serving.knative.dev/service=%s", srvName) // Get the list of Pods based on the label selector pods, err := cluster.Client.CoreV1().Pods(namespace).List(ctx, metav1.ListOptions{ @@ -207,13 +246,15 @@ func (s *ServiceComponent) GetServicePodsWithStatus(ctx context.Context, cluster // Extract the Pod names and status var podInstances []types.Instance for _, pod := range pods.Items { + if pod.DeletionTimestamp != nil { + continue + } podInstances = append(podInstances, types.Instance{ Name: pod.Name, Status: string(pod.Status.Phase), }, ) - slog.Debug("pod", slog.Any("pod.Name", pod.Name), slog.Any("pod.Status.Phase", pod.Status.Phase)) } return podInstances, nil } @@ -251,7 +292,7 @@ func GenerateResources(hardware types.HardWare) (map[corev1.ResourceName]resourc } // NewPersistentVolumeClaim creates a new k8s PVC with some default values set. -func (s *ServiceComponent) NewPersistentVolumeClaim(name string, ctx context.Context, cluster cluster.Cluster, hardware types.HardWare) error { +func (s *serviceComponentImpl) NewPersistentVolumeClaim(name string, ctx context.Context, cluster cluster.Cluster, hardware types.HardWare) error { // Check if it already exists _, err := cluster.Client.CoreV1().PersistentVolumeClaims(s.k8sNameSpace).Get(ctx, name, metav1.GetOptions{}) if err == nil { @@ -287,3 +328,573 @@ func (s *ServiceComponent) NewPersistentVolumeClaim(name string, ctx context.Con _, err = cluster.Client.CoreV1().PersistentVolumeClaims(s.k8sNameSpace).Create(ctx, &pvc, metav1.CreateOptions{}) return err } + +func (s *serviceComponentImpl) RunInformer() { + var wg sync.WaitGroup + stopCh := make(chan struct{}) + defer close(stopCh) + defer runtime.HandleCrash() + for _, cls := range s.clusterPool.Clusters { + _, err := cls.Client.Discovery().ServerVersion() + if err != nil { + slog.Error("cluster is unavailable ", slog.Any("cluster config", cls.CID), slog.Any("error", err)) + continue + } + wg.Add(2) + go func(cluster cluster.Cluster) { + defer wg.Done() + s.RunRevisionInformer(stopCh, cluster) + }(cls) + go func(cluster cluster.Cluster) { + defer wg.Done() + s.RunServiceInformer(stopCh, cluster) + }(cls) + } + wg.Wait() +} + +// Run Revision informer,mainly handle pod changes +func (s *serviceComponentImpl) RunRevisionInformer(stopCh <-chan struct{}, cluster cluster.Cluster) { + informerFactory := externalversions.NewSharedInformerFactoryWithOptions( + cluster.KnativeClient, + 0, //never resync + externalversions.WithNamespace(s.k8sNameSpace), + ) + informer := informerFactory.Serving().V1().Revisions().Informer() + _, err := informer.AddEventHandler(cache.ResourceEventHandlerFuncs{ + UpdateFunc: func(oldObj, newObj interface{}) { + revision := newObj.(*v1.Revision) + service, err := s.getServiceByRevision(revision, cluster.ID) + if err != nil { + slog.Error("failed to get service from revision ", slog.Any("service", service.Name), slog.Any("error", err)) + return + } + err = s.UpdateServiceInDB(*service, revision, cluster.ID) + if err != nil { + slog.Error("failed to update service status ", slog.Any("service", service.Name), slog.Any("error", err)) + } + + }, + }) + if err != nil { + runtime.HandleError(fmt.Errorf("failed to add event handler for knative revision informer")) + } + informer.Run(stopCh) + if !cache.WaitForCacheSync(stopCh, informer.HasSynced) { + runtime.HandleError(fmt.Errorf("timed out waiting for caches to sync for knative revision informer")) + } +} + +// Run service informer, main handle the service changes +func (s *serviceComponentImpl) RunServiceInformer(stopCh <-chan struct{}, cluster cluster.Cluster) { + informerFactory := externalversions.NewSharedInformerFactoryWithOptions( + cluster.KnativeClient, + 1*time.Hour, //sync every 1 hour + externalversions.WithNamespace(s.k8sNameSpace), + ) + informer := informerFactory.Serving().V1().Services().Informer() + _, err := informer.AddEventHandler(cache.ResourceEventHandlerFuncs{ + AddFunc: func(obj interface{}) { + service := obj.(*v1.Service) + err := s.AddServiceInDB(*service, cluster.ID) + if err != nil { + slog.Error("failed to add service ", slog.Any("service", service.Name), slog.Any("error", err)) + } + }, + UpdateFunc: func(oldObj, newObj interface{}) { + new := newObj.(*v1.Service) + old := oldObj.(*v1.Service) + newStatus := getReadyCondition(new) + oldStatus := getReadyCondition(old) + if newStatus != oldStatus || newStatus == corev1.ConditionUnknown { + err := s.UpdateServiceInDB(*new, nil, cluster.ID) + if err != nil { + slog.Error("failed to update service status ", slog.Any("service", new.Name), slog.Any("error", err)) + } + } + }, + DeleteFunc: func(obj interface{}) { + service := obj.(*v1.Service) + err := s.DeleteServiceInDB(*service, cluster.ID) + if err != nil { + slog.Error("failed to mark service as deleted ", slog.Any("service", service.Name), slog.Any("error", err)) + } + }, + }) + if err != nil { + runtime.HandleError(fmt.Errorf("failed to add event handler for knative service informer")) + } + informer.Run(stopCh) + if !cache.WaitForCacheSync(stopCh, informer.HasSynced) { + runtime.HandleError(fmt.Errorf("timed out waiting for caches to sync for knative service informer")) + } +} + +func (s *serviceComponentImpl) getServiceByRevision(revision *v1.Revision, clusterID string) (*v1.Service, error) { + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + serviceName, exists := revision.Labels["serving.knative.dev/service"] + if !exists { + return nil, fmt.Errorf("revision %s does not have a parent service", revision.Name) + } + cluster, err := s.clusterPool.GetClusterByID(context.Background(), clusterID) + if err != nil { + return nil, fmt.Errorf("fail to get cluster,error: %v ", err) + } + return cluster.KnativeClient.ServingV1().Services(s.k8sNameSpace).Get(ctx, serviceName, metav1.GetOptions{}) +} + +func (s *serviceComponentImpl) AddServiceInDB(srv v1.Service, clusterID string) error { + ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second) + defer cancel() + status, err := s.GetServiceStatus(ctx, srv, clusterID) + if err != nil { + return err + } + deployIDStr := srv.Annotations[KeyDeployID] + deployID, _ := strconv.ParseInt(deployIDStr, 10, 64) + deployTypeStr := srv.Annotations[KeyDeployType] + deployType, err := strconv.Atoi(deployTypeStr) + if err != nil { + deployType = 0 + } + userID := srv.Annotations[KeyUserID] + deploySku := srv.Annotations[KeyDeploySKU] + orderDetailIdStr := srv.Annotations[KeyOrderDetailID] + orderDetailId, err := strconv.ParseInt(orderDetailIdStr, 10, 64) + if err != nil { + orderDetailId = 0 + } + DesiredReplica := 1 + if minScale, ok := srv.Spec.Template.Annotations[KeyMinScale]; ok { + DesiredReplica, _ = strconv.Atoi(minScale) + } + service := &database.KnativeService{ + Code: status.Code, + Name: srv.Name, + ClusterID: clusterID, + Status: getReadyCondition(&srv), + Endpoint: srv.Status.URL.String(), + DeployID: deployID, + UserUUID: userID, + DeployType: deployType, + DeploySKU: deploySku, + OrderDetailID: orderDetailId, + Instances: status.Instances, + DesiredReplica: DesiredReplica, + ActualReplica: len(status.Instances), + } + + return s.serviceStore.Add(ctx, service) +} + +// Delete service, just mark the service as stopped +func (s *serviceComponentImpl) DeleteServiceInDB(srv v1.Service, clusterID string) error { + ctx, cancel := context.WithTimeout(context.Background(), 20*time.Second) + defer cancel() + return s.serviceStore.Delete(ctx, srv.Name, clusterID) +} + +func (s *serviceComponentImpl) UpdateServiceInDB(srv v1.Service, revision *v1.Revision, clusterID string) error { + ctx, cancel := context.WithTimeout(context.Background(), 20*time.Second) + defer cancel() + status, err := s.GetServiceStatus(ctx, srv, clusterID) + if err != nil { + return err + } + oldService, err := s.serviceStore.Get(ctx, srv.Name, clusterID) + if err != nil { + return err + } + oldService.Code = status.Code + oldService.Endpoint = srv.Status.URL.String() + oldService.Status = getReadyCondition(&srv) + oldService.Instances = status.Instances + if revision != nil { + DesiredReplicas := 1 + ActualReplicas := 0 + if revision.Status.DesiredReplicas != nil { + DesiredReplicas = int(*revision.Status.DesiredReplicas) + } + + if revision.Status.ActualReplicas != nil { + ActualReplicas = int(*revision.Status.ActualReplicas) + } + oldService.DesiredReplica = DesiredReplicas + oldService.ActualReplica = ActualReplicas + } + return s.serviceStore.Update(ctx, oldService) +} + +func (s *serviceComponentImpl) GetServiceStatus(ctx context.Context, ks v1.Service, clusterID string) (resp types.StatusResponse, err error) { + serviceCondition := ks.Status.GetCondition(v1.ServiceConditionReady) + cluster, err := s.clusterPool.GetClusterByID(ctx, clusterID) + if err != nil { + return resp, fmt.Errorf("fail to get cluster,error: %v ", err) + } + instList, err := s.GetServicePodsWithStatus(ctx, *cluster, ks.Name, ks.Namespace) + if err != nil { + return resp, fmt.Errorf("fail to get service pod name list,error: %v ", err) + } + switch { + case serviceCondition == nil: + resp.Code = common.Deploying + case serviceCondition.Status == corev1.ConditionUnknown: + resp.Code = common.DeployFailed + for _, instance := range instList { + if instance.Status == string(corev1.PodRunning) || instance.Status == string(corev1.PodPending) { + resp.Code = common.Deploying + break + } + } + case serviceCondition.Status == corev1.ConditionTrue: + resp.Code = common.Running + if len(instList) == 0 { + resp.Code = common.Sleeping + } + case serviceCondition.Status == corev1.ConditionFalse: + resp.Code = common.DeployFailed + } + resp.Instances = instList + return resp, err +} + +// corev1.ConditionTrue +func getReadyCondition(service *v1.Service) corev1.ConditionStatus { + for _, condition := range service.Status.Conditions { + if condition.Type == v1.ServiceConditionReady { + return condition.Status + } + } + return corev1.ConditionUnknown +} + +func (s *serviceComponentImpl) GetServicePods(ctx context.Context, cluster cluster.Cluster, srvName string, namespace string, limit int64) ([]string, error) { + labelSelector := fmt.Sprintf("serving.knative.dev/service=%s", srvName) + // Get the list of Pods based on the label selector + opts := metav1.ListOptions{ + LabelSelector: labelSelector, + } + if limit > 0 { + opts = metav1.ListOptions{ + LabelSelector: labelSelector, + Limit: limit, + } + } + pods, err := cluster.Client.CoreV1().Pods(namespace).List(ctx, opts) + if err != nil { + return nil, err + } + + // Extract the Pod names + var podNames []string + for _, pod := range pods.Items { + if pod.DeletionTimestamp != nil { + continue + } + podNames = append(podNames, pod.Name) + } + + return podNames, nil +} + +// GetAllServiceStatus +func (s *serviceComponentImpl) GetAllServiceStatus(ctx context.Context) (map[string]*types.StatusResponse, error) { + allStatus := make(map[string]*types.StatusResponse) + for _, cls := range s.clusterPool.Clusters { + svcs, err := s.serviceStore.GetByCluster(ctx, cls.ID) + if err != nil { + return nil, fmt.Errorf("fail to get service list,error: %v ", err) + } + for _, svc := range svcs { + status := &types.StatusResponse{ + DeployID: svc.DeployID, + UserID: svc.UserUUID, + DeployType: svc.DeployType, + ServiceName: svc.Name, + DeploySku: svc.DeploySKU, + OrderDetailID: svc.OrderDetailID, + Code: svc.Code, + } + allStatus[svc.Name] = status + } + } + return allStatus, nil +} + +// GetServiceStatus +func (s *serviceComponentImpl) GetServiceByName(ctx context.Context, srvName, clusterId string) (*types.StatusResponse, error) { + svc, err := s.serviceStore.Get(ctx, srvName, clusterId) + if err != nil { + return nil, err + } + resp := &types.StatusResponse{ + DeployID: svc.DeployID, + UserID: svc.UserUUID, + DeployType: svc.DeployType, + ServiceName: svc.Name, + DeploySku: svc.DeploySKU, + OrderDetailID: svc.OrderDetailID, + Endpoint: svc.Endpoint, + Code: svc.Code, + Instances: svc.Instances, + Replica: len(svc.Instances), + } + return resp, nil + +} + +// RunService +func (s *serviceComponentImpl) RunService(ctx context.Context, req types.SVCRequest) error { + cluster, err := s.clusterPool.GetClusterByID(ctx, req.ClusterID) + if err != nil { + return fmt.Errorf("fail to get cluster, error %v ", err) + } + + // check if the ksvc exists + _, err = cluster.KnativeClient.ServingV1().Services(s.k8sNameSpace).Get(ctx, req.SrvName, metav1.GetOptions{}) + if err == nil { + err = s.RemoveServiceForcely(ctx, cluster, req.SrvName) + if err != nil { + slog.Error("fail to remove service", slog.Any("error", err), slog.Any("req", req)) + } + slog.Info("service already exists,delete it first", slog.String("srv_name", req.SrvName), slog.Any("image_id", req.ImageID)) + } + service, err := s.GenerateService(ctx, *cluster, req) + if err != nil { + return fmt.Errorf("fail to generate service, %v ", err) + } + volumes := []corev1.Volume{} + volumeMounts := []corev1.VolumeMount{} + if req.DeployType != types.SpaceType { + // dshm volume for multi-gpu share memory + volumes = append(volumes, corev1.Volume{ + Name: "dshm", + VolumeSource: corev1.VolumeSource{ + EmptyDir: &corev1.EmptyDirVolumeSource{ + Medium: corev1.StorageMediumMemory, + }, + }, + }) + + volumeMounts = append(volumeMounts, corev1.VolumeMount{ + Name: "dshm", + MountPath: "/dev/shm", + }) + } + pvcName := req.SrvName + if req.DeployType == types.InferenceType { + pvcName = req.UserID + } + // add pvc if possible + // space image was built from user's code, model cache dir is hard to control + // so no PV cache for space case so far + if cluster.StorageClass != "" && req.DeployType != types.SpaceType { + err = s.NewPersistentVolumeClaim(pvcName, ctx, *cluster, req.Hardware) + if err != nil { + return fmt.Errorf("failed to create persist volume, %v", err) + } + volumes = append(volumes, corev1.Volume{ + Name: "nas-pvc", + VolumeSource: corev1.VolumeSource{ + PersistentVolumeClaim: &corev1.PersistentVolumeClaimVolumeSource{ + ClaimName: pvcName, + }, + }, + }) + + volumeMounts = append(volumeMounts, corev1.VolumeMount{ + Name: "nas-pvc", + MountPath: "/workspace", + }) + } + service.Spec.Template.Spec.Volumes = volumes + service.Spec.Template.Spec.Containers[0].VolumeMounts = volumeMounts + + slog.Debug("ksvc", slog.Any("knative service", service)) + + // create ksvc + _, err = cluster.KnativeClient.ServingV1().Services(s.k8sNameSpace).Create(ctx, service, metav1.CreateOptions{}) + if err != nil { + return fmt.Errorf("failed to create service, error: %v, req: %v", err, req) + } + return nil +} + +func (s *serviceComponentImpl) RemoveServiceForcely(ctx context.Context, cluster *cluster.Cluster, svcName string) error { + err := cluster.KnativeClient.ServingV1().Services(s.k8sNameSpace).Delete(context.Background(), svcName, *metav1.NewDeleteOptions(0)) + if err != nil { + return err + } + podNames, _ := s.GetServicePods(ctx, *cluster, svcName, s.k8sNameSpace, -1) + if podNames == nil { + return nil + } + //before k8s 1.31, kill pod does not kill the process immediately, instead we still need wait for the process to exit. more details see: https://github.com/kubernetes/kubernetes/issues/120449 + gracePeriodSeconds := int64(10) + deletePolicy := metav1.DeletePropagationForeground + deleteOptions := metav1.DeleteOptions{ + GracePeriodSeconds: &gracePeriodSeconds, + PropagationPolicy: &deletePolicy, + } + + for _, podName := range podNames { + errForce := cluster.Client.CoreV1().Pods(s.k8sNameSpace).Delete(ctx, podName, deleteOptions) + if errForce != nil { + slog.Error("removeServiceForcely failed to delete pod", slog.String("pod_name", podName), slog.Any("error", errForce)) + } + } + return nil +} + +// StopService +func (s *serviceComponentImpl) StopService(ctx context.Context, req types.StopRequest) (*types.StopResponse, error) { + var resp types.StopResponse + cluster, err := s.clusterPool.GetClusterByID(ctx, req.ClusterID) + if err != nil { + return nil, fmt.Errorf("fail to get cluster, error: %v ", err) + } + + srv, err := cluster.KnativeClient.ServingV1().Services(s.k8sNameSpace). + Get(ctx, req.SvcName, metav1.GetOptions{}) + if err != nil { + k8serr := new(k8serrors.StatusError) + if errors.As(err, &k8serr) { + if k8serr.Status().Code == http.StatusNotFound { + slog.Info("stop service skip,service not exist", slog.String("srv_name", req.SvcName), slog.Any("k8s_err", k8serr)) + resp.Code = 0 + resp.Message = "skip,service not exist" + return &resp, nil + } + } + resp.Code = -1 + resp.Message = "failed to get service status" + return &resp, fmt.Errorf("cannot get service info, error: %v", err) + } + + if srv == nil { + resp.Code = 0 + resp.Message = "service not exist" + return &resp, nil + } + err = s.RemoveServiceForcely(ctx, cluster, req.SvcName) + if err != nil { + resp.Code = -1 + resp.Message = "failed to get service status" + return &resp, fmt.Errorf("cannot delete service,error: %v", err) + } + return &resp, nil +} + +// UpdateService +func (s *serviceComponentImpl) UpdateService(ctx context.Context, req types.ModelUpdateRequest) (*types.ModelUpdateResponse, error) { + var resp types.ModelUpdateResponse + cluster, err := s.clusterPool.GetClusterByID(ctx, req.ClusterID) + if err != nil { + return nil, fmt.Errorf("fail to get cluster, error: %v ", err) + } + + srv, err := cluster.KnativeClient.ServingV1().Services(s.k8sNameSpace). + Get(ctx, req.SvcName, metav1.GetOptions{}) + if err != nil { + k8serr := new(k8serrors.StatusError) + if errors.As(err, &k8serr) { + if k8serr.Status().Code == http.StatusNotFound { + resp.Code = 0 + resp.Message = "skipped, service not exist" + return &resp, nil + } + } + resp.Code = -1 + resp.Message = "failed to get service status" + return &resp, fmt.Errorf("cannot get service info, error: %v", err) + } + + if srv == nil { + resp.Code = 0 + resp.Message = "service not exist" + return &resp, nil + } + // Update Image + containerImg := path.Join(s.modelDockerRegBase, req.ImageID) + srv.Spec.Template.Spec.Containers[0].Image = containerImg + // Update env + environments := []corev1.EnvVar{} + if req.Env != nil { + // generate env + for key, value := range req.Env { + environments = append(environments, corev1.EnvVar{Name: key, Value: value}) + } + srv.Spec.Template.Spec.Containers[0].Env = environments + } + // Update CPU and Memory requests and limits + hardware := req.Hardware + resReq, _ := GenerateResources(hardware) + resources := corev1.ResourceRequirements{ + Limits: resReq, + Requests: resReq, + } + srv.Spec.Template.Spec.Containers[0].Resources = resources + // Update replica + srv.Spec.Template.Annotations["autoscaling.knative.dev/min-scale"] = strconv.Itoa(req.MinReplica) + srv.Spec.Template.Annotations["autoscaling.knative.dev/max-scale"] = strconv.Itoa(req.MaxReplica) + + _, err = cluster.KnativeClient.ServingV1().Services(s.k8sNameSpace).Update(ctx, srv, metav1.UpdateOptions{}) + if err != nil { + resp.Code = -1 + resp.Message = "failed to update service" + return &resp, fmt.Errorf("cannot update service, error: %v", err) + } + return &resp, nil +} + +func (s *serviceComponentImpl) PurgeService(ctx context.Context, req types.PurgeRequest) (*types.PurgeResponse, error) { + var resp types.PurgeResponse + cluster, err := s.clusterPool.GetClusterByID(ctx, req.ClusterID) + if err != nil { + return nil, fmt.Errorf("fail to get cluster, error: %v ", err) + } + _, err = cluster.KnativeClient.ServingV1().Services(s.k8sNameSpace). + Get(ctx, req.SvcName, metav1.GetOptions{}) + if err != nil { + k8serr := new(k8serrors.StatusError) + if errors.As(err, &k8serr) { + if k8serr.Status().Code == http.StatusNotFound { + slog.Info("service not exist", slog.String("srv_name", req.SvcName), slog.Any("k8s_err", k8serr)) + } + } + slog.Error("cannot get service info, skip service purge", slog.String("srv_name", req.SvcName), slog.Any("error", err)) + } else { + // 1 delete service + err = s.RemoveServiceForcely(ctx, cluster, req.SvcName) + if err != nil { + resp.Code = -1 + resp.Message = "failed to remove service" + return &resp, fmt.Errorf("failed to remove service, error: %v", err) + } + } + // 2 clean up pvc + if cluster.StorageClass != "" && req.DeployType == types.FinetuneType { + err = cluster.Client.CoreV1().PersistentVolumeClaims(s.k8sNameSpace).Delete(ctx, req.SvcName, metav1.DeleteOptions{}) + if err != nil { + resp.Code = -1 + resp.Message = "failed to remove pvc" + return &resp, fmt.Errorf("failed to remove pvc, error: %v", err) + } + } + return &resp, nil +} + +// GetServiceInfo +func (s *serviceComponentImpl) GetServiceInfo(ctx context.Context, req types.ServiceRequest) (*types.ServiceInfoResponse, error) { + var resp types.ServiceInfoResponse + svc, err := s.serviceStore.Get(ctx, req.ServiceName, req.ClusterID) + if err != nil { + return nil, err + } + resp.ServiceName = svc.Name + for _, v := range svc.Instances { + resp.PodNames = append(resp.PodNames, v.Name) + } + return &resp, nil +} diff --git a/runner/component/service_test.go b/runner/component/service_test.go new file mode 100644 index 00000000..71c4eb3e --- /dev/null +++ b/runner/component/service_test.go @@ -0,0 +1,670 @@ +package component + +import ( + "context" + "testing" + + "github.com/stretchr/testify/mock" + "github.com/stretchr/testify/require" + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" + "k8s.io/client-go/kubernetes/fake" + knativefake "knative.dev/serving/pkg/client/clientset/versioned/fake" + mockdb "opencsg.com/csghub-server/_mocks/opencsg.com/csghub-server/builder/store/database" + "opencsg.com/csghub-server/builder/deploy/cluster" + "opencsg.com/csghub-server/builder/deploy/common" + "opencsg.com/csghub-server/builder/store/database" + "opencsg.com/csghub-server/common/config" + "opencsg.com/csghub-server/common/types" +) + +func TestServiceComponent_RunService(t *testing.T) { + kss := mockdb.NewMockKnativeServiceStore(t) + ctx := context.TODO() + pool := &cluster.ClusterPool{} + pool.ClusterStore = mockdb.NewMockClusterInfoStore(t) + kubeClient := fake.NewSimpleClientset() + pool.Clusters = append(pool.Clusters, cluster.Cluster{ + CID: "config", + ID: "test", + Client: kubeClient, + KnativeClient: knativefake.NewSimpleClientset(), + }) + sc := &serviceComponentImpl{ + k8sNameSpace: "test", + env: &config.Config{}, + spaceDockerRegBase: "http://test.com", + modelDockerRegBase: "http://test.com", + imagePullSecret: "test", + serviceStore: kss, + clusterPool: pool, + } + req := types.SVCRequest{ + ImageID: "test", + DeployID: 1, + DeployType: types.InferenceType, + RepoType: string(types.ModelRepo), + MinReplica: 1, + MaxReplica: 1, + UserID: "test", + Sku: "1", + SrvName: "test", + Hardware: types.HardWare{ + Gpu: types.GPU{ + Num: "1", + Type: "A10", + }, + Memory: "16Gi", + }, + Env: map[string]string{ + "test": "test", + "port": "8000", + }, + Annotation: map[string]string{}, + } + err := sc.RunService(ctx, req) + require.Nil(t, err) +} + +func TestServiceComponent_StopService(t *testing.T) { + kss := mockdb.NewMockKnativeServiceStore(t) + ctx := context.TODO() + pool := &cluster.ClusterPool{} + cis := mockdb.NewMockClusterInfoStore(t) + pool.ClusterStore = cis + kubeClient := fake.NewSimpleClientset() + cluster := cluster.Cluster{ + CID: "config", + ID: "test", + Client: kubeClient, + KnativeClient: knativefake.NewSimpleClientset(), + } + pool.Clusters = append(pool.Clusters, cluster) + sc := &serviceComponentImpl{ + k8sNameSpace: "test", + env: &config.Config{}, + spaceDockerRegBase: "http://test.com", + modelDockerRegBase: "http://test.com", + imagePullSecret: "test", + serviceStore: kss, + clusterPool: pool, + } + req := types.SVCRequest{ + ImageID: "test", + DeployID: 1, + DeployType: types.InferenceType, + RepoType: string(types.ModelRepo), + MinReplica: 1, + MaxReplica: 1, + UserID: "test", + Sku: "1", + SrvName: "test", + Hardware: types.HardWare{ + Gpu: types.GPU{ + Num: "1", + Type: "A10", + }, + Memory: "16Gi", + }, + Env: map[string]string{ + "test": "test", + "port": "8000", + }, + Annotation: map[string]string{}, + } + err := sc.RunService(ctx, req) + require.Nil(t, err) + cis.EXPECT().ByClusterID(ctx, "test").Return(database.ClusterInfo{ + ClusterID: "test", + ClusterConfig: "config", + StorageClass: "test", + }, nil) + resp, err := sc.StopService(ctx, types.StopRequest{ + SvcName: "test", + ClusterID: "test", + }) + require.Nil(t, err) + require.NotNil(t, resp) + require.Equal(t, resp.Code, 0) +} + +func TestServiceComponent_PurgeService(t *testing.T) { + kss := mockdb.NewMockKnativeServiceStore(t) + ctx := context.TODO() + pool := &cluster.ClusterPool{} + cis := mockdb.NewMockClusterInfoStore(t) + pool.ClusterStore = cis + kubeClient := fake.NewSimpleClientset() + cluster := cluster.Cluster{ + CID: "config", + ID: "test", + Client: kubeClient, + KnativeClient: knativefake.NewSimpleClientset(), + } + pool.Clusters = append(pool.Clusters, cluster) + sc := &serviceComponentImpl{ + k8sNameSpace: "test", + env: &config.Config{}, + spaceDockerRegBase: "http://test.com", + modelDockerRegBase: "http://test.com", + imagePullSecret: "test", + serviceStore: kss, + clusterPool: pool, + } + req := types.SVCRequest{ + ImageID: "test", + DeployID: 1, + DeployType: types.InferenceType, + RepoType: string(types.ModelRepo), + MinReplica: 1, + MaxReplica: 1, + UserID: "test", + Sku: "1", + SrvName: "test", + Hardware: types.HardWare{ + Gpu: types.GPU{ + Num: "1", + Type: "A10", + }, + Memory: "16Gi", + }, + Env: map[string]string{ + "test": "test", + "port": "8000", + }, + Annotation: map[string]string{}, + } + err := sc.RunService(ctx, req) + require.Nil(t, err) + cis.EXPECT().ByClusterID(ctx, "test").Return(database.ClusterInfo{ + ClusterID: "test", + ClusterConfig: "config", + StorageClass: "test", + }, nil) + resp, err := sc.PurgeService(ctx, types.PurgeRequest{ + SvcName: "test", + ClusterID: "test", + }) + require.Nil(t, err) + require.NotNil(t, resp) + require.Equal(t, resp.Code, 0) +} + +func TestServiceComponent_UpdateService(t *testing.T) { + kss := mockdb.NewMockKnativeServiceStore(t) + ctx := context.TODO() + pool := &cluster.ClusterPool{} + cis := mockdb.NewMockClusterInfoStore(t) + pool.ClusterStore = cis + kubeClient := fake.NewSimpleClientset() + cluster := cluster.Cluster{ + CID: "config", + ID: "test", + Client: kubeClient, + KnativeClient: knativefake.NewSimpleClientset(), + } + pool.Clusters = append(pool.Clusters, cluster) + sc := &serviceComponentImpl{ + k8sNameSpace: "test", + env: &config.Config{}, + spaceDockerRegBase: "http://test.com", + modelDockerRegBase: "http://test.com", + imagePullSecret: "test", + serviceStore: kss, + clusterPool: pool, + } + req := types.SVCRequest{ + ImageID: "test", + DeployID: 1, + DeployType: types.InferenceType, + RepoType: string(types.ModelRepo), + MinReplica: 1, + MaxReplica: 1, + UserID: "test", + Sku: "1", + SrvName: "test", + Hardware: types.HardWare{ + Gpu: types.GPU{ + Num: "1", + Type: "A10", + }, + Memory: "16Gi", + }, + Env: map[string]string{ + "test": "test", + "port": "8000", + }, + Annotation: map[string]string{}, + } + err := sc.RunService(ctx, req) + require.Nil(t, err) + cis.EXPECT().ByClusterID(ctx, "test").Return(database.ClusterInfo{ + ClusterID: "test", + ClusterConfig: "config", + StorageClass: "test", + }, nil) + resp, err := sc.UpdateService(ctx, types.ModelUpdateRequest{ + SvcName: "test", + ClusterID: "test", + MinReplica: 2, + MaxReplica: 2, + ImageID: "test2", + Hardware: types.HardWare{ + Gpu: types.GPU{ + Num: "1", + Type: "A10", + }, + Memory: "16Gi", + }, + }) + require.Nil(t, err) + require.NotNil(t, resp) + require.Equal(t, resp.Code, 0) +} +func TestServiceComponent_GetServicePodWithStatus(t *testing.T) { + kss := mockdb.NewMockKnativeServiceStore(t) + ctx := context.TODO() + pool := &cluster.ClusterPool{} + pool.ClusterStore = mockdb.NewMockClusterInfoStore(t) + kubeClient := fake.NewSimpleClientset() + pool.Clusters = append(pool.Clusters, cluster.Cluster{ + CID: "config", + ID: "test", + Client: kubeClient, + KnativeClient: knativefake.NewSimpleClientset(), + }) + sc := &serviceComponentImpl{ + k8sNameSpace: "test", + env: &config.Config{}, + spaceDockerRegBase: "http://test.com", + modelDockerRegBase: "http://test.com", + imagePullSecret: "test", + serviceStore: kss, + clusterPool: pool, + } + req := types.SVCRequest{ + ImageID: "test", + DeployID: 1, + DeployType: types.InferenceType, + RepoType: string(types.ModelRepo), + MinReplica: 1, + MaxReplica: 1, + UserID: "test", + Sku: "1", + SrvName: "test", + Hardware: types.HardWare{ + Gpu: types.GPU{ + Num: "1", + Type: "A10", + }, + Memory: "16Gi", + }, + Env: map[string]string{ + "test": "test", + "port": "8000", + }, + Annotation: map[string]string{}, + } + err := sc.RunService(ctx, req) + require.Nil(t, err) + _, err = sc.GetServicePodsWithStatus(ctx, pool.Clusters[0], "test", "test") + require.Nil(t, err) +} + +func TestServiceComponent_GetAllStatus(t *testing.T) { + kss := mockdb.NewMockKnativeServiceStore(t) + ctx := context.TODO() + pool := &cluster.ClusterPool{} + pool.ClusterStore = mockdb.NewMockClusterInfoStore(t) + kubeClient := fake.NewSimpleClientset() + pool.Clusters = append(pool.Clusters, cluster.Cluster{ + CID: "config", + ID: "test", + Client: kubeClient, + KnativeClient: knativefake.NewSimpleClientset(), + }) + sc := &serviceComponentImpl{ + k8sNameSpace: "test", + env: &config.Config{}, + spaceDockerRegBase: "http://test.com", + modelDockerRegBase: "http://test.com", + imagePullSecret: "test", + serviceStore: kss, + clusterPool: pool, + } + req := types.SVCRequest{ + ImageID: "test", + DeployID: 1, + DeployType: types.InferenceType, + RepoType: string(types.ModelRepo), + MinReplica: 1, + MaxReplica: 1, + UserID: "test", + Sku: "1", + SrvName: "test", + Hardware: types.HardWare{ + Gpu: types.GPU{ + Num: "1", + Type: "A10", + }, + Memory: "16Gi", + }, + Env: map[string]string{ + "test": "test", + "port": "8000", + }, + Annotation: map[string]string{}, + } + err := sc.RunService(ctx, req) + require.Nil(t, err) + kss.EXPECT().GetByCluster(ctx, "test").Return([]database.KnativeService{ + { + Name: "test", + ID: 1, + Code: common.Running, + }, + }, nil) + status, err := sc.GetAllServiceStatus(ctx) + require.Nil(t, err) + require.Equal(t, 1, len(status)) + require.Equal(t, common.Running, status["test"].Code) +} + +func TestServiceComponent_GetServiceByName(t *testing.T) { + kss := mockdb.NewMockKnativeServiceStore(t) + ctx := context.TODO() + pool := &cluster.ClusterPool{} + pool.ClusterStore = mockdb.NewMockClusterInfoStore(t) + kubeClient := fake.NewSimpleClientset() + pool.Clusters = append(pool.Clusters, cluster.Cluster{ + CID: "config", + ID: "test", + Client: kubeClient, + KnativeClient: knativefake.NewSimpleClientset(), + }) + sc := &serviceComponentImpl{ + k8sNameSpace: "test", + env: &config.Config{}, + spaceDockerRegBase: "http://test.com", + modelDockerRegBase: "http://test.com", + imagePullSecret: "test", + serviceStore: kss, + clusterPool: pool, + } + req := types.SVCRequest{ + ImageID: "test", + DeployID: 1, + DeployType: types.InferenceType, + RepoType: string(types.ModelRepo), + MinReplica: 1, + MaxReplica: 1, + UserID: "test", + Sku: "1", + SrvName: "test", + Hardware: types.HardWare{ + Gpu: types.GPU{ + Num: "1", + Type: "A10", + }, + Memory: "16Gi", + }, + Env: map[string]string{ + "test": "test", + "port": "8000", + }, + Annotation: map[string]string{}, + } + err := sc.RunService(ctx, req) + require.Nil(t, err) + kss.EXPECT().Get(ctx, "test", "test").Return(&database.KnativeService{ + Name: "test", + ID: 1, + Code: common.Running, + }, nil) + resp, err := sc.GetServiceByName(ctx, "test", "test") + require.Nil(t, err) + require.Equal(t, "test", resp.ServiceName) + require.Equal(t, common.Running, resp.Code) +} + +func TestServiceComponent_GetServiceInfo(t *testing.T) { + kss := mockdb.NewMockKnativeServiceStore(t) + ctx := context.TODO() + pool := &cluster.ClusterPool{} + pool.ClusterStore = mockdb.NewMockClusterInfoStore(t) + kubeClient := fake.NewSimpleClientset() + pool.Clusters = append(pool.Clusters, cluster.Cluster{ + CID: "config", + ID: "test", + Client: kubeClient, + KnativeClient: knativefake.NewSimpleClientset(), + }) + sc := &serviceComponentImpl{ + k8sNameSpace: "test", + env: &config.Config{}, + spaceDockerRegBase: "http://test.com", + modelDockerRegBase: "http://test.com", + imagePullSecret: "test", + serviceStore: kss, + clusterPool: pool, + } + req := types.SVCRequest{ + ImageID: "test", + DeployID: 1, + DeployType: types.InferenceType, + RepoType: string(types.ModelRepo), + MinReplica: 1, + MaxReplica: 1, + UserID: "test", + Sku: "1", + SrvName: "test", + Hardware: types.HardWare{ + Gpu: types.GPU{ + Num: "1", + Type: "A10", + }, + Memory: "16Gi", + }, + Env: map[string]string{ + "test": "test", + "port": "8000", + }, + Annotation: map[string]string{}, + } + err := sc.RunService(ctx, req) + require.Nil(t, err) + kss.EXPECT().Get(ctx, "test", "test").Return(&database.KnativeService{ + Name: "test", + ID: 1, + Code: common.Running, + }, nil) + resp, err := sc.GetServiceInfo(ctx, types.ServiceRequest{ + ServiceName: "test", + ClusterID: "test", + }) + require.Nil(t, err) + require.Equal(t, "test", resp.ServiceName) +} + +func TestServiceComponent_AddServiceInDB(t *testing.T) { + kss := mockdb.NewMockKnativeServiceStore(t) + pool := &cluster.ClusterPool{} + cis := mockdb.NewMockClusterInfoStore(t) + pool.ClusterStore = cis + kubeClient := fake.NewSimpleClientset() + knativeClient := knativefake.NewSimpleClientset() + pool.Clusters = append(pool.Clusters, cluster.Cluster{ + CID: "config", + ID: "test", + Client: kubeClient, + KnativeClient: knativeClient, + }) + sc := &serviceComponentImpl{ + k8sNameSpace: "test", + env: &config.Config{}, + spaceDockerRegBase: "http://test.com", + modelDockerRegBase: "http://test.com", + imagePullSecret: "test", + serviceStore: kss, + clusterPool: pool, + } + req := types.SVCRequest{ + ImageID: "test", + DeployID: 1, + DeployType: types.InferenceType, + RepoType: string(types.ModelRepo), + MinReplica: 1, + MaxReplica: 1, + UserID: "test", + Sku: "1", + SrvName: "test", + Hardware: types.HardWare{ + Gpu: types.GPU{ + Num: "1", + Type: "A10", + }, + Memory: "16Gi", + }, + Env: map[string]string{ + "test": "test", + "port": "8000", + }, + Annotation: map[string]string{}, + } + ctx := context.TODO() + err := sc.RunService(ctx, req) + require.Nil(t, err) + cis.EXPECT().ByClusterID(mock.Anything, "test").Return(database.ClusterInfo{ + ClusterID: "test", + ClusterConfig: "config", + StorageClass: "test", + }, nil) + kss.EXPECT().Add(mock.Anything, mock.Anything).Return(nil) + ksvc, err := knativeClient.ServingV1().Services("test"). + Get(ctx, "test", metav1.GetOptions{}) + require.Nil(t, err) + err = sc.AddServiceInDB(*ksvc, "test") + require.Nil(t, err) +} + +func TestServiceComponent_updateServiceInDB(t *testing.T) { + kss := mockdb.NewMockKnativeServiceStore(t) + pool := &cluster.ClusterPool{} + cis := mockdb.NewMockClusterInfoStore(t) + pool.ClusterStore = cis + kubeClient := fake.NewSimpleClientset() + knativeClient := knativefake.NewSimpleClientset() + pool.Clusters = append(pool.Clusters, cluster.Cluster{ + CID: "config", + ID: "test", + Client: kubeClient, + KnativeClient: knativeClient, + }) + sc := &serviceComponentImpl{ + k8sNameSpace: "test", + env: &config.Config{}, + spaceDockerRegBase: "http://test.com", + modelDockerRegBase: "http://test.com", + imagePullSecret: "test", + serviceStore: kss, + clusterPool: pool, + } + req := types.SVCRequest{ + ImageID: "test", + DeployID: 1, + DeployType: types.InferenceType, + RepoType: string(types.ModelRepo), + MinReplica: 1, + MaxReplica: 1, + UserID: "test", + Sku: "1", + SrvName: "test", + Hardware: types.HardWare{ + Gpu: types.GPU{ + Num: "1", + Type: "A10", + }, + Memory: "16Gi", + }, + Env: map[string]string{ + "test": "test", + "port": "8000", + }, + Annotation: map[string]string{}, + } + ctx := context.TODO() + err := sc.RunService(ctx, req) + require.Nil(t, err) + cis.EXPECT().ByClusterID(mock.Anything, "test").Return(database.ClusterInfo{ + ClusterID: "test", + ClusterConfig: "config", + StorageClass: "test", + }, nil) + kss.EXPECT().Update(mock.Anything, mock.Anything).Return(nil) + kss.EXPECT().Get(mock.Anything, "test", "test").Return(&database.KnativeService{ + ID: 1, + Name: "test", + ClusterID: "test", + Code: common.Running, + }, nil) + ksvc, err := knativeClient.ServingV1().Services("test"). + Get(ctx, "test", metav1.GetOptions{}) + require.Nil(t, err) + err = sc.UpdateServiceInDB(*ksvc, nil, "test") + require.Nil(t, err) +} + +func TestServiceComponent_deleteServiceInDB(t *testing.T) { + kss := mockdb.NewMockKnativeServiceStore(t) + pool := &cluster.ClusterPool{} + cis := mockdb.NewMockClusterInfoStore(t) + pool.ClusterStore = cis + kubeClient := fake.NewSimpleClientset() + knativeClient := knativefake.NewSimpleClientset() + pool.Clusters = append(pool.Clusters, cluster.Cluster{ + CID: "config", + ID: "test", + Client: kubeClient, + KnativeClient: knativeClient, + }) + sc := &serviceComponentImpl{ + k8sNameSpace: "test", + env: &config.Config{}, + spaceDockerRegBase: "http://test.com", + modelDockerRegBase: "http://test.com", + imagePullSecret: "test", + serviceStore: kss, + clusterPool: pool, + } + req := types.SVCRequest{ + ImageID: "test", + DeployID: 1, + DeployType: types.InferenceType, + RepoType: string(types.ModelRepo), + MinReplica: 1, + MaxReplica: 1, + UserID: "test", + Sku: "1", + SrvName: "test", + Hardware: types.HardWare{ + Gpu: types.GPU{ + Num: "1", + Type: "A10", + }, + Memory: "16Gi", + }, + Env: map[string]string{ + "test": "test", + "port": "8000", + }, + Annotation: map[string]string{}, + } + ctx := context.TODO() + err := sc.RunService(ctx, req) + require.Nil(t, err) + kss.EXPECT().Delete(mock.Anything, "test", "test").Return(nil) + ksvc, err := knativeClient.ServingV1().Services("test"). + Get(ctx, "test", metav1.GetOptions{}) + require.Nil(t, err) + err = sc.DeleteServiceInDB(*ksvc, "test") + require.Nil(t, err) +} diff --git a/runner/component/workflow.go b/runner/component/workflow.go index 3b164256..9038aeff 100644 --- a/runner/component/workflow.go +++ b/runner/component/workflow.go @@ -55,8 +55,6 @@ func NewWorkFlowComponent(config *config.Config, clusterPool *cluster.ClusterPoo clusterPool: clusterPool, eventPub: &event.DefaultEventPublisher, } - //watch workflows events - go wc.RunWorkflowsInformer(clusterPool, config) return wc } @@ -271,7 +269,7 @@ func generateWorkflow(req types.ArgoWorkFlowReq, config *config.Config) *v1alpha func (wc *workFlowComponentImpl) RunWorkflowsInformer(clusterPool *cluster.ClusterPool, c *config.Config) { clientset := clusterPool.Clusters[0].ArgoClient - f := externalversions.NewSharedInformerFactoryWithOptions(clientset, 60*time.Second, externalversions.WithTweakListOptions(func(list *v1.ListOptions) { + f := externalversions.NewSharedInformerFactoryWithOptions(clientset, 2*time.Minute, externalversions.WithTweakListOptions(func(list *v1.ListOptions) { list.LabelSelector = "workflow-scope=csghub" })) informer := f.Argoproj().V1alpha1().Workflows().Informer() @@ -371,7 +369,7 @@ func (wc *workFlowComponentImpl) StartAcctRequestFee(wf database.ArgoWorkflow) { // get cluster func GetCluster(ctx context.Context, clusterPool *cluster.ClusterPool, clusterID string) (*cluster.Cluster, string, error) { if clusterID == "" { - clusterInfo, err := clusterPool.ClusterStore.ByClusterConfig(ctx, clusterPool.Clusters[0].ID) + clusterInfo, err := clusterPool.ClusterStore.ByClusterConfig(ctx, clusterPool.Clusters[0].CID) if err != nil { return nil, "", fmt.Errorf("failed to get cluster info: %v", err) } diff --git a/runner/component/workflow_test.go b/runner/component/workflow_test.go new file mode 100644 index 00000000..26b6ebed --- /dev/null +++ b/runner/component/workflow_test.go @@ -0,0 +1,170 @@ +package component + +import ( + "context" + "testing" + + "github.com/argoproj/argo-workflows/v3/pkg/apis/workflow/v1alpha1" + argofake "github.com/argoproj/argo-workflows/v3/pkg/client/clientset/versioned/fake" + + "github.com/stretchr/testify/mock" + "github.com/stretchr/testify/require" + "k8s.io/client-go/kubernetes/fake" + knativefake "knative.dev/serving/pkg/client/clientset/versioned/fake" + mockdb "opencsg.com/csghub-server/_mocks/opencsg.com/csghub-server/builder/store/database" + "opencsg.com/csghub-server/builder/deploy/cluster" + "opencsg.com/csghub-server/builder/store/database" + "opencsg.com/csghub-server/common/config" + "opencsg.com/csghub-server/common/types" +) + +func TestArgoComponent_CreateWorkflow(t *testing.T) { + argoStore := mockdb.NewMockArgoWorkFlowStore(t) + pool := &cluster.ClusterPool{} + cis := mockdb.NewMockClusterInfoStore(t) + pool.ClusterStore = cis + kubeClient := fake.NewSimpleClientset() + pool.Clusters = append(pool.Clusters, cluster.Cluster{ + CID: "config", + ID: "test", + Client: kubeClient, + KnativeClient: knativefake.NewSimpleClientset(), + ArgoClient: argofake.NewSimpleClientset(), + }) + wfc := workFlowComponentImpl{ + wf: argoStore, + clusterPool: pool, + config: &config.Config{}, + } + cis.EXPECT().ByClusterID(mock.Anything, "test").Return(database.ClusterInfo{ + ClusterID: "test", + ClusterConfig: "config", + StorageClass: "test", + }, nil) + ctx := context.TODO() + argoStore.EXPECT().CreateWorkFlow(ctx, mock.Anything).Return(&database.ArgoWorkflow{ + ID: 1, + ClusterID: "test", + RepoType: "test", + TaskName: "test", + TaskId: "test", + Username: "test", + UserUUID: "test", + RepoIds: []string{"test"}, + Status: v1alpha1.WorkflowPhase(v1alpha1.NodePending), + }, nil) + wf, err := wfc.CreateWorkflow(ctx, types.ArgoWorkFlowReq{ + ClusterID: "test", + RepoType: string(types.ModelRepo), + TaskName: "test", + TaskId: "test", + Username: "test", + UserUUID: "test", + RepoIds: []string{"test"}, + Datasets: []string{"test"}, + Image: "test", + }) + require.Nil(t, err) + require.Equal(t, v1alpha1.WorkflowPhase(v1alpha1.NodePending), wf.Status) +} + +// func TestArgoComponent_UpdateWorkflow(t *testing.T) { +// argoStore := mockdb.NewMockArgoWorkFlowStore(t) +// pool := &cluster.ClusterPool{} +// cis := mockdb.NewMockClusterInfoStore(t) +// pool.ClusterStore = cis +// kubeClient := fake.NewSimpleClientset() +// argoClient := argofake.NewSimpleClientset() +// pool.Clusters = append(pool.Clusters, cluster.Cluster{ +// CID: "config", +// ID: "test", +// Client: kubeClient, +// KnativeClient: knativefake.NewSimpleClientset(), +// ArgoClient: argoClient, +// }) +// wfc := workFlowComponentImpl{ +// wf: argoStore, +// clusterPool: pool, +// config: &config.Config{}, +// } +// cis.EXPECT().ByClusterID(mock.Anything, "test").Return(database.ClusterInfo{ +// ClusterID: "test", +// ClusterConfig: "config", +// StorageClass: "test", +// }, nil) +// ctx := context.TODO() +// argoStore.EXPECT().CreateWorkFlow(ctx, mock.Anything).Return(&database.ArgoWorkflow{ +// ID: 1, +// ClusterID: "test", +// RepoType: "test", +// TaskName: "test", +// TaskId: "test", +// Username: "test", +// UserUUID: "test", +// RepoIds: []string{"test"}, +// Status: v1alpha1.WorkflowPhase(v1alpha1.NodePending), +// }, nil) +// wf, err := wfc.CreateWorkflow(ctx, types.ArgoWorkFlowReq{ +// ClusterID: "test", +// RepoType: string(types.ModelRepo), +// TaskName: "test", +// TaskId: "test", +// Username: "test", +// UserUUID: "test", +// RepoIds: []string{"test"}, +// Datasets: []string{"test"}, +// Image: "test", +// }) +// require.Nil(t, err) +// require.Equal(t, v1alpha1.WorkflowPhase(v1alpha1.NodePending), wf.Status) +// oldWF, err := argoClient.ArgoprojV1alpha1().Workflows("").Get(ctx, "test", metav1.GetOptions{}) +// require.Nil(t, err) +// oldWF.Status = v1alpha1.WorkflowStatus{ +// Phase: v1alpha1.WorkflowRunning, +// } +// arf, err := wfc.UpdateWorkflow(ctx, oldWF) +// require.Nil(t, err) +// require.Equal(t, v1alpha1.WorkflowPhase(v1alpha1.WorkflowRunning), arf.Status) +// } + +func TestArgoComponent_DeleteWorkflow(t *testing.T) { + argoStore := mockdb.NewMockArgoWorkFlowStore(t) + pool := &cluster.ClusterPool{} + cis := mockdb.NewMockClusterInfoStore(t) + pool.ClusterStore = cis + kubeClient := fake.NewSimpleClientset() + argoClient := argofake.NewSimpleClientset() + pool.Clusters = append(pool.Clusters, cluster.Cluster{ + CID: "config", + ID: "test", + Client: kubeClient, + KnativeClient: knativefake.NewSimpleClientset(), + ArgoClient: argoClient, + }) + wfc := workFlowComponentImpl{ + wf: argoStore, + clusterPool: pool, + config: &config.Config{}, + } + cis.EXPECT().ByClusterID(mock.Anything, "test").Return(database.ClusterInfo{ + ClusterID: "test", + ClusterConfig: "config", + StorageClass: "test", + }, nil) + ctx := context.TODO() + argoStore.EXPECT().DeleteWorkFlow(ctx, mock.Anything).Return(nil) + argoStore.EXPECT().FindByID(ctx, mock.Anything).Return(database.ArgoWorkflow{ + ID: 1, + ClusterID: "test", + RepoType: "test", + TaskName: "test", + TaskId: "test", + Username: "test", + UserUUID: "test", + RepoIds: []string{"test"}, + Status: v1alpha1.WorkflowPhase(v1alpha1.NodePending), + }, nil) + err := wfc.DeleteWorkflow(ctx, 1, "test") + require.Nil(t, err) + +} diff --git a/runner/handler/service.go b/runner/handler/service.go index 863e4c28..7014f779 100644 --- a/runner/handler/service.go +++ b/runner/handler/service.go @@ -2,20 +2,17 @@ package handler import ( "context" + "database/sql" "errors" "fmt" "log/slog" "net/http" - "path" - "strconv" "strings" "time" "github.com/gin-gonic/gin" corev1 "k8s.io/api/core/v1" - k8serrors "k8s.io/apimachinery/pkg/api/errors" metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" - v1 "knative.dev/serving/pkg/apis/serving/v1" "opencsg.com/csghub-server/builder/deploy/cluster" "opencsg.com/csghub-server/builder/deploy/common" "opencsg.com/csghub-server/builder/store/database" @@ -29,17 +26,18 @@ type K8sHandler struct { k8sNameSpace string modelDockerRegBase string env *config.Config - s *component.ServiceComponent + serviceCompoent component.ServiceComponent } func NewK8sHandler(config *config.Config, clusterPool *cluster.ClusterPool) (*K8sHandler, error) { domainParts := strings.SplitN(config.Space.InternalRootDomain, ".", 2) - serviceComponent := component.NewServiceComponent(config, domainParts[0]) + serviceComponent := component.NewServiceComponent(config, clusterPool) + go serviceComponent.RunInformer() return &K8sHandler{ k8sNameSpace: domainParts[0], clusterPool: clusterPool, env: config, - s: serviceComponent, + serviceCompoent: serviceComponent, modelDockerRegBase: config.Model.DockerRegBase, }, nil } @@ -52,99 +50,20 @@ func (s *K8sHandler) RunService(c *gin.Context) { c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) return } - slog.Debug("Recv request", slog.Any("body", request)) - - cluster, err := s.clusterPool.GetClusterByID(c, request.ClusterID) - if err != nil { - slog.Error("fail to get cluster ", slog.Any("error", err), slog.Any("req", request)) - c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) - return - } - srvName := s.getServiceNameFromRequest(c) - // check if the ksvc exists - _, err = cluster.KnativeClient.ServingV1().Services(s.k8sNameSpace).Get(c.Request.Context(), srvName, metav1.GetOptions{}) - if err == nil { - err = s.removeServiceForcely(c, cluster, srvName) - if err != nil { - slog.Error("fail to remove service", slog.Any("error", err), slog.Any("req", request)) - c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) - } - slog.Info("service already exists,delete it first", slog.String("srv_name", srvName), slog.Any("image_id", request.ImageID)) - } - service, err := s.s.GenerateService(c, *cluster, *request, srvName) - if err != nil { - slog.Error("fail to generate service ", slog.Any("error", err), slog.Any("req", request)) - c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) - return - } - volumes := []corev1.Volume{} - volumeMounts := []corev1.VolumeMount{} - if request.DeployType != types.SpaceType { - // dshm volume for multi-gpu share memory - volumes = append(volumes, corev1.Volume{ - Name: "dshm", - VolumeSource: corev1.VolumeSource{ - EmptyDir: &corev1.EmptyDirVolumeSource{ - Medium: corev1.StorageMediumMemory, - }, - }, - }) - - volumeMounts = append(volumeMounts, corev1.VolumeMount{ - Name: "dshm", - MountPath: "/dev/shm", - }) - } - pvcName := srvName - if request.DeployType == types.InferenceType { - pvcName = request.UserID - } - // add pvc if possible - // space image was built from user's code, model cache dir is hard to control - // so no PV cache for space case so far - if cluster.StorageClass != "" && request.DeployType != types.SpaceType { - err = s.s.NewPersistentVolumeClaim(pvcName, c, *cluster, request.Hardware) - if err != nil { - slog.Error("Failed to create persist volume", "error", err) - c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to create persist volume"}) - return - } - volumes = append(volumes, corev1.Volume{ - Name: "nas-pvc", - VolumeSource: corev1.VolumeSource{ - PersistentVolumeClaim: &corev1.PersistentVolumeClaimVolumeSource{ - ClaimName: pvcName, - }, - }, - }) - - volumeMounts = append(volumeMounts, corev1.VolumeMount{ - Name: "nas-pvc", - MountPath: "/workspace", - }) - } - service.Spec.Template.Spec.Volumes = volumes - service.Spec.Template.Spec.Containers[0].VolumeMounts = volumeMounts - - slog.Debug("ksvc", slog.Any("knative service", service)) - - // create ksvc - _, err = cluster.KnativeClient.ServingV1().Services(s.k8sNameSpace).Create(c, service, metav1.CreateOptions{}) + request.SrvName = srvName + err = s.serviceCompoent.RunService(c.Request.Context(), *request) if err != nil { - slog.Error("Failed to create service", "error", err, slog.Int64("deploy_id", request.DeployID), - slog.String("image_id", request.ImageID), - slog.String("srv_name", srvName)) - c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to create service"}) + slog.Error("fail to run service", slog.Any("error", err), slog.Any("req", request)) + c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) return } - slog.Info("service created successfully", slog.String("srv_name", srvName), slog.Int64("deploy_id", request.DeployID)) c.JSON(http.StatusOK, gin.H{"message": "Service created successfully"}) } func (s *K8sHandler) StopService(c *gin.Context) { - var resp types.StopResponse + var request = &types.StopRequest{} err := c.BindJSON(request) @@ -153,48 +72,12 @@ func (s *K8sHandler) StopService(c *gin.Context) { c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) return } - cluster, err := s.clusterPool.GetClusterByID(c, request.ClusterID) - if err != nil { - slog.Error("fail to get cluster ", slog.Any("error", err)) - c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) - return - } - srvName := s.getServiceNameFromRequest(c) - srv, err := cluster.KnativeClient.ServingV1().Services(s.k8sNameSpace). - Get(c.Request.Context(), srvName, metav1.GetOptions{}) - if err != nil { - k8serr := new(k8serrors.StatusError) - if errors.As(err, &k8serr) { - if k8serr.Status().Code == http.StatusNotFound { - slog.Info("stop image skip,service not exist", slog.String("srv_name", srvName), slog.Any("k8s_err", k8serr)) - resp.Code = 0 - resp.Message = "skip,service not exist" - c.JSON(http.StatusOK, nil) - return - } - } - slog.Error("stop image failed, cannot get service info", slog.String("srv_name", srvName), slog.Any("error", err), - slog.String("srv_name", srvName)) - resp.Code = -1 - resp.Message = "failed to get service status" - c.JSON(http.StatusInternalServerError, resp) - return - } - - if srv == nil { - resp.Code = 0 - resp.Message = "service not exist" - c.JSON(http.StatusOK, resp) - return - } - err = s.removeServiceForcely(c, cluster, srvName) + request.SvcName = srvName + resp, err := s.serviceCompoent.StopService(c.Request.Context(), *request) if err != nil { - slog.Error("stop image failed, cannot delete service ", slog.String("srv_name", srvName), slog.Any("error", err), - slog.String("srv_name", srvName)) - resp.Code = -1 - resp.Message = "failed to get service status" - c.JSON(http.StatusInternalServerError, resp) + slog.Error("failed to stop service", slog.Any("error", err), slog.Any("req", request)) + c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) return } @@ -204,34 +87,8 @@ func (s *K8sHandler) StopService(c *gin.Context) { c.JSON(http.StatusOK, resp) } -func (s *K8sHandler) removeServiceForcely(c *gin.Context, cluster *cluster.Cluster, svcName string) error { - err := cluster.KnativeClient.ServingV1().Services(s.k8sNameSpace).Delete(context.Background(), svcName, *metav1.NewDeleteOptions(0)) - if err != nil { - return err - } - podNames, _ := s.GetServicePods(c.Request.Context(), *cluster, svcName, s.k8sNameSpace, -1) - if podNames == nil { - return nil - } - //before k8s 1.31, kill pod does not kill the process immediately, instead we still need wait for the process to exit. more details see: https://github.com/kubernetes/kubernetes/issues/120449 - gracePeriodSeconds := int64(10) - deletePolicy := metav1.DeletePropagationForeground - deleteOptions := metav1.DeleteOptions{ - GracePeriodSeconds: &gracePeriodSeconds, - PropagationPolicy: &deletePolicy, - } - - for _, podName := range podNames { - errForce := cluster.Client.CoreV1().Pods(s.k8sNameSpace).Delete(c.Request.Context(), podName, deleteOptions) - if errForce != nil { - slog.Error("removeServiceForcely failed to delete pod", slog.String("pod_name", podName), slog.Any("error", errForce)) - } - } - return nil -} - func (s *K8sHandler) UpdateService(c *gin.Context) { - var resp types.ModelUpdateResponse + var request = &types.ModelUpdateRequest{} err := c.BindJSON(request) @@ -240,75 +97,13 @@ func (s *K8sHandler) UpdateService(c *gin.Context) { c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) return } - cluster, err := s.clusterPool.GetClusterByID(c, request.ClusterID) - if err != nil { - slog.Error("fail to get cluster ", slog.Any("error", err)) - c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) - return - } - srvName := s.getServiceNameFromRequest(c) - srv, err := cluster.KnativeClient.ServingV1().Services(s.k8sNameSpace). - Get(c.Request.Context(), srvName, metav1.GetOptions{}) + request.SvcName = srvName + resp, err := s.serviceCompoent.UpdateService(c.Request.Context(), *request) if err != nil { - k8serr := new(k8serrors.StatusError) - if errors.As(err, &k8serr) { - if k8serr.Status().Code == http.StatusNotFound { - slog.Info("update service skip,service not exist", slog.String("srv_name", srvName), slog.Any("k8s_err", k8serr)) - resp.Code = 0 - resp.Message = "skip,service not exist" - c.JSON(http.StatusOK, nil) - return - } - } - slog.Error("update service failed, cannot get service info", slog.String("srv_name", srvName), slog.Any("error", err), - slog.String("srv_name", srvName)) - resp.Code = -1 - resp.Message = "failed to get service status" - c.JSON(http.StatusInternalServerError, resp) - return + slog.Error("failed to update service", slog.Any("error", err), slog.Any("req", request)) + c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) } - - if srv == nil { - resp.Code = 0 - resp.Message = "service not exist" - c.JSON(http.StatusOK, resp) - return - } - // Update Image - containerImg := path.Join(s.modelDockerRegBase, request.ImageID) - srv.Spec.Template.Spec.Containers[0].Image = containerImg - // Update env - environments := []corev1.EnvVar{} - if request.Env != nil { - // generate env - for key, value := range request.Env { - environments = append(environments, corev1.EnvVar{Name: key, Value: value}) - } - srv.Spec.Template.Spec.Containers[0].Env = environments - } - // Update CPU and Memory requests and limits - hardware := request.Hardware - resReq, _ := component.GenerateResources(hardware) - resources := corev1.ResourceRequirements{ - Limits: resReq, - Requests: resReq, - } - srv.Spec.Template.Spec.Containers[0].Resources = resources - // Update replica - srv.Spec.Template.Annotations["autoscaling.knative.dev/min-scale"] = strconv.Itoa(request.MinReplica) - srv.Spec.Template.Annotations["autoscaling.knative.dev/max-scale"] = strconv.Itoa(request.MaxReplica) - - _, err = cluster.KnativeClient.ServingV1().Services(s.k8sNameSpace).Update(c, srv, metav1.UpdateOptions{}) - if err != nil { - slog.Error("failed to update service ", slog.String("srv_name", srvName), slog.Any("error", err), - slog.String("srv_name", srvName)) - resp.Code = -1 - resp.Message = "failed to update service" - c.JSON(http.StatusInternalServerError, resp) - return - } - slog.Info("service updated", slog.String("srv_name", srvName)) resp.Code = 0 resp.Message = "service updated" @@ -316,104 +111,19 @@ func (s *K8sHandler) UpdateService(c *gin.Context) { } func (s *K8sHandler) ServiceStatus(c *gin.Context) { - var resp types.StatusResponse - var request = &types.StatusRequest{} err := c.BindJSON(request) - if err != nil { slog.Error("serviceStatus get bad request", slog.Any("error", err), slog.Any("req", request)) c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) return } - - cluster, err := s.clusterPool.GetClusterByID(c, request.ClusterID) - - if err != nil { - slog.Error("fail to get cluster ", slog.Any("error", err)) - c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) - return - } - srvName := s.getServiceNameFromRequest(c) - srv, err := cluster.KnativeClient.ServingV1().Services(s.k8sNameSpace). - Get(c.Request.Context(), srvName, metav1.GetOptions{}) + resp, err := s.serviceCompoent.GetServiceByName(c.Request.Context(), srvName, request.ClusterID) if err != nil { - slog.Error("get image status failed, cannot get service info", slog.String("srv_name", srvName), slog.Any("error", err), - slog.String("srv_name", srvName)) - resp.Code = common.Stopped - resp.Message = "failed to get service status" - c.JSON(http.StatusOK, resp) - return - } - deployIDStr := srv.Annotations["deploy_id"] - deployID, _ := strconv.ParseInt(deployIDStr, 10, 64) - resp.DeployID = deployID - resp.UserID = srv.Annotations["user_id"] - - // retrieve pod list and status - if request.NeedDetails { - instList, err := s.s.GetServicePodsWithStatus(c.Request.Context(), *cluster, srvName, s.k8sNameSpace) - if err != nil { - slog.Error("fail to get service pod name list", slog.Any("error", err)) - c.JSON(http.StatusNotFound, gin.H{"error": "fail to get service pod name list"}) - return - } - resp.Instances = instList - } - - if srv.IsFailed() { - resp.Code = common.DeployFailed - // read message of Ready - resp.Message = srv.Status.GetCondition(v1.ServiceConditionReady).Message - // append message of ConfigurationsReady - srvConfigReady := srv.Status.GetCondition(v1.ServiceConditionConfigurationsReady) - if srvConfigReady != nil { - resp.Message += srvConfigReady.Message - } - // for inference case: model loading case one pod is not ready - for _, instance := range resp.Instances { - if instance.Status == string(corev1.PodRunning) || instance.Status == string(corev1.PodPending) { - resp.Code = common.Deploying - break - } - } - slog.Info("service status is failed", slog.String("srv_name", srvName), slog.Any("resp", resp)) - c.JSON(http.StatusOK, resp) - return - } - - if srv.IsReady() { - podNames, err := s.GetServicePods(c.Request.Context(), *cluster, srvName, s.k8sNameSpace, 1) - if err != nil { - slog.Error("get image status failed, can not get pods info", slog.String("srv_name", srvName), slog.Any("error", err)) - c.JSON(http.StatusInternalServerError, gin.H{"code": 0, "message": "unknown service status, failed to get pods"}) - return - } - if len(podNames) == 0 { - resp.Code = common.Sleeping - resp.Message = "service sleeping, no running pods" - slog.Debug("get image status success", slog.String("srv_name", srvName), slog.Any("resp", resp)) - c.JSON(http.StatusOK, resp) - return - } - - resp.Code = common.Running - resp.Message = "service running" - if srv.Status.URL != nil { - slog.Debug("knative endpoint", slog.Any("svc name", srvName), slog.Any("url", srv.Status.URL.URL().String())) - resp.Endpoint = srv.Status.URL.URL().String() - } - - slog.Debug("service status is ready", slog.String("srv_name", srvName), slog.Any("resp", resp)) - c.JSON(http.StatusOK, resp) - return + slog.Error("failed to get service", slog.Any("error", err), slog.String("srv_name", srvName)) + c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to get service"}) } - - // default to deploying status - resp.Code = common.Deploying - resp.Message = "service is not ready or failed" - slog.Info("get service status success, service is not ready or failed", slog.String("srv_name", srvName), slog.Any("resp", resp)) c.JSON(http.StatusOK, resp) } @@ -422,7 +132,7 @@ func (s *K8sHandler) ServiceLogs(c *gin.Context) { err := c.BindJSON(request) if err != nil { - slog.Error("serviceLogs get bad request", slog.Any("error", err), slog.Any("req", request)) + slog.Error("get bad request", slog.Any("error", err), slog.Any("req", request)) c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) return } @@ -433,7 +143,7 @@ func (s *K8sHandler) ServiceLogs(c *gin.Context) { return } srvName := s.getServiceNameFromRequest(c) - podNames, err := s.GetServicePods(c.Request.Context(), *cluster, srvName, s.k8sNameSpace, 1) + podNames, err := s.serviceCompoent.GetServicePods(c.Request.Context(), *cluster, srvName, s.k8sNameSpace, 1) if err != nil { slog.Error("failed to read image logs, cannot get pods info", slog.Any("error", err), slog.String("srv_name", srvName)) c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to get pods info"}) @@ -539,96 +249,19 @@ func (s *K8sHandler) GetLogsByPod(c *gin.Context, cluster cluster.Cluster, podNa } func (s *K8sHandler) ServiceStatusAll(c *gin.Context) { - allStatus := make(map[string]*types.StatusResponse) - for index := range s.clusterPool.Clusters { - cluster := s.clusterPool.Clusters[index] - services, err := cluster.KnativeClient.ServingV1().Services(s.k8sNameSpace). - List(c.Request.Context(), metav1.ListOptions{}) - if err != nil { - slog.Error("get image status all failed, cannot get service infos", slog.Any("error", err)) - //continue to next in multi cluster - continue - } - - for _, srv := range services.Items { - deployIDStr := srv.Annotations[component.KeyDeployID] - deployID, _ := strconv.ParseInt(deployIDStr, 10, 64) - deployTypeStr := srv.Annotations[component.KeyDeployType] - deployType, err := strconv.ParseInt(deployTypeStr, 10, 64) - if err != nil { - deployType = 0 - } - userID := srv.Annotations[component.KeyUserID] - deploySku := srv.Annotations[component.KeyDeploySKU] - status := &types.StatusResponse{ - DeployID: deployID, - UserID: userID, - DeployType: int(deployType), - ServiceName: srv.Name, - DeploySku: deploySku, - } - allStatus[srv.Name] = status - if srv.IsFailed() { - status.Code = common.DeployFailed - continue - } - - if srv.IsReady() { - podNames, err := s.GetServicePods(c.Request.Context(), cluster, srv.Name, s.k8sNameSpace, 1) - if err != nil { - slog.Error("get image status failed, cannot get pods info", slog.Any("error", err)) - status.Code = common.Running - continue - } - status.Replica = len(podNames) - if len(podNames) == 0 { - status.Code = common.Sleeping - continue - } - - status.Code = common.Running - continue - } - - // default to deploying - status.Code = common.Deploying - } - } - - c.JSON(http.StatusOK, allStatus) -} - -func (s *K8sHandler) GetServicePods(ctx context.Context, cluster cluster.Cluster, srvName string, namespace string, limit int64) ([]string, error) { - labelSelector := fmt.Sprintf("serving.knative.dev/service=%s", srvName) - // Get the list of Pods based on the label selector - opts := metav1.ListOptions{ - LabelSelector: labelSelector, - } - if limit > 0 { - opts = metav1.ListOptions{ - LabelSelector: labelSelector, - Limit: limit, - } - } - pods, err := cluster.Client.CoreV1().Pods(namespace).List(ctx, opts) + allStatus, err := s.serviceCompoent.GetAllServiceStatus(c) if err != nil { - return nil, err - } - - // Extract the Pod names - var podNames []string - for _, pod := range pods.Items { - podNames = append(podNames, pod.Name) + c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) + return } - - return podNames, nil + c.JSON(http.StatusOK, allStatus) } func (s *K8sHandler) GetClusterInfo(c *gin.Context) { clusterRes := []types.CluserResponse{} for index := range s.clusterPool.Clusters { cls := s.clusterPool.Clusters[index] - cInfo, err := s.clusterPool.ClusterStore.ByClusterConfig(c.Request.Context(), cls.ID) + cInfo, err := s.clusterPool.ClusterStore.ByClusterConfig(c.Request.Context(), cls.CID) if err != nil { slog.Error("get cluster info failed", slog.Any("error", err)) continue @@ -687,52 +320,27 @@ func (s *K8sHandler) GetServiceByName(c *gin.Context) { c.JSON(http.StatusOK, resp) return } - cluster, err := s.clusterPool.GetClusterByID(c, request.ClusterID) - if err != nil { - slog.Error("fail to get cluster config", slog.Any("error", err)) - resp.Code = -1 - resp.Message = "fail to get cluster config" - c.JSON(http.StatusOK, resp) - return - } srvName := s.getServiceNameFromRequest(c) - srv, err := cluster.KnativeClient.ServingV1().Services(s.k8sNameSpace).Get(c.Request.Context(), srvName, metav1.GetOptions{}) - if err != nil { - k8serr := new(k8serrors.StatusError) - if errors.As(err, &k8serr) { - if k8serr.Status().Code == http.StatusNotFound { - // service not exist - resp.Code = 0 - resp.Message = "service not exist" - c.JSON(http.StatusOK, resp) - return - } - } - // get service with error - slog.Error("fail to get service with error", slog.Any("error", err)) + svc, err := s.serviceCompoent.GetServiceByName(c.Request.Context(), srvName, request.ClusterID) + if err != nil && !errors.Is(err, sql.ErrNoRows) { resp.Code = -1 resp.Message = "fail to get service" c.JSON(http.StatusOK, resp) return } - - if srv == nil { + if svc == nil { // service not exist - resp.Code = 0 + resp.Code = common.Stopped resp.Message = "service not exist" c.JSON(http.StatusOK, resp) return } - // service exist - deployIDStr := srv.Annotations[types.ResDeployID] - deployID, _ := strconv.ParseInt(deployIDStr, 10, 64) - resp.DeployID = deployID - resp.Code = 1 + resp.DeployID = svc.DeployID + resp.Code = svc.Code resp.Message = srvName - if srv.Status.URL != nil { - resp.Endpoint = srv.Status.URL.URL().String() - } + resp.Endpoint = svc.Endpoint + resp.Instances = svc.Instances c.JSON(http.StatusOK, resp) } @@ -745,59 +353,26 @@ func (s *K8sHandler) GetReplica(c *gin.Context) { c.JSON(http.StatusInternalServerError, gin.H{"error": "fail to parse input parameters"}) return } - cluster, err := s.clusterPool.GetClusterByID(c, request.ClusterID) - if err != nil { - slog.Error("fail to get cluster config", slog.Any("error", err)) - c.JSON(http.StatusInternalServerError, gin.H{"error": "fail to get cluster config"}) - return - } srvName := s.getServiceNameFromRequest(c) - srv, err := cluster.KnativeClient.ServingV1().Services(s.k8sNameSpace).Get(c.Request.Context(), srvName, metav1.GetOptions{}) - if err != nil { - // get service with error + svc, err := s.serviceCompoent.GetServiceByName(c.Request.Context(), srvName, request.ClusterID) + if err != nil && !errors.Is(err, sql.ErrNoRows) { slog.Error("fail to get service", slog.Any("error", err)) c.JSON(http.StatusInternalServerError, gin.H{"error": "fail to get service"}) return } - - if srv == nil { + if svc == nil { // service not exist slog.Error("service not exist") c.JSON(http.StatusNotFound, gin.H{"error": "service not exist"}) return } - // revisionName := srv.Status.LatestReadyRevisionName - revisionName := srv.Status.LatestCreatedRevisionName - if len(revisionName) < 1 { - slog.Error("fail to get latest created revision") - c.JSON(http.StatusInternalServerError, gin.H{"error": "fail to get latest created revision"}) - return - } - revision, err := cluster.KnativeClient.ServingV1().Revisions(s.k8sNameSpace).Get(c.Request.Context(), revisionName, metav1.GetOptions{}) - if err != nil { - slog.Error("fail to get revision with error", slog.Any("error", err)) - c.JSON(http.StatusInternalServerError, gin.H{"error": "fail to get revision with error"}) - return - } - - if revision == nil { - slog.Error("revision not exist") - c.JSON(http.StatusNotFound, gin.H{"error": "revision not exist"}) - return - } - instList, err := s.s.GetServicePodsWithStatus(c.Request.Context(), *cluster, srvName, s.k8sNameSpace) - if err != nil { - slog.Error("fail to get service pod name list", slog.Any("error", err)) - c.JSON(http.StatusNotFound, gin.H{"error": "fail to get service pod name list"}) - return - } // revision exist resp.Code = 1 resp.Message = srvName - resp.ActualReplica = int(*revision.Status.ActualReplicas) - resp.DesiredReplica = int(*revision.Status.DesiredReplicas) - resp.Instances = instList + resp.ActualReplica = svc.ActualReplica + resp.DesiredReplica = svc.DesiredReplica + resp.Instances = svc.Instances c.JSON(http.StatusOK, resp) } @@ -826,7 +401,7 @@ func (s *K8sHandler) UpdateCluster(c *gin.Context) { } func (s *K8sHandler) PurgeService(c *gin.Context) { - var resp types.PurgeResponse + var resp = &types.PurgeResponse{} var request = &types.PurgeRequest{} err := c.BindJSON(request) if err != nil { @@ -836,46 +411,13 @@ func (s *K8sHandler) PurgeService(c *gin.Context) { c.JSON(http.StatusBadRequest, resp) return } - cluster, err := s.clusterPool.GetClusterByID(c, request.ClusterID) - if err != nil { - slog.Error("fail to get cluster config", slog.Any("error", err)) - c.JSON(http.StatusInternalServerError, gin.H{"error": "fail to get cluster config"}) - return - } srvName := s.getServiceNameFromRequest(c) - _, err = cluster.KnativeClient.ServingV1().Services(s.k8sNameSpace). - Get(c.Request.Context(), srvName, metav1.GetOptions{}) + request.SvcName = srvName + resp, err = s.serviceCompoent.PurgeService(c.Request.Context(), *request) if err != nil { - k8serr := new(k8serrors.StatusError) - if errors.As(err, &k8serr) { - if k8serr.Status().Code == http.StatusNotFound { - slog.Info("service not exist", slog.String("srv_name", srvName), slog.Any("k8s_err", k8serr)) - } - } - slog.Error("purge service failed, cannot get service info", slog.String("srv_name", srvName), slog.Any("error", err), - slog.String("srv_name", srvName)) - } else { - // 1 delete service - err = s.removeServiceForcely(c, cluster, srvName) - if err != nil { - slog.Error("failed to delete service ", slog.String("srv_name", srvName), slog.Any("error", err), - slog.String("srv_name", srvName)) - resp.Code = -1 - resp.Message = "failed to get service status" - c.JSON(http.StatusInternalServerError, resp) - return - } - } - - // 2 clean up pvc - if cluster.StorageClass != "" && request.DeployType == types.FinetuneType { - err = cluster.Client.CoreV1().PersistentVolumeClaims(s.k8sNameSpace).Delete(c, srvName, metav1.DeleteOptions{}) - if err != nil { - slog.Error("fail to delete pvc", slog.Any("error", err)) - c.JSON(http.StatusInternalServerError, gin.H{"error": "fail to delete pvc"}) - return - } - slog.Info("persistent volume claims deleted.", slog.String("srv_name", srvName)) + slog.Error("fail to purge service", slog.Any("error", err)) + c.JSON(http.StatusInternalServerError, resp) + return } slog.Info("service deleted.", slog.String("srv_name", srvName)) resp.Code = 0 @@ -884,7 +426,6 @@ func (s *K8sHandler) PurgeService(c *gin.Context) { } func (s *K8sHandler) GetServiceInfo(c *gin.Context) { - var resp types.ServiceInfoResponse var request = &types.ServiceRequest{} err := c.BindJSON(request) if err != nil { @@ -892,21 +433,14 @@ func (s *K8sHandler) GetServiceInfo(c *gin.Context) { c.JSON(http.StatusInternalServerError, gin.H{"error": "fail to parse input parameters"}) return } - cluster, err := s.clusterPool.GetClusterByID(c, request.ClusterID) - if err != nil { - slog.Error("fail to get cluster config", slog.Any("error", err)) - c.JSON(http.StatusInternalServerError, gin.H{"error": "fail to get cluster config"}) - return - } srvName := s.getServiceNameFromRequest(c) - podNames, err := s.GetServicePods(c.Request.Context(), *cluster, srvName, s.k8sNameSpace, -1) + request.ServiceName = srvName + resp, err := s.serviceCompoent.GetServiceInfo(c.Request.Context(), *request) if err != nil { - slog.Error("failed to read image logs, cannot get pods info", slog.Any("error", err), slog.String("srv_name", srvName)) - c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to get pods info"}) + slog.Error("fail to get service info", slog.Any("error", err)) + c.JSON(http.StatusInternalServerError, gin.H{"error": "fail to get service info"}) return } - resp.PodNames = podNames - resp.ServiceName = srvName c.JSON(http.StatusOK, resp) } diff --git a/runner/handler/workflow.go b/runner/handler/workflow.go index b29f7b41..5b94d320 100644 --- a/runner/handler/workflow.go +++ b/runner/handler/workflow.go @@ -24,6 +24,8 @@ type ArgoHandler struct { func NewArgoHandler(config *config.Config, clusterPool *cluster.ClusterPool) (*ArgoHandler, error) { wfc := component.NewWorkFlowComponent(config, clusterPool) + //watch workflows events + go wfc.RunWorkflowsInformer(clusterPool, config) return &ArgoHandler{ clusterPool: clusterPool, config: config, From 6419f2d5de6af54f51f4b46536cbfdc9abbe9687 Mon Sep 17 00:00:00 2001 From: Yiling-J Date: Fri, 3 Jan 2025 16:03:26 +0800 Subject: [PATCH 33/34] Refactor temporal workflow (#231) * Merge branch 'refactor/temporal_client' into 'main' Refactor temporal and add workflow tests See merge request product/starhub/starhub-server!777 * Merge branch 'fix/exec_workflow' into 'main' Fix temporal workflow run param See merge request product/starhub/starhub-server!782 * update mockery yaml --------- Co-authored-by: yiling.ji --- .mockery.yaml | 273 +- .../go.temporal.io/sdk/client/mock_Client.go | 119 + .../builder/temporal/mock_Client.go | 2319 +++++++++++++++++ .../callback/mock_GitCallbackComponent.go | 305 +++ .../component/mock_MultiSyncComponent.go | 146 ++ .../component/mock_RecomComponent.go | 166 ++ .../component/mock_OrganizationComponent.go | 381 +++ api/handler/callback/git_callback.go | 8 +- api/handler/internal.go | 11 +- api/handler/internal_test.go | 14 +- api/workflow/activity/activities_ce.go | 46 + api/workflow/activity/calc_recom_score.go | 13 +- api/workflow/activity/handle_push.go | 53 +- api/workflow/activity/sync_as_client.go | 17 +- api/workflow/cron_calc_recom_score.go | 6 +- api/workflow/cron_sync_as_client.go | 6 +- api/workflow/cron_worker.go | 82 - api/workflow/cron_worker_ce.go | 63 + api/workflow/handle_push.go | 14 +- api/workflow/worker.go | 49 - api/workflow/worker_ce.go | 80 + api/workflow/workflow_ce_test.go | 54 + api/workflow/workflow_test.go | 74 + builder/temporal/temporal.go | 61 + builder/temporal/temporal_test.go | 71 + cmd/csghub-server/cmd/mirror/repo_sync.go | 10 +- cmd/csghub-server/cmd/start/server.go | 18 +- cmd/csghub-server/cmd/trigger/git_callback.go | 15 +- component/callback/git_callback.go | 1 + go.mod | 123 +- go.sum | 282 +- mirror/reposyncer/local_woker.go | 8 +- 32 files changed, 4393 insertions(+), 495 deletions(-) create mode 100644 _mocks/opencsg.com/csghub-server/builder/temporal/mock_Client.go create mode 100644 _mocks/opencsg.com/csghub-server/component/callback/mock_GitCallbackComponent.go create mode 100644 _mocks/opencsg.com/csghub-server/component/mock_MultiSyncComponent.go create mode 100644 _mocks/opencsg.com/csghub-server/component/mock_RecomComponent.go create mode 100644 _mocks/opencsg.com/csghub-server/user/component/mock_OrganizationComponent.go create mode 100644 api/workflow/activity/activities_ce.go create mode 100644 api/workflow/cron_worker_ce.go delete mode 100644 api/workflow/worker.go create mode 100644 api/workflow/worker_ce.go create mode 100644 api/workflow/workflow_ce_test.go create mode 100644 api/workflow/workflow_test.go create mode 100644 builder/temporal/temporal.go create mode 100644 builder/temporal/temporal_test.go diff --git a/.mockery.yaml b/.mockery.yaml index 719c4b2b..25e5682f 100644 --- a/.mockery.yaml +++ b/.mockery.yaml @@ -4,129 +4,152 @@ dir: "_mocks/{{.PackagePath}}" # https://github.com/vektra/mockery/commit/13fc60740a1eab3fc705fcc131a4b9a6eacfd0fe#diff-4bc8f03915f15ecf2cb9965592fcf933bb06986003613b5633a8dea7962c31a9R40 resolve-type-alias: false packages: - io: - config: - interfaces: - Reader: - opencsg.com/csghub-server/component: - config: - interfaces: - RepoComponent: - TagComponent: - AccountingComponent: - SpaceComponent: - SpaceResourceComponent: - RuntimeArchitectureComponent: - SensitiveComponent: - CodeComponent: - PromptComponent: - ModelComponent: - UserComponent: - GitHTTPComponent: - DiscussionComponent: - DatasetComponent: - CollectionComponent: - InternalComponent: - MirrorSourceComponent: - MirrorComponent: - EvaluationComponent: + io: + config: + interfaces: + Reader: + opencsg.com/csghub-server/component: + config: + interfaces: + RepoComponent: + TagComponent: + AccountingComponent: + SpaceComponent: + SpaceResourceComponent: + RuntimeArchitectureComponent: + SensitiveComponent: + CodeComponent: + PromptComponent: + ModelComponent: + UserComponent: + GitHTTPComponent: + DiscussionComponent: + DatasetComponent: + CollectionComponent: + InternalComponent: + MirrorSourceComponent: + MirrorComponent: + ImportComponent: + EvaluationComponent: + RecomComponent: + MultiSyncComponent: + opencsg.com/csghub-server/component/callback: + config: + interfaces: + SyncVersionGenerator: + GitCallbackComponent: + opencsg.com/csghub-server/user/component: + config: + interfaces: + MemberComponent: + OrganizationComponent: - - opencsg.com/csghub-server/user/component: - config: - interfaces: - MemberComponent: - opencsg.com/csghub-server/builder/store/database: - config: - all: True - opencsg.com/csghub-server/builder/rpc: - config: - interfaces: - UserSvcClient: - ModerationSvcClient: - opencsg.com/csghub-server/builder/sensitive: - config: - interfaces: - GreenClient: - Green2022Client: - SensitiveChecker: - opencsg.com/csghub-server/builder/git/gitserver: - config: - interfaces: - GitServer: - opencsg.com/csghub-server/builder/git/mirrorserver: - config: - interfaces: - MirrorServer: - opencsg.com/csghub-server/builder/git/membership: - config: - interfaces: - GitMemerShip: - opencsg.com/csghub-server/builder/rsa: - config: - interfaces: - KeysReader: - opencsg.com/csghub-server/mirror/cache: - config: - interfaces: - Cache: - opencsg.com/csghub-server/common/types: - config: - interfaces: - SensitiveRequestV2: - opencsg.com/csghub-server/mq: - config: - interfaces: - MessageQueue: - opencsg.com/csghub-server/builder/store/s3: - config: - interfaces: - Client: - opencsg.com/csghub-server/mirror/queue: - config: - interfaces: - PriorityQueue: - opencsg.com/csghub-server/builder/deploy: - config: - interfaces: - Deployer: - opencsg.com/csghub-server/builder/deploy/scheduler: - config: - interfaces: - Scheduler: - opencsg.com/csghub-server/builder/deploy/imagerunner: - config: - interfaces: - Runner: - opencsg.com/csghub-server/builder/deploy/imagebuilder: - config: - interfaces: - Builder: - - opencsg.com/csghub-server/accounting/component: - config: - interfaces: - AccountingUserComponent: - AccountingStatementComponent: - AccountingBillComponent: - AccountingPresentComponent: - opencsg.com/csghub-server/builder/accounting: - config: - interfaces: - AccountingClient: - opencsg.com/csghub-server/builder/parquet: - config: - interfaces: - Reader: - opencsg.com/csghub-server/builder/multisync: - config: - interfaces: - Client: - github.com/nats-io/nats.go/jetstream: - config: - interfaces: - Msg: - go.temporal.io/sdk/client: - config: - interfaces: - Client: + opencsg.com/csghub-server/builder/store/database: + config: + all: True + opencsg.com/csghub-server/builder/rpc: + config: + interfaces: + UserSvcClient: + ModerationSvcClient: + opencsg.com/csghub-server/builder/sensitive: + config: + interfaces: + GreenClient: + Green2022Client: + SensitiveChecker: + opencsg.com/csghub-server/builder/git/gitserver: + config: + interfaces: + GitServer: + opencsg.com/csghub-server/builder/git/mirrorserver: + config: + interfaces: + MirrorServer: + opencsg.com/csghub-server/builder/git/membership: + config: + interfaces: + GitMemerShip: + opencsg.com/csghub-server/builder/rsa: + config: + interfaces: + KeysReader: + opencsg.com/csghub-server/mirror/cache: + config: + interfaces: + Cache: + opencsg.com/csghub-server/common/types: + config: + interfaces: + SensitiveRequestV2: + opencsg.com/csghub-server/mq: + config: + interfaces: + MessageQueue: + opencsg.com/csghub-server/builder/store/s3: + config: + interfaces: + Client: + opencsg.com/csghub-server/mirror/queue: + config: + interfaces: + PriorityQueue: + opencsg.com/csghub-server/builder/deploy: + config: + interfaces: + Deployer: + opencsg.com/csghub-server/builder/deploy/scheduler: + config: + interfaces: + Scheduler: + opencsg.com/csghub-server/builder/deploy/imagerunner: + config: + interfaces: + Runner: + opencsg.com/csghub-server/builder/deploy/imagebuilder: + config: + interfaces: + Builder: + opencsg.com/csghub-server/accounting/component: + config: + interfaces: + AccountingUserComponent: + AccountingStatementComponent: + AccountingBillComponent: + AccountingPresentComponent: + MeteringComponent: + AccountingEventComponent: + AccountingPriceComponent: + AccountingOrderComponent: + opencsg.com/csghub-server/builder/accounting: + config: + interfaces: + AccountingClient: + opencsg.com/csghub-server/builder/parquet: + config: + interfaces: + Reader: + opencsg.com/csghub-server/builder/multisync: + config: + interfaces: + Client: + opencsg.com/csghub-server/builder/store/cache: + config: + interfaces: + RedisClient: + github.com/nats-io/nats.go/jetstream: + config: + interfaces: + Msg: + go.temporal.io/sdk/client: + config: + interfaces: + Client: + opencsg.com/csghub-server/builder/temporal: + config: + interfaces: + Client: + opencsg.com/csghub-server/builder/importer: + config: + interfaces: + Importer: diff --git a/_mocks/go.temporal.io/sdk/client/mock_Client.go b/_mocks/go.temporal.io/sdk/client/mock_Client.go index 350bfeff..d44d8a24 100644 --- a/_mocks/go.temporal.io/sdk/client/mock_Client.go +++ b/_mocks/go.temporal.io/sdk/client/mock_Client.go @@ -1199,6 +1199,66 @@ func (_c *MockClient_ListWorkflow_Call) RunAndReturn(run func(context.Context, * return _c } +// NewWithStartWorkflowOperation provides a mock function with given fields: options, workflow, args +func (_m *MockClient) NewWithStartWorkflowOperation(options client.StartWorkflowOptions, workflow interface{}, args ...interface{}) client.WithStartWorkflowOperation { + var _ca []interface{} + _ca = append(_ca, options, workflow) + _ca = append(_ca, args...) + ret := _m.Called(_ca...) + + if len(ret) == 0 { + panic("no return value specified for NewWithStartWorkflowOperation") + } + + var r0 client.WithStartWorkflowOperation + if rf, ok := ret.Get(0).(func(client.StartWorkflowOptions, interface{}, ...interface{}) client.WithStartWorkflowOperation); ok { + r0 = rf(options, workflow, args...) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(client.WithStartWorkflowOperation) + } + } + + return r0 +} + +// MockClient_NewWithStartWorkflowOperation_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'NewWithStartWorkflowOperation' +type MockClient_NewWithStartWorkflowOperation_Call struct { + *mock.Call +} + +// NewWithStartWorkflowOperation is a helper method to define mock.On call +// - options client.StartWorkflowOptions +// - workflow interface{} +// - args ...interface{} +func (_e *MockClient_Expecter) NewWithStartWorkflowOperation(options interface{}, workflow interface{}, args ...interface{}) *MockClient_NewWithStartWorkflowOperation_Call { + return &MockClient_NewWithStartWorkflowOperation_Call{Call: _e.mock.On("NewWithStartWorkflowOperation", + append([]interface{}{options, workflow}, args...)...)} +} + +func (_c *MockClient_NewWithStartWorkflowOperation_Call) Run(run func(options client.StartWorkflowOptions, workflow interface{}, args ...interface{})) *MockClient_NewWithStartWorkflowOperation_Call { + _c.Call.Run(func(args mock.Arguments) { + variadicArgs := make([]interface{}, len(args)-2) + for i, a := range args[2:] { + if a != nil { + variadicArgs[i] = a.(interface{}) + } + } + run(args[0].(client.StartWorkflowOptions), args[1].(interface{}), variadicArgs...) + }) + return _c +} + +func (_c *MockClient_NewWithStartWorkflowOperation_Call) Return(_a0 client.WithStartWorkflowOperation) *MockClient_NewWithStartWorkflowOperation_Call { + _c.Call.Return(_a0) + return _c +} + +func (_c *MockClient_NewWithStartWorkflowOperation_Call) RunAndReturn(run func(client.StartWorkflowOptions, interface{}, ...interface{}) client.WithStartWorkflowOperation) *MockClient_NewWithStartWorkflowOperation_Call { + _c.Call.Return(run) + return _c +} + // OperatorService provides a mock function with given fields: func (_m *MockClient) OperatorService() operatorservice.OperatorServiceClient { ret := _m.Called() @@ -1845,6 +1905,65 @@ func (_c *MockClient_TerminateWorkflow_Call) RunAndReturn(run func(context.Conte return _c } +// UpdateWithStartWorkflow provides a mock function with given fields: ctx, options +func (_m *MockClient) UpdateWithStartWorkflow(ctx context.Context, options client.UpdateWithStartWorkflowOptions) (client.WorkflowUpdateHandle, error) { + ret := _m.Called(ctx, options) + + if len(ret) == 0 { + panic("no return value specified for UpdateWithStartWorkflow") + } + + var r0 client.WorkflowUpdateHandle + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, client.UpdateWithStartWorkflowOptions) (client.WorkflowUpdateHandle, error)); ok { + return rf(ctx, options) + } + if rf, ok := ret.Get(0).(func(context.Context, client.UpdateWithStartWorkflowOptions) client.WorkflowUpdateHandle); ok { + r0 = rf(ctx, options) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(client.WorkflowUpdateHandle) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, client.UpdateWithStartWorkflowOptions) error); ok { + r1 = rf(ctx, options) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// MockClient_UpdateWithStartWorkflow_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'UpdateWithStartWorkflow' +type MockClient_UpdateWithStartWorkflow_Call struct { + *mock.Call +} + +// UpdateWithStartWorkflow is a helper method to define mock.On call +// - ctx context.Context +// - options client.UpdateWithStartWorkflowOptions +func (_e *MockClient_Expecter) UpdateWithStartWorkflow(ctx interface{}, options interface{}) *MockClient_UpdateWithStartWorkflow_Call { + return &MockClient_UpdateWithStartWorkflow_Call{Call: _e.mock.On("UpdateWithStartWorkflow", ctx, options)} +} + +func (_c *MockClient_UpdateWithStartWorkflow_Call) Run(run func(ctx context.Context, options client.UpdateWithStartWorkflowOptions)) *MockClient_UpdateWithStartWorkflow_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(client.UpdateWithStartWorkflowOptions)) + }) + return _c +} + +func (_c *MockClient_UpdateWithStartWorkflow_Call) Return(_a0 client.WorkflowUpdateHandle, _a1 error) *MockClient_UpdateWithStartWorkflow_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *MockClient_UpdateWithStartWorkflow_Call) RunAndReturn(run func(context.Context, client.UpdateWithStartWorkflowOptions) (client.WorkflowUpdateHandle, error)) *MockClient_UpdateWithStartWorkflow_Call { + _c.Call.Return(run) + return _c +} + // UpdateWorkerBuildIdCompatibility provides a mock function with given fields: ctx, options func (_m *MockClient) UpdateWorkerBuildIdCompatibility(ctx context.Context, options *client.UpdateWorkerBuildIdCompatibilityOptions) error { ret := _m.Called(ctx, options) diff --git a/_mocks/opencsg.com/csghub-server/builder/temporal/mock_Client.go b/_mocks/opencsg.com/csghub-server/builder/temporal/mock_Client.go new file mode 100644 index 00000000..cd0fde1a --- /dev/null +++ b/_mocks/opencsg.com/csghub-server/builder/temporal/mock_Client.go @@ -0,0 +1,2319 @@ +// Code generated by mockery v2.49.1. DO NOT EDIT. + +package temporal + +import ( + context "context" + + client "go.temporal.io/sdk/client" + + converter "go.temporal.io/sdk/converter" + + enums "go.temporal.io/api/enums/v1" + + mock "github.com/stretchr/testify/mock" + + operatorservice "go.temporal.io/api/operatorservice/v1" + + worker "go.temporal.io/sdk/worker" + + workflowservice "go.temporal.io/api/workflowservice/v1" +) + +// MockClient is an autogenerated mock type for the Client type +type MockClient struct { + mock.Mock +} + +type MockClient_Expecter struct { + mock *mock.Mock +} + +func (_m *MockClient) EXPECT() *MockClient_Expecter { + return &MockClient_Expecter{mock: &_m.Mock} +} + +// CancelWorkflow provides a mock function with given fields: ctx, workflowID, runID +func (_m *MockClient) CancelWorkflow(ctx context.Context, workflowID string, runID string) error { + ret := _m.Called(ctx, workflowID, runID) + + if len(ret) == 0 { + panic("no return value specified for CancelWorkflow") + } + + var r0 error + if rf, ok := ret.Get(0).(func(context.Context, string, string) error); ok { + r0 = rf(ctx, workflowID, runID) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// MockClient_CancelWorkflow_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'CancelWorkflow' +type MockClient_CancelWorkflow_Call struct { + *mock.Call +} + +// CancelWorkflow is a helper method to define mock.On call +// - ctx context.Context +// - workflowID string +// - runID string +func (_e *MockClient_Expecter) CancelWorkflow(ctx interface{}, workflowID interface{}, runID interface{}) *MockClient_CancelWorkflow_Call { + return &MockClient_CancelWorkflow_Call{Call: _e.mock.On("CancelWorkflow", ctx, workflowID, runID)} +} + +func (_c *MockClient_CancelWorkflow_Call) Run(run func(ctx context.Context, workflowID string, runID string)) *MockClient_CancelWorkflow_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(string), args[2].(string)) + }) + return _c +} + +func (_c *MockClient_CancelWorkflow_Call) Return(_a0 error) *MockClient_CancelWorkflow_Call { + _c.Call.Return(_a0) + return _c +} + +func (_c *MockClient_CancelWorkflow_Call) RunAndReturn(run func(context.Context, string, string) error) *MockClient_CancelWorkflow_Call { + _c.Call.Return(run) + return _c +} + +// CheckHealth provides a mock function with given fields: ctx, request +func (_m *MockClient) CheckHealth(ctx context.Context, request *client.CheckHealthRequest) (*client.CheckHealthResponse, error) { + ret := _m.Called(ctx, request) + + if len(ret) == 0 { + panic("no return value specified for CheckHealth") + } + + var r0 *client.CheckHealthResponse + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, *client.CheckHealthRequest) (*client.CheckHealthResponse, error)); ok { + return rf(ctx, request) + } + if rf, ok := ret.Get(0).(func(context.Context, *client.CheckHealthRequest) *client.CheckHealthResponse); ok { + r0 = rf(ctx, request) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*client.CheckHealthResponse) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, *client.CheckHealthRequest) error); ok { + r1 = rf(ctx, request) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// MockClient_CheckHealth_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'CheckHealth' +type MockClient_CheckHealth_Call struct { + *mock.Call +} + +// CheckHealth is a helper method to define mock.On call +// - ctx context.Context +// - request *client.CheckHealthRequest +func (_e *MockClient_Expecter) CheckHealth(ctx interface{}, request interface{}) *MockClient_CheckHealth_Call { + return &MockClient_CheckHealth_Call{Call: _e.mock.On("CheckHealth", ctx, request)} +} + +func (_c *MockClient_CheckHealth_Call) Run(run func(ctx context.Context, request *client.CheckHealthRequest)) *MockClient_CheckHealth_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(*client.CheckHealthRequest)) + }) + return _c +} + +func (_c *MockClient_CheckHealth_Call) Return(_a0 *client.CheckHealthResponse, _a1 error) *MockClient_CheckHealth_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *MockClient_CheckHealth_Call) RunAndReturn(run func(context.Context, *client.CheckHealthRequest) (*client.CheckHealthResponse, error)) *MockClient_CheckHealth_Call { + _c.Call.Return(run) + return _c +} + +// Close provides a mock function with given fields: +func (_m *MockClient) Close() { + _m.Called() +} + +// MockClient_Close_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Close' +type MockClient_Close_Call struct { + *mock.Call +} + +// Close is a helper method to define mock.On call +func (_e *MockClient_Expecter) Close() *MockClient_Close_Call { + return &MockClient_Close_Call{Call: _e.mock.On("Close")} +} + +func (_c *MockClient_Close_Call) Run(run func()) *MockClient_Close_Call { + _c.Call.Run(func(args mock.Arguments) { + run() + }) + return _c +} + +func (_c *MockClient_Close_Call) Return() *MockClient_Close_Call { + _c.Call.Return() + return _c +} + +func (_c *MockClient_Close_Call) RunAndReturn(run func()) *MockClient_Close_Call { + _c.Call.Return(run) + return _c +} + +// CompleteActivity provides a mock function with given fields: ctx, taskToken, result, err +func (_m *MockClient) CompleteActivity(ctx context.Context, taskToken []byte, result interface{}, err error) error { + ret := _m.Called(ctx, taskToken, result, err) + + if len(ret) == 0 { + panic("no return value specified for CompleteActivity") + } + + var r0 error + if rf, ok := ret.Get(0).(func(context.Context, []byte, interface{}, error) error); ok { + r0 = rf(ctx, taskToken, result, err) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// MockClient_CompleteActivity_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'CompleteActivity' +type MockClient_CompleteActivity_Call struct { + *mock.Call +} + +// CompleteActivity is a helper method to define mock.On call +// - ctx context.Context +// - taskToken []byte +// - result interface{} +// - err error +func (_e *MockClient_Expecter) CompleteActivity(ctx interface{}, taskToken interface{}, result interface{}, err interface{}) *MockClient_CompleteActivity_Call { + return &MockClient_CompleteActivity_Call{Call: _e.mock.On("CompleteActivity", ctx, taskToken, result, err)} +} + +func (_c *MockClient_CompleteActivity_Call) Run(run func(ctx context.Context, taskToken []byte, result interface{}, err error)) *MockClient_CompleteActivity_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].([]byte), args[2].(interface{}), args[3].(error)) + }) + return _c +} + +func (_c *MockClient_CompleteActivity_Call) Return(_a0 error) *MockClient_CompleteActivity_Call { + _c.Call.Return(_a0) + return _c +} + +func (_c *MockClient_CompleteActivity_Call) RunAndReturn(run func(context.Context, []byte, interface{}, error) error) *MockClient_CompleteActivity_Call { + _c.Call.Return(run) + return _c +} + +// CompleteActivityByID provides a mock function with given fields: ctx, namespace, workflowID, runID, activityID, result, err +func (_m *MockClient) CompleteActivityByID(ctx context.Context, namespace string, workflowID string, runID string, activityID string, result interface{}, err error) error { + ret := _m.Called(ctx, namespace, workflowID, runID, activityID, result, err) + + if len(ret) == 0 { + panic("no return value specified for CompleteActivityByID") + } + + var r0 error + if rf, ok := ret.Get(0).(func(context.Context, string, string, string, string, interface{}, error) error); ok { + r0 = rf(ctx, namespace, workflowID, runID, activityID, result, err) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// MockClient_CompleteActivityByID_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'CompleteActivityByID' +type MockClient_CompleteActivityByID_Call struct { + *mock.Call +} + +// CompleteActivityByID is a helper method to define mock.On call +// - ctx context.Context +// - namespace string +// - workflowID string +// - runID string +// - activityID string +// - result interface{} +// - err error +func (_e *MockClient_Expecter) CompleteActivityByID(ctx interface{}, namespace interface{}, workflowID interface{}, runID interface{}, activityID interface{}, result interface{}, err interface{}) *MockClient_CompleteActivityByID_Call { + return &MockClient_CompleteActivityByID_Call{Call: _e.mock.On("CompleteActivityByID", ctx, namespace, workflowID, runID, activityID, result, err)} +} + +func (_c *MockClient_CompleteActivityByID_Call) Run(run func(ctx context.Context, namespace string, workflowID string, runID string, activityID string, result interface{}, err error)) *MockClient_CompleteActivityByID_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(string), args[2].(string), args[3].(string), args[4].(string), args[5].(interface{}), args[6].(error)) + }) + return _c +} + +func (_c *MockClient_CompleteActivityByID_Call) Return(_a0 error) *MockClient_CompleteActivityByID_Call { + _c.Call.Return(_a0) + return _c +} + +func (_c *MockClient_CompleteActivityByID_Call) RunAndReturn(run func(context.Context, string, string, string, string, interface{}, error) error) *MockClient_CompleteActivityByID_Call { + _c.Call.Return(run) + return _c +} + +// CountWorkflow provides a mock function with given fields: ctx, request +func (_m *MockClient) CountWorkflow(ctx context.Context, request *workflowservice.CountWorkflowExecutionsRequest) (*workflowservice.CountWorkflowExecutionsResponse, error) { + ret := _m.Called(ctx, request) + + if len(ret) == 0 { + panic("no return value specified for CountWorkflow") + } + + var r0 *workflowservice.CountWorkflowExecutionsResponse + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, *workflowservice.CountWorkflowExecutionsRequest) (*workflowservice.CountWorkflowExecutionsResponse, error)); ok { + return rf(ctx, request) + } + if rf, ok := ret.Get(0).(func(context.Context, *workflowservice.CountWorkflowExecutionsRequest) *workflowservice.CountWorkflowExecutionsResponse); ok { + r0 = rf(ctx, request) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*workflowservice.CountWorkflowExecutionsResponse) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, *workflowservice.CountWorkflowExecutionsRequest) error); ok { + r1 = rf(ctx, request) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// MockClient_CountWorkflow_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'CountWorkflow' +type MockClient_CountWorkflow_Call struct { + *mock.Call +} + +// CountWorkflow is a helper method to define mock.On call +// - ctx context.Context +// - request *workflowservice.CountWorkflowExecutionsRequest +func (_e *MockClient_Expecter) CountWorkflow(ctx interface{}, request interface{}) *MockClient_CountWorkflow_Call { + return &MockClient_CountWorkflow_Call{Call: _e.mock.On("CountWorkflow", ctx, request)} +} + +func (_c *MockClient_CountWorkflow_Call) Run(run func(ctx context.Context, request *workflowservice.CountWorkflowExecutionsRequest)) *MockClient_CountWorkflow_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(*workflowservice.CountWorkflowExecutionsRequest)) + }) + return _c +} + +func (_c *MockClient_CountWorkflow_Call) Return(_a0 *workflowservice.CountWorkflowExecutionsResponse, _a1 error) *MockClient_CountWorkflow_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *MockClient_CountWorkflow_Call) RunAndReturn(run func(context.Context, *workflowservice.CountWorkflowExecutionsRequest) (*workflowservice.CountWorkflowExecutionsResponse, error)) *MockClient_CountWorkflow_Call { + _c.Call.Return(run) + return _c +} + +// DescribeTaskQueue provides a mock function with given fields: ctx, taskqueue, taskqueueType +func (_m *MockClient) DescribeTaskQueue(ctx context.Context, taskqueue string, taskqueueType enums.TaskQueueType) (*workflowservice.DescribeTaskQueueResponse, error) { + ret := _m.Called(ctx, taskqueue, taskqueueType) + + if len(ret) == 0 { + panic("no return value specified for DescribeTaskQueue") + } + + var r0 *workflowservice.DescribeTaskQueueResponse + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, string, enums.TaskQueueType) (*workflowservice.DescribeTaskQueueResponse, error)); ok { + return rf(ctx, taskqueue, taskqueueType) + } + if rf, ok := ret.Get(0).(func(context.Context, string, enums.TaskQueueType) *workflowservice.DescribeTaskQueueResponse); ok { + r0 = rf(ctx, taskqueue, taskqueueType) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*workflowservice.DescribeTaskQueueResponse) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, string, enums.TaskQueueType) error); ok { + r1 = rf(ctx, taskqueue, taskqueueType) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// MockClient_DescribeTaskQueue_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'DescribeTaskQueue' +type MockClient_DescribeTaskQueue_Call struct { + *mock.Call +} + +// DescribeTaskQueue is a helper method to define mock.On call +// - ctx context.Context +// - taskqueue string +// - taskqueueType enums.TaskQueueType +func (_e *MockClient_Expecter) DescribeTaskQueue(ctx interface{}, taskqueue interface{}, taskqueueType interface{}) *MockClient_DescribeTaskQueue_Call { + return &MockClient_DescribeTaskQueue_Call{Call: _e.mock.On("DescribeTaskQueue", ctx, taskqueue, taskqueueType)} +} + +func (_c *MockClient_DescribeTaskQueue_Call) Run(run func(ctx context.Context, taskqueue string, taskqueueType enums.TaskQueueType)) *MockClient_DescribeTaskQueue_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(string), args[2].(enums.TaskQueueType)) + }) + return _c +} + +func (_c *MockClient_DescribeTaskQueue_Call) Return(_a0 *workflowservice.DescribeTaskQueueResponse, _a1 error) *MockClient_DescribeTaskQueue_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *MockClient_DescribeTaskQueue_Call) RunAndReturn(run func(context.Context, string, enums.TaskQueueType) (*workflowservice.DescribeTaskQueueResponse, error)) *MockClient_DescribeTaskQueue_Call { + _c.Call.Return(run) + return _c +} + +// DescribeTaskQueueEnhanced provides a mock function with given fields: ctx, options +func (_m *MockClient) DescribeTaskQueueEnhanced(ctx context.Context, options client.DescribeTaskQueueEnhancedOptions) (client.TaskQueueDescription, error) { + ret := _m.Called(ctx, options) + + if len(ret) == 0 { + panic("no return value specified for DescribeTaskQueueEnhanced") + } + + var r0 client.TaskQueueDescription + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, client.DescribeTaskQueueEnhancedOptions) (client.TaskQueueDescription, error)); ok { + return rf(ctx, options) + } + if rf, ok := ret.Get(0).(func(context.Context, client.DescribeTaskQueueEnhancedOptions) client.TaskQueueDescription); ok { + r0 = rf(ctx, options) + } else { + r0 = ret.Get(0).(client.TaskQueueDescription) + } + + if rf, ok := ret.Get(1).(func(context.Context, client.DescribeTaskQueueEnhancedOptions) error); ok { + r1 = rf(ctx, options) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// MockClient_DescribeTaskQueueEnhanced_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'DescribeTaskQueueEnhanced' +type MockClient_DescribeTaskQueueEnhanced_Call struct { + *mock.Call +} + +// DescribeTaskQueueEnhanced is a helper method to define mock.On call +// - ctx context.Context +// - options client.DescribeTaskQueueEnhancedOptions +func (_e *MockClient_Expecter) DescribeTaskQueueEnhanced(ctx interface{}, options interface{}) *MockClient_DescribeTaskQueueEnhanced_Call { + return &MockClient_DescribeTaskQueueEnhanced_Call{Call: _e.mock.On("DescribeTaskQueueEnhanced", ctx, options)} +} + +func (_c *MockClient_DescribeTaskQueueEnhanced_Call) Run(run func(ctx context.Context, options client.DescribeTaskQueueEnhancedOptions)) *MockClient_DescribeTaskQueueEnhanced_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(client.DescribeTaskQueueEnhancedOptions)) + }) + return _c +} + +func (_c *MockClient_DescribeTaskQueueEnhanced_Call) Return(_a0 client.TaskQueueDescription, _a1 error) *MockClient_DescribeTaskQueueEnhanced_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *MockClient_DescribeTaskQueueEnhanced_Call) RunAndReturn(run func(context.Context, client.DescribeTaskQueueEnhancedOptions) (client.TaskQueueDescription, error)) *MockClient_DescribeTaskQueueEnhanced_Call { + _c.Call.Return(run) + return _c +} + +// DescribeWorkflowExecution provides a mock function with given fields: ctx, workflowID, runID +func (_m *MockClient) DescribeWorkflowExecution(ctx context.Context, workflowID string, runID string) (*workflowservice.DescribeWorkflowExecutionResponse, error) { + ret := _m.Called(ctx, workflowID, runID) + + if len(ret) == 0 { + panic("no return value specified for DescribeWorkflowExecution") + } + + var r0 *workflowservice.DescribeWorkflowExecutionResponse + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, string, string) (*workflowservice.DescribeWorkflowExecutionResponse, error)); ok { + return rf(ctx, workflowID, runID) + } + if rf, ok := ret.Get(0).(func(context.Context, string, string) *workflowservice.DescribeWorkflowExecutionResponse); ok { + r0 = rf(ctx, workflowID, runID) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*workflowservice.DescribeWorkflowExecutionResponse) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, string, string) error); ok { + r1 = rf(ctx, workflowID, runID) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// MockClient_DescribeWorkflowExecution_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'DescribeWorkflowExecution' +type MockClient_DescribeWorkflowExecution_Call struct { + *mock.Call +} + +// DescribeWorkflowExecution is a helper method to define mock.On call +// - ctx context.Context +// - workflowID string +// - runID string +func (_e *MockClient_Expecter) DescribeWorkflowExecution(ctx interface{}, workflowID interface{}, runID interface{}) *MockClient_DescribeWorkflowExecution_Call { + return &MockClient_DescribeWorkflowExecution_Call{Call: _e.mock.On("DescribeWorkflowExecution", ctx, workflowID, runID)} +} + +func (_c *MockClient_DescribeWorkflowExecution_Call) Run(run func(ctx context.Context, workflowID string, runID string)) *MockClient_DescribeWorkflowExecution_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(string), args[2].(string)) + }) + return _c +} + +func (_c *MockClient_DescribeWorkflowExecution_Call) Return(_a0 *workflowservice.DescribeWorkflowExecutionResponse, _a1 error) *MockClient_DescribeWorkflowExecution_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *MockClient_DescribeWorkflowExecution_Call) RunAndReturn(run func(context.Context, string, string) (*workflowservice.DescribeWorkflowExecutionResponse, error)) *MockClient_DescribeWorkflowExecution_Call { + _c.Call.Return(run) + return _c +} + +// ExecuteWorkflow provides a mock function with given fields: ctx, options, workflow, args +func (_m *MockClient) ExecuteWorkflow(ctx context.Context, options client.StartWorkflowOptions, workflow interface{}, args ...interface{}) (client.WorkflowRun, error) { + var _ca []interface{} + _ca = append(_ca, ctx, options, workflow) + _ca = append(_ca, args...) + ret := _m.Called(_ca...) + + if len(ret) == 0 { + panic("no return value specified for ExecuteWorkflow") + } + + var r0 client.WorkflowRun + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, client.StartWorkflowOptions, interface{}, ...interface{}) (client.WorkflowRun, error)); ok { + return rf(ctx, options, workflow, args...) + } + if rf, ok := ret.Get(0).(func(context.Context, client.StartWorkflowOptions, interface{}, ...interface{}) client.WorkflowRun); ok { + r0 = rf(ctx, options, workflow, args...) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(client.WorkflowRun) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, client.StartWorkflowOptions, interface{}, ...interface{}) error); ok { + r1 = rf(ctx, options, workflow, args...) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// MockClient_ExecuteWorkflow_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'ExecuteWorkflow' +type MockClient_ExecuteWorkflow_Call struct { + *mock.Call +} + +// ExecuteWorkflow is a helper method to define mock.On call +// - ctx context.Context +// - options client.StartWorkflowOptions +// - workflow interface{} +// - args ...interface{} +func (_e *MockClient_Expecter) ExecuteWorkflow(ctx interface{}, options interface{}, workflow interface{}, args ...interface{}) *MockClient_ExecuteWorkflow_Call { + return &MockClient_ExecuteWorkflow_Call{Call: _e.mock.On("ExecuteWorkflow", + append([]interface{}{ctx, options, workflow}, args...)...)} +} + +func (_c *MockClient_ExecuteWorkflow_Call) Run(run func(ctx context.Context, options client.StartWorkflowOptions, workflow interface{}, args ...interface{})) *MockClient_ExecuteWorkflow_Call { + _c.Call.Run(func(args mock.Arguments) { + variadicArgs := make([]interface{}, len(args)-3) + for i, a := range args[3:] { + if a != nil { + variadicArgs[i] = a.(interface{}) + } + } + run(args[0].(context.Context), args[1].(client.StartWorkflowOptions), args[2].(interface{}), variadicArgs...) + }) + return _c +} + +func (_c *MockClient_ExecuteWorkflow_Call) Return(_a0 client.WorkflowRun, _a1 error) *MockClient_ExecuteWorkflow_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *MockClient_ExecuteWorkflow_Call) RunAndReturn(run func(context.Context, client.StartWorkflowOptions, interface{}, ...interface{}) (client.WorkflowRun, error)) *MockClient_ExecuteWorkflow_Call { + _c.Call.Return(run) + return _c +} + +// GetSearchAttributes provides a mock function with given fields: ctx +func (_m *MockClient) GetSearchAttributes(ctx context.Context) (*workflowservice.GetSearchAttributesResponse, error) { + ret := _m.Called(ctx) + + if len(ret) == 0 { + panic("no return value specified for GetSearchAttributes") + } + + var r0 *workflowservice.GetSearchAttributesResponse + var r1 error + if rf, ok := ret.Get(0).(func(context.Context) (*workflowservice.GetSearchAttributesResponse, error)); ok { + return rf(ctx) + } + if rf, ok := ret.Get(0).(func(context.Context) *workflowservice.GetSearchAttributesResponse); ok { + r0 = rf(ctx) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*workflowservice.GetSearchAttributesResponse) + } + } + + if rf, ok := ret.Get(1).(func(context.Context) error); ok { + r1 = rf(ctx) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// MockClient_GetSearchAttributes_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'GetSearchAttributes' +type MockClient_GetSearchAttributes_Call struct { + *mock.Call +} + +// GetSearchAttributes is a helper method to define mock.On call +// - ctx context.Context +func (_e *MockClient_Expecter) GetSearchAttributes(ctx interface{}) *MockClient_GetSearchAttributes_Call { + return &MockClient_GetSearchAttributes_Call{Call: _e.mock.On("GetSearchAttributes", ctx)} +} + +func (_c *MockClient_GetSearchAttributes_Call) Run(run func(ctx context.Context)) *MockClient_GetSearchAttributes_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context)) + }) + return _c +} + +func (_c *MockClient_GetSearchAttributes_Call) Return(_a0 *workflowservice.GetSearchAttributesResponse, _a1 error) *MockClient_GetSearchAttributes_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *MockClient_GetSearchAttributes_Call) RunAndReturn(run func(context.Context) (*workflowservice.GetSearchAttributesResponse, error)) *MockClient_GetSearchAttributes_Call { + _c.Call.Return(run) + return _c +} + +// GetWorkerBuildIdCompatibility provides a mock function with given fields: ctx, options +func (_m *MockClient) GetWorkerBuildIdCompatibility(ctx context.Context, options *client.GetWorkerBuildIdCompatibilityOptions) (*client.WorkerBuildIDVersionSets, error) { + ret := _m.Called(ctx, options) + + if len(ret) == 0 { + panic("no return value specified for GetWorkerBuildIdCompatibility") + } + + var r0 *client.WorkerBuildIDVersionSets + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, *client.GetWorkerBuildIdCompatibilityOptions) (*client.WorkerBuildIDVersionSets, error)); ok { + return rf(ctx, options) + } + if rf, ok := ret.Get(0).(func(context.Context, *client.GetWorkerBuildIdCompatibilityOptions) *client.WorkerBuildIDVersionSets); ok { + r0 = rf(ctx, options) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*client.WorkerBuildIDVersionSets) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, *client.GetWorkerBuildIdCompatibilityOptions) error); ok { + r1 = rf(ctx, options) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// MockClient_GetWorkerBuildIdCompatibility_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'GetWorkerBuildIdCompatibility' +type MockClient_GetWorkerBuildIdCompatibility_Call struct { + *mock.Call +} + +// GetWorkerBuildIdCompatibility is a helper method to define mock.On call +// - ctx context.Context +// - options *client.GetWorkerBuildIdCompatibilityOptions +func (_e *MockClient_Expecter) GetWorkerBuildIdCompatibility(ctx interface{}, options interface{}) *MockClient_GetWorkerBuildIdCompatibility_Call { + return &MockClient_GetWorkerBuildIdCompatibility_Call{Call: _e.mock.On("GetWorkerBuildIdCompatibility", ctx, options)} +} + +func (_c *MockClient_GetWorkerBuildIdCompatibility_Call) Run(run func(ctx context.Context, options *client.GetWorkerBuildIdCompatibilityOptions)) *MockClient_GetWorkerBuildIdCompatibility_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(*client.GetWorkerBuildIdCompatibilityOptions)) + }) + return _c +} + +func (_c *MockClient_GetWorkerBuildIdCompatibility_Call) Return(_a0 *client.WorkerBuildIDVersionSets, _a1 error) *MockClient_GetWorkerBuildIdCompatibility_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *MockClient_GetWorkerBuildIdCompatibility_Call) RunAndReturn(run func(context.Context, *client.GetWorkerBuildIdCompatibilityOptions) (*client.WorkerBuildIDVersionSets, error)) *MockClient_GetWorkerBuildIdCompatibility_Call { + _c.Call.Return(run) + return _c +} + +// GetWorkerTaskReachability provides a mock function with given fields: ctx, options +func (_m *MockClient) GetWorkerTaskReachability(ctx context.Context, options *client.GetWorkerTaskReachabilityOptions) (*client.WorkerTaskReachability, error) { + ret := _m.Called(ctx, options) + + if len(ret) == 0 { + panic("no return value specified for GetWorkerTaskReachability") + } + + var r0 *client.WorkerTaskReachability + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, *client.GetWorkerTaskReachabilityOptions) (*client.WorkerTaskReachability, error)); ok { + return rf(ctx, options) + } + if rf, ok := ret.Get(0).(func(context.Context, *client.GetWorkerTaskReachabilityOptions) *client.WorkerTaskReachability); ok { + r0 = rf(ctx, options) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*client.WorkerTaskReachability) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, *client.GetWorkerTaskReachabilityOptions) error); ok { + r1 = rf(ctx, options) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// MockClient_GetWorkerTaskReachability_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'GetWorkerTaskReachability' +type MockClient_GetWorkerTaskReachability_Call struct { + *mock.Call +} + +// GetWorkerTaskReachability is a helper method to define mock.On call +// - ctx context.Context +// - options *client.GetWorkerTaskReachabilityOptions +func (_e *MockClient_Expecter) GetWorkerTaskReachability(ctx interface{}, options interface{}) *MockClient_GetWorkerTaskReachability_Call { + return &MockClient_GetWorkerTaskReachability_Call{Call: _e.mock.On("GetWorkerTaskReachability", ctx, options)} +} + +func (_c *MockClient_GetWorkerTaskReachability_Call) Run(run func(ctx context.Context, options *client.GetWorkerTaskReachabilityOptions)) *MockClient_GetWorkerTaskReachability_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(*client.GetWorkerTaskReachabilityOptions)) + }) + return _c +} + +func (_c *MockClient_GetWorkerTaskReachability_Call) Return(_a0 *client.WorkerTaskReachability, _a1 error) *MockClient_GetWorkerTaskReachability_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *MockClient_GetWorkerTaskReachability_Call) RunAndReturn(run func(context.Context, *client.GetWorkerTaskReachabilityOptions) (*client.WorkerTaskReachability, error)) *MockClient_GetWorkerTaskReachability_Call { + _c.Call.Return(run) + return _c +} + +// GetWorkerVersioningRules provides a mock function with given fields: ctx, options +func (_m *MockClient) GetWorkerVersioningRules(ctx context.Context, options client.GetWorkerVersioningOptions) (*client.WorkerVersioningRules, error) { + ret := _m.Called(ctx, options) + + if len(ret) == 0 { + panic("no return value specified for GetWorkerVersioningRules") + } + + var r0 *client.WorkerVersioningRules + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, client.GetWorkerVersioningOptions) (*client.WorkerVersioningRules, error)); ok { + return rf(ctx, options) + } + if rf, ok := ret.Get(0).(func(context.Context, client.GetWorkerVersioningOptions) *client.WorkerVersioningRules); ok { + r0 = rf(ctx, options) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*client.WorkerVersioningRules) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, client.GetWorkerVersioningOptions) error); ok { + r1 = rf(ctx, options) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// MockClient_GetWorkerVersioningRules_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'GetWorkerVersioningRules' +type MockClient_GetWorkerVersioningRules_Call struct { + *mock.Call +} + +// GetWorkerVersioningRules is a helper method to define mock.On call +// - ctx context.Context +// - options client.GetWorkerVersioningOptions +func (_e *MockClient_Expecter) GetWorkerVersioningRules(ctx interface{}, options interface{}) *MockClient_GetWorkerVersioningRules_Call { + return &MockClient_GetWorkerVersioningRules_Call{Call: _e.mock.On("GetWorkerVersioningRules", ctx, options)} +} + +func (_c *MockClient_GetWorkerVersioningRules_Call) Run(run func(ctx context.Context, options client.GetWorkerVersioningOptions)) *MockClient_GetWorkerVersioningRules_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(client.GetWorkerVersioningOptions)) + }) + return _c +} + +func (_c *MockClient_GetWorkerVersioningRules_Call) Return(_a0 *client.WorkerVersioningRules, _a1 error) *MockClient_GetWorkerVersioningRules_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *MockClient_GetWorkerVersioningRules_Call) RunAndReturn(run func(context.Context, client.GetWorkerVersioningOptions) (*client.WorkerVersioningRules, error)) *MockClient_GetWorkerVersioningRules_Call { + _c.Call.Return(run) + return _c +} + +// GetWorkflow provides a mock function with given fields: ctx, workflowID, runID +func (_m *MockClient) GetWorkflow(ctx context.Context, workflowID string, runID string) client.WorkflowRun { + ret := _m.Called(ctx, workflowID, runID) + + if len(ret) == 0 { + panic("no return value specified for GetWorkflow") + } + + var r0 client.WorkflowRun + if rf, ok := ret.Get(0).(func(context.Context, string, string) client.WorkflowRun); ok { + r0 = rf(ctx, workflowID, runID) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(client.WorkflowRun) + } + } + + return r0 +} + +// MockClient_GetWorkflow_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'GetWorkflow' +type MockClient_GetWorkflow_Call struct { + *mock.Call +} + +// GetWorkflow is a helper method to define mock.On call +// - ctx context.Context +// - workflowID string +// - runID string +func (_e *MockClient_Expecter) GetWorkflow(ctx interface{}, workflowID interface{}, runID interface{}) *MockClient_GetWorkflow_Call { + return &MockClient_GetWorkflow_Call{Call: _e.mock.On("GetWorkflow", ctx, workflowID, runID)} +} + +func (_c *MockClient_GetWorkflow_Call) Run(run func(ctx context.Context, workflowID string, runID string)) *MockClient_GetWorkflow_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(string), args[2].(string)) + }) + return _c +} + +func (_c *MockClient_GetWorkflow_Call) Return(_a0 client.WorkflowRun) *MockClient_GetWorkflow_Call { + _c.Call.Return(_a0) + return _c +} + +func (_c *MockClient_GetWorkflow_Call) RunAndReturn(run func(context.Context, string, string) client.WorkflowRun) *MockClient_GetWorkflow_Call { + _c.Call.Return(run) + return _c +} + +// GetWorkflowHistory provides a mock function with given fields: ctx, workflowID, runID, isLongPoll, filterType +func (_m *MockClient) GetWorkflowHistory(ctx context.Context, workflowID string, runID string, isLongPoll bool, filterType enums.HistoryEventFilterType) client.HistoryEventIterator { + ret := _m.Called(ctx, workflowID, runID, isLongPoll, filterType) + + if len(ret) == 0 { + panic("no return value specified for GetWorkflowHistory") + } + + var r0 client.HistoryEventIterator + if rf, ok := ret.Get(0).(func(context.Context, string, string, bool, enums.HistoryEventFilterType) client.HistoryEventIterator); ok { + r0 = rf(ctx, workflowID, runID, isLongPoll, filterType) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(client.HistoryEventIterator) + } + } + + return r0 +} + +// MockClient_GetWorkflowHistory_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'GetWorkflowHistory' +type MockClient_GetWorkflowHistory_Call struct { + *mock.Call +} + +// GetWorkflowHistory is a helper method to define mock.On call +// - ctx context.Context +// - workflowID string +// - runID string +// - isLongPoll bool +// - filterType enums.HistoryEventFilterType +func (_e *MockClient_Expecter) GetWorkflowHistory(ctx interface{}, workflowID interface{}, runID interface{}, isLongPoll interface{}, filterType interface{}) *MockClient_GetWorkflowHistory_Call { + return &MockClient_GetWorkflowHistory_Call{Call: _e.mock.On("GetWorkflowHistory", ctx, workflowID, runID, isLongPoll, filterType)} +} + +func (_c *MockClient_GetWorkflowHistory_Call) Run(run func(ctx context.Context, workflowID string, runID string, isLongPoll bool, filterType enums.HistoryEventFilterType)) *MockClient_GetWorkflowHistory_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(string), args[2].(string), args[3].(bool), args[4].(enums.HistoryEventFilterType)) + }) + return _c +} + +func (_c *MockClient_GetWorkflowHistory_Call) Return(_a0 client.HistoryEventIterator) *MockClient_GetWorkflowHistory_Call { + _c.Call.Return(_a0) + return _c +} + +func (_c *MockClient_GetWorkflowHistory_Call) RunAndReturn(run func(context.Context, string, string, bool, enums.HistoryEventFilterType) client.HistoryEventIterator) *MockClient_GetWorkflowHistory_Call { + _c.Call.Return(run) + return _c +} + +// GetWorkflowUpdateHandle provides a mock function with given fields: ref +func (_m *MockClient) GetWorkflowUpdateHandle(ref client.GetWorkflowUpdateHandleOptions) client.WorkflowUpdateHandle { + ret := _m.Called(ref) + + if len(ret) == 0 { + panic("no return value specified for GetWorkflowUpdateHandle") + } + + var r0 client.WorkflowUpdateHandle + if rf, ok := ret.Get(0).(func(client.GetWorkflowUpdateHandleOptions) client.WorkflowUpdateHandle); ok { + r0 = rf(ref) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(client.WorkflowUpdateHandle) + } + } + + return r0 +} + +// MockClient_GetWorkflowUpdateHandle_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'GetWorkflowUpdateHandle' +type MockClient_GetWorkflowUpdateHandle_Call struct { + *mock.Call +} + +// GetWorkflowUpdateHandle is a helper method to define mock.On call +// - ref client.GetWorkflowUpdateHandleOptions +func (_e *MockClient_Expecter) GetWorkflowUpdateHandle(ref interface{}) *MockClient_GetWorkflowUpdateHandle_Call { + return &MockClient_GetWorkflowUpdateHandle_Call{Call: _e.mock.On("GetWorkflowUpdateHandle", ref)} +} + +func (_c *MockClient_GetWorkflowUpdateHandle_Call) Run(run func(ref client.GetWorkflowUpdateHandleOptions)) *MockClient_GetWorkflowUpdateHandle_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(client.GetWorkflowUpdateHandleOptions)) + }) + return _c +} + +func (_c *MockClient_GetWorkflowUpdateHandle_Call) Return(_a0 client.WorkflowUpdateHandle) *MockClient_GetWorkflowUpdateHandle_Call { + _c.Call.Return(_a0) + return _c +} + +func (_c *MockClient_GetWorkflowUpdateHandle_Call) RunAndReturn(run func(client.GetWorkflowUpdateHandleOptions) client.WorkflowUpdateHandle) *MockClient_GetWorkflowUpdateHandle_Call { + _c.Call.Return(run) + return _c +} + +// ListArchivedWorkflow provides a mock function with given fields: ctx, request +func (_m *MockClient) ListArchivedWorkflow(ctx context.Context, request *workflowservice.ListArchivedWorkflowExecutionsRequest) (*workflowservice.ListArchivedWorkflowExecutionsResponse, error) { + ret := _m.Called(ctx, request) + + if len(ret) == 0 { + panic("no return value specified for ListArchivedWorkflow") + } + + var r0 *workflowservice.ListArchivedWorkflowExecutionsResponse + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, *workflowservice.ListArchivedWorkflowExecutionsRequest) (*workflowservice.ListArchivedWorkflowExecutionsResponse, error)); ok { + return rf(ctx, request) + } + if rf, ok := ret.Get(0).(func(context.Context, *workflowservice.ListArchivedWorkflowExecutionsRequest) *workflowservice.ListArchivedWorkflowExecutionsResponse); ok { + r0 = rf(ctx, request) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*workflowservice.ListArchivedWorkflowExecutionsResponse) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, *workflowservice.ListArchivedWorkflowExecutionsRequest) error); ok { + r1 = rf(ctx, request) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// MockClient_ListArchivedWorkflow_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'ListArchivedWorkflow' +type MockClient_ListArchivedWorkflow_Call struct { + *mock.Call +} + +// ListArchivedWorkflow is a helper method to define mock.On call +// - ctx context.Context +// - request *workflowservice.ListArchivedWorkflowExecutionsRequest +func (_e *MockClient_Expecter) ListArchivedWorkflow(ctx interface{}, request interface{}) *MockClient_ListArchivedWorkflow_Call { + return &MockClient_ListArchivedWorkflow_Call{Call: _e.mock.On("ListArchivedWorkflow", ctx, request)} +} + +func (_c *MockClient_ListArchivedWorkflow_Call) Run(run func(ctx context.Context, request *workflowservice.ListArchivedWorkflowExecutionsRequest)) *MockClient_ListArchivedWorkflow_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(*workflowservice.ListArchivedWorkflowExecutionsRequest)) + }) + return _c +} + +func (_c *MockClient_ListArchivedWorkflow_Call) Return(_a0 *workflowservice.ListArchivedWorkflowExecutionsResponse, _a1 error) *MockClient_ListArchivedWorkflow_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *MockClient_ListArchivedWorkflow_Call) RunAndReturn(run func(context.Context, *workflowservice.ListArchivedWorkflowExecutionsRequest) (*workflowservice.ListArchivedWorkflowExecutionsResponse, error)) *MockClient_ListArchivedWorkflow_Call { + _c.Call.Return(run) + return _c +} + +// ListClosedWorkflow provides a mock function with given fields: ctx, request +func (_m *MockClient) ListClosedWorkflow(ctx context.Context, request *workflowservice.ListClosedWorkflowExecutionsRequest) (*workflowservice.ListClosedWorkflowExecutionsResponse, error) { + ret := _m.Called(ctx, request) + + if len(ret) == 0 { + panic("no return value specified for ListClosedWorkflow") + } + + var r0 *workflowservice.ListClosedWorkflowExecutionsResponse + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, *workflowservice.ListClosedWorkflowExecutionsRequest) (*workflowservice.ListClosedWorkflowExecutionsResponse, error)); ok { + return rf(ctx, request) + } + if rf, ok := ret.Get(0).(func(context.Context, *workflowservice.ListClosedWorkflowExecutionsRequest) *workflowservice.ListClosedWorkflowExecutionsResponse); ok { + r0 = rf(ctx, request) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*workflowservice.ListClosedWorkflowExecutionsResponse) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, *workflowservice.ListClosedWorkflowExecutionsRequest) error); ok { + r1 = rf(ctx, request) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// MockClient_ListClosedWorkflow_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'ListClosedWorkflow' +type MockClient_ListClosedWorkflow_Call struct { + *mock.Call +} + +// ListClosedWorkflow is a helper method to define mock.On call +// - ctx context.Context +// - request *workflowservice.ListClosedWorkflowExecutionsRequest +func (_e *MockClient_Expecter) ListClosedWorkflow(ctx interface{}, request interface{}) *MockClient_ListClosedWorkflow_Call { + return &MockClient_ListClosedWorkflow_Call{Call: _e.mock.On("ListClosedWorkflow", ctx, request)} +} + +func (_c *MockClient_ListClosedWorkflow_Call) Run(run func(ctx context.Context, request *workflowservice.ListClosedWorkflowExecutionsRequest)) *MockClient_ListClosedWorkflow_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(*workflowservice.ListClosedWorkflowExecutionsRequest)) + }) + return _c +} + +func (_c *MockClient_ListClosedWorkflow_Call) Return(_a0 *workflowservice.ListClosedWorkflowExecutionsResponse, _a1 error) *MockClient_ListClosedWorkflow_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *MockClient_ListClosedWorkflow_Call) RunAndReturn(run func(context.Context, *workflowservice.ListClosedWorkflowExecutionsRequest) (*workflowservice.ListClosedWorkflowExecutionsResponse, error)) *MockClient_ListClosedWorkflow_Call { + _c.Call.Return(run) + return _c +} + +// ListOpenWorkflow provides a mock function with given fields: ctx, request +func (_m *MockClient) ListOpenWorkflow(ctx context.Context, request *workflowservice.ListOpenWorkflowExecutionsRequest) (*workflowservice.ListOpenWorkflowExecutionsResponse, error) { + ret := _m.Called(ctx, request) + + if len(ret) == 0 { + panic("no return value specified for ListOpenWorkflow") + } + + var r0 *workflowservice.ListOpenWorkflowExecutionsResponse + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, *workflowservice.ListOpenWorkflowExecutionsRequest) (*workflowservice.ListOpenWorkflowExecutionsResponse, error)); ok { + return rf(ctx, request) + } + if rf, ok := ret.Get(0).(func(context.Context, *workflowservice.ListOpenWorkflowExecutionsRequest) *workflowservice.ListOpenWorkflowExecutionsResponse); ok { + r0 = rf(ctx, request) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*workflowservice.ListOpenWorkflowExecutionsResponse) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, *workflowservice.ListOpenWorkflowExecutionsRequest) error); ok { + r1 = rf(ctx, request) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// MockClient_ListOpenWorkflow_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'ListOpenWorkflow' +type MockClient_ListOpenWorkflow_Call struct { + *mock.Call +} + +// ListOpenWorkflow is a helper method to define mock.On call +// - ctx context.Context +// - request *workflowservice.ListOpenWorkflowExecutionsRequest +func (_e *MockClient_Expecter) ListOpenWorkflow(ctx interface{}, request interface{}) *MockClient_ListOpenWorkflow_Call { + return &MockClient_ListOpenWorkflow_Call{Call: _e.mock.On("ListOpenWorkflow", ctx, request)} +} + +func (_c *MockClient_ListOpenWorkflow_Call) Run(run func(ctx context.Context, request *workflowservice.ListOpenWorkflowExecutionsRequest)) *MockClient_ListOpenWorkflow_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(*workflowservice.ListOpenWorkflowExecutionsRequest)) + }) + return _c +} + +func (_c *MockClient_ListOpenWorkflow_Call) Return(_a0 *workflowservice.ListOpenWorkflowExecutionsResponse, _a1 error) *MockClient_ListOpenWorkflow_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *MockClient_ListOpenWorkflow_Call) RunAndReturn(run func(context.Context, *workflowservice.ListOpenWorkflowExecutionsRequest) (*workflowservice.ListOpenWorkflowExecutionsResponse, error)) *MockClient_ListOpenWorkflow_Call { + _c.Call.Return(run) + return _c +} + +// ListWorkflow provides a mock function with given fields: ctx, request +func (_m *MockClient) ListWorkflow(ctx context.Context, request *workflowservice.ListWorkflowExecutionsRequest) (*workflowservice.ListWorkflowExecutionsResponse, error) { + ret := _m.Called(ctx, request) + + if len(ret) == 0 { + panic("no return value specified for ListWorkflow") + } + + var r0 *workflowservice.ListWorkflowExecutionsResponse + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, *workflowservice.ListWorkflowExecutionsRequest) (*workflowservice.ListWorkflowExecutionsResponse, error)); ok { + return rf(ctx, request) + } + if rf, ok := ret.Get(0).(func(context.Context, *workflowservice.ListWorkflowExecutionsRequest) *workflowservice.ListWorkflowExecutionsResponse); ok { + r0 = rf(ctx, request) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*workflowservice.ListWorkflowExecutionsResponse) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, *workflowservice.ListWorkflowExecutionsRequest) error); ok { + r1 = rf(ctx, request) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// MockClient_ListWorkflow_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'ListWorkflow' +type MockClient_ListWorkflow_Call struct { + *mock.Call +} + +// ListWorkflow is a helper method to define mock.On call +// - ctx context.Context +// - request *workflowservice.ListWorkflowExecutionsRequest +func (_e *MockClient_Expecter) ListWorkflow(ctx interface{}, request interface{}) *MockClient_ListWorkflow_Call { + return &MockClient_ListWorkflow_Call{Call: _e.mock.On("ListWorkflow", ctx, request)} +} + +func (_c *MockClient_ListWorkflow_Call) Run(run func(ctx context.Context, request *workflowservice.ListWorkflowExecutionsRequest)) *MockClient_ListWorkflow_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(*workflowservice.ListWorkflowExecutionsRequest)) + }) + return _c +} + +func (_c *MockClient_ListWorkflow_Call) Return(_a0 *workflowservice.ListWorkflowExecutionsResponse, _a1 error) *MockClient_ListWorkflow_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *MockClient_ListWorkflow_Call) RunAndReturn(run func(context.Context, *workflowservice.ListWorkflowExecutionsRequest) (*workflowservice.ListWorkflowExecutionsResponse, error)) *MockClient_ListWorkflow_Call { + _c.Call.Return(run) + return _c +} + +// NewWithStartWorkflowOperation provides a mock function with given fields: options, workflow, args +func (_m *MockClient) NewWithStartWorkflowOperation(options client.StartWorkflowOptions, workflow interface{}, args ...interface{}) client.WithStartWorkflowOperation { + var _ca []interface{} + _ca = append(_ca, options, workflow) + _ca = append(_ca, args...) + ret := _m.Called(_ca...) + + if len(ret) == 0 { + panic("no return value specified for NewWithStartWorkflowOperation") + } + + var r0 client.WithStartWorkflowOperation + if rf, ok := ret.Get(0).(func(client.StartWorkflowOptions, interface{}, ...interface{}) client.WithStartWorkflowOperation); ok { + r0 = rf(options, workflow, args...) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(client.WithStartWorkflowOperation) + } + } + + return r0 +} + +// MockClient_NewWithStartWorkflowOperation_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'NewWithStartWorkflowOperation' +type MockClient_NewWithStartWorkflowOperation_Call struct { + *mock.Call +} + +// NewWithStartWorkflowOperation is a helper method to define mock.On call +// - options client.StartWorkflowOptions +// - workflow interface{} +// - args ...interface{} +func (_e *MockClient_Expecter) NewWithStartWorkflowOperation(options interface{}, workflow interface{}, args ...interface{}) *MockClient_NewWithStartWorkflowOperation_Call { + return &MockClient_NewWithStartWorkflowOperation_Call{Call: _e.mock.On("NewWithStartWorkflowOperation", + append([]interface{}{options, workflow}, args...)...)} +} + +func (_c *MockClient_NewWithStartWorkflowOperation_Call) Run(run func(options client.StartWorkflowOptions, workflow interface{}, args ...interface{})) *MockClient_NewWithStartWorkflowOperation_Call { + _c.Call.Run(func(args mock.Arguments) { + variadicArgs := make([]interface{}, len(args)-2) + for i, a := range args[2:] { + if a != nil { + variadicArgs[i] = a.(interface{}) + } + } + run(args[0].(client.StartWorkflowOptions), args[1].(interface{}), variadicArgs...) + }) + return _c +} + +func (_c *MockClient_NewWithStartWorkflowOperation_Call) Return(_a0 client.WithStartWorkflowOperation) *MockClient_NewWithStartWorkflowOperation_Call { + _c.Call.Return(_a0) + return _c +} + +func (_c *MockClient_NewWithStartWorkflowOperation_Call) RunAndReturn(run func(client.StartWorkflowOptions, interface{}, ...interface{}) client.WithStartWorkflowOperation) *MockClient_NewWithStartWorkflowOperation_Call { + _c.Call.Return(run) + return _c +} + +// NewWorker provides a mock function with given fields: queue, options +func (_m *MockClient) NewWorker(queue string, options worker.Options) worker.Registry { + ret := _m.Called(queue, options) + + if len(ret) == 0 { + panic("no return value specified for NewWorker") + } + + var r0 worker.Registry + if rf, ok := ret.Get(0).(func(string, worker.Options) worker.Registry); ok { + r0 = rf(queue, options) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(worker.Registry) + } + } + + return r0 +} + +// MockClient_NewWorker_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'NewWorker' +type MockClient_NewWorker_Call struct { + *mock.Call +} + +// NewWorker is a helper method to define mock.On call +// - queue string +// - options worker.Options +func (_e *MockClient_Expecter) NewWorker(queue interface{}, options interface{}) *MockClient_NewWorker_Call { + return &MockClient_NewWorker_Call{Call: _e.mock.On("NewWorker", queue, options)} +} + +func (_c *MockClient_NewWorker_Call) Run(run func(queue string, options worker.Options)) *MockClient_NewWorker_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(string), args[1].(worker.Options)) + }) + return _c +} + +func (_c *MockClient_NewWorker_Call) Return(_a0 worker.Registry) *MockClient_NewWorker_Call { + _c.Call.Return(_a0) + return _c +} + +func (_c *MockClient_NewWorker_Call) RunAndReturn(run func(string, worker.Options) worker.Registry) *MockClient_NewWorker_Call { + _c.Call.Return(run) + return _c +} + +// OperatorService provides a mock function with given fields: +func (_m *MockClient) OperatorService() operatorservice.OperatorServiceClient { + ret := _m.Called() + + if len(ret) == 0 { + panic("no return value specified for OperatorService") + } + + var r0 operatorservice.OperatorServiceClient + if rf, ok := ret.Get(0).(func() operatorservice.OperatorServiceClient); ok { + r0 = rf() + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(operatorservice.OperatorServiceClient) + } + } + + return r0 +} + +// MockClient_OperatorService_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'OperatorService' +type MockClient_OperatorService_Call struct { + *mock.Call +} + +// OperatorService is a helper method to define mock.On call +func (_e *MockClient_Expecter) OperatorService() *MockClient_OperatorService_Call { + return &MockClient_OperatorService_Call{Call: _e.mock.On("OperatorService")} +} + +func (_c *MockClient_OperatorService_Call) Run(run func()) *MockClient_OperatorService_Call { + _c.Call.Run(func(args mock.Arguments) { + run() + }) + return _c +} + +func (_c *MockClient_OperatorService_Call) Return(_a0 operatorservice.OperatorServiceClient) *MockClient_OperatorService_Call { + _c.Call.Return(_a0) + return _c +} + +func (_c *MockClient_OperatorService_Call) RunAndReturn(run func() operatorservice.OperatorServiceClient) *MockClient_OperatorService_Call { + _c.Call.Return(run) + return _c +} + +// QueryWorkflow provides a mock function with given fields: ctx, workflowID, runID, queryType, args +func (_m *MockClient) QueryWorkflow(ctx context.Context, workflowID string, runID string, queryType string, args ...interface{}) (converter.EncodedValue, error) { + var _ca []interface{} + _ca = append(_ca, ctx, workflowID, runID, queryType) + _ca = append(_ca, args...) + ret := _m.Called(_ca...) + + if len(ret) == 0 { + panic("no return value specified for QueryWorkflow") + } + + var r0 converter.EncodedValue + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, string, string, string, ...interface{}) (converter.EncodedValue, error)); ok { + return rf(ctx, workflowID, runID, queryType, args...) + } + if rf, ok := ret.Get(0).(func(context.Context, string, string, string, ...interface{}) converter.EncodedValue); ok { + r0 = rf(ctx, workflowID, runID, queryType, args...) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(converter.EncodedValue) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, string, string, string, ...interface{}) error); ok { + r1 = rf(ctx, workflowID, runID, queryType, args...) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// MockClient_QueryWorkflow_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'QueryWorkflow' +type MockClient_QueryWorkflow_Call struct { + *mock.Call +} + +// QueryWorkflow is a helper method to define mock.On call +// - ctx context.Context +// - workflowID string +// - runID string +// - queryType string +// - args ...interface{} +func (_e *MockClient_Expecter) QueryWorkflow(ctx interface{}, workflowID interface{}, runID interface{}, queryType interface{}, args ...interface{}) *MockClient_QueryWorkflow_Call { + return &MockClient_QueryWorkflow_Call{Call: _e.mock.On("QueryWorkflow", + append([]interface{}{ctx, workflowID, runID, queryType}, args...)...)} +} + +func (_c *MockClient_QueryWorkflow_Call) Run(run func(ctx context.Context, workflowID string, runID string, queryType string, args ...interface{})) *MockClient_QueryWorkflow_Call { + _c.Call.Run(func(args mock.Arguments) { + variadicArgs := make([]interface{}, len(args)-4) + for i, a := range args[4:] { + if a != nil { + variadicArgs[i] = a.(interface{}) + } + } + run(args[0].(context.Context), args[1].(string), args[2].(string), args[3].(string), variadicArgs...) + }) + return _c +} + +func (_c *MockClient_QueryWorkflow_Call) Return(_a0 converter.EncodedValue, _a1 error) *MockClient_QueryWorkflow_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *MockClient_QueryWorkflow_Call) RunAndReturn(run func(context.Context, string, string, string, ...interface{}) (converter.EncodedValue, error)) *MockClient_QueryWorkflow_Call { + _c.Call.Return(run) + return _c +} + +// QueryWorkflowWithOptions provides a mock function with given fields: ctx, request +func (_m *MockClient) QueryWorkflowWithOptions(ctx context.Context, request *client.QueryWorkflowWithOptionsRequest) (*client.QueryWorkflowWithOptionsResponse, error) { + ret := _m.Called(ctx, request) + + if len(ret) == 0 { + panic("no return value specified for QueryWorkflowWithOptions") + } + + var r0 *client.QueryWorkflowWithOptionsResponse + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, *client.QueryWorkflowWithOptionsRequest) (*client.QueryWorkflowWithOptionsResponse, error)); ok { + return rf(ctx, request) + } + if rf, ok := ret.Get(0).(func(context.Context, *client.QueryWorkflowWithOptionsRequest) *client.QueryWorkflowWithOptionsResponse); ok { + r0 = rf(ctx, request) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*client.QueryWorkflowWithOptionsResponse) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, *client.QueryWorkflowWithOptionsRequest) error); ok { + r1 = rf(ctx, request) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// MockClient_QueryWorkflowWithOptions_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'QueryWorkflowWithOptions' +type MockClient_QueryWorkflowWithOptions_Call struct { + *mock.Call +} + +// QueryWorkflowWithOptions is a helper method to define mock.On call +// - ctx context.Context +// - request *client.QueryWorkflowWithOptionsRequest +func (_e *MockClient_Expecter) QueryWorkflowWithOptions(ctx interface{}, request interface{}) *MockClient_QueryWorkflowWithOptions_Call { + return &MockClient_QueryWorkflowWithOptions_Call{Call: _e.mock.On("QueryWorkflowWithOptions", ctx, request)} +} + +func (_c *MockClient_QueryWorkflowWithOptions_Call) Run(run func(ctx context.Context, request *client.QueryWorkflowWithOptionsRequest)) *MockClient_QueryWorkflowWithOptions_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(*client.QueryWorkflowWithOptionsRequest)) + }) + return _c +} + +func (_c *MockClient_QueryWorkflowWithOptions_Call) Return(_a0 *client.QueryWorkflowWithOptionsResponse, _a1 error) *MockClient_QueryWorkflowWithOptions_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *MockClient_QueryWorkflowWithOptions_Call) RunAndReturn(run func(context.Context, *client.QueryWorkflowWithOptionsRequest) (*client.QueryWorkflowWithOptionsResponse, error)) *MockClient_QueryWorkflowWithOptions_Call { + _c.Call.Return(run) + return _c +} + +// RecordActivityHeartbeat provides a mock function with given fields: ctx, taskToken, details +func (_m *MockClient) RecordActivityHeartbeat(ctx context.Context, taskToken []byte, details ...interface{}) error { + var _ca []interface{} + _ca = append(_ca, ctx, taskToken) + _ca = append(_ca, details...) + ret := _m.Called(_ca...) + + if len(ret) == 0 { + panic("no return value specified for RecordActivityHeartbeat") + } + + var r0 error + if rf, ok := ret.Get(0).(func(context.Context, []byte, ...interface{}) error); ok { + r0 = rf(ctx, taskToken, details...) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// MockClient_RecordActivityHeartbeat_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'RecordActivityHeartbeat' +type MockClient_RecordActivityHeartbeat_Call struct { + *mock.Call +} + +// RecordActivityHeartbeat is a helper method to define mock.On call +// - ctx context.Context +// - taskToken []byte +// - details ...interface{} +func (_e *MockClient_Expecter) RecordActivityHeartbeat(ctx interface{}, taskToken interface{}, details ...interface{}) *MockClient_RecordActivityHeartbeat_Call { + return &MockClient_RecordActivityHeartbeat_Call{Call: _e.mock.On("RecordActivityHeartbeat", + append([]interface{}{ctx, taskToken}, details...)...)} +} + +func (_c *MockClient_RecordActivityHeartbeat_Call) Run(run func(ctx context.Context, taskToken []byte, details ...interface{})) *MockClient_RecordActivityHeartbeat_Call { + _c.Call.Run(func(args mock.Arguments) { + variadicArgs := make([]interface{}, len(args)-2) + for i, a := range args[2:] { + if a != nil { + variadicArgs[i] = a.(interface{}) + } + } + run(args[0].(context.Context), args[1].([]byte), variadicArgs...) + }) + return _c +} + +func (_c *MockClient_RecordActivityHeartbeat_Call) Return(_a0 error) *MockClient_RecordActivityHeartbeat_Call { + _c.Call.Return(_a0) + return _c +} + +func (_c *MockClient_RecordActivityHeartbeat_Call) RunAndReturn(run func(context.Context, []byte, ...interface{}) error) *MockClient_RecordActivityHeartbeat_Call { + _c.Call.Return(run) + return _c +} + +// RecordActivityHeartbeatByID provides a mock function with given fields: ctx, namespace, workflowID, runID, activityID, details +func (_m *MockClient) RecordActivityHeartbeatByID(ctx context.Context, namespace string, workflowID string, runID string, activityID string, details ...interface{}) error { + var _ca []interface{} + _ca = append(_ca, ctx, namespace, workflowID, runID, activityID) + _ca = append(_ca, details...) + ret := _m.Called(_ca...) + + if len(ret) == 0 { + panic("no return value specified for RecordActivityHeartbeatByID") + } + + var r0 error + if rf, ok := ret.Get(0).(func(context.Context, string, string, string, string, ...interface{}) error); ok { + r0 = rf(ctx, namespace, workflowID, runID, activityID, details...) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// MockClient_RecordActivityHeartbeatByID_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'RecordActivityHeartbeatByID' +type MockClient_RecordActivityHeartbeatByID_Call struct { + *mock.Call +} + +// RecordActivityHeartbeatByID is a helper method to define mock.On call +// - ctx context.Context +// - namespace string +// - workflowID string +// - runID string +// - activityID string +// - details ...interface{} +func (_e *MockClient_Expecter) RecordActivityHeartbeatByID(ctx interface{}, namespace interface{}, workflowID interface{}, runID interface{}, activityID interface{}, details ...interface{}) *MockClient_RecordActivityHeartbeatByID_Call { + return &MockClient_RecordActivityHeartbeatByID_Call{Call: _e.mock.On("RecordActivityHeartbeatByID", + append([]interface{}{ctx, namespace, workflowID, runID, activityID}, details...)...)} +} + +func (_c *MockClient_RecordActivityHeartbeatByID_Call) Run(run func(ctx context.Context, namespace string, workflowID string, runID string, activityID string, details ...interface{})) *MockClient_RecordActivityHeartbeatByID_Call { + _c.Call.Run(func(args mock.Arguments) { + variadicArgs := make([]interface{}, len(args)-5) + for i, a := range args[5:] { + if a != nil { + variadicArgs[i] = a.(interface{}) + } + } + run(args[0].(context.Context), args[1].(string), args[2].(string), args[3].(string), args[4].(string), variadicArgs...) + }) + return _c +} + +func (_c *MockClient_RecordActivityHeartbeatByID_Call) Return(_a0 error) *MockClient_RecordActivityHeartbeatByID_Call { + _c.Call.Return(_a0) + return _c +} + +func (_c *MockClient_RecordActivityHeartbeatByID_Call) RunAndReturn(run func(context.Context, string, string, string, string, ...interface{}) error) *MockClient_RecordActivityHeartbeatByID_Call { + _c.Call.Return(run) + return _c +} + +// ResetWorkflowExecution provides a mock function with given fields: ctx, request +func (_m *MockClient) ResetWorkflowExecution(ctx context.Context, request *workflowservice.ResetWorkflowExecutionRequest) (*workflowservice.ResetWorkflowExecutionResponse, error) { + ret := _m.Called(ctx, request) + + if len(ret) == 0 { + panic("no return value specified for ResetWorkflowExecution") + } + + var r0 *workflowservice.ResetWorkflowExecutionResponse + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, *workflowservice.ResetWorkflowExecutionRequest) (*workflowservice.ResetWorkflowExecutionResponse, error)); ok { + return rf(ctx, request) + } + if rf, ok := ret.Get(0).(func(context.Context, *workflowservice.ResetWorkflowExecutionRequest) *workflowservice.ResetWorkflowExecutionResponse); ok { + r0 = rf(ctx, request) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*workflowservice.ResetWorkflowExecutionResponse) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, *workflowservice.ResetWorkflowExecutionRequest) error); ok { + r1 = rf(ctx, request) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// MockClient_ResetWorkflowExecution_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'ResetWorkflowExecution' +type MockClient_ResetWorkflowExecution_Call struct { + *mock.Call +} + +// ResetWorkflowExecution is a helper method to define mock.On call +// - ctx context.Context +// - request *workflowservice.ResetWorkflowExecutionRequest +func (_e *MockClient_Expecter) ResetWorkflowExecution(ctx interface{}, request interface{}) *MockClient_ResetWorkflowExecution_Call { + return &MockClient_ResetWorkflowExecution_Call{Call: _e.mock.On("ResetWorkflowExecution", ctx, request)} +} + +func (_c *MockClient_ResetWorkflowExecution_Call) Run(run func(ctx context.Context, request *workflowservice.ResetWorkflowExecutionRequest)) *MockClient_ResetWorkflowExecution_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(*workflowservice.ResetWorkflowExecutionRequest)) + }) + return _c +} + +func (_c *MockClient_ResetWorkflowExecution_Call) Return(_a0 *workflowservice.ResetWorkflowExecutionResponse, _a1 error) *MockClient_ResetWorkflowExecution_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *MockClient_ResetWorkflowExecution_Call) RunAndReturn(run func(context.Context, *workflowservice.ResetWorkflowExecutionRequest) (*workflowservice.ResetWorkflowExecutionResponse, error)) *MockClient_ResetWorkflowExecution_Call { + _c.Call.Return(run) + return _c +} + +// ScanWorkflow provides a mock function with given fields: ctx, request +func (_m *MockClient) ScanWorkflow(ctx context.Context, request *workflowservice.ScanWorkflowExecutionsRequest) (*workflowservice.ScanWorkflowExecutionsResponse, error) { + ret := _m.Called(ctx, request) + + if len(ret) == 0 { + panic("no return value specified for ScanWorkflow") + } + + var r0 *workflowservice.ScanWorkflowExecutionsResponse + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, *workflowservice.ScanWorkflowExecutionsRequest) (*workflowservice.ScanWorkflowExecutionsResponse, error)); ok { + return rf(ctx, request) + } + if rf, ok := ret.Get(0).(func(context.Context, *workflowservice.ScanWorkflowExecutionsRequest) *workflowservice.ScanWorkflowExecutionsResponse); ok { + r0 = rf(ctx, request) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*workflowservice.ScanWorkflowExecutionsResponse) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, *workflowservice.ScanWorkflowExecutionsRequest) error); ok { + r1 = rf(ctx, request) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// MockClient_ScanWorkflow_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'ScanWorkflow' +type MockClient_ScanWorkflow_Call struct { + *mock.Call +} + +// ScanWorkflow is a helper method to define mock.On call +// - ctx context.Context +// - request *workflowservice.ScanWorkflowExecutionsRequest +func (_e *MockClient_Expecter) ScanWorkflow(ctx interface{}, request interface{}) *MockClient_ScanWorkflow_Call { + return &MockClient_ScanWorkflow_Call{Call: _e.mock.On("ScanWorkflow", ctx, request)} +} + +func (_c *MockClient_ScanWorkflow_Call) Run(run func(ctx context.Context, request *workflowservice.ScanWorkflowExecutionsRequest)) *MockClient_ScanWorkflow_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(*workflowservice.ScanWorkflowExecutionsRequest)) + }) + return _c +} + +func (_c *MockClient_ScanWorkflow_Call) Return(_a0 *workflowservice.ScanWorkflowExecutionsResponse, _a1 error) *MockClient_ScanWorkflow_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *MockClient_ScanWorkflow_Call) RunAndReturn(run func(context.Context, *workflowservice.ScanWorkflowExecutionsRequest) (*workflowservice.ScanWorkflowExecutionsResponse, error)) *MockClient_ScanWorkflow_Call { + _c.Call.Return(run) + return _c +} + +// ScheduleClient provides a mock function with given fields: +func (_m *MockClient) ScheduleClient() client.ScheduleClient { + ret := _m.Called() + + if len(ret) == 0 { + panic("no return value specified for ScheduleClient") + } + + var r0 client.ScheduleClient + if rf, ok := ret.Get(0).(func() client.ScheduleClient); ok { + r0 = rf() + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(client.ScheduleClient) + } + } + + return r0 +} + +// MockClient_ScheduleClient_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'ScheduleClient' +type MockClient_ScheduleClient_Call struct { + *mock.Call +} + +// ScheduleClient is a helper method to define mock.On call +func (_e *MockClient_Expecter) ScheduleClient() *MockClient_ScheduleClient_Call { + return &MockClient_ScheduleClient_Call{Call: _e.mock.On("ScheduleClient")} +} + +func (_c *MockClient_ScheduleClient_Call) Run(run func()) *MockClient_ScheduleClient_Call { + _c.Call.Run(func(args mock.Arguments) { + run() + }) + return _c +} + +func (_c *MockClient_ScheduleClient_Call) Return(_a0 client.ScheduleClient) *MockClient_ScheduleClient_Call { + _c.Call.Return(_a0) + return _c +} + +func (_c *MockClient_ScheduleClient_Call) RunAndReturn(run func() client.ScheduleClient) *MockClient_ScheduleClient_Call { + _c.Call.Return(run) + return _c +} + +// SignalWithStartWorkflow provides a mock function with given fields: ctx, workflowID, signalName, signalArg, options, workflow, workflowArgs +func (_m *MockClient) SignalWithStartWorkflow(ctx context.Context, workflowID string, signalName string, signalArg interface{}, options client.StartWorkflowOptions, workflow interface{}, workflowArgs ...interface{}) (client.WorkflowRun, error) { + var _ca []interface{} + _ca = append(_ca, ctx, workflowID, signalName, signalArg, options, workflow) + _ca = append(_ca, workflowArgs...) + ret := _m.Called(_ca...) + + if len(ret) == 0 { + panic("no return value specified for SignalWithStartWorkflow") + } + + var r0 client.WorkflowRun + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, string, string, interface{}, client.StartWorkflowOptions, interface{}, ...interface{}) (client.WorkflowRun, error)); ok { + return rf(ctx, workflowID, signalName, signalArg, options, workflow, workflowArgs...) + } + if rf, ok := ret.Get(0).(func(context.Context, string, string, interface{}, client.StartWorkflowOptions, interface{}, ...interface{}) client.WorkflowRun); ok { + r0 = rf(ctx, workflowID, signalName, signalArg, options, workflow, workflowArgs...) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(client.WorkflowRun) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, string, string, interface{}, client.StartWorkflowOptions, interface{}, ...interface{}) error); ok { + r1 = rf(ctx, workflowID, signalName, signalArg, options, workflow, workflowArgs...) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// MockClient_SignalWithStartWorkflow_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'SignalWithStartWorkflow' +type MockClient_SignalWithStartWorkflow_Call struct { + *mock.Call +} + +// SignalWithStartWorkflow is a helper method to define mock.On call +// - ctx context.Context +// - workflowID string +// - signalName string +// - signalArg interface{} +// - options client.StartWorkflowOptions +// - workflow interface{} +// - workflowArgs ...interface{} +func (_e *MockClient_Expecter) SignalWithStartWorkflow(ctx interface{}, workflowID interface{}, signalName interface{}, signalArg interface{}, options interface{}, workflow interface{}, workflowArgs ...interface{}) *MockClient_SignalWithStartWorkflow_Call { + return &MockClient_SignalWithStartWorkflow_Call{Call: _e.mock.On("SignalWithStartWorkflow", + append([]interface{}{ctx, workflowID, signalName, signalArg, options, workflow}, workflowArgs...)...)} +} + +func (_c *MockClient_SignalWithStartWorkflow_Call) Run(run func(ctx context.Context, workflowID string, signalName string, signalArg interface{}, options client.StartWorkflowOptions, workflow interface{}, workflowArgs ...interface{})) *MockClient_SignalWithStartWorkflow_Call { + _c.Call.Run(func(args mock.Arguments) { + variadicArgs := make([]interface{}, len(args)-6) + for i, a := range args[6:] { + if a != nil { + variadicArgs[i] = a.(interface{}) + } + } + run(args[0].(context.Context), args[1].(string), args[2].(string), args[3].(interface{}), args[4].(client.StartWorkflowOptions), args[5].(interface{}), variadicArgs...) + }) + return _c +} + +func (_c *MockClient_SignalWithStartWorkflow_Call) Return(_a0 client.WorkflowRun, _a1 error) *MockClient_SignalWithStartWorkflow_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *MockClient_SignalWithStartWorkflow_Call) RunAndReturn(run func(context.Context, string, string, interface{}, client.StartWorkflowOptions, interface{}, ...interface{}) (client.WorkflowRun, error)) *MockClient_SignalWithStartWorkflow_Call { + _c.Call.Return(run) + return _c +} + +// SignalWorkflow provides a mock function with given fields: ctx, workflowID, runID, signalName, arg +func (_m *MockClient) SignalWorkflow(ctx context.Context, workflowID string, runID string, signalName string, arg interface{}) error { + ret := _m.Called(ctx, workflowID, runID, signalName, arg) + + if len(ret) == 0 { + panic("no return value specified for SignalWorkflow") + } + + var r0 error + if rf, ok := ret.Get(0).(func(context.Context, string, string, string, interface{}) error); ok { + r0 = rf(ctx, workflowID, runID, signalName, arg) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// MockClient_SignalWorkflow_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'SignalWorkflow' +type MockClient_SignalWorkflow_Call struct { + *mock.Call +} + +// SignalWorkflow is a helper method to define mock.On call +// - ctx context.Context +// - workflowID string +// - runID string +// - signalName string +// - arg interface{} +func (_e *MockClient_Expecter) SignalWorkflow(ctx interface{}, workflowID interface{}, runID interface{}, signalName interface{}, arg interface{}) *MockClient_SignalWorkflow_Call { + return &MockClient_SignalWorkflow_Call{Call: _e.mock.On("SignalWorkflow", ctx, workflowID, runID, signalName, arg)} +} + +func (_c *MockClient_SignalWorkflow_Call) Run(run func(ctx context.Context, workflowID string, runID string, signalName string, arg interface{})) *MockClient_SignalWorkflow_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(string), args[2].(string), args[3].(string), args[4].(interface{})) + }) + return _c +} + +func (_c *MockClient_SignalWorkflow_Call) Return(_a0 error) *MockClient_SignalWorkflow_Call { + _c.Call.Return(_a0) + return _c +} + +func (_c *MockClient_SignalWorkflow_Call) RunAndReturn(run func(context.Context, string, string, string, interface{}) error) *MockClient_SignalWorkflow_Call { + _c.Call.Return(run) + return _c +} + +// Start provides a mock function with given fields: +func (_m *MockClient) Start() error { + ret := _m.Called() + + if len(ret) == 0 { + panic("no return value specified for Start") + } + + var r0 error + if rf, ok := ret.Get(0).(func() error); ok { + r0 = rf() + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// MockClient_Start_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Start' +type MockClient_Start_Call struct { + *mock.Call +} + +// Start is a helper method to define mock.On call +func (_e *MockClient_Expecter) Start() *MockClient_Start_Call { + return &MockClient_Start_Call{Call: _e.mock.On("Start")} +} + +func (_c *MockClient_Start_Call) Run(run func()) *MockClient_Start_Call { + _c.Call.Run(func(args mock.Arguments) { + run() + }) + return _c +} + +func (_c *MockClient_Start_Call) Return(_a0 error) *MockClient_Start_Call { + _c.Call.Return(_a0) + return _c +} + +func (_c *MockClient_Start_Call) RunAndReturn(run func() error) *MockClient_Start_Call { + _c.Call.Return(run) + return _c +} + +// Stop provides a mock function with given fields: +func (_m *MockClient) Stop() { + _m.Called() +} + +// MockClient_Stop_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Stop' +type MockClient_Stop_Call struct { + *mock.Call +} + +// Stop is a helper method to define mock.On call +func (_e *MockClient_Expecter) Stop() *MockClient_Stop_Call { + return &MockClient_Stop_Call{Call: _e.mock.On("Stop")} +} + +func (_c *MockClient_Stop_Call) Run(run func()) *MockClient_Stop_Call { + _c.Call.Run(func(args mock.Arguments) { + run() + }) + return _c +} + +func (_c *MockClient_Stop_Call) Return() *MockClient_Stop_Call { + _c.Call.Return() + return _c +} + +func (_c *MockClient_Stop_Call) RunAndReturn(run func()) *MockClient_Stop_Call { + _c.Call.Return(run) + return _c +} + +// TerminateWorkflow provides a mock function with given fields: ctx, workflowID, runID, reason, details +func (_m *MockClient) TerminateWorkflow(ctx context.Context, workflowID string, runID string, reason string, details ...interface{}) error { + var _ca []interface{} + _ca = append(_ca, ctx, workflowID, runID, reason) + _ca = append(_ca, details...) + ret := _m.Called(_ca...) + + if len(ret) == 0 { + panic("no return value specified for TerminateWorkflow") + } + + var r0 error + if rf, ok := ret.Get(0).(func(context.Context, string, string, string, ...interface{}) error); ok { + r0 = rf(ctx, workflowID, runID, reason, details...) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// MockClient_TerminateWorkflow_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'TerminateWorkflow' +type MockClient_TerminateWorkflow_Call struct { + *mock.Call +} + +// TerminateWorkflow is a helper method to define mock.On call +// - ctx context.Context +// - workflowID string +// - runID string +// - reason string +// - details ...interface{} +func (_e *MockClient_Expecter) TerminateWorkflow(ctx interface{}, workflowID interface{}, runID interface{}, reason interface{}, details ...interface{}) *MockClient_TerminateWorkflow_Call { + return &MockClient_TerminateWorkflow_Call{Call: _e.mock.On("TerminateWorkflow", + append([]interface{}{ctx, workflowID, runID, reason}, details...)...)} +} + +func (_c *MockClient_TerminateWorkflow_Call) Run(run func(ctx context.Context, workflowID string, runID string, reason string, details ...interface{})) *MockClient_TerminateWorkflow_Call { + _c.Call.Run(func(args mock.Arguments) { + variadicArgs := make([]interface{}, len(args)-4) + for i, a := range args[4:] { + if a != nil { + variadicArgs[i] = a.(interface{}) + } + } + run(args[0].(context.Context), args[1].(string), args[2].(string), args[3].(string), variadicArgs...) + }) + return _c +} + +func (_c *MockClient_TerminateWorkflow_Call) Return(_a0 error) *MockClient_TerminateWorkflow_Call { + _c.Call.Return(_a0) + return _c +} + +func (_c *MockClient_TerminateWorkflow_Call) RunAndReturn(run func(context.Context, string, string, string, ...interface{}) error) *MockClient_TerminateWorkflow_Call { + _c.Call.Return(run) + return _c +} + +// UpdateWithStartWorkflow provides a mock function with given fields: ctx, options +func (_m *MockClient) UpdateWithStartWorkflow(ctx context.Context, options client.UpdateWithStartWorkflowOptions) (client.WorkflowUpdateHandle, error) { + ret := _m.Called(ctx, options) + + if len(ret) == 0 { + panic("no return value specified for UpdateWithStartWorkflow") + } + + var r0 client.WorkflowUpdateHandle + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, client.UpdateWithStartWorkflowOptions) (client.WorkflowUpdateHandle, error)); ok { + return rf(ctx, options) + } + if rf, ok := ret.Get(0).(func(context.Context, client.UpdateWithStartWorkflowOptions) client.WorkflowUpdateHandle); ok { + r0 = rf(ctx, options) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(client.WorkflowUpdateHandle) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, client.UpdateWithStartWorkflowOptions) error); ok { + r1 = rf(ctx, options) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// MockClient_UpdateWithStartWorkflow_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'UpdateWithStartWorkflow' +type MockClient_UpdateWithStartWorkflow_Call struct { + *mock.Call +} + +// UpdateWithStartWorkflow is a helper method to define mock.On call +// - ctx context.Context +// - options client.UpdateWithStartWorkflowOptions +func (_e *MockClient_Expecter) UpdateWithStartWorkflow(ctx interface{}, options interface{}) *MockClient_UpdateWithStartWorkflow_Call { + return &MockClient_UpdateWithStartWorkflow_Call{Call: _e.mock.On("UpdateWithStartWorkflow", ctx, options)} +} + +func (_c *MockClient_UpdateWithStartWorkflow_Call) Run(run func(ctx context.Context, options client.UpdateWithStartWorkflowOptions)) *MockClient_UpdateWithStartWorkflow_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(client.UpdateWithStartWorkflowOptions)) + }) + return _c +} + +func (_c *MockClient_UpdateWithStartWorkflow_Call) Return(_a0 client.WorkflowUpdateHandle, _a1 error) *MockClient_UpdateWithStartWorkflow_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *MockClient_UpdateWithStartWorkflow_Call) RunAndReturn(run func(context.Context, client.UpdateWithStartWorkflowOptions) (client.WorkflowUpdateHandle, error)) *MockClient_UpdateWithStartWorkflow_Call { + _c.Call.Return(run) + return _c +} + +// UpdateWorkerBuildIdCompatibility provides a mock function with given fields: ctx, options +func (_m *MockClient) UpdateWorkerBuildIdCompatibility(ctx context.Context, options *client.UpdateWorkerBuildIdCompatibilityOptions) error { + ret := _m.Called(ctx, options) + + if len(ret) == 0 { + panic("no return value specified for UpdateWorkerBuildIdCompatibility") + } + + var r0 error + if rf, ok := ret.Get(0).(func(context.Context, *client.UpdateWorkerBuildIdCompatibilityOptions) error); ok { + r0 = rf(ctx, options) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// MockClient_UpdateWorkerBuildIdCompatibility_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'UpdateWorkerBuildIdCompatibility' +type MockClient_UpdateWorkerBuildIdCompatibility_Call struct { + *mock.Call +} + +// UpdateWorkerBuildIdCompatibility is a helper method to define mock.On call +// - ctx context.Context +// - options *client.UpdateWorkerBuildIdCompatibilityOptions +func (_e *MockClient_Expecter) UpdateWorkerBuildIdCompatibility(ctx interface{}, options interface{}) *MockClient_UpdateWorkerBuildIdCompatibility_Call { + return &MockClient_UpdateWorkerBuildIdCompatibility_Call{Call: _e.mock.On("UpdateWorkerBuildIdCompatibility", ctx, options)} +} + +func (_c *MockClient_UpdateWorkerBuildIdCompatibility_Call) Run(run func(ctx context.Context, options *client.UpdateWorkerBuildIdCompatibilityOptions)) *MockClient_UpdateWorkerBuildIdCompatibility_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(*client.UpdateWorkerBuildIdCompatibilityOptions)) + }) + return _c +} + +func (_c *MockClient_UpdateWorkerBuildIdCompatibility_Call) Return(_a0 error) *MockClient_UpdateWorkerBuildIdCompatibility_Call { + _c.Call.Return(_a0) + return _c +} + +func (_c *MockClient_UpdateWorkerBuildIdCompatibility_Call) RunAndReturn(run func(context.Context, *client.UpdateWorkerBuildIdCompatibilityOptions) error) *MockClient_UpdateWorkerBuildIdCompatibility_Call { + _c.Call.Return(run) + return _c +} + +// UpdateWorkerVersioningRules provides a mock function with given fields: ctx, options +func (_m *MockClient) UpdateWorkerVersioningRules(ctx context.Context, options client.UpdateWorkerVersioningRulesOptions) (*client.WorkerVersioningRules, error) { + ret := _m.Called(ctx, options) + + if len(ret) == 0 { + panic("no return value specified for UpdateWorkerVersioningRules") + } + + var r0 *client.WorkerVersioningRules + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, client.UpdateWorkerVersioningRulesOptions) (*client.WorkerVersioningRules, error)); ok { + return rf(ctx, options) + } + if rf, ok := ret.Get(0).(func(context.Context, client.UpdateWorkerVersioningRulesOptions) *client.WorkerVersioningRules); ok { + r0 = rf(ctx, options) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*client.WorkerVersioningRules) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, client.UpdateWorkerVersioningRulesOptions) error); ok { + r1 = rf(ctx, options) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// MockClient_UpdateWorkerVersioningRules_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'UpdateWorkerVersioningRules' +type MockClient_UpdateWorkerVersioningRules_Call struct { + *mock.Call +} + +// UpdateWorkerVersioningRules is a helper method to define mock.On call +// - ctx context.Context +// - options client.UpdateWorkerVersioningRulesOptions +func (_e *MockClient_Expecter) UpdateWorkerVersioningRules(ctx interface{}, options interface{}) *MockClient_UpdateWorkerVersioningRules_Call { + return &MockClient_UpdateWorkerVersioningRules_Call{Call: _e.mock.On("UpdateWorkerVersioningRules", ctx, options)} +} + +func (_c *MockClient_UpdateWorkerVersioningRules_Call) Run(run func(ctx context.Context, options client.UpdateWorkerVersioningRulesOptions)) *MockClient_UpdateWorkerVersioningRules_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(client.UpdateWorkerVersioningRulesOptions)) + }) + return _c +} + +func (_c *MockClient_UpdateWorkerVersioningRules_Call) Return(_a0 *client.WorkerVersioningRules, _a1 error) *MockClient_UpdateWorkerVersioningRules_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *MockClient_UpdateWorkerVersioningRules_Call) RunAndReturn(run func(context.Context, client.UpdateWorkerVersioningRulesOptions) (*client.WorkerVersioningRules, error)) *MockClient_UpdateWorkerVersioningRules_Call { + _c.Call.Return(run) + return _c +} + +// UpdateWorkflow provides a mock function with given fields: ctx, options +func (_m *MockClient) UpdateWorkflow(ctx context.Context, options client.UpdateWorkflowOptions) (client.WorkflowUpdateHandle, error) { + ret := _m.Called(ctx, options) + + if len(ret) == 0 { + panic("no return value specified for UpdateWorkflow") + } + + var r0 client.WorkflowUpdateHandle + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, client.UpdateWorkflowOptions) (client.WorkflowUpdateHandle, error)); ok { + return rf(ctx, options) + } + if rf, ok := ret.Get(0).(func(context.Context, client.UpdateWorkflowOptions) client.WorkflowUpdateHandle); ok { + r0 = rf(ctx, options) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(client.WorkflowUpdateHandle) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, client.UpdateWorkflowOptions) error); ok { + r1 = rf(ctx, options) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// MockClient_UpdateWorkflow_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'UpdateWorkflow' +type MockClient_UpdateWorkflow_Call struct { + *mock.Call +} + +// UpdateWorkflow is a helper method to define mock.On call +// - ctx context.Context +// - options client.UpdateWorkflowOptions +func (_e *MockClient_Expecter) UpdateWorkflow(ctx interface{}, options interface{}) *MockClient_UpdateWorkflow_Call { + return &MockClient_UpdateWorkflow_Call{Call: _e.mock.On("UpdateWorkflow", ctx, options)} +} + +func (_c *MockClient_UpdateWorkflow_Call) Run(run func(ctx context.Context, options client.UpdateWorkflowOptions)) *MockClient_UpdateWorkflow_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(client.UpdateWorkflowOptions)) + }) + return _c +} + +func (_c *MockClient_UpdateWorkflow_Call) Return(_a0 client.WorkflowUpdateHandle, _a1 error) *MockClient_UpdateWorkflow_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *MockClient_UpdateWorkflow_Call) RunAndReturn(run func(context.Context, client.UpdateWorkflowOptions) (client.WorkflowUpdateHandle, error)) *MockClient_UpdateWorkflow_Call { + _c.Call.Return(run) + return _c +} + +// WorkflowService provides a mock function with given fields: +func (_m *MockClient) WorkflowService() workflowservice.WorkflowServiceClient { + ret := _m.Called() + + if len(ret) == 0 { + panic("no return value specified for WorkflowService") + } + + var r0 workflowservice.WorkflowServiceClient + if rf, ok := ret.Get(0).(func() workflowservice.WorkflowServiceClient); ok { + r0 = rf() + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(workflowservice.WorkflowServiceClient) + } + } + + return r0 +} + +// MockClient_WorkflowService_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'WorkflowService' +type MockClient_WorkflowService_Call struct { + *mock.Call +} + +// WorkflowService is a helper method to define mock.On call +func (_e *MockClient_Expecter) WorkflowService() *MockClient_WorkflowService_Call { + return &MockClient_WorkflowService_Call{Call: _e.mock.On("WorkflowService")} +} + +func (_c *MockClient_WorkflowService_Call) Run(run func()) *MockClient_WorkflowService_Call { + _c.Call.Run(func(args mock.Arguments) { + run() + }) + return _c +} + +func (_c *MockClient_WorkflowService_Call) Return(_a0 workflowservice.WorkflowServiceClient) *MockClient_WorkflowService_Call { + _c.Call.Return(_a0) + return _c +} + +func (_c *MockClient_WorkflowService_Call) RunAndReturn(run func() workflowservice.WorkflowServiceClient) *MockClient_WorkflowService_Call { + _c.Call.Return(run) + return _c +} + +// NewMockClient creates a new instance of MockClient. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. +// The first argument is typically a *testing.T value. +func NewMockClient(t interface { + mock.TestingT + Cleanup(func()) +}) *MockClient { + mock := &MockClient{} + mock.Mock.Test(t) + + t.Cleanup(func() { mock.AssertExpectations(t) }) + + return mock +} diff --git a/_mocks/opencsg.com/csghub-server/component/callback/mock_GitCallbackComponent.go b/_mocks/opencsg.com/csghub-server/component/callback/mock_GitCallbackComponent.go new file mode 100644 index 00000000..3ce91d7f --- /dev/null +++ b/_mocks/opencsg.com/csghub-server/component/callback/mock_GitCallbackComponent.go @@ -0,0 +1,305 @@ +// Code generated by mockery v2.49.1. DO NOT EDIT. + +package callback + +import ( + context "context" + + mock "github.com/stretchr/testify/mock" + types "opencsg.com/csghub-server/common/types" +) + +// MockGitCallbackComponent is an autogenerated mock type for the GitCallbackComponent type +type MockGitCallbackComponent struct { + mock.Mock +} + +type MockGitCallbackComponent_Expecter struct { + mock *mock.Mock +} + +func (_m *MockGitCallbackComponent) EXPECT() *MockGitCallbackComponent_Expecter { + return &MockGitCallbackComponent_Expecter{mock: &_m.Mock} +} + +// SensitiveCheck provides a mock function with given fields: ctx, req +func (_m *MockGitCallbackComponent) SensitiveCheck(ctx context.Context, req *types.GiteaCallbackPushReq) error { + ret := _m.Called(ctx, req) + + if len(ret) == 0 { + panic("no return value specified for SensitiveCheck") + } + + var r0 error + if rf, ok := ret.Get(0).(func(context.Context, *types.GiteaCallbackPushReq) error); ok { + r0 = rf(ctx, req) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// MockGitCallbackComponent_SensitiveCheck_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'SensitiveCheck' +type MockGitCallbackComponent_SensitiveCheck_Call struct { + *mock.Call +} + +// SensitiveCheck is a helper method to define mock.On call +// - ctx context.Context +// - req *types.GiteaCallbackPushReq +func (_e *MockGitCallbackComponent_Expecter) SensitiveCheck(ctx interface{}, req interface{}) *MockGitCallbackComponent_SensitiveCheck_Call { + return &MockGitCallbackComponent_SensitiveCheck_Call{Call: _e.mock.On("SensitiveCheck", ctx, req)} +} + +func (_c *MockGitCallbackComponent_SensitiveCheck_Call) Run(run func(ctx context.Context, req *types.GiteaCallbackPushReq)) *MockGitCallbackComponent_SensitiveCheck_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(*types.GiteaCallbackPushReq)) + }) + return _c +} + +func (_c *MockGitCallbackComponent_SensitiveCheck_Call) Return(_a0 error) *MockGitCallbackComponent_SensitiveCheck_Call { + _c.Call.Return(_a0) + return _c +} + +func (_c *MockGitCallbackComponent_SensitiveCheck_Call) RunAndReturn(run func(context.Context, *types.GiteaCallbackPushReq) error) *MockGitCallbackComponent_SensitiveCheck_Call { + _c.Call.Return(run) + return _c +} + +// SetRepoUpdateTime provides a mock function with given fields: ctx, req +func (_m *MockGitCallbackComponent) SetRepoUpdateTime(ctx context.Context, req *types.GiteaCallbackPushReq) error { + ret := _m.Called(ctx, req) + + if len(ret) == 0 { + panic("no return value specified for SetRepoUpdateTime") + } + + var r0 error + if rf, ok := ret.Get(0).(func(context.Context, *types.GiteaCallbackPushReq) error); ok { + r0 = rf(ctx, req) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// MockGitCallbackComponent_SetRepoUpdateTime_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'SetRepoUpdateTime' +type MockGitCallbackComponent_SetRepoUpdateTime_Call struct { + *mock.Call +} + +// SetRepoUpdateTime is a helper method to define mock.On call +// - ctx context.Context +// - req *types.GiteaCallbackPushReq +func (_e *MockGitCallbackComponent_Expecter) SetRepoUpdateTime(ctx interface{}, req interface{}) *MockGitCallbackComponent_SetRepoUpdateTime_Call { + return &MockGitCallbackComponent_SetRepoUpdateTime_Call{Call: _e.mock.On("SetRepoUpdateTime", ctx, req)} +} + +func (_c *MockGitCallbackComponent_SetRepoUpdateTime_Call) Run(run func(ctx context.Context, req *types.GiteaCallbackPushReq)) *MockGitCallbackComponent_SetRepoUpdateTime_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(*types.GiteaCallbackPushReq)) + }) + return _c +} + +func (_c *MockGitCallbackComponent_SetRepoUpdateTime_Call) Return(_a0 error) *MockGitCallbackComponent_SetRepoUpdateTime_Call { + _c.Call.Return(_a0) + return _c +} + +func (_c *MockGitCallbackComponent_SetRepoUpdateTime_Call) RunAndReturn(run func(context.Context, *types.GiteaCallbackPushReq) error) *MockGitCallbackComponent_SetRepoUpdateTime_Call { + _c.Call.Return(run) + return _c +} + +// SetRepoVisibility provides a mock function with given fields: yes +func (_m *MockGitCallbackComponent) SetRepoVisibility(yes bool) { + _m.Called(yes) +} + +// MockGitCallbackComponent_SetRepoVisibility_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'SetRepoVisibility' +type MockGitCallbackComponent_SetRepoVisibility_Call struct { + *mock.Call +} + +// SetRepoVisibility is a helper method to define mock.On call +// - yes bool +func (_e *MockGitCallbackComponent_Expecter) SetRepoVisibility(yes interface{}) *MockGitCallbackComponent_SetRepoVisibility_Call { + return &MockGitCallbackComponent_SetRepoVisibility_Call{Call: _e.mock.On("SetRepoVisibility", yes)} +} + +func (_c *MockGitCallbackComponent_SetRepoVisibility_Call) Run(run func(yes bool)) *MockGitCallbackComponent_SetRepoVisibility_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(bool)) + }) + return _c +} + +func (_c *MockGitCallbackComponent_SetRepoVisibility_Call) Return() *MockGitCallbackComponent_SetRepoVisibility_Call { + _c.Call.Return() + return _c +} + +func (_c *MockGitCallbackComponent_SetRepoVisibility_Call) RunAndReturn(run func(bool)) *MockGitCallbackComponent_SetRepoVisibility_Call { + _c.Call.Return(run) + return _c +} + +// UpdateRepoInfos provides a mock function with given fields: ctx, req +func (_m *MockGitCallbackComponent) UpdateRepoInfos(ctx context.Context, req *types.GiteaCallbackPushReq) error { + ret := _m.Called(ctx, req) + + if len(ret) == 0 { + panic("no return value specified for UpdateRepoInfos") + } + + var r0 error + if rf, ok := ret.Get(0).(func(context.Context, *types.GiteaCallbackPushReq) error); ok { + r0 = rf(ctx, req) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// MockGitCallbackComponent_UpdateRepoInfos_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'UpdateRepoInfos' +type MockGitCallbackComponent_UpdateRepoInfos_Call struct { + *mock.Call +} + +// UpdateRepoInfos is a helper method to define mock.On call +// - ctx context.Context +// - req *types.GiteaCallbackPushReq +func (_e *MockGitCallbackComponent_Expecter) UpdateRepoInfos(ctx interface{}, req interface{}) *MockGitCallbackComponent_UpdateRepoInfos_Call { + return &MockGitCallbackComponent_UpdateRepoInfos_Call{Call: _e.mock.On("UpdateRepoInfos", ctx, req)} +} + +func (_c *MockGitCallbackComponent_UpdateRepoInfos_Call) Run(run func(ctx context.Context, req *types.GiteaCallbackPushReq)) *MockGitCallbackComponent_UpdateRepoInfos_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(*types.GiteaCallbackPushReq)) + }) + return _c +} + +func (_c *MockGitCallbackComponent_UpdateRepoInfos_Call) Return(_a0 error) *MockGitCallbackComponent_UpdateRepoInfos_Call { + _c.Call.Return(_a0) + return _c +} + +func (_c *MockGitCallbackComponent_UpdateRepoInfos_Call) RunAndReturn(run func(context.Context, *types.GiteaCallbackPushReq) error) *MockGitCallbackComponent_UpdateRepoInfos_Call { + _c.Call.Return(run) + return _c +} + +// WatchRepoRelation provides a mock function with given fields: ctx, req +func (_m *MockGitCallbackComponent) WatchRepoRelation(ctx context.Context, req *types.GiteaCallbackPushReq) error { + ret := _m.Called(ctx, req) + + if len(ret) == 0 { + panic("no return value specified for WatchRepoRelation") + } + + var r0 error + if rf, ok := ret.Get(0).(func(context.Context, *types.GiteaCallbackPushReq) error); ok { + r0 = rf(ctx, req) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// MockGitCallbackComponent_WatchRepoRelation_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'WatchRepoRelation' +type MockGitCallbackComponent_WatchRepoRelation_Call struct { + *mock.Call +} + +// WatchRepoRelation is a helper method to define mock.On call +// - ctx context.Context +// - req *types.GiteaCallbackPushReq +func (_e *MockGitCallbackComponent_Expecter) WatchRepoRelation(ctx interface{}, req interface{}) *MockGitCallbackComponent_WatchRepoRelation_Call { + return &MockGitCallbackComponent_WatchRepoRelation_Call{Call: _e.mock.On("WatchRepoRelation", ctx, req)} +} + +func (_c *MockGitCallbackComponent_WatchRepoRelation_Call) Run(run func(ctx context.Context, req *types.GiteaCallbackPushReq)) *MockGitCallbackComponent_WatchRepoRelation_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(*types.GiteaCallbackPushReq)) + }) + return _c +} + +func (_c *MockGitCallbackComponent_WatchRepoRelation_Call) Return(_a0 error) *MockGitCallbackComponent_WatchRepoRelation_Call { + _c.Call.Return(_a0) + return _c +} + +func (_c *MockGitCallbackComponent_WatchRepoRelation_Call) RunAndReturn(run func(context.Context, *types.GiteaCallbackPushReq) error) *MockGitCallbackComponent_WatchRepoRelation_Call { + _c.Call.Return(run) + return _c +} + +// WatchSpaceChange provides a mock function with given fields: ctx, req +func (_m *MockGitCallbackComponent) WatchSpaceChange(ctx context.Context, req *types.GiteaCallbackPushReq) error { + ret := _m.Called(ctx, req) + + if len(ret) == 0 { + panic("no return value specified for WatchSpaceChange") + } + + var r0 error + if rf, ok := ret.Get(0).(func(context.Context, *types.GiteaCallbackPushReq) error); ok { + r0 = rf(ctx, req) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// MockGitCallbackComponent_WatchSpaceChange_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'WatchSpaceChange' +type MockGitCallbackComponent_WatchSpaceChange_Call struct { + *mock.Call +} + +// WatchSpaceChange is a helper method to define mock.On call +// - ctx context.Context +// - req *types.GiteaCallbackPushReq +func (_e *MockGitCallbackComponent_Expecter) WatchSpaceChange(ctx interface{}, req interface{}) *MockGitCallbackComponent_WatchSpaceChange_Call { + return &MockGitCallbackComponent_WatchSpaceChange_Call{Call: _e.mock.On("WatchSpaceChange", ctx, req)} +} + +func (_c *MockGitCallbackComponent_WatchSpaceChange_Call) Run(run func(ctx context.Context, req *types.GiteaCallbackPushReq)) *MockGitCallbackComponent_WatchSpaceChange_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(*types.GiteaCallbackPushReq)) + }) + return _c +} + +func (_c *MockGitCallbackComponent_WatchSpaceChange_Call) Return(_a0 error) *MockGitCallbackComponent_WatchSpaceChange_Call { + _c.Call.Return(_a0) + return _c +} + +func (_c *MockGitCallbackComponent_WatchSpaceChange_Call) RunAndReturn(run func(context.Context, *types.GiteaCallbackPushReq) error) *MockGitCallbackComponent_WatchSpaceChange_Call { + _c.Call.Return(run) + return _c +} + +// NewMockGitCallbackComponent creates a new instance of MockGitCallbackComponent. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. +// The first argument is typically a *testing.T value. +func NewMockGitCallbackComponent(t interface { + mock.TestingT + Cleanup(func()) +}) *MockGitCallbackComponent { + mock := &MockGitCallbackComponent{} + mock.Mock.Test(t) + + t.Cleanup(func() { mock.AssertExpectations(t) }) + + return mock +} diff --git a/_mocks/opencsg.com/csghub-server/component/mock_MultiSyncComponent.go b/_mocks/opencsg.com/csghub-server/component/mock_MultiSyncComponent.go new file mode 100644 index 00000000..64593c06 --- /dev/null +++ b/_mocks/opencsg.com/csghub-server/component/mock_MultiSyncComponent.go @@ -0,0 +1,146 @@ +// Code generated by mockery v2.49.1. DO NOT EDIT. + +package component + +import ( + context "context" + + mock "github.com/stretchr/testify/mock" + multisync "opencsg.com/csghub-server/builder/multisync" + + types "opencsg.com/csghub-server/common/types" +) + +// MockMultiSyncComponent is an autogenerated mock type for the MultiSyncComponent type +type MockMultiSyncComponent struct { + mock.Mock +} + +type MockMultiSyncComponent_Expecter struct { + mock *mock.Mock +} + +func (_m *MockMultiSyncComponent) EXPECT() *MockMultiSyncComponent_Expecter { + return &MockMultiSyncComponent_Expecter{mock: &_m.Mock} +} + +// More provides a mock function with given fields: ctx, cur, limit +func (_m *MockMultiSyncComponent) More(ctx context.Context, cur int64, limit int64) ([]types.SyncVersion, error) { + ret := _m.Called(ctx, cur, limit) + + if len(ret) == 0 { + panic("no return value specified for More") + } + + var r0 []types.SyncVersion + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, int64, int64) ([]types.SyncVersion, error)); ok { + return rf(ctx, cur, limit) + } + if rf, ok := ret.Get(0).(func(context.Context, int64, int64) []types.SyncVersion); ok { + r0 = rf(ctx, cur, limit) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).([]types.SyncVersion) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, int64, int64) error); ok { + r1 = rf(ctx, cur, limit) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// MockMultiSyncComponent_More_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'More' +type MockMultiSyncComponent_More_Call struct { + *mock.Call +} + +// More is a helper method to define mock.On call +// - ctx context.Context +// - cur int64 +// - limit int64 +func (_e *MockMultiSyncComponent_Expecter) More(ctx interface{}, cur interface{}, limit interface{}) *MockMultiSyncComponent_More_Call { + return &MockMultiSyncComponent_More_Call{Call: _e.mock.On("More", ctx, cur, limit)} +} + +func (_c *MockMultiSyncComponent_More_Call) Run(run func(ctx context.Context, cur int64, limit int64)) *MockMultiSyncComponent_More_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(int64), args[2].(int64)) + }) + return _c +} + +func (_c *MockMultiSyncComponent_More_Call) Return(_a0 []types.SyncVersion, _a1 error) *MockMultiSyncComponent_More_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *MockMultiSyncComponent_More_Call) RunAndReturn(run func(context.Context, int64, int64) ([]types.SyncVersion, error)) *MockMultiSyncComponent_More_Call { + _c.Call.Return(run) + return _c +} + +// SyncAsClient provides a mock function with given fields: ctx, sc +func (_m *MockMultiSyncComponent) SyncAsClient(ctx context.Context, sc multisync.Client) error { + ret := _m.Called(ctx, sc) + + if len(ret) == 0 { + panic("no return value specified for SyncAsClient") + } + + var r0 error + if rf, ok := ret.Get(0).(func(context.Context, multisync.Client) error); ok { + r0 = rf(ctx, sc) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// MockMultiSyncComponent_SyncAsClient_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'SyncAsClient' +type MockMultiSyncComponent_SyncAsClient_Call struct { + *mock.Call +} + +// SyncAsClient is a helper method to define mock.On call +// - ctx context.Context +// - sc multisync.Client +func (_e *MockMultiSyncComponent_Expecter) SyncAsClient(ctx interface{}, sc interface{}) *MockMultiSyncComponent_SyncAsClient_Call { + return &MockMultiSyncComponent_SyncAsClient_Call{Call: _e.mock.On("SyncAsClient", ctx, sc)} +} + +func (_c *MockMultiSyncComponent_SyncAsClient_Call) Run(run func(ctx context.Context, sc multisync.Client)) *MockMultiSyncComponent_SyncAsClient_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(multisync.Client)) + }) + return _c +} + +func (_c *MockMultiSyncComponent_SyncAsClient_Call) Return(_a0 error) *MockMultiSyncComponent_SyncAsClient_Call { + _c.Call.Return(_a0) + return _c +} + +func (_c *MockMultiSyncComponent_SyncAsClient_Call) RunAndReturn(run func(context.Context, multisync.Client) error) *MockMultiSyncComponent_SyncAsClient_Call { + _c.Call.Return(run) + return _c +} + +// NewMockMultiSyncComponent creates a new instance of MockMultiSyncComponent. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. +// The first argument is typically a *testing.T value. +func NewMockMultiSyncComponent(t interface { + mock.TestingT + Cleanup(func()) +}) *MockMultiSyncComponent { + mock := &MockMultiSyncComponent{} + mock.Mock.Test(t) + + t.Cleanup(func() { mock.AssertExpectations(t) }) + + return mock +} diff --git a/_mocks/opencsg.com/csghub-server/component/mock_RecomComponent.go b/_mocks/opencsg.com/csghub-server/component/mock_RecomComponent.go new file mode 100644 index 00000000..91dad12d --- /dev/null +++ b/_mocks/opencsg.com/csghub-server/component/mock_RecomComponent.go @@ -0,0 +1,166 @@ +// Code generated by mockery v2.49.1. DO NOT EDIT. + +package component + +import ( + context "context" + + mock "github.com/stretchr/testify/mock" + database "opencsg.com/csghub-server/builder/store/database" +) + +// MockRecomComponent is an autogenerated mock type for the RecomComponent type +type MockRecomComponent struct { + mock.Mock +} + +type MockRecomComponent_Expecter struct { + mock *mock.Mock +} + +func (_m *MockRecomComponent) EXPECT() *MockRecomComponent_Expecter { + return &MockRecomComponent_Expecter{mock: &_m.Mock} +} + +// CalcTotalScore provides a mock function with given fields: ctx, repo, weights +func (_m *MockRecomComponent) CalcTotalScore(ctx context.Context, repo *database.Repository, weights map[string]string) float64 { + ret := _m.Called(ctx, repo, weights) + + if len(ret) == 0 { + panic("no return value specified for CalcTotalScore") + } + + var r0 float64 + if rf, ok := ret.Get(0).(func(context.Context, *database.Repository, map[string]string) float64); ok { + r0 = rf(ctx, repo, weights) + } else { + r0 = ret.Get(0).(float64) + } + + return r0 +} + +// MockRecomComponent_CalcTotalScore_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'CalcTotalScore' +type MockRecomComponent_CalcTotalScore_Call struct { + *mock.Call +} + +// CalcTotalScore is a helper method to define mock.On call +// - ctx context.Context +// - repo *database.Repository +// - weights map[string]string +func (_e *MockRecomComponent_Expecter) CalcTotalScore(ctx interface{}, repo interface{}, weights interface{}) *MockRecomComponent_CalcTotalScore_Call { + return &MockRecomComponent_CalcTotalScore_Call{Call: _e.mock.On("CalcTotalScore", ctx, repo, weights)} +} + +func (_c *MockRecomComponent_CalcTotalScore_Call) Run(run func(ctx context.Context, repo *database.Repository, weights map[string]string)) *MockRecomComponent_CalcTotalScore_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(*database.Repository), args[2].(map[string]string)) + }) + return _c +} + +func (_c *MockRecomComponent_CalcTotalScore_Call) Return(_a0 float64) *MockRecomComponent_CalcTotalScore_Call { + _c.Call.Return(_a0) + return _c +} + +func (_c *MockRecomComponent_CalcTotalScore_Call) RunAndReturn(run func(context.Context, *database.Repository, map[string]string) float64) *MockRecomComponent_CalcTotalScore_Call { + _c.Call.Return(run) + return _c +} + +// CalculateRecomScore provides a mock function with given fields: ctx +func (_m *MockRecomComponent) CalculateRecomScore(ctx context.Context) { + _m.Called(ctx) +} + +// MockRecomComponent_CalculateRecomScore_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'CalculateRecomScore' +type MockRecomComponent_CalculateRecomScore_Call struct { + *mock.Call +} + +// CalculateRecomScore is a helper method to define mock.On call +// - ctx context.Context +func (_e *MockRecomComponent_Expecter) CalculateRecomScore(ctx interface{}) *MockRecomComponent_CalculateRecomScore_Call { + return &MockRecomComponent_CalculateRecomScore_Call{Call: _e.mock.On("CalculateRecomScore", ctx)} +} + +func (_c *MockRecomComponent_CalculateRecomScore_Call) Run(run func(ctx context.Context)) *MockRecomComponent_CalculateRecomScore_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context)) + }) + return _c +} + +func (_c *MockRecomComponent_CalculateRecomScore_Call) Return() *MockRecomComponent_CalculateRecomScore_Call { + _c.Call.Return() + return _c +} + +func (_c *MockRecomComponent_CalculateRecomScore_Call) RunAndReturn(run func(context.Context)) *MockRecomComponent_CalculateRecomScore_Call { + _c.Call.Return(run) + return _c +} + +// SetOpWeight provides a mock function with given fields: ctx, repoID, weight +func (_m *MockRecomComponent) SetOpWeight(ctx context.Context, repoID int64, weight int64) error { + ret := _m.Called(ctx, repoID, weight) + + if len(ret) == 0 { + panic("no return value specified for SetOpWeight") + } + + var r0 error + if rf, ok := ret.Get(0).(func(context.Context, int64, int64) error); ok { + r0 = rf(ctx, repoID, weight) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// MockRecomComponent_SetOpWeight_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'SetOpWeight' +type MockRecomComponent_SetOpWeight_Call struct { + *mock.Call +} + +// SetOpWeight is a helper method to define mock.On call +// - ctx context.Context +// - repoID int64 +// - weight int64 +func (_e *MockRecomComponent_Expecter) SetOpWeight(ctx interface{}, repoID interface{}, weight interface{}) *MockRecomComponent_SetOpWeight_Call { + return &MockRecomComponent_SetOpWeight_Call{Call: _e.mock.On("SetOpWeight", ctx, repoID, weight)} +} + +func (_c *MockRecomComponent_SetOpWeight_Call) Run(run func(ctx context.Context, repoID int64, weight int64)) *MockRecomComponent_SetOpWeight_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(int64), args[2].(int64)) + }) + return _c +} + +func (_c *MockRecomComponent_SetOpWeight_Call) Return(_a0 error) *MockRecomComponent_SetOpWeight_Call { + _c.Call.Return(_a0) + return _c +} + +func (_c *MockRecomComponent_SetOpWeight_Call) RunAndReturn(run func(context.Context, int64, int64) error) *MockRecomComponent_SetOpWeight_Call { + _c.Call.Return(run) + return _c +} + +// NewMockRecomComponent creates a new instance of MockRecomComponent. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. +// The first argument is typically a *testing.T value. +func NewMockRecomComponent(t interface { + mock.TestingT + Cleanup(func()) +}) *MockRecomComponent { + mock := &MockRecomComponent{} + mock.Mock.Test(t) + + t.Cleanup(func() { mock.AssertExpectations(t) }) + + return mock +} diff --git a/_mocks/opencsg.com/csghub-server/user/component/mock_OrganizationComponent.go b/_mocks/opencsg.com/csghub-server/user/component/mock_OrganizationComponent.go new file mode 100644 index 00000000..3640b555 --- /dev/null +++ b/_mocks/opencsg.com/csghub-server/user/component/mock_OrganizationComponent.go @@ -0,0 +1,381 @@ +// Code generated by mockery v2.49.1. DO NOT EDIT. + +package component + +import ( + context "context" + + mock "github.com/stretchr/testify/mock" + database "opencsg.com/csghub-server/builder/store/database" + + types "opencsg.com/csghub-server/common/types" +) + +// MockOrganizationComponent is an autogenerated mock type for the OrganizationComponent type +type MockOrganizationComponent struct { + mock.Mock +} + +type MockOrganizationComponent_Expecter struct { + mock *mock.Mock +} + +func (_m *MockOrganizationComponent) EXPECT() *MockOrganizationComponent_Expecter { + return &MockOrganizationComponent_Expecter{mock: &_m.Mock} +} + +// Create provides a mock function with given fields: ctx, req +func (_m *MockOrganizationComponent) Create(ctx context.Context, req *types.CreateOrgReq) (*types.Organization, error) { + ret := _m.Called(ctx, req) + + if len(ret) == 0 { + panic("no return value specified for Create") + } + + var r0 *types.Organization + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, *types.CreateOrgReq) (*types.Organization, error)); ok { + return rf(ctx, req) + } + if rf, ok := ret.Get(0).(func(context.Context, *types.CreateOrgReq) *types.Organization); ok { + r0 = rf(ctx, req) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*types.Organization) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, *types.CreateOrgReq) error); ok { + r1 = rf(ctx, req) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// MockOrganizationComponent_Create_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Create' +type MockOrganizationComponent_Create_Call struct { + *mock.Call +} + +// Create is a helper method to define mock.On call +// - ctx context.Context +// - req *types.CreateOrgReq +func (_e *MockOrganizationComponent_Expecter) Create(ctx interface{}, req interface{}) *MockOrganizationComponent_Create_Call { + return &MockOrganizationComponent_Create_Call{Call: _e.mock.On("Create", ctx, req)} +} + +func (_c *MockOrganizationComponent_Create_Call) Run(run func(ctx context.Context, req *types.CreateOrgReq)) *MockOrganizationComponent_Create_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(*types.CreateOrgReq)) + }) + return _c +} + +func (_c *MockOrganizationComponent_Create_Call) Return(_a0 *types.Organization, _a1 error) *MockOrganizationComponent_Create_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *MockOrganizationComponent_Create_Call) RunAndReturn(run func(context.Context, *types.CreateOrgReq) (*types.Organization, error)) *MockOrganizationComponent_Create_Call { + _c.Call.Return(run) + return _c +} + +// Delete provides a mock function with given fields: ctx, req +func (_m *MockOrganizationComponent) Delete(ctx context.Context, req *types.DeleteOrgReq) error { + ret := _m.Called(ctx, req) + + if len(ret) == 0 { + panic("no return value specified for Delete") + } + + var r0 error + if rf, ok := ret.Get(0).(func(context.Context, *types.DeleteOrgReq) error); ok { + r0 = rf(ctx, req) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// MockOrganizationComponent_Delete_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Delete' +type MockOrganizationComponent_Delete_Call struct { + *mock.Call +} + +// Delete is a helper method to define mock.On call +// - ctx context.Context +// - req *types.DeleteOrgReq +func (_e *MockOrganizationComponent_Expecter) Delete(ctx interface{}, req interface{}) *MockOrganizationComponent_Delete_Call { + return &MockOrganizationComponent_Delete_Call{Call: _e.mock.On("Delete", ctx, req)} +} + +func (_c *MockOrganizationComponent_Delete_Call) Run(run func(ctx context.Context, req *types.DeleteOrgReq)) *MockOrganizationComponent_Delete_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(*types.DeleteOrgReq)) + }) + return _c +} + +func (_c *MockOrganizationComponent_Delete_Call) Return(_a0 error) *MockOrganizationComponent_Delete_Call { + _c.Call.Return(_a0) + return _c +} + +func (_c *MockOrganizationComponent_Delete_Call) RunAndReturn(run func(context.Context, *types.DeleteOrgReq) error) *MockOrganizationComponent_Delete_Call { + _c.Call.Return(run) + return _c +} + +// FixOrgData provides a mock function with given fields: ctx, org +func (_m *MockOrganizationComponent) FixOrgData(ctx context.Context, org *database.Organization) (*database.Organization, error) { + ret := _m.Called(ctx, org) + + if len(ret) == 0 { + panic("no return value specified for FixOrgData") + } + + var r0 *database.Organization + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, *database.Organization) (*database.Organization, error)); ok { + return rf(ctx, org) + } + if rf, ok := ret.Get(0).(func(context.Context, *database.Organization) *database.Organization); ok { + r0 = rf(ctx, org) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*database.Organization) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, *database.Organization) error); ok { + r1 = rf(ctx, org) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// MockOrganizationComponent_FixOrgData_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'FixOrgData' +type MockOrganizationComponent_FixOrgData_Call struct { + *mock.Call +} + +// FixOrgData is a helper method to define mock.On call +// - ctx context.Context +// - org *database.Organization +func (_e *MockOrganizationComponent_Expecter) FixOrgData(ctx interface{}, org interface{}) *MockOrganizationComponent_FixOrgData_Call { + return &MockOrganizationComponent_FixOrgData_Call{Call: _e.mock.On("FixOrgData", ctx, org)} +} + +func (_c *MockOrganizationComponent_FixOrgData_Call) Run(run func(ctx context.Context, org *database.Organization)) *MockOrganizationComponent_FixOrgData_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(*database.Organization)) + }) + return _c +} + +func (_c *MockOrganizationComponent_FixOrgData_Call) Return(_a0 *database.Organization, _a1 error) *MockOrganizationComponent_FixOrgData_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *MockOrganizationComponent_FixOrgData_Call) RunAndReturn(run func(context.Context, *database.Organization) (*database.Organization, error)) *MockOrganizationComponent_FixOrgData_Call { + _c.Call.Return(run) + return _c +} + +// Get provides a mock function with given fields: ctx, orgName +func (_m *MockOrganizationComponent) Get(ctx context.Context, orgName string) (*types.Organization, error) { + ret := _m.Called(ctx, orgName) + + if len(ret) == 0 { + panic("no return value specified for Get") + } + + var r0 *types.Organization + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, string) (*types.Organization, error)); ok { + return rf(ctx, orgName) + } + if rf, ok := ret.Get(0).(func(context.Context, string) *types.Organization); ok { + r0 = rf(ctx, orgName) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*types.Organization) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, string) error); ok { + r1 = rf(ctx, orgName) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// MockOrganizationComponent_Get_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Get' +type MockOrganizationComponent_Get_Call struct { + *mock.Call +} + +// Get is a helper method to define mock.On call +// - ctx context.Context +// - orgName string +func (_e *MockOrganizationComponent_Expecter) Get(ctx interface{}, orgName interface{}) *MockOrganizationComponent_Get_Call { + return &MockOrganizationComponent_Get_Call{Call: _e.mock.On("Get", ctx, orgName)} +} + +func (_c *MockOrganizationComponent_Get_Call) Run(run func(ctx context.Context, orgName string)) *MockOrganizationComponent_Get_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(string)) + }) + return _c +} + +func (_c *MockOrganizationComponent_Get_Call) Return(_a0 *types.Organization, _a1 error) *MockOrganizationComponent_Get_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *MockOrganizationComponent_Get_Call) RunAndReturn(run func(context.Context, string) (*types.Organization, error)) *MockOrganizationComponent_Get_Call { + _c.Call.Return(run) + return _c +} + +// Index provides a mock function with given fields: ctx, username +func (_m *MockOrganizationComponent) Index(ctx context.Context, username string) ([]types.Organization, error) { + ret := _m.Called(ctx, username) + + if len(ret) == 0 { + panic("no return value specified for Index") + } + + var r0 []types.Organization + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, string) ([]types.Organization, error)); ok { + return rf(ctx, username) + } + if rf, ok := ret.Get(0).(func(context.Context, string) []types.Organization); ok { + r0 = rf(ctx, username) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).([]types.Organization) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, string) error); ok { + r1 = rf(ctx, username) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// MockOrganizationComponent_Index_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Index' +type MockOrganizationComponent_Index_Call struct { + *mock.Call +} + +// Index is a helper method to define mock.On call +// - ctx context.Context +// - username string +func (_e *MockOrganizationComponent_Expecter) Index(ctx interface{}, username interface{}) *MockOrganizationComponent_Index_Call { + return &MockOrganizationComponent_Index_Call{Call: _e.mock.On("Index", ctx, username)} +} + +func (_c *MockOrganizationComponent_Index_Call) Run(run func(ctx context.Context, username string)) *MockOrganizationComponent_Index_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(string)) + }) + return _c +} + +func (_c *MockOrganizationComponent_Index_Call) Return(_a0 []types.Organization, _a1 error) *MockOrganizationComponent_Index_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *MockOrganizationComponent_Index_Call) RunAndReturn(run func(context.Context, string) ([]types.Organization, error)) *MockOrganizationComponent_Index_Call { + _c.Call.Return(run) + return _c +} + +// Update provides a mock function with given fields: ctx, req +func (_m *MockOrganizationComponent) Update(ctx context.Context, req *types.EditOrgReq) (*database.Organization, error) { + ret := _m.Called(ctx, req) + + if len(ret) == 0 { + panic("no return value specified for Update") + } + + var r0 *database.Organization + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, *types.EditOrgReq) (*database.Organization, error)); ok { + return rf(ctx, req) + } + if rf, ok := ret.Get(0).(func(context.Context, *types.EditOrgReq) *database.Organization); ok { + r0 = rf(ctx, req) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*database.Organization) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, *types.EditOrgReq) error); ok { + r1 = rf(ctx, req) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// MockOrganizationComponent_Update_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Update' +type MockOrganizationComponent_Update_Call struct { + *mock.Call +} + +// Update is a helper method to define mock.On call +// - ctx context.Context +// - req *types.EditOrgReq +func (_e *MockOrganizationComponent_Expecter) Update(ctx interface{}, req interface{}) *MockOrganizationComponent_Update_Call { + return &MockOrganizationComponent_Update_Call{Call: _e.mock.On("Update", ctx, req)} +} + +func (_c *MockOrganizationComponent_Update_Call) Run(run func(ctx context.Context, req *types.EditOrgReq)) *MockOrganizationComponent_Update_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(*types.EditOrgReq)) + }) + return _c +} + +func (_c *MockOrganizationComponent_Update_Call) Return(_a0 *database.Organization, _a1 error) *MockOrganizationComponent_Update_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *MockOrganizationComponent_Update_Call) RunAndReturn(run func(context.Context, *types.EditOrgReq) (*database.Organization, error)) *MockOrganizationComponent_Update_Call { + _c.Call.Return(run) + return _c +} + +// NewMockOrganizationComponent creates a new instance of MockOrganizationComponent. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. +// The first argument is typically a *testing.T value. +func NewMockOrganizationComponent(t interface { + mock.TestingT + Cleanup(func()) +}) *MockOrganizationComponent { + mock := &MockOrganizationComponent{} + mock.Mock.Test(t) + + t.Cleanup(func() { mock.AssertExpectations(t) }) + + return mock +} diff --git a/api/handler/callback/git_callback.go b/api/handler/callback/git_callback.go index 1fda23f8..c21e7555 100644 --- a/api/handler/callback/git_callback.go +++ b/api/handler/callback/git_callback.go @@ -8,6 +8,7 @@ import ( "go.temporal.io/sdk/client" "opencsg.com/csghub-server/api/httpbase" "opencsg.com/csghub-server/api/workflow" + "opencsg.com/csghub-server/builder/temporal" "opencsg.com/csghub-server/common/config" "opencsg.com/csghub-server/common/types" component "opencsg.com/csghub-server/component/callback" @@ -48,14 +49,13 @@ func (h *GitCallbackHandler) handlePush(c *gin.Context) { return } //start workflow to handle push request - workflowClient := workflow.GetWorkflowClient() + workflowClient := temporal.GetClient() workflowOptions := client.StartWorkflowOptions{ TaskQueue: workflow.HandlePushQueueName, } - we, err := workflowClient.ExecuteWorkflow(c, workflowOptions, workflow.HandlePushWorkflow, - &req, - h.config, + we, err := workflowClient.ExecuteWorkflow( + c, workflowOptions, workflow.HandlePushWorkflow, &req, ) if err != nil { slog.Error("failed to handle git push callback", slog.Any("error", err)) diff --git a/api/handler/internal.go b/api/handler/internal.go index 12ada5cc..2c66efe8 100644 --- a/api/handler/internal.go +++ b/api/handler/internal.go @@ -9,6 +9,7 @@ import ( "go.temporal.io/sdk/client" "opencsg.com/csghub-server/api/httpbase" "opencsg.com/csghub-server/api/workflow" + "opencsg.com/csghub-server/builder/temporal" "opencsg.com/csghub-server/common/config" "opencsg.com/csghub-server/common/types" "opencsg.com/csghub-server/component" @@ -22,14 +23,14 @@ func NewInternalHandler(config *config.Config) (*InternalHandler, error) { return &InternalHandler{ internal: uc, config: config, - workflowClient: workflow.GetWorkflowClient(), + temporalClient: temporal.GetClient(), }, nil } type InternalHandler struct { internal component.InternalComponent config *config.Config - workflowClient client.Client + temporalClient temporal.Client } // TODO: add prmission check @@ -138,14 +139,12 @@ func (h *InternalHandler) PostReceive(ctx *gin.Context) { } callback.Ref = originalRef //start workflow to handle push request - workflowClient := h.workflowClient workflowOptions := client.StartWorkflowOptions{ TaskQueue: workflow.HandlePushQueueName, } - we, err := workflowClient.ExecuteWorkflow(ctx, workflowOptions, workflow.HandlePushWorkflow, - callback, - h.config, + we, err := h.temporalClient.ExecuteWorkflow( + ctx, workflowOptions, workflow.HandlePushWorkflow, callback, ) if err != nil { slog.Error("failed to handle git push callback", slog.Any("error", err)) diff --git a/api/handler/internal_test.go b/api/handler/internal_test.go index 82446dbe..5ff19e2f 100644 --- a/api/handler/internal_test.go +++ b/api/handler/internal_test.go @@ -7,7 +7,7 @@ import ( "github.com/stretchr/testify/mock" "go.temporal.io/sdk/client" temporal_mock "go.temporal.io/sdk/mocks" - mock_temporal "opencsg.com/csghub-server/_mocks/go.temporal.io/sdk/client" + workflow_mock "opencsg.com/csghub-server/_mocks/opencsg.com/csghub-server/builder/temporal" mockcomponent "opencsg.com/csghub-server/_mocks/opencsg.com/csghub-server/component" "opencsg.com/csghub-server/api/workflow" "opencsg.com/csghub-server/builder/store/database" @@ -19,19 +19,19 @@ type InternalTester struct { *GinTester handler *InternalHandler mocks struct { - internal *mockcomponent.MockInternalComponent - workflowClient *mock_temporal.MockClient + internal *mockcomponent.MockInternalComponent + workflow *workflow_mock.MockClient } } func NewInternalTester(t *testing.T) *InternalTester { tester := &InternalTester{GinTester: NewGinTester()} tester.mocks.internal = mockcomponent.NewMockInternalComponent(t) - tester.mocks.workflowClient = mock_temporal.NewMockClient(t) + tester.mocks.workflow = workflow_mock.NewMockClient(t) tester.handler = &InternalHandler{ internal: tester.mocks.internal, - workflowClient: tester.mocks.workflowClient, + temporalClient: tester.mocks.workflow, config: &config.Config{}, } tester.WithParam("internalId", "testInternalId") @@ -149,11 +149,11 @@ func TestInternalHandler_PostReceive(t *testing.T) { runMock := &temporal_mock.WorkflowRun{} runMock.On("GetID").Return("id") - tester.mocks.workflowClient.EXPECT().ExecuteWorkflow( + tester.mocks.workflow.EXPECT().ExecuteWorkflow( tester.ctx, client.StartWorkflowOptions{ TaskQueue: workflow.HandlePushQueueName, }, mock.Anything, - &types.GiteaCallbackPushReq{Ref: "ref/heads/main"}, &config.Config{}, + &types.GiteaCallbackPushReq{Ref: "ref/heads/main"}, ).Return( runMock, nil, ) diff --git a/api/workflow/activity/activities_ce.go b/api/workflow/activity/activities_ce.go new file mode 100644 index 00000000..8cc27961 --- /dev/null +++ b/api/workflow/activity/activities_ce.go @@ -0,0 +1,46 @@ +//go:build !ee && !saas + +package activity + +import ( + "opencsg.com/csghub-server/builder/git/gitserver" + "opencsg.com/csghub-server/builder/store/database" + "opencsg.com/csghub-server/common/config" + "opencsg.com/csghub-server/component" + "opencsg.com/csghub-server/component/callback" +) + +type stores struct { + syncClientSetting database.SyncClientSettingStore +} + +type Activities struct { + config *config.Config + callback callback.GitCallbackComponent + recom component.RecomComponent + gitServer gitserver.GitServer + multisync component.MultiSyncComponent + stores stores +} + +func NewActivities( + cfg *config.Config, + callback callback.GitCallbackComponent, + recom component.RecomComponent, + gitServer gitserver.GitServer, + multisync component.MultiSyncComponent, + syncClientSetting database.SyncClientSettingStore, +) *Activities { + stores := stores{ + syncClientSetting: syncClientSetting, + } + + return &Activities{ + config: cfg, + callback: callback, + recom: recom, + gitServer: gitServer, + multisync: multisync, + stores: stores, + } +} diff --git a/api/workflow/activity/calc_recom_score.go b/api/workflow/activity/calc_recom_score.go index a2704f60..e11df67f 100644 --- a/api/workflow/activity/calc_recom_score.go +++ b/api/workflow/activity/calc_recom_score.go @@ -2,18 +2,9 @@ package activity import ( "context" - "log/slog" - - "opencsg.com/csghub-server/common/config" - "opencsg.com/csghub-server/component" ) -func CalcRecomScore(ctx context.Context, config *config.Config) error { - c, err := component.NewRecomComponent(config) - if err != nil { - slog.Error("failed to create recom component", "err", err) - return err - } - c.CalculateRecomScore(context.Background()) +func (a *Activities) CalcRecomScore(ctx context.Context) error { + a.recom.CalculateRecomScore(context.Background()) return nil } diff --git a/api/workflow/activity/handle_push.go b/api/workflow/activity/handle_push.go index 112e7a16..2a1171d3 100644 --- a/api/workflow/activity/handle_push.go +++ b/api/workflow/activity/handle_push.go @@ -2,65 +2,42 @@ package activity import ( "context" - "fmt" "go.temporal.io/sdk/activity" - "opencsg.com/csghub-server/common/config" "opencsg.com/csghub-server/common/types" - "opencsg.com/csghub-server/component/callback" ) -func WatchSpaceChange(ctx context.Context, req *types.GiteaCallbackPushReq, config *config.Config) error { +func (a *Activities) WatchSpaceChange(ctx context.Context, req *types.GiteaCallbackPushReq) error { logger := activity.GetLogger(ctx) logger.Info("watch space change start", "req", req) - callbackComponent, err := callback.NewGitCallback(config) - if err != nil { - return fmt.Errorf("failed to create callback component, error: %w", err) - } - callbackComponent.SetRepoVisibility(true) - return callbackComponent.WatchSpaceChange(ctx, req) + a.callback.SetRepoVisibility(true) + return a.callback.WatchSpaceChange(ctx, req) } -func WatchRepoRelation(ctx context.Context, req *types.GiteaCallbackPushReq, config *config.Config) error { +func (a *Activities) WatchRepoRelation(ctx context.Context, req *types.GiteaCallbackPushReq) error { logger := activity.GetLogger(ctx) logger.Info("watch repo relation start", "req", req) - callbackComponent, err := callback.NewGitCallback(config) - if err != nil { - return fmt.Errorf("failed to create callback component, error: %w", err) - } - callbackComponent.SetRepoVisibility(true) - return callbackComponent.WatchRepoRelation(ctx, req) + a.callback.SetRepoVisibility(true) + return a.callback.WatchRepoRelation(ctx, req) } -func SetRepoUpdateTime(ctx context.Context, req *types.GiteaCallbackPushReq, config *config.Config) error { +func (a *Activities) SetRepoUpdateTime(ctx context.Context, req *types.GiteaCallbackPushReq) error { logger := activity.GetLogger(ctx) logger.Info("set repo update time start", "req", req) - callbackComponent, err := callback.NewGitCallback(config) - if err != nil { - return fmt.Errorf("failed to create callback component, error: %w", err) - } - callbackComponent.SetRepoVisibility(true) - return callbackComponent.SetRepoUpdateTime(ctx, req) + a.callback.SetRepoVisibility(true) + return a.callback.SetRepoUpdateTime(ctx, req) } -func UpdateRepoInfos(ctx context.Context, req *types.GiteaCallbackPushReq, config *config.Config) error { +func (a *Activities) UpdateRepoInfos(ctx context.Context, req *types.GiteaCallbackPushReq) error { logger := activity.GetLogger(ctx) logger.Info("update repo infos start", "req", req) - callbackComponent, err := callback.NewGitCallback(config) - if err != nil { - return fmt.Errorf("failed to create callback component, error: %w", err) - } - callbackComponent.SetRepoVisibility(true) - return callbackComponent.UpdateRepoInfos(ctx, req) + a.callback.SetRepoVisibility(true) + return a.callback.UpdateRepoInfos(ctx, req) } -func SensitiveCheck(ctx context.Context, req *types.GiteaCallbackPushReq, config *config.Config) error { +func (a *Activities) SensitiveCheck(ctx context.Context, req *types.GiteaCallbackPushReq) error { logger := activity.GetLogger(ctx) logger.Info("sensitive check start", "req", req) - callbackComponent, err := callback.NewGitCallback(config) - if err != nil { - return fmt.Errorf("failed to create callback component, error: %w", err) - } - callbackComponent.SetRepoVisibility(true) - return callbackComponent.SensitiveCheck(ctx, req) + a.callback.SetRepoVisibility(true) + return a.callback.SensitiveCheck(ctx, req) } diff --git a/api/workflow/activity/sync_as_client.go b/api/workflow/activity/sync_as_client.go index 99e4ca3e..92652ae2 100644 --- a/api/workflow/activity/sync_as_client.go +++ b/api/workflow/activity/sync_as_client.go @@ -5,24 +5,15 @@ import ( "log/slog" "opencsg.com/csghub-server/builder/multisync" - "opencsg.com/csghub-server/builder/store/database" - "opencsg.com/csghub-server/common/config" - "opencsg.com/csghub-server/component" ) -func SyncAsClient(ctx context.Context, config *config.Config) error { - c, err := component.NewMultiSyncComponent(config) - if err != nil { - slog.Error("failed to create multi sync component", "err", err) - return err - } - syncClientSettingStore := database.NewSyncClientSettingStore() - setting, err := syncClientSettingStore.First(ctx) +func (a *Activities) SyncAsClient(ctx context.Context) error { + setting, err := a.stores.syncClientSetting.First(ctx) if err != nil { slog.Error("failed to find sync client setting", "error", err) return err } - apiDomain := config.MultiSync.SaasAPIDomain + apiDomain := a.config.MultiSync.SaasAPIDomain sc := multisync.FromOpenCSG(apiDomain, setting.Token) - return c.SyncAsClient(ctx, sc) + return a.multisync.SyncAsClient(ctx, sc) } diff --git a/api/workflow/cron_calc_recom_score.go b/api/workflow/cron_calc_recom_score.go index d6c44fab..293a54e6 100644 --- a/api/workflow/cron_calc_recom_score.go +++ b/api/workflow/cron_calc_recom_score.go @@ -5,11 +5,9 @@ import ( "go.temporal.io/sdk/temporal" "go.temporal.io/sdk/workflow" - "opencsg.com/csghub-server/api/workflow/activity" - "opencsg.com/csghub-server/common/config" ) -func CalcRecomScoreWorkflow(ctx workflow.Context, config *config.Config) error { +func CalcRecomScoreWorkflow(ctx workflow.Context) error { logger := workflow.GetLogger(ctx) logger.Info("calc recom score workflow started") @@ -23,7 +21,7 @@ func CalcRecomScoreWorkflow(ctx workflow.Context, config *config.Config) error { } ctx = workflow.WithActivityOptions(ctx, options) - err := workflow.ExecuteActivity(ctx, activity.CalcRecomScore, config).Get(ctx, nil) + err := workflow.ExecuteActivity(ctx, activities.CalcRecomScore).Get(ctx, nil) if err != nil { logger.Error("failed to calc recom score", "error", err) return err diff --git a/api/workflow/cron_sync_as_client.go b/api/workflow/cron_sync_as_client.go index 54f5ab6b..1c200a49 100644 --- a/api/workflow/cron_sync_as_client.go +++ b/api/workflow/cron_sync_as_client.go @@ -5,11 +5,9 @@ import ( "go.temporal.io/sdk/temporal" "go.temporal.io/sdk/workflow" - "opencsg.com/csghub-server/api/workflow/activity" - "opencsg.com/csghub-server/common/config" ) -func SyncAsClientWorkflow(ctx workflow.Context, config *config.Config) error { +func SyncAsClientWorkflow(ctx workflow.Context) error { logger := workflow.GetLogger(ctx) logger.Info("sync as client workflow started") @@ -23,7 +21,7 @@ func SyncAsClientWorkflow(ctx workflow.Context, config *config.Config) error { } ctx = workflow.WithActivityOptions(ctx, options) - err := workflow.ExecuteActivity(ctx, activity.SyncAsClient, config).Get(ctx, nil) + err := workflow.ExecuteActivity(ctx, activities.SyncAsClient).Get(ctx, nil) if err != nil { logger.Error("failed to sync as client", "error", err) return err diff --git a/api/workflow/cron_worker.go b/api/workflow/cron_worker.go index b0409c12..7ba1c348 100644 --- a/api/workflow/cron_worker.go +++ b/api/workflow/cron_worker.go @@ -1,88 +1,6 @@ package workflow -import ( - "context" - "fmt" - - enumspb "go.temporal.io/api/enums/v1" - "go.temporal.io/sdk/client" - "go.temporal.io/sdk/worker" - "opencsg.com/csghub-server/api/workflow/activity" - "opencsg.com/csghub-server/common/config" -) - const ( AlreadyScheduledMessage = "schedule with this ID is already registered" CronJobQueueName = "workflow_cron_queue" ) - -func RegisterCronJobs(config *config.Config) error { - var err error - if wfClient == nil { - wfClient, err = client.Dial(client.Options{ - HostPort: config.WorkFLow.Endpoint, - }) - if err != nil { - return fmt.Errorf("unable to create workflow client, error:%w", err) - } - } - - if !config.Saas { - _, err = wfClient.ScheduleClient().Create(context.Background(), client.ScheduleOptions{ - ID: "sync-as-client-schedule", - Spec: client.ScheduleSpec{ - CronExpressions: []string{config.CronJob.SyncAsClientCronExpression}, - }, - Overlap: enumspb.SCHEDULE_OVERLAP_POLICY_SKIP, - Action: &client.ScheduleWorkflowAction{ - ID: "sync-as-client-workflow", - TaskQueue: CronJobQueueName, - Workflow: SyncAsClientWorkflow, - Args: []interface{}{config}, - }, - }) - if err != nil && err.Error() != AlreadyScheduledMessage { - return fmt.Errorf("unable to create schedule, error:%w", err) - } - } - - _, err = wfClient.ScheduleClient().Create(context.Background(), client.ScheduleOptions{ - ID: "calc-recom-score-schedule", - Spec: client.ScheduleSpec{ - CronExpressions: []string{config.CronJob.CalcRecomScoreCronExpression}, - }, - Overlap: enumspb.SCHEDULE_OVERLAP_POLICY_SKIP, - Action: &client.ScheduleWorkflowAction{ - ID: "calc-recom-score-workflow", - TaskQueue: CronJobQueueName, - Workflow: CalcRecomScoreWorkflow, - Args: []interface{}{config}, - }, - }) - if err != nil && err.Error() != AlreadyScheduledMessage { - return fmt.Errorf("unable to create schedule, error:%w", err) - } - - return nil -} - -func StartCronWorker(config *config.Config) error { - var err error - if wfClient == nil { - wfClient, err = client.Dial(client.Options{ - HostPort: config.WorkFLow.Endpoint, - }) - if err != nil { - return fmt.Errorf("unable to create workflow client, error:%w", err) - } - } - wfWorker = worker.New(wfClient, CronJobQueueName, worker.Options{}) - if !config.Saas { - wfWorker.RegisterWorkflow(SyncAsClientWorkflow) - wfWorker.RegisterActivity(activity.SyncAsClient) - } - wfWorker.RegisterWorkflow(CalcRecomScoreWorkflow) - wfWorker.RegisterActivity(activity.CalcRecomScore) - - return wfWorker.Start() -} diff --git a/api/workflow/cron_worker_ce.go b/api/workflow/cron_worker_ce.go new file mode 100644 index 00000000..faee075f --- /dev/null +++ b/api/workflow/cron_worker_ce.go @@ -0,0 +1,63 @@ +//go:build !ee && !saas + +package workflow + +import ( + "context" + "fmt" + + enumspb "go.temporal.io/api/enums/v1" + "go.temporal.io/sdk/client" + "go.temporal.io/sdk/worker" + "opencsg.com/csghub-server/builder/temporal" + "opencsg.com/csghub-server/common/config" +) + +func RegisterCronJobs(config *config.Config, temporalClient temporal.Client) error { + var err error + scheduler := temporalClient.ScheduleClient() + + _, err = scheduler.Create(context.Background(), client.ScheduleOptions{ + ID: "sync-as-client-schedule", + Spec: client.ScheduleSpec{ + CronExpressions: []string{config.CronJob.SyncAsClientCronExpression}, + }, + Overlap: enumspb.SCHEDULE_OVERLAP_POLICY_SKIP, + Action: &client.ScheduleWorkflowAction{ + ID: "sync-as-client-workflow", + TaskQueue: CronJobQueueName, + Workflow: SyncAsClientWorkflow, + Args: []interface{}{config}, + }, + }) + if err != nil && err.Error() != AlreadyScheduledMessage { + return fmt.Errorf("unable to create schedule, error:%w", err) + } + + _, err = scheduler.Create(context.Background(), client.ScheduleOptions{ + ID: "calc-recom-score-schedule", + Spec: client.ScheduleSpec{ + CronExpressions: []string{config.CronJob.CalcRecomScoreCronExpression}, + }, + Overlap: enumspb.SCHEDULE_OVERLAP_POLICY_SKIP, + Action: &client.ScheduleWorkflowAction{ + ID: "calc-recom-score-workflow", + TaskQueue: CronJobQueueName, + Workflow: CalcRecomScoreWorkflow, + Args: []interface{}{config}, + }, + }) + if err != nil && err.Error() != AlreadyScheduledMessage { + return fmt.Errorf("unable to create schedule, error:%w", err) + } + + return nil +} + +func RegisterCronWorker(config *config.Config, temporalClient temporal.Client) { + + wfWorker := temporalClient.NewWorker(CronJobQueueName, worker.Options{}) + wfWorker.RegisterWorkflow(SyncAsClientWorkflow) + wfWorker.RegisterWorkflow(CalcRecomScoreWorkflow) + +} diff --git a/api/workflow/handle_push.go b/api/workflow/handle_push.go index 05e9c4d1..f3746025 100644 --- a/api/workflow/handle_push.go +++ b/api/workflow/handle_push.go @@ -6,12 +6,10 @@ import ( "go.temporal.io/sdk/temporal" "go.temporal.io/sdk/workflow" - "opencsg.com/csghub-server/api/workflow/activity" - "opencsg.com/csghub-server/common/config" "opencsg.com/csghub-server/common/types" ) -func HandlePushWorkflow(ctx workflow.Context, req *types.GiteaCallbackPushReq, config *config.Config) error { +func HandlePushWorkflow(ctx workflow.Context, req *types.GiteaCallbackPushReq) error { logger := workflow.GetLogger(ctx) logger.Info("handle push workflow started") @@ -26,35 +24,35 @@ func HandlePushWorkflow(ctx workflow.Context, req *types.GiteaCallbackPushReq, c ctx = workflow.WithActivityOptions(ctx, options) // Watch space change - err := workflow.ExecuteActivity(ctx, activity.WatchSpaceChange, req, config).Get(ctx, nil) + err := workflow.ExecuteActivity(ctx, activities.WatchSpaceChange, req).Get(ctx, nil) if err != nil { logger.Error("failed to watch space change", "error", err, "req", req) return err } // Watch repo relation - err = workflow.ExecuteActivity(ctx, activity.WatchRepoRelation, req, config).Get(ctx, nil) + err = workflow.ExecuteActivity(ctx, activities.WatchRepoRelation, req).Get(ctx, nil) if err != nil { logger.Error("failed to watch repo relation", "error", err, "req", req) return err } // Set repo update time - err = workflow.ExecuteActivity(ctx, activity.SetRepoUpdateTime, req, config).Get(ctx, nil) + err = workflow.ExecuteActivity(ctx, activities.SetRepoUpdateTime, req).Get(ctx, nil) if err != nil { logger.Error("failed to set repo update time", "error", err, "req", req) return err } // Update repo infos - err = workflow.ExecuteActivity(ctx, activity.UpdateRepoInfos, req, config).Get(ctx, nil) + err = workflow.ExecuteActivity(ctx, activities.UpdateRepoInfos, req).Get(ctx, nil) if err != nil { logger.Error("failed to update repo infos", "error", err, "req", req) return err } // Sensitive check - err = workflow.ExecuteActivity(ctx, activity.SensitiveCheck, req, config).Get(ctx, nil) + err = workflow.ExecuteActivity(ctx, activities.SensitiveCheck, req).Get(ctx, nil) if err != nil { logger.Error("failed to sensitive check", "error", err, "req", req) return err diff --git a/api/workflow/worker.go b/api/workflow/worker.go deleted file mode 100644 index bad5bd62..00000000 --- a/api/workflow/worker.go +++ /dev/null @@ -1,49 +0,0 @@ -package workflow - -import ( - "fmt" - - "go.temporal.io/sdk/client" - "go.temporal.io/sdk/worker" - "opencsg.com/csghub-server/api/workflow/activity" - "opencsg.com/csghub-server/common/config" -) - -const HandlePushQueueName = "workflow_handle_push_queue" - -var ( - wfWorker worker.Worker - wfClient client.Client -) - -func StartWorker(config *config.Config) error { - var err error - wfClient, err = client.Dial(client.Options{ - HostPort: config.WorkFLow.Endpoint, - }) - if err != nil { - return fmt.Errorf("unable to create workflow client, error:%w", err) - } - wfWorker = worker.New(wfClient, HandlePushQueueName, worker.Options{}) - wfWorker.RegisterWorkflow(HandlePushWorkflow) - wfWorker.RegisterActivity(activity.WatchSpaceChange) - wfWorker.RegisterActivity(activity.WatchRepoRelation) - wfWorker.RegisterActivity(activity.SetRepoUpdateTime) - wfWorker.RegisterActivity(activity.UpdateRepoInfos) - wfWorker.RegisterActivity(activity.SensitiveCheck) - - return wfWorker.Start() -} - -func StopWorker() { - if wfWorker != nil { - wfWorker.Stop() - } - if wfClient != nil { - wfClient.Close() - } -} - -func GetWorkflowClient() client.Client { - return wfClient -} diff --git a/api/workflow/worker_ce.go b/api/workflow/worker_ce.go new file mode 100644 index 00000000..0af82e7a --- /dev/null +++ b/api/workflow/worker_ce.go @@ -0,0 +1,80 @@ +//go:build !ee && !saas + +package workflow + +import ( + "fmt" + + "go.temporal.io/sdk/client" + "go.temporal.io/sdk/worker" + "opencsg.com/csghub-server/api/workflow/activity" + "opencsg.com/csghub-server/builder/git" + "opencsg.com/csghub-server/builder/git/gitserver" + "opencsg.com/csghub-server/builder/store/database" + "opencsg.com/csghub-server/builder/temporal" + "opencsg.com/csghub-server/common/config" + "opencsg.com/csghub-server/component" + "opencsg.com/csghub-server/component/callback" +) + +const HandlePushQueueName = "workflow_handle_push_queue" + +var activities activity.Activities + +func StartWorkflow(cfg *config.Config) error { + gitcallback, err := callback.NewGitCallback(cfg) + if err != nil { + return err + } + recom, err := component.NewRecomComponent(cfg) + if err != nil { + return err + } + gitserver, err := git.NewGitServer(cfg) + if err != nil { + return err + } + multisync, err := component.NewMultiSyncComponent(cfg) + if err != nil { + return err + } + temporalClient, err := client.Dial(client.Options{ + HostPort: cfg.WorkFLow.Endpoint, + }) + if err != nil { + return fmt.Errorf("unable to create workflow client, error: %w", err) + } + client, err := temporal.NewClient(temporalClient) + if err != nil { + return err + } + return StartWorkflowDI( + cfg, gitcallback, recom, + gitserver, multisync, database.NewSyncClientSettingStore(), client, + ) +} + +func StartWorkflowDI( + cfg *config.Config, + callback callback.GitCallbackComponent, + recom component.RecomComponent, + gitServer gitserver.GitServer, + multisync component.MultiSyncComponent, + syncClientSetting database.SyncClientSettingStore, + temporalClient temporal.Client, +) error { + worker := temporalClient.NewWorker(HandlePushQueueName, worker.Options{}) + act := activity.NewActivities(cfg, callback, recom, gitServer, multisync, syncClientSetting) + worker.RegisterActivity(act) + + worker.RegisterWorkflow(HandlePushWorkflow) + + RegisterCronWorker(cfg, temporalClient) + + err := temporalClient.Start() + if err != nil { + return fmt.Errorf("failed to start worker: %w", err) + } + return nil + +} diff --git a/api/workflow/workflow_ce_test.go b/api/workflow/workflow_ce_test.go new file mode 100644 index 00000000..1e499d60 --- /dev/null +++ b/api/workflow/workflow_ce_test.go @@ -0,0 +1,54 @@ +//go:build !ee && !saas + +package workflow_test + +import ( + "testing" + + "github.com/stretchr/testify/mock" + "go.temporal.io/sdk/testsuite" + mock_git "opencsg.com/csghub-server/_mocks/opencsg.com/csghub-server/builder/git/gitserver" + mock_temporal "opencsg.com/csghub-server/_mocks/opencsg.com/csghub-server/builder/temporal" + mock_component "opencsg.com/csghub-server/_mocks/opencsg.com/csghub-server/component" + mock_callback "opencsg.com/csghub-server/_mocks/opencsg.com/csghub-server/component/callback" + "opencsg.com/csghub-server/api/workflow" + "opencsg.com/csghub-server/common/config" + "opencsg.com/csghub-server/common/tests" +) + +func newWorkflowTester(t *testing.T) (*workflowTester, error) { + suite := testsuite.WorkflowTestSuite{} + tester := &workflowTester{env: suite.NewTestWorkflowEnvironment()} + + // Mock the dependencies + tester.mocks.stores = tests.NewMockStores(t) + + mcb := mock_callback.NewMockGitCallbackComponent(t) + tester.mocks.callback = mcb + + mr := mock_component.NewMockRecomComponent(t) + tester.mocks.recom = mr + + mm := mock_component.NewMockMultiSyncComponent(t) + tester.mocks.multisync = mm + + mg := mock_git.NewMockGitServer(t) + tester.mocks.gitServer = mg + + mtc := mock_temporal.NewMockClient(t) + mtc.EXPECT().NewWorker(workflow.HandlePushQueueName, mock.Anything).Return(tester.env) + mtc.EXPECT().NewWorker(workflow.CronJobQueueName, mock.Anything).Return(tester.env) + mtc.EXPECT().Start().Return(nil) + tester.mocks.temporal = mtc + + cfg := &config.Config{} + + err := workflow.StartWorkflowDI( + cfg, mcb, mr, mg, mm, tester.mocks.stores.SyncClientSettingMock(), mtc, + ) + + if err != nil { + return nil, err + } + return tester, nil +} diff --git a/api/workflow/workflow_test.go b/api/workflow/workflow_test.go new file mode 100644 index 00000000..6f29e91e --- /dev/null +++ b/api/workflow/workflow_test.go @@ -0,0 +1,74 @@ +package workflow_test + +import ( + "testing" + + "github.com/stretchr/testify/mock" + "github.com/stretchr/testify/require" + "go.temporal.io/sdk/testsuite" + mock_git "opencsg.com/csghub-server/_mocks/opencsg.com/csghub-server/builder/git/gitserver" + mock_temporal "opencsg.com/csghub-server/_mocks/opencsg.com/csghub-server/builder/temporal" + mock_component "opencsg.com/csghub-server/_mocks/opencsg.com/csghub-server/component" + mock_callback "opencsg.com/csghub-server/_mocks/opencsg.com/csghub-server/component/callback" + "opencsg.com/csghub-server/api/workflow" + "opencsg.com/csghub-server/builder/multisync" + "opencsg.com/csghub-server/builder/store/database" + "opencsg.com/csghub-server/common/tests" + "opencsg.com/csghub-server/common/types" +) + +type workflowTester struct { + env *testsuite.TestWorkflowEnvironment + mocks struct { + callback *mock_callback.MockGitCallbackComponent + recom *mock_component.MockRecomComponent + multisync *mock_component.MockMultiSyncComponent + gitServer *mock_git.MockGitServer + temporal *mock_temporal.MockClient + stores *tests.MockStores + } +} + +func TestWorkflow_CalcRecomScoreWorkflow(t *testing.T) { + tester, err := newWorkflowTester(t) + require.NoError(t, err) + + tester.mocks.recom.EXPECT().CalculateRecomScore(mock.Anything).Return() + tester.env.ExecuteWorkflow(workflow.CalcRecomScoreWorkflow) + require.True(t, tester.env.IsWorkflowCompleted()) + require.NoError(t, tester.env.GetWorkflowError()) +} + +func TestWorkflow_SyncAsClient(t *testing.T) { + tester, err := newWorkflowTester(t) + require.NoError(t, err) + + tester.mocks.stores.SyncClientSettingMock().EXPECT().First(mock.Anything).Return( + &database.SyncClientSetting{Token: "tk"}, nil, + ) + tester.mocks.multisync.EXPECT().SyncAsClient( + mock.Anything, multisync.FromOpenCSG("", "tk"), + ).Return(nil) + + tester.env.ExecuteWorkflow(workflow.SyncAsClientWorkflow) + require.True(t, tester.env.IsWorkflowCompleted()) + require.NoError(t, tester.env.GetWorkflowError()) + +} + +func TestWorkflow_HandlePushWorkflow(t *testing.T) { + tester, err := newWorkflowTester(t) + require.NoError(t, err) + + tester.mocks.callback.EXPECT().SetRepoVisibility(true).Return() + tester.mocks.callback.EXPECT().WatchSpaceChange(mock.Anything, &types.GiteaCallbackPushReq{}).Return(nil) + tester.mocks.callback.EXPECT().WatchRepoRelation(mock.Anything, &types.GiteaCallbackPushReq{}).Return(nil) + tester.mocks.callback.EXPECT().SetRepoUpdateTime(mock.Anything, &types.GiteaCallbackPushReq{}).Return(nil) + tester.mocks.callback.EXPECT().UpdateRepoInfos(mock.Anything, &types.GiteaCallbackPushReq{}).Return(nil) + tester.mocks.callback.EXPECT().SensitiveCheck(mock.Anything, &types.GiteaCallbackPushReq{}).Return(nil) + + tester.env.ExecuteWorkflow(workflow.HandlePushWorkflow, &types.GiteaCallbackPushReq{}) + require.True(t, tester.env.IsWorkflowCompleted()) + require.NoError(t, tester.env.GetWorkflowError()) + +} diff --git a/builder/temporal/temporal.go b/builder/temporal/temporal.go new file mode 100644 index 00000000..43439db9 --- /dev/null +++ b/builder/temporal/temporal.go @@ -0,0 +1,61 @@ +package temporal + +import ( + "go.temporal.io/sdk/client" + "go.temporal.io/sdk/worker" +) + +type Client interface { + client.Client + NewWorker(queue string, options worker.Options) worker.Registry + Start() error + Stop() +} + +type clientImpl struct { + client.Client + workers []worker.Worker +} + +var _client Client = &clientImpl{} + +func NewClient(temporalClient client.Client) (*clientImpl, error) { + c := _client.(*clientImpl) + c.Client = temporalClient + + return c, nil +} + +func (c *clientImpl) NewWorker(queue string, options worker.Options) worker.Registry { + w := worker.New(c.Client, queue, options) + c.workers = append(c.workers, w) + return w +} + +func (c *clientImpl) Start() error { + for _, worker := range c.workers { + err := worker.Start() + if err != nil { + return err + } + } + return nil +} + +func (c *clientImpl) Stop() { + for _, worker := range c.workers { + worker.Stop() + } + + if c.Client != nil { + c.Close() + } +} + +func GetClient() Client { + return _client +} + +func Stop() { + _client.Close() +} diff --git a/builder/temporal/temporal_test.go b/builder/temporal/temporal_test.go new file mode 100644 index 00000000..0d3defcf --- /dev/null +++ b/builder/temporal/temporal_test.go @@ -0,0 +1,71 @@ +package temporal_test + +import ( + "context" + "testing" + + "github.com/stretchr/testify/require" + "go.temporal.io/sdk/client" + "go.temporal.io/sdk/worker" + "go.temporal.io/sdk/workflow" + "go.temporal.io/server/temporaltest" + "opencsg.com/csghub-server/builder/temporal" +) + +type Tester struct { + client temporal.Client + counter int + total int +} + +func (t *Tester) Count(ctx workflow.Context) error { + t.counter += 1 + return nil +} + +func (t *Tester) Add(ctx workflow.Context) error { + t.total += 1 + return nil +} + +func TestTemporalClient(t *testing.T) { + ts := temporaltest.NewServer(temporaltest.WithT(t)) + defer ts.Stop() + c := ts.GetDefaultClient() + + tester := &Tester{client: temporal.GetClient()} + _, err := temporal.NewClient(c) + require.NoError(t, err) + + worker1 := tester.client.NewWorker("q1", worker.Options{}) + worker1.RegisterWorkflow(tester.Count) + worker2 := tester.client.NewWorker("q2", worker.Options{}) + worker2.RegisterWorkflow(tester.Add) + + err = tester.client.Start() + require.NoError(t, err) + + r, err := tester.client.ExecuteWorkflow(context.TODO(), client.StartWorkflowOptions{ + TaskQueue: "q1", + }, tester.Count) + require.NoError(t, err) + err = r.Get(context.Background(), nil) + require.NoError(t, err) + + r, err = tester.client.ExecuteWorkflow(context.TODO(), client.StartWorkflowOptions{ + TaskQueue: "q2", + }, tester.Add) + require.NoError(t, err) + err = r.Get(context.Background(), nil) + require.NoError(t, err) + + require.Equal(t, 1, tester.counter) + require.Equal(t, 1, tester.total) + + temporal.Stop() + _, err = tester.client.ExecuteWorkflow(context.TODO(), client.StartWorkflowOptions{ + TaskQueue: "q1", + }, tester.Count) + require.Error(t, err) + +} diff --git a/cmd/csghub-server/cmd/mirror/repo_sync.go b/cmd/csghub-server/cmd/mirror/repo_sync.go index 25699681..cabfc1a8 100644 --- a/cmd/csghub-server/cmd/mirror/repo_sync.go +++ b/cmd/csghub-server/cmd/mirror/repo_sync.go @@ -1,11 +1,10 @@ package mirror import ( - "fmt" - "github.com/spf13/cobra" "opencsg.com/csghub-server/api/workflow" "opencsg.com/csghub-server/builder/store/database" + "opencsg.com/csghub-server/builder/temporal" "opencsg.com/csghub-server/common/config" "opencsg.com/csghub-server/mirror" ) @@ -30,13 +29,14 @@ var repoSyncCmd = &cobra.Command{ if err != nil { return err } - err = workflow.StartWorker(cfg) + + err = workflow.StartWorkflow(cfg) if err != nil { - return fmt.Errorf("failed to start worker: %w", err) + return err } repoSYncer.Run() - workflow.StopWorker() + temporal.Stop() return nil }, diff --git a/cmd/csghub-server/cmd/start/server.go b/cmd/csghub-server/cmd/start/server.go index c9e66bb6..c1bac060 100644 --- a/cmd/csghub-server/cmd/start/server.go +++ b/cmd/csghub-server/cmd/start/server.go @@ -13,6 +13,7 @@ import ( "opencsg.com/csghub-server/builder/deploy/common" "opencsg.com/csghub-server/builder/event" "opencsg.com/csghub-server/builder/store/database" + "opencsg.com/csghub-server/builder/temporal" "opencsg.com/csghub-server/common/config" "opencsg.com/csghub-server/common/types" "opencsg.com/csghub-server/docs" @@ -83,25 +84,16 @@ var serverCmd = &cobra.Command{ if err != nil { return fmt.Errorf("failed to init deploy: %w", err) } - err = workflow.StartWorker(cfg) + + err = workflow.StartWorkflow(cfg) if err != nil { - return fmt.Errorf("failed to start worker: %w", err) + return err } r, err := router.NewRouter(cfg, enableSwagger) if err != nil { return fmt.Errorf("failed to init router: %w", err) } - err = workflow.RegisterCronJobs(cfg) - if err != nil { - return fmt.Errorf("failed to register cron jobs: %w", err) - } - - err = workflow.StartCronWorker(cfg) - if err != nil { - return fmt.Errorf("failed to start cron worker: %w", err) - } - server := httpbase.NewGracefulServer( httpbase.GraceServerOpt{ Port: cfg.APIServer.Port, @@ -120,7 +112,7 @@ var serverCmd = &cobra.Command{ } server.Run() - workflow.StopWorker() + temporal.Stop() return nil }, diff --git a/cmd/csghub-server/cmd/trigger/git_callback.go b/cmd/csghub-server/cmd/trigger/git_callback.go index 24fb4509..f108511a 100644 --- a/cmd/csghub-server/cmd/trigger/git_callback.go +++ b/cmd/csghub-server/cmd/trigger/git_callback.go @@ -11,6 +11,7 @@ import ( "opencsg.com/csghub-server/api/workflow" "opencsg.com/csghub-server/builder/git/gitserver" "opencsg.com/csghub-server/builder/store/database" + "opencsg.com/csghub-server/builder/temporal" "opencsg.com/csghub-server/common/config" "opencsg.com/csghub-server/common/types" ) @@ -39,11 +40,12 @@ var gitCallbackCmd = &cobra.Command{ if err != nil { return fmt.Errorf("failed to load config: %w", err) } - err = workflow.StartWorker(config) + + err = workflow.StartWorkflow(config) if err != nil { - slog.Error("failed to start worker", slog.Any("error", err)) - return fmt.Errorf("failed to start worker: %w", err) + return err } + if len(repoPaths) > 0 { for _, rp := range repoPaths { parts := strings.Split(rp, "/") @@ -85,14 +87,13 @@ var gitCallbackCmd = &cobra.Command{ req.Commits = append(req.Commits, types.GiteaCallbackPushReq_Commit{}) req.Commits[0].Added = append(req.Commits[0].Added, filePaths...) //start workflow to handle push request - workflowClient := workflow.GetWorkflowClient() + workflowClient := temporal.GetClient() workflowOptions := client.StartWorkflowOptions{ TaskQueue: workflow.HandlePushQueueName, } - we, err := workflowClient.ExecuteWorkflow(context.Background(), workflowOptions, workflow.HandlePushWorkflow, - req, - config, + we, err := workflowClient.ExecuteWorkflow( + context.Background(), workflowOptions, workflow.HandlePushWorkflow, req, ) if err != nil { slog.Error("failed to handle git push callback", slog.String("repo", repo.Path), slog.Any("error", err)) diff --git a/component/callback/git_callback.go b/component/callback/git_callback.go index 0810f1e4..e022ff33 100644 --- a/component/callback/git_callback.go +++ b/component/callback/git_callback.go @@ -27,6 +27,7 @@ type GitCallbackComponent interface { WatchRepoRelation(ctx context.Context, req *types.GiteaCallbackPushReq) error SetRepoUpdateTime(ctx context.Context, req *types.GiteaCallbackPushReq) error UpdateRepoInfos(ctx context.Context, req *types.GiteaCallbackPushReq) error + SensitiveCheck(ctx context.Context, req *types.GiteaCallbackPushReq) error } type gitCallbackComponentImpl struct { diff --git a/go.mod b/go.mod index 93e4fadf..87eab4f6 100644 --- a/go.mod +++ b/go.mod @@ -1,6 +1,8 @@ module opencsg.com/csghub-server -go 1.23 +go 1.23.2 + +toolchain go1.23.4 require ( github.com/DATA-DOG/go-txdb v0.2.0 @@ -26,7 +28,7 @@ require ( github.com/sethvargo/go-envconfig v1.1.0 github.com/spf13/cast v1.5.1 github.com/spf13/cobra v1.8.0 - github.com/stretchr/testify v1.9.0 + github.com/stretchr/testify v1.10.0 github.com/swaggo/files v1.0.1 github.com/swaggo/gin-swagger v1.6.0 github.com/swaggo/swag v1.16.2 @@ -39,9 +41,10 @@ require ( github.com/uptrace/bun/driver/sqliteshim v1.1.16 github.com/uptrace/bun/extra/bundebug v1.1.16 gitlab.com/gitlab-org/gitaly/v16 v16.11.8 - go.temporal.io/api v1.40.0 - go.temporal.io/sdk v1.30.0 - google.golang.org/grpc v1.66.0 + go.temporal.io/api v1.43.0 + go.temporal.io/sdk v1.31.0 + go.temporal.io/server v1.26.2 + google.golang.org/grpc v1.67.1 gopkg.in/yaml.v2 v2.4.0 gopkg.in/yaml.v3 v3.0.1 k8s.io/api v0.31.3 @@ -51,20 +54,31 @@ require ( ) require ( - cloud.google.com/go/compute/metadata v0.3.0 // indirect - cloud.google.com/go/monitoring v1.18.0 // indirect - cloud.google.com/go/trace v1.10.5 // indirect + cloud.google.com/go v0.114.0 // indirect + cloud.google.com/go/auth v0.5.0 // indirect + cloud.google.com/go/auth/oauth2adapt v0.2.2 // indirect + cloud.google.com/go/compute/metadata v0.5.0 // indirect + cloud.google.com/go/iam v1.1.8 // indirect + cloud.google.com/go/monitoring v1.19.0 // indirect + cloud.google.com/go/storage v1.41.0 // indirect + cloud.google.com/go/trace v1.10.7 // indirect contrib.go.opencensus.io/exporter/stackdriver v0.13.14 // indirect dario.cat/mergo v1.0.0 // indirect + filippo.io/edwards25519 v1.1.0 // indirect github.com/Azure/go-ansiterm v0.0.0-20210617225240-d185dfc1b5a1 // indirect github.com/DataDog/datadog-go v4.4.0+incompatible // indirect github.com/DataDog/sketches-go v1.0.0 // indirect github.com/Microsoft/go-winio v0.6.2 // indirect - github.com/aws/aws-sdk-go v1.50.36 // indirect + github.com/apache/thrift v0.16.0 // indirect + github.com/aws/aws-sdk-go v1.53.15 // indirect github.com/beevik/ntp v1.3.1 // indirect + github.com/benbjohnson/clock v1.1.0 // indirect github.com/beorn7/perks v1.0.1 // indirect + github.com/blang/semver/v4 v4.0.0 // indirect github.com/bytedance/sonic/loader v0.1.1 // indirect - github.com/cenkalti/backoff/v4 v4.2.1 // indirect + github.com/cactus/go-statsd-client/statsd v0.0.0-20200423205355-cb0885a1018c // indirect + github.com/cactus/go-statsd-client/v5 v5.1.0 // indirect + github.com/cenkalti/backoff/v4 v4.3.0 // indirect github.com/census-instrumentation/opencensus-proto v0.4.1 // indirect github.com/cloudwego/base64x v0.1.4 // indirect github.com/cloudwego/iasm v0.2.0 // indirect @@ -72,27 +86,41 @@ require ( github.com/containerd/log v0.1.0 // indirect github.com/containerd/platforms v0.2.1 // indirect github.com/cpuguy83/dockercfg v0.3.2 // indirect + github.com/dgryski/go-farm v0.0.0-20240924180020-3414d57e47da // indirect github.com/distribution/reference v0.6.0 // indirect github.com/docker/docker v27.1.1+incompatible // indirect github.com/docker/go-connections v0.5.0 // indirect github.com/docker/go-units v0.5.0 // indirect + github.com/emirpasic/gods v1.18.1 // indirect github.com/evanphx/json-patch v5.9.0+incompatible // indirect github.com/facebookgo/clock v0.0.0-20150410010913-600d898af40a // indirect github.com/felixge/httpsnoop v1.0.4 // indirect github.com/fxamacker/cbor/v2 v2.7.0 // indirect github.com/go-logr/stdr v1.2.2 // indirect github.com/go-ole/go-ole v1.2.6 // indirect + github.com/go-sql-driver/mysql v1.8.1 // indirect + github.com/gocql/gocql v1.6.0 // indirect github.com/golang/groupcache v0.0.0-20210331224755-41bb18bfe9da // indirect github.com/golang/mock v1.6.0 // indirect + github.com/golang/snappy v0.0.4 // indirect github.com/google/s2a-go v0.1.7 // indirect github.com/googleapis/enterprise-certificate-proxy v0.3.2 // indirect - github.com/googleapis/gax-go/v2 v2.12.2 // indirect + github.com/googleapis/gax-go/v2 v2.12.4 // indirect + github.com/gorilla/mux v1.8.1 // indirect github.com/grpc-ecosystem/go-grpc-middleware v1.4.0 // indirect github.com/grpc-ecosystem/go-grpc-middleware/v2 v2.0.0 // indirect github.com/grpc-ecosystem/grpc-gateway v1.16.0 // indirect github.com/grpc-ecosystem/grpc-gateway/v2 v2.22.0 // indirect + github.com/hailocab/go-hostpool v0.0.0-20160125115350-e80d13ce29ed // indirect github.com/hashicorp/golang-lru/v2 v2.0.7 // indirect github.com/hashicorp/yamux v0.1.2-0.20220728231024-8f49b6f63f18 // indirect + github.com/iancoleman/strcase v0.3.0 // indirect + github.com/jackc/pgpassfile v1.0.0 // indirect + github.com/jackc/pgservicefile v0.0.0-20221227161230-091c0ba34f0a // indirect + github.com/jackc/pgx/v5 v5.6.0 // indirect + github.com/jackc/puddle/v2 v2.2.1 // indirect + github.com/jmoiron/sqlx v1.3.4 // indirect + github.com/lib/pq v1.10.9 // indirect github.com/lightstep/lightstep-tracer-common/golang/gogo v0.0.0-20210210170715-a8dfcb80d3a7 // indirect github.com/lightstep/lightstep-tracer-go v0.25.0 // indirect github.com/lufia/plan9stats v0.0.0-20211012122336-39d0f177ccd0 // indirect @@ -107,26 +135,36 @@ require ( github.com/nats-io/nkeys v0.4.7 // indirect github.com/nats-io/nuid v1.0.1 // indirect github.com/ncruces/go-strftime v0.1.9 // indirect - github.com/nexus-rpc/sdk-go v0.0.11 // indirect + github.com/nexus-rpc/sdk-go v0.1.0 // indirect github.com/oklog/ulid/v2 v2.0.2 // indirect + github.com/olivere/elastic/v7 v7.0.32 // indirect github.com/opencontainers/image-spec v1.1.0 // indirect github.com/pborman/uuid v1.2.1 // indirect github.com/philhofer/fwd v1.1.1 // indirect github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 // indirect github.com/power-devops/perfstat v0.0.0-20210106213030-5aafc221ea8c // indirect - github.com/prometheus/client_golang v1.19.0 // indirect - github.com/prometheus/client_model v0.6.0 // indirect - github.com/prometheus/common v0.48.0 // indirect - github.com/prometheus/procfs v0.12.0 // indirect + github.com/prometheus/client_golang v1.20.4 // indirect + github.com/prometheus/client_model v0.6.1 // indirect + github.com/prometheus/common v0.60.0 // indirect + github.com/prometheus/procfs v0.15.1 // indirect github.com/prometheus/prometheus v0.50.1 // indirect + github.com/rcrowley/go-metrics v0.0.0-20141108142129-dee209f2455f // indirect github.com/robfig/cron v1.2.0 // indirect + github.com/robfig/cron/v3 v3.0.1 // indirect github.com/sebest/xff v0.0.0-20210106013422-671bd2870b3a // indirect github.com/shirou/gopsutil/v3 v3.23.12 // indirect github.com/shoenig/go-m1cpu v0.1.6 // indirect + github.com/sony/gobreaker v1.0.0 // indirect github.com/stretchr/objx v0.5.2 // indirect + github.com/temporalio/ringpop-go v0.0.0-20241119001152-e505ebd8f887 // indirect + github.com/temporalio/sqlparser v0.0.0-20231115171017-f4060bcfa6cb // indirect + github.com/temporalio/tchannel-go v1.22.1-0.20240528171429-1db37fdea938 // indirect github.com/tinylib/msgp v1.1.2 // indirect github.com/tklauser/go-sysconf v0.3.12 // indirect github.com/tklauser/numcpus v0.6.1 // indirect + github.com/twmb/murmur3 v1.1.8 // indirect + github.com/uber-common/bark v1.0.0 // indirect + github.com/uber-go/tally/v4 v4.1.17-0.20240412215630-22fe011f5ff0 // indirect github.com/uber/jaeger-client-go v2.30.0+incompatible // indirect github.com/uber/jaeger-lib v2.4.1+incompatible // indirect github.com/x448/float16 v0.8.4 // indirect @@ -134,20 +172,33 @@ require ( gitlab.com/gitlab-org/go/reopen v1.0.0 // indirect gitlab.com/gitlab-org/labkit v1.21.2 // indirect go.opencensus.io v0.24.0 // indirect - go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc v0.49.0 // indirect - go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.49.0 // indirect - go.opentelemetry.io/otel v1.24.0 // indirect - go.opentelemetry.io/otel/metric v1.24.0 // indirect - go.opentelemetry.io/otel/trace v1.24.0 // indirect - golang.org/x/exp v0.0.0-20240119083558-1b970713d09a // indirect + go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc v0.56.0 // indirect + go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.52.0 // indirect + go.opentelemetry.io/otel v1.31.0 // indirect + go.opentelemetry.io/otel/exporters/otlp/otlpmetric/otlpmetricgrpc v1.31.0 // indirect + go.opentelemetry.io/otel/exporters/otlp/otlptrace v1.31.0 // indirect + go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracegrpc v1.31.0 // indirect + go.opentelemetry.io/otel/exporters/prometheus v0.53.0 // indirect + go.opentelemetry.io/otel/metric v1.31.0 // indirect + go.opentelemetry.io/otel/sdk v1.31.0 // indirect + go.opentelemetry.io/otel/sdk/metric v1.31.0 // indirect + go.opentelemetry.io/otel/trace v1.31.0 // indirect + go.opentelemetry.io/proto/otlp v1.3.1 // indirect + go.temporal.io/version v0.3.0 // indirect + go.uber.org/dig v1.17.1 // indirect + go.uber.org/fx v1.22.0 // indirect + go.uber.org/mock v0.4.0 // indirect + golang.org/x/exp v0.0.0-20240531132922-fd00a4e0eefc // indirect golang.org/x/xerrors v0.0.0-20231012003039-104605ab7028 // indirect - google.golang.org/api v0.169.0 // indirect - google.golang.org/genproto v0.0.0-20240311173647-c811ad7063a7 // indirect - google.golang.org/genproto/googleapis/api v0.0.0-20240827150818-7e3bb234dfed // indirect - google.golang.org/genproto/googleapis/rpc v0.0.0-20240827150818-7e3bb234dfed // indirect + google.golang.org/api v0.182.0 // indirect + google.golang.org/genproto v0.0.0-20240528184218-531527333157 // indirect + google.golang.org/genproto/googleapis/api v0.0.0-20241007155032-5fefd90f89a9 // indirect + google.golang.org/genproto/googleapis/rpc v0.0.0-20241007155032-5fefd90f89a9 // indirect gopkg.in/DataDog/dd-trace-go.v1 v1.32.0 // indirect gopkg.in/evanphx/json-patch.v4 v4.12.0 // indirect - modernc.org/gc/v3 v3.0.0-20240107210532-573471604cb6 // indirect + gopkg.in/go-jose/go-jose.v2 v2.6.3 // indirect + gopkg.in/validator.v2 v2.0.1 // indirect + modernc.org/gc/v3 v3.0.0-20240304020402-f0dba7c97c2b // indirect ) require ( @@ -172,7 +223,7 @@ require ( github.com/dustin/go-humanize v1.0.1 // indirect github.com/emicklei/go-restful/v3 v3.11.0 // indirect github.com/evanphx/json-patch/v5 v5.8.0 // indirect - github.com/fatih/color v1.15.0 // indirect + github.com/fatih/color v1.17.0 // indirect github.com/gabriel-vasile/mimetype v1.4.3 // indirect github.com/gin-contrib/sse v0.1.0 // indirect github.com/go-fed/httpsig v1.1.0 // indirect @@ -205,7 +256,7 @@ require ( github.com/jmespath/go-jmespath v0.4.0 // indirect github.com/josharian/intern v1.0.0 // indirect github.com/json-iterator/go v1.1.12 // indirect - github.com/klauspost/compress v1.17.4 // indirect + github.com/klauspost/compress v1.17.9 // indirect github.com/klauspost/cpuid/v2 v2.2.7 // indirect github.com/leodido/go-urn v1.4.0 // indirect github.com/mailru/easyjson v0.7.7 // indirect @@ -234,11 +285,11 @@ require ( github.com/vmihailenco/tagparser/v2 v2.0.0 // indirect go.uber.org/atomic v1.11.0 // indirect go.uber.org/multierr v1.11.0 // indirect - go.uber.org/zap v1.26.0 // indirect + go.uber.org/zap v1.27.0 // indirect golang.org/x/arch v0.8.0 // indirect golang.org/x/crypto v0.31.0 - golang.org/x/net v0.28.0 // indirect - golang.org/x/oauth2 v0.22.0 // indirect + golang.org/x/net v0.31.0 // indirect + golang.org/x/oauth2 v0.23.0 // indirect golang.org/x/sync v0.10.0 // indirect golang.org/x/sys v0.28.0 // indirect golang.org/x/term v0.27.0 // indirect @@ -246,7 +297,7 @@ require ( golang.org/x/time v0.5.0 // indirect golang.org/x/tools v0.21.1-0.20240508182429-e35e4ccd0d2d // indirect gomodules.xyz/jsonpatch/v2 v2.4.0 // indirect - google.golang.org/protobuf v1.34.2 + google.golang.org/protobuf v1.35.1 gopkg.in/inf.v0 v0.9.1 // indirect gopkg.in/ini.v1 v1.67.0 // indirect k8s.io/klog/v2 v2.130.1 // indirect @@ -255,10 +306,10 @@ require ( knative.dev/networking v0.0.0-20240116081125-ce0738abf051 // indirect knative.dev/pkg v0.0.0-20240116073220-b488e7be5902 // indirect mellium.im/sasl v0.3.1 // indirect - modernc.org/libc v1.41.0 // indirect + modernc.org/libc v1.55.3 // indirect modernc.org/mathutil v1.6.0 // indirect - modernc.org/memory v1.7.2 // indirect - modernc.org/sqlite v1.29.1 // indirect + modernc.org/memory v1.8.0 // indirect + modernc.org/sqlite v1.34.1 // indirect modernc.org/strutil v1.2.0 // indirect modernc.org/token v1.1.0 // indirect sigs.k8s.io/json v0.0.0-20221116044647-bc3834ca7abd // indirect diff --git a/go.sum b/go.sum index 52183d31..53f12bcd 100644 --- a/go.sum +++ b/go.sum @@ -1,12 +1,22 @@ cloud.google.com/go v0.16.0/go.mod h1:aQUYkXzVsufM+DwF1aE+0xfcU+56JwCaLick0ClmMTw= cloud.google.com/go v0.26.0/go.mod h1:aQUYkXzVsufM+DwF1aE+0xfcU+56JwCaLick0ClmMTw= cloud.google.com/go v0.34.0/go.mod h1:aQUYkXzVsufM+DwF1aE+0xfcU+56JwCaLick0ClmMTw= -cloud.google.com/go/compute/metadata v0.3.0 h1:Tz+eQXMEqDIKRsmY3cHTL6FVaynIjX2QxYC4trgAKZc= -cloud.google.com/go/compute/metadata v0.3.0/go.mod h1:zFmK7XCadkQkj6TtorcaGlCW1hT1fIilQDwofLpJ20k= -cloud.google.com/go/monitoring v1.18.0 h1:NfkDLQDG2UR3WYZVQE8kwSbUIEyIqJUPl+aOQdFH1T4= -cloud.google.com/go/monitoring v1.18.0/go.mod h1:c92vVBCeq/OB4Ioyo+NbN2U7tlg5ZH41PZcdvfc+Lcg= -cloud.google.com/go/trace v1.10.5 h1:0pr4lIKJ5XZFYD9GtxXEWr0KkVeigc3wlGpZco0X1oA= -cloud.google.com/go/trace v1.10.5/go.mod h1:9hjCV1nGBCtXbAE4YK7OqJ8pmPYSxPA0I67JwRd5s3M= +cloud.google.com/go v0.114.0 h1:OIPFAdfrFDFO2ve2U7r/H5SwSbBzEdrBdE7xkgwc+kY= +cloud.google.com/go v0.114.0/go.mod h1:ZV9La5YYxctro1HTPug5lXH/GefROyW8PPD4T8n9J8E= +cloud.google.com/go/auth v0.5.0 h1:GtSZfKJkPrZi/s3AkiHnUYVI4dTP/kg8+I3unm0omag= +cloud.google.com/go/auth v0.5.0/go.mod h1:Kqvlz1cf1sNA0D+sYJnkPQOP+JMHkuHeIgVmCRtZOLc= +cloud.google.com/go/auth/oauth2adapt v0.2.2 h1:+TTV8aXpjeChS9M+aTtN/TjdQnzJvmzKFt//oWu7HX4= +cloud.google.com/go/auth/oauth2adapt v0.2.2/go.mod h1:wcYjgpZI9+Yu7LyYBg4pqSiaRkfEK3GQcpb7C/uyF1Q= +cloud.google.com/go/compute/metadata v0.5.0 h1:Zr0eK8JbFv6+Wi4ilXAR8FJ3wyNdpxHKJNPos6LTZOY= +cloud.google.com/go/compute/metadata v0.5.0/go.mod h1:aHnloV2TPI38yx4s9+wAZhHykWvVCfu7hQbF+9CWoiY= +cloud.google.com/go/iam v1.1.8 h1:r7umDwhj+BQyz0ScZMp4QrGXjSTI3ZINnpgU2nlB/K0= +cloud.google.com/go/iam v1.1.8/go.mod h1:GvE6lyMmfxXauzNq8NbgJbeVQNspG+tcdL/W8QO1+zE= +cloud.google.com/go/monitoring v1.19.0 h1:NCXf8hfQi+Kmr56QJezXRZ6GPb80ZI7El1XztyUuLQI= +cloud.google.com/go/monitoring v1.19.0/go.mod h1:25IeMR5cQ5BoZ8j1eogHE5VPJLlReQ7zFp5OiLgiGZw= +cloud.google.com/go/storage v1.41.0 h1:RusiwatSu6lHeEXe3kglxakAmAbfV+rhtPqA6i8RBx0= +cloud.google.com/go/storage v1.41.0/go.mod h1:J1WCa/Z2FcgdEDuPUY8DxT5I+d9mFKsCepp5vR6Sq80= +cloud.google.com/go/trace v1.10.7 h1:gK8z2BIJQ3KIYGddw9RJLne5Fx0FEXkrEQzPaeEYVvk= +cloud.google.com/go/trace v1.10.7/go.mod h1:qk3eiKmZX0ar2dzIJN/3QhY2PIFh1eqcIdaN5uEjQPM= contrib.go.opencensus.io/exporter/ocagent v0.7.1-0.20200907061046-05415f1de66d h1:LblfooH1lKOpp1hIhukktmSAxFkqMPFk9KR6iZ0MJNI= contrib.go.opencensus.io/exporter/ocagent v0.7.1-0.20200907061046-05415f1de66d/go.mod h1:IshRmMJBhDfFj5Y67nVhMYTTIze91RUeT73ipWKs/GY= contrib.go.opencensus.io/exporter/prometheus v0.4.2 h1:sqfsYl5GIY/L570iT+l93ehxaWJs2/OwXtiWwew3oAg= @@ -75,17 +85,29 @@ github.com/aliyun/credentials-go v1.3.1/go.mod h1:8jKYhQuDawt8x2+fusqa1Y6mPxemTs github.com/aliyun/credentials-go v1.3.2 h1:L4WppI9rctC8PdlMgyTkF8bBsy9pyKQEzBD1bHMRl+g= github.com/aliyun/credentials-go v1.3.2/go.mod h1:tlpz4uys4Rn7Ik4/piGRrTbXy2uLKvePgQJJduE+Y5c= github.com/antihax/optional v1.0.0/go.mod h1:uupD/76wgC+ih3iEmQUL+0Ugr19nfwCT1kdvxnR2qWY= +github.com/apache/thrift v0.16.0 h1:qEy6UW60iVOlUy+b9ZR0d5WzUWYGOo4HfopoyBaNmoY= +github.com/apache/thrift v0.16.0/go.mod h1:PHK3hniurgQaNMZYaCLEqXKsYK8upmhPbmdP2FXSqgU= github.com/argoproj/argo-workflows/v3 v3.5.13 h1:d+t+nTBgfHsTTuw+KL3CmBrjvo9/VlRcMNm+FRf8FBA= github.com/argoproj/argo-workflows/v3 v3.5.13/go.mod h1:DecB01a8UXDCjtIh0udY8XfIMIRrWrlbob7hk/uMmg0= -github.com/aws/aws-sdk-go v1.50.36 h1:PjWXHwZPuTLMR1NIb8nEjLucZBMzmf84TLoLbD8BZqk= -github.com/aws/aws-sdk-go v1.50.36/go.mod h1:LF8svs817+Nz+DmiMQKTO3ubZ/6IaTpq3TjupRn3Eqk= +github.com/aws/aws-sdk-go v1.53.15 h1:FtZmkg7xM8RfP2oY6p7xdKBYrRgkITk9yve2QV7N938= +github.com/aws/aws-sdk-go v1.53.15/go.mod h1:LF8svs817+Nz+DmiMQKTO3ubZ/6IaTpq3TjupRn3Eqk= github.com/beevik/ntp v1.3.1 h1:Y/srlT8L1yQr58kyPWFPZIxRL8ttx2SRIpVYJqZIlAM= github.com/beevik/ntp v1.3.1/go.mod h1:fT6PylBq86Tsq23ZMEe47b7QQrZfYBFPnpzt0a9kJxw= +github.com/benbjohnson/clock v0.0.0-20160125162948-a620c1cc9866/go.mod h1:UMqtWQTnOe4byzwe7Zhwh8f8s+36uszN51sJrSIZlTE= +github.com/benbjohnson/clock v1.1.0 h1:Q92kusRqC1XV2MjkWETPvjJVqKetz1OzxZB7mHJLju8= github.com/benbjohnson/clock v1.1.0/go.mod h1:J11/hYXuz8f4ySSvYwY0FKfm+ezbsZBKZxNJlLklBHA= github.com/beorn7/perks v1.0.1 h1:VlbKKnNfV8bJzeqoa4cOKqO6bYr3WgKZxO8Z16+hsOM= github.com/beorn7/perks v1.0.1/go.mod h1:G2ZrVWU2WbWT9wwq4/hrbKbnv/1ERSJQ0ibhJ6rlkpw= +github.com/bitly/go-hostpool v0.0.0-20171023180738-a3a6125de932 h1:mXoPYz/Ul5HYEDvkta6I8/rnYM5gSdSV2tJ6XbZuEtY= +github.com/bitly/go-hostpool v0.0.0-20171023180738-a3a6125de932/go.mod h1:NOuUCSz6Q9T7+igc/hlvDOUdtWKryOrtFyIVABv/p7k= +github.com/blang/semver/v4 v4.0.0 h1:1PFHFE6yCCTv8C1TeyNNarDzntLi7wMI5i/pzqYIsAM= +github.com/blang/semver/v4 v4.0.0/go.mod h1:IbckMUScFkM3pff0VJDNKRiT6TG/YpiHIM2yvyW5YoQ= github.com/blendle/zapdriver v1.3.1 h1:C3dydBOWYRiOk+B8X9IVZ5IOe+7cl+tGOexN4QqHfpE= github.com/blendle/zapdriver v1.3.1/go.mod h1:mdXfREi6u5MArG4j9fewC+FGnXaBR+T4Ox4J2u4eHCc= +github.com/bmizerany/assert v0.0.0-20160611221934-b7ed37b82869 h1:DDGfHa7BWjL4YnC6+E63dPcxHo2sUxDIu8g3QgEJdRY= +github.com/bmizerany/assert v0.0.0-20160611221934-b7ed37b82869/go.mod h1:Ekp36dRnpXw/yCqJaO+ZrUyxD+3VXMFFr56k5XYrpB4= +github.com/bmizerany/perks v0.0.0-20141205001514-d9a9656a3a4b h1:AP/Y7sqYicnjGDfD5VcY4CIfh1hRXBUavxrvELjTiOE= +github.com/bmizerany/perks v0.0.0-20141205001514-d9a9656a3a4b/go.mod h1:ac9efd0D1fsDb3EJvhqgXRbFx7bs2wqZ10HQPeU8U/Q= github.com/bradfitz/gomemcache v0.0.0-20170208213004-1952afaa557d/go.mod h1:PmM6Mmwb0LSuEubjR8N7PtNe1KxZLtOUHtbeikc5h60= github.com/bsm/ginkgo/v2 v2.12.0 h1:Ny8MWAHyOepLGlLKYmXG4IEkioBysk6GpaRTLC8zwWs= github.com/bsm/ginkgo/v2 v2.12.0/go.mod h1:SwYbGRRDovPVboqFv0tPTcG1sN61LM1Z4ARdbAV9g4c= @@ -97,10 +119,15 @@ github.com/bytedance/sonic v1.11.6 h1:oUp34TzMlL+OY1OUWxHqsdkgC/Zfc85zGqw9siXjrc github.com/bytedance/sonic v1.11.6/go.mod h1:LysEHSvpvDySVdC2f87zGWf6CIKJcAvqab1ZaiQtds4= github.com/bytedance/sonic/loader v0.1.1 h1:c+e5Pt1k/cy5wMveRDyk2X4B9hF4g7an8N3zCYjJFNM= github.com/bytedance/sonic/loader v0.1.1/go.mod h1:ncP89zfokxS5LZrJxl5z0UJcsk4M4yY2JpfqGeCtNLU= +github.com/cactus/go-statsd-client/statsd v0.0.0-20200423205355-cb0885a1018c h1:HIGF0r/56+7fuIZw2V4isE22MK6xpxWx7BbV8dJ290w= +github.com/cactus/go-statsd-client/statsd v0.0.0-20200423205355-cb0885a1018c/go.mod h1:l/bIBLeOl9eX+wxJAzxS4TveKRtAqlyDpHjhkfO0MEI= +github.com/cactus/go-statsd-client/v4 v4.0.0/go.mod h1:m73kwJp6TN0Ja9P6ycdZhWM1MlfxY/95WZ//IptPQ+Y= +github.com/cactus/go-statsd-client/v5 v5.1.0 h1:sbbdfIl9PgisjEoXzvXI1lwUKWElngsjJKaZeC021P4= +github.com/cactus/go-statsd-client/v5 v5.1.0/go.mod h1:COEvJ1E+/E2L4q6QE5CkjWPi4eeDw9maJBMIuMPBZbY= github.com/casdoor/casdoor-go-sdk v0.41.0 h1:mqqoc1Jub34/OkAQqjeASRAaiy7x/5ZWtGObI08cfEk= github.com/casdoor/casdoor-go-sdk v0.41.0/go.mod h1:cMnkCQJgMYpgAlgEx8reSt1AVaDIQLcJ1zk5pzBaz+4= -github.com/cenkalti/backoff/v4 v4.2.1 h1:y4OZtCnogmCPw98Zjyt5a6+QwPLGkiQsYW5oUqylYbM= -github.com/cenkalti/backoff/v4 v4.2.1/go.mod h1:Y3VNntkOUPxTVeUxJ/G5vcM//AlwfmyYozVcomhLiZE= +github.com/cenkalti/backoff/v4 v4.3.0 h1:MyRJ/UdXutAwSAT+s3wNd7MfTIcy71VQueUuFK343L8= +github.com/cenkalti/backoff/v4 v4.3.0/go.mod h1:Y3VNntkOUPxTVeUxJ/G5vcM//AlwfmyYozVcomhLiZE= github.com/census-instrumentation/opencensus-proto v0.2.1/go.mod h1:f6KPmirojxKA12rnyqOA5BBL4O983OfeGPqjHWSTneU= github.com/census-instrumentation/opencensus-proto v0.4.1 h1:iKLQ0xPNFxR/2hzXZMrBo8f1j86j5WHzznCCQxV/b8g= github.com/census-instrumentation/opencensus-proto v0.4.1/go.mod h1:4T9NM4+4Vw91VeyqjLS6ao50K5bOcLKN6Q42XnYaRYw= @@ -135,6 +162,7 @@ github.com/cpuguy83/go-md2man/v2 v2.0.3/go.mod h1:tgQtvFlXSQOSOSIRvRPT7W67SCa46t github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E= github.com/creack/pty v1.1.21 h1:1/QdRyBaHHJP61QkWMXlOIBfsgdDeeKfK8SYVUWJKf0= github.com/creack/pty v1.1.21/go.mod h1:MOBLtS5ELjhRRrroQr9kyvTxUAFNvYEK993ew/Vr4O4= +github.com/crossdock/crossdock-go v0.0.0-20160816171116-049aabb0122b/go.mod h1:v9FBN7gdVTpiD/+LZ7Po0UKvROyT87uLVxTHVky/dlQ= github.com/d5/tengo/v2 v2.17.0 h1:BWUN9NoJzw48jZKiYDXDIF3QrIVZRm1uV1gTzeZ2lqM= github.com/d5/tengo/v2 v2.17.0/go.mod h1:XRGjEs5I9jYIKTxly6HCF8oiiilk5E/RYXOZ5b0DZC8= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= @@ -143,6 +171,9 @@ github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc h1:U9qPSI2PIWSS1 github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davidmz/go-pageant v1.0.2 h1:bPblRCh5jGU+Uptpz6LgMZGD5hJoOt7otgT454WvHn0= github.com/davidmz/go-pageant v1.0.2/go.mod h1:P2EDDnMqIwG5Rrp05dTRITj9z2zpGcD9efWSkTNKLIE= +github.com/dgryski/go-farm v0.0.0-20140601200337-fc41e106ee0e/go.mod h1:SqUrOPUnsFjfmXRMNPybcSiG0BgUW2AuFH8PAnS2iTw= +github.com/dgryski/go-farm v0.0.0-20240924180020-3414d57e47da h1:aIftn67I1fkbMa512G+w+Pxci9hJPB8oMnkcP3iZF38= +github.com/dgryski/go-farm v0.0.0-20240924180020-3414d57e47da/go.mod h1:SqUrOPUnsFjfmXRMNPybcSiG0BgUW2AuFH8PAnS2iTw= github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f h1:lO4WD4F/rVNCu3HqELle0jiPLLBs70cWOduZpkS1E78= github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f/go.mod h1:cuUVRXasLTGF7a8hSLbxyZXjz+1KgoB3wDUb6vlszIc= github.com/distribution/reference v0.6.0 h1:0IXCQ5g4/QMHHkarYzh5l+u8T3t73zM5QvfrDyIgxBk= @@ -157,6 +188,8 @@ github.com/dustin/go-humanize v1.0.1 h1:GzkhY7T5VNhEkwH0PVJgjz+fX1rhBrR7pRT3mDkp github.com/dustin/go-humanize v1.0.1/go.mod h1:Mu1zIs6XwVuF/gI1OepvI0qD18qycQx+mFykh5fBlto= github.com/emicklei/go-restful/v3 v3.11.0 h1:rAQeMHw1c7zTmncogyy8VvRZwtkmkZ4FxERmMY4rD+g= github.com/emicklei/go-restful/v3 v3.11.0/go.mod h1:6n3XBCmQQb25CM2LCACGz8ukIrRry+4bhvbpWn3mrbc= +github.com/emirpasic/gods v1.18.1 h1:FXtiHYKDGKCW2KzwZKx0iC0PQmdlorYgdFG9jPXJ1Bc= +github.com/emirpasic/gods v1.18.1/go.mod h1:8tpGGwCnJ5H4r6BWwaV6OrWmMoPhUl5jm/FMNAnJvWQ= github.com/envoyproxy/go-control-plane v0.9.0/go.mod h1:YTl/9mNaCwkRvm6d1a2C3ymFceY/DCBVvsKhRF0iEA4= github.com/envoyproxy/go-control-plane v0.9.1-0.20191026205805-5f8ba28d4473/go.mod h1:YTl/9mNaCwkRvm6d1a2C3ymFceY/DCBVvsKhRF0iEA4= github.com/envoyproxy/go-control-plane v0.9.4/go.mod h1:6rpuAdCZL397s3pYoYcLgu1mIlRU8Am5FuJP05cCM98= @@ -167,11 +200,13 @@ github.com/evanphx/json-patch/v5 v5.8.0 h1:lRj6N9Nci7MvzrXuX6HFzU8XjmhPiXPlsKEy1 github.com/evanphx/json-patch/v5 v5.8.0/go.mod h1:VNkHZ/282BpEyt/tObQO8s5CMPmYYq14uClGH4abBuQ= github.com/facebookgo/clock v0.0.0-20150410010913-600d898af40a h1:yDWHCSQ40h88yih2JAcL6Ls/kVkSE8GFACTGVnMPruw= github.com/facebookgo/clock v0.0.0-20150410010913-600d898af40a/go.mod h1:7Ga40egUymuWXxAe151lTNnCv97MddSOVsjpPPkityA= -github.com/fatih/color v1.15.0 h1:kOqh6YHBtK8aywxGerMG2Eq3H6Qgoqeo13Bk2Mv/nBs= -github.com/fatih/color v1.15.0/go.mod h1:0h5ZqXfHYED7Bhv2ZJamyIOUej9KtShiJESRwBDUSsw= +github.com/fatih/color v1.17.0 h1:GlRw1BRJxkpqUCBKzKOw098ed57fEsKeNjpTe3cSjK4= +github.com/fatih/color v1.17.0/go.mod h1:YZ7TlrGPkiz6ku9fK3TLD/pl3CpsiFyu8N92HLgmosI= github.com/felixge/httpsnoop v1.0.4 h1:NFTV2Zj1bL4mc9sqWACXbQFVBBg2W3GPvqp8/ESS2Wg= github.com/felixge/httpsnoop v1.0.4/go.mod h1:m8KPJKqk1gH5J9DgRY2ASl2lWCfGKXixSwevea8zH2U= github.com/fogleman/gg v1.2.1-0.20190220221249-0403632d5b90/go.mod h1:R/bRT+9gY/C5z7JzPU0zXsXHKM4/ayA+zqcVNZzPa1k= +github.com/fortytw2/leaktest v1.3.0 h1:u8491cBMTQ8ft8aeV+adlcytMZylmA5nnwwkRZjI8vw= +github.com/fortytw2/leaktest v1.3.0/go.mod h1:jDsjWgpAGjm2CA7WthBh/CdZYEPF31XHquHwclZch5g= github.com/frankban/quicktest v1.14.4 h1:g2rn0vABPOOXmZUj+vbmUp0lPoXEMuhTpIluN0XL9UY= github.com/frankban/quicktest v1.14.4/go.mod h1:4ptaffx2x8+WTWXmUCuVU6aPUX1/Mz7zb5vbUoiM6w0= github.com/fsnotify/fsnotify v1.4.3-0.20170329110642-4da3e2cfbabc/go.mod h1:jwhsz4b93w/PPRr/qN1Yymfu8t87LnFCMoQvtojpjFo= @@ -196,6 +231,8 @@ github.com/gin-contrib/sse v0.1.0/go.mod h1:RHrZQHXnP2xjPF+u1gW/2HnVO7nvIa9PG3Gm github.com/gin-gonic/gin v1.7.7/go.mod h1:axIBovoeJpVj8S3BwE0uPMTeReE4+AfFtqpqaZ1qq1U= github.com/gin-gonic/gin v1.10.0 h1:nTuyha1TYqgedzytsKYqna+DfLos46nTv2ygFy86HFU= github.com/gin-gonic/gin v1.10.0/go.mod h1:4PMNQiOhvDRa013RKVbsiNwoyezlm2rm0uX/T7kzp5Y= +github.com/go-faker/faker/v4 v4.4.1 h1:LY1jDgjVkBZWIhATCt+gkl0x9i/7wC61gZx73GTFb+Q= +github.com/go-faker/faker/v4 v4.4.1/go.mod h1:HRLrjis+tYsbFtIHufEPTAIzcZiRu0rS9EYl2Ccwme4= github.com/go-fed/httpsig v1.1.0 h1:9M+hb0jkEICD8/cAiNqEB66R87tTINszBRTjwjQzWcI= github.com/go-fed/httpsig v1.1.0/go.mod h1:RCMrTZvN1bJYtofsG4rd5NaO5obxQ5xBkdiS7xsT7bM= github.com/go-gl/glfw v0.0.0-20190409004039-e6da0acd62b1/go.mod h1:vR7hzQXu2zJy9AVAgeJqvqgH9Q5CA+iKCZ2gyEVpxRU= @@ -235,6 +272,7 @@ github.com/go-playground/validator/v10 v10.20.0 h1:K9ISHbSaI0lyB2eWMPJo+kOS/FBEx github.com/go-playground/validator/v10 v10.20.0/go.mod h1:dbuPbCMFw/DrkbEynArYaCwl3amGuJotoKCe95atGMM= github.com/go-redis/redis/v8 v8.11.5 h1:AcZZR7igkdvfVmQTPnu9WE37LRrO/YrBH5zWyjDC0oI= github.com/go-redis/redis/v8 v8.11.5/go.mod h1:gREzHqY1hg6oD9ngVRbLStwAWKhA0FEgq8Jd4h5lpwo= +github.com/go-sql-driver/mysql v1.5.0/go.mod h1:DCzpHaOWr8IXmIStZouvnhqoel9Qv2LBy8hT2VhHyBg= github.com/go-sql-driver/mysql v1.8.1 h1:LedoTUt/eveggdHS9qUFC1EFSa8bU2+1pZjSRpvNJ1Y= github.com/go-sql-driver/mysql v1.8.1/go.mod h1:wEBSXgmK//2ZFJyE+qWnIsVGmvmEKlqwuVSjsCm7DZg= github.com/go-stack/stack v1.6.0/go.mod h1:v0f6uXyyMGvRgIKkXu+yp6POWl0qKG85gN/melR3HDY= @@ -245,6 +283,8 @@ github.com/go-task/slim-sprig/v3 v3.0.0 h1:sUs3vkvUymDpBKi3qH1YSqBQk9+9D/8M2mN1v github.com/go-task/slim-sprig/v3 v3.0.0/go.mod h1:W848ghGpv3Qj3dhTPRyJypKRiqCdHZiAzKg9hl15HA8= github.com/goccy/go-json v0.10.3 h1:KZ5WoDbxAIgm2HNbYckL0se1fHD6rz5j4ywS6ebzDqA= github.com/goccy/go-json v0.10.3/go.mod h1:oq7eo15ShAhp70Anwd5lgX2pLfOS3QCiwU/PULtXL6M= +github.com/gocql/gocql v1.6.0 h1:IdFdOTbnpbd0pDhl4REKQDM+Q0SzKXQ1Yh+YZZ8T/qU= +github.com/gocql/gocql v1.6.0/go.mod h1:3gM2c4D3AnkISwBxGnMMsS8Oy4y2lhbPRsH4xnJrHG8= github.com/gogo/protobuf v1.3.2 h1:Ov1cvc58UF3b5XjBnZv7+opcTcQFZebYjWzi34vdm4Q= github.com/gogo/protobuf v1.3.2/go.mod h1:P1XiOD3dCwIKUDQYPy72D8LYyHL2YPYrpS2s69NZV8Q= github.com/goji/httpauth v0.0.0-20160601135302-2da839ab0f4d/go.mod h1:nnjvkQ9ptGaCkuDUx6wNykzzlUixGxvkme+H/lnzb+A= @@ -262,6 +302,7 @@ github.com/golang/groupcache v0.0.0-20210331224755-41bb18bfe9da h1:oI5xCqsCo564l github.com/golang/groupcache v0.0.0-20210331224755-41bb18bfe9da/go.mod h1:cIg4eruTrX1D+g88fzRXU5OdNfaM+9IcxsU14FzY7Hc= github.com/golang/lint v0.0.0-20170918230701-e5d664eb928e/go.mod h1:tluoj9z5200jBnyusfRPU2LqT6J+DAorxEvtC7LHB+E= github.com/golang/mock v1.1.1/go.mod h1:oTYuIxOrZwtPieC+H1uAHpcLFnEyAGVDL/k47Jfbm0A= +github.com/golang/mock v1.5.0/go.mod h1:CWnOUgYIOo4TcNZ0wHX3YZCqsaM1I1Jvs6v3mP3KVu8= github.com/golang/mock v1.6.0 h1:ErTB+efbowRARo13NNdxyJji2egdxLGQhRaY+DUumQc= github.com/golang/mock v1.6.0/go.mod h1:p6yTPP+5HYm5mzsMV8JkE6ZKdX+/wYM6Hr+LicevLPs= github.com/golang/protobuf v1.2.0/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= @@ -281,6 +322,9 @@ github.com/golang/protobuf v1.5.2/go.mod h1:XVQd3VNwM+JqD3oG2Ue2ip4fOMUkwXdXDdiu github.com/golang/protobuf v1.5.4 h1:i7eJL8qZTpSEXOPTxNKhASYpMn+8e5Q6AdndVa1dWek= github.com/golang/protobuf v1.5.4/go.mod h1:lnTiLA8Wa4RWRcIUkrtSVa5nRhsEGBg48fD6rSs7xps= github.com/golang/snappy v0.0.0-20170215233205-553a64147049/go.mod h1:/XxbfmMg8lxefKM7IXC3fBNl/7bRcc72aCRzEWrmP2Q= +github.com/golang/snappy v0.0.3/go.mod h1:/XxbfmMg8lxefKM7IXC3fBNl/7bRcc72aCRzEWrmP2Q= +github.com/golang/snappy v0.0.4 h1:yAGX7huGHXlcLOEtBnF4w7FQwA26wojNCwOYAEhLjQM= +github.com/golang/snappy v0.0.4/go.mod h1:/XxbfmMg8lxefKM7IXC3fBNl/7bRcc72aCRzEWrmP2Q= github.com/google/gnostic-models v0.6.8 h1:yo/ABAfM5IMRsS1VnXjTBvUb61tFIHozhlYvRgGre9I= github.com/google/gnostic-models v0.6.8/go.mod h1:5n7qKqH0f5wFt+aWF8CW6pZLLNOfYuF5OpfBSENuI8U= github.com/google/go-cmp v0.1.1-0.20171103154506-982329095285/go.mod h1:oXzfMopK8JAjlY9xF4vHSVASa0yLyX7SntLO5aqRK0M= @@ -301,6 +345,8 @@ github.com/google/go-containerregistry v0.16.1/go.mod h1:u0qB2l7mvtWVR5kNcbFIhFY github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg= github.com/google/gofuzz v1.2.0 h1:xRy4A+RhZaiKjJ1bPfwQ8sedCA+YS2YcCHW6ec7JMi0= github.com/google/gofuzz v1.2.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg= +github.com/google/martian/v3 v3.3.3 h1:DIhPTQrbPkgs2yJYdXU/eNACCG5DVQjySNRNlflZ9Fc= +github.com/google/martian/v3 v3.3.3/go.mod h1:iEPrYcgCF7jA9OtScMFQyAlZZ4YXTKEtJ1E6RWzmBA0= github.com/google/pprof v0.0.0-20210125172800-10e9aeb4a998/go.mod h1:kpwsk12EmLew5upagYY7GY0pfYCcupk39gWOCRROcvE= github.com/google/pprof v0.0.0-20210407192527-94a9f03dee38/go.mod h1:kpwsk12EmLew5upagYY7GY0pfYCcupk39gWOCRROcvE= github.com/google/pprof v0.0.0-20240525223248-4bfdf5a9a2af h1:kmjWCqn2qkEml422C2Rrd27c3VGxi6a/6HNq8QmHRKM= @@ -317,12 +363,14 @@ github.com/google/wire v0.6.0/go.mod h1:F4QhpQ9EDIdJ1Mbop/NZBRB+5yrR6qg3BnctaoUk github.com/googleapis/enterprise-certificate-proxy v0.3.2 h1:Vie5ybvEvT75RniqhfFxPRy3Bf7vr3h0cechB90XaQs= github.com/googleapis/enterprise-certificate-proxy v0.3.2/go.mod h1:VLSiSSBs/ksPL8kq3OBOQ6WRI2QnaFynd1DCjZ62+V0= github.com/googleapis/gax-go v2.0.0+incompatible/go.mod h1:SFVmujtThgffbyetf+mdk2eWhX2bMyUtNHzFKcPA9HY= -github.com/googleapis/gax-go/v2 v2.12.2 h1:mhN09QQW1jEWeMF74zGR81R30z4VJzjZsfkUhuHF+DA= -github.com/googleapis/gax-go/v2 v2.12.2/go.mod h1:61M8vcyyXR2kqKFxKrfA22jaA8JGF7Dc8App1U3H6jc= +github.com/googleapis/gax-go/v2 v2.12.4 h1:9gWcmF85Wvq4ryPFvGFaOgPIs1AQX0d0bcbGw4Z96qg= +github.com/googleapis/gax-go/v2 v2.12.4/go.mod h1:KYEYLorsnIGDi/rPC8b5TdlB9kbKoFubselGIoBMCwI= github.com/gopherjs/gopherjs v0.0.0-20181017120253-0766667cb4d1/go.mod h1:wJfORRmW1u3UXTncJ5qlYoELFm8eSnnEO6hX4iZ3EWY= github.com/gopherjs/gopherjs v0.0.0-20200217142428-fce0ec30dd00/go.mod h1:wJfORRmW1u3UXTncJ5qlYoELFm8eSnnEO6hX4iZ3EWY= github.com/gorilla/context v1.1.1 h1:AWwleXJkX/nhcU9bZSnZoi3h/qGYqQAGhq6zZe/aQW8= github.com/gorilla/context v1.1.1/go.mod h1:kBGZzfjB9CEq2AlWe17Uuf7NDRt0dE0s8S51q0aT7Yg= +github.com/gorilla/mux v1.8.1 h1:TuBL49tXwgrFYWhqrNgrUNEY92u81SPhu7sTdzQEiWY= +github.com/gorilla/mux v1.8.1/go.mod h1:AKf9I4AEqPTmMytcMc0KkNouC66V3BtZ4qD5fmWSiMQ= github.com/gorilla/securecookie v1.1.1 h1:miw7JPhV+b/lAHSXz4qd/nN9jRiAFV5FwjeKyCS8BvQ= github.com/gorilla/securecookie v1.1.1/go.mod h1:ra0sb63/xPlUeL+yeDciTfxMRAA+MP+HVt/4epWDjd4= github.com/gorilla/sessions v1.2.1 h1:DHd3rPN5lE3Ts3D8rKkQ8x/0kqfeNmBAaiSi+o7FsgI= @@ -336,6 +384,8 @@ github.com/grpc-ecosystem/grpc-gateway v1.16.0 h1:gmcG1KaJ57LophUzW0Hy8NmPhnMZb4 github.com/grpc-ecosystem/grpc-gateway v1.16.0/go.mod h1:BDjrQk3hbvj6Nolgz8mAMFbcEtjT1g+wF4CSlocrBnw= github.com/grpc-ecosystem/grpc-gateway/v2 v2.22.0 h1:asbCHRVmodnJTuQ3qamDwqVOIjwqUPTYmYuemVOx+Ys= github.com/grpc-ecosystem/grpc-gateway/v2 v2.22.0/go.mod h1:ggCgvZ2r7uOoQjOyu2Y1NhHmEPPzzuhWgcza5M1Ji1I= +github.com/hailocab/go-hostpool v0.0.0-20160125115350-e80d13ce29ed h1:5upAirOpQc1Q53c0bnx2ufif5kANL7bfZWcc6VJWJd8= +github.com/hailocab/go-hostpool v0.0.0-20160125115350-e80d13ce29ed/go.mod h1:tMWxXQ9wFIaZeTI9F+hmhFiGpFmhOHzyShyFUhRm0H4= github.com/hashicorp/go-version v1.5.0/go.mod h1:fltr4n8CU8Ke44wwGCBoEymUuxUHl09ZGVZPK5anwXA= github.com/hashicorp/go-version v1.7.0 h1:5tqGy27NaOTB8yJKUZELlFAS/LTKJkrmONwQKeRZfjY= github.com/hashicorp/go-version v1.7.0/go.mod h1:fltr4n8CU8Ke44wwGCBoEymUuxUHl09ZGVZPK5anwXA= @@ -345,6 +395,8 @@ github.com/hashicorp/hcl v0.0.0-20170914154624-68e816d1c783/go.mod h1:oZtUIOe8dh github.com/hashicorp/yamux v0.1.2-0.20220728231024-8f49b6f63f18 h1:IVujPV6DRIu1fYF4zUHrfhkngJzmYjelXa+iSUiFZSI= github.com/hashicorp/yamux v0.1.2-0.20220728231024-8f49b6f63f18/go.mod h1:CtWFDAQgb7dxtzFs4tWbplKIe2jSi3+5vKbgIO0SLnQ= github.com/hpcloud/tail v1.0.0/go.mod h1:ab1qPbhIpdTxEkNHXyeSf5vhxWSCs/tWer42PpOxQnU= +github.com/iancoleman/strcase v0.3.0 h1:nTXanmYxhfFAMjZL34Ov6gkzEsSJZ5DbhxWjvSASxEI= +github.com/iancoleman/strcase v0.3.0/go.mod h1:iwCmte+B7n89clKwxIoIXy/HfoL7AsD47ZCWhYzw7ho= github.com/ianlancetaylor/demangle v0.0.0-20200824232613-28f6c0f3b639/go.mod h1:aSSvb/t6k1mPoxDqO4vJh6VOCGPwU4O0C2/Eqndh1Sc= github.com/imdario/mergo v0.3.16 h1:wwQJbIsHYGMUyLSPrEq1CT16AhnhNJQ51+4fdHUnCl4= github.com/imdario/mergo v0.3.16/go.mod h1:WBLT9ZmE3lPoWsEzCh9LPo3TiwVN+ZKEjmz+hD27ysY= @@ -355,12 +407,14 @@ github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsI github.com/jackc/pgpassfile v1.0.0/go.mod h1:CEx0iS5ambNFdcRtxPj5JhEz+xB6uRky5eyVu/W2HEg= github.com/jackc/pgservicefile v0.0.0-20221227161230-091c0ba34f0a h1:bbPeKD0xmW/Y25WS6cokEszi5g+S0QxI/d45PkRi7Nk= github.com/jackc/pgservicefile v0.0.0-20221227161230-091c0ba34f0a/go.mod h1:5TJZWKEWniPve33vlWYSoGYefn3gLQRzjfDlhSJ9ZKM= -github.com/jackc/pgx/v5 v5.5.5 h1:amBjrZVmksIdNjxGW/IiIMzxMKZFelXbUoPNb+8sjQw= -github.com/jackc/pgx/v5 v5.5.5/go.mod h1:ez9gk+OAat140fv9ErkZDYFWmXLfV+++K0uAOiwgm1A= +github.com/jackc/pgx/v5 v5.6.0 h1:SWJzexBzPL5jb0GEsrPMLIsi/3jOo7RHlzTjcAeDrPY= +github.com/jackc/pgx/v5 v5.6.0/go.mod h1:DNZ/vlrUnhWCoFGxHAG8U2ljioxukquj7utPDgtQdTw= github.com/jackc/puddle/v2 v2.2.1 h1:RhxXJtFG022u4ibrCSMSiu5aOq1i77R3OHKNJj77OAk= github.com/jackc/puddle/v2 v2.2.1/go.mod h1:vriiEXHvEE654aYKXXjOvZM39qJ0q+azkZFrfEOc3H4= github.com/jellydator/ttlcache/v2 v2.11.1 h1:AZGME43Eh2Vv3giG6GeqeLeFXxwxn1/qHItqWZl6U64= github.com/jellydator/ttlcache/v2 v2.11.1/go.mod h1:RtE5Snf0/57e+2cLWFYWCCsLas2Hy3c5Z4n14XmSvTI= +github.com/jessevdk/go-flags v1.4.0/go.mod h1:4FA24M0QyGHXBuZZK/XkWh8h0e1EYbRYJSGM75WSRxI= +github.com/jessevdk/go-flags v1.5.0/go.mod h1:Fw0T6WPc1dYxT4mKEZRfG5kJhaTDP9pj1c2EWnYs/m4= github.com/jinzhu/inflection v1.0.0 h1:K317FqzuhWc8YvSVlFMCCUb36O/S9MCKRDI7QkRKD/E= github.com/jinzhu/inflection v1.0.0/go.mod h1:h+uFLlag+Qp1Va5pdKtLDYj+kHp5pxUVkryuEj+Srlc= github.com/jmespath/go-jmespath v0.0.0-20180206201540-c2b33e8439af/go.mod h1:Nht3zPeWKUH0NzdCt2Blrr5ys8VGpn0CEB0cQHVjt7k= @@ -368,6 +422,8 @@ github.com/jmespath/go-jmespath v0.4.0 h1:BEgLn5cpjn8UN1mAw4NjwDrS35OdebyEtFe+9Y github.com/jmespath/go-jmespath v0.4.0/go.mod h1:T8mJZnbsbmF+m6zOOFylbeCJqk5+pHWvzYPziyZiYoo= github.com/jmespath/go-jmespath/internal/testify v1.5.1 h1:shLQSRRSCCPj3f2gpwzGwWFoC7ycTf1rcQZHOlsJ6N8= github.com/jmespath/go-jmespath/internal/testify v1.5.1/go.mod h1:L3OGu8Wl2/fWfCI6z80xFu9LTZmf1ZRjMHUOPmWr69U= +github.com/jmoiron/sqlx v1.3.4 h1:wv+0IJZfL5z0uZoUjlpKgHkgaFSYD+r9CfrXjEXsO7w= +github.com/jmoiron/sqlx v1.3.4/go.mod h1:2BljVx/86SuTyjE+aPYlHCTNvZrnJXghYGpNiXLBMCQ= github.com/josharian/intern v1.0.0 h1:vlS4z54oSdjm0bgjRigI+G1HpF+tI+9rE5LLzOg8HmY= github.com/josharian/intern v1.0.0/go.mod h1:5DoeVV0s6jJacbCEi61lwdGj/aVlrQvzHFFd8Hwg//Y= github.com/json-iterator/go v1.1.5/go.mod h1:+SdeFBvtyEkXs7REEP0seUULqWtbJapLOCVDaaPEHmU= @@ -379,8 +435,8 @@ github.com/jtolds/gls v4.20.0+incompatible/go.mod h1:QJZ7F/aHp+rZTRtaJ1ow/lLfFfV github.com/jung-kurt/gofpdf v1.0.3-0.20190309125859-24315acbbda5/go.mod h1:7Id9E/uU8ce6rXgefFLlgrJj/GYY22cpxn+r32jIOes= github.com/kisielk/errcheck v1.5.0/go.mod h1:pFxgyoBC7bSaBwPgfKdkLd5X25qrDl4LWUI2bnpBCr8= github.com/kisielk/gotool v1.0.0/go.mod h1:XhKaO+MFFWcvkIS/tQcRk01m1F5IRFswLeQ+oQHNcck= -github.com/klauspost/compress v1.17.4 h1:Ej5ixsIri7BrIjBkRZLTo6ghwrEtHFk7ijlczPW4fZ4= -github.com/klauspost/compress v1.17.4/go.mod h1:/dCuZOvVtNoHsyb+cuJD3itjs3NbnF6KH9zAO4BDxPM= +github.com/klauspost/compress v1.17.9 h1:6KIumPrER1LHsvBVuDa0r5xaG0Es51mhhB9BQB2qeMA= +github.com/klauspost/compress v1.17.9/go.mod h1:Di0epgTjJY877eYKx5yC51cX2A2Vl2ibi7bDH9ttBbw= github.com/klauspost/cpuid/v2 v2.0.1/go.mod h1:FInQzS24/EEf25PyTYn52gqo7WaD8xa0213Md/qVLRg= github.com/klauspost/cpuid/v2 v2.0.9/go.mod h1:FInQzS24/EEf25PyTYn52gqo7WaD8xa0213Md/qVLRg= github.com/klauspost/cpuid/v2 v2.2.7 h1:ZWSB3igEs+d0qvnxR/ZBzXVmxkgt8DdzP6m9pfuVLDM= @@ -400,6 +456,7 @@ github.com/kylelemons/godebug v1.1.0/go.mod h1:9/0rRGxNHcop5bhtWyNeEfOS8JIWk580+ github.com/leodido/go-urn v1.2.0/go.mod h1:+8+nEpDfqqsY+g338gtMEUOtuK+4dEMhiQEgxpxOKII= github.com/leodido/go-urn v1.4.0 h1:WT9HwE9SGECu3lg4d/dIA+jxlljEa1/ffXKmRjqdmIQ= github.com/leodido/go-urn v1.4.0/go.mod h1:bvxc+MVxLKB4z00jd1z+Dvzr47oO32F/QSNjSBOlFxI= +github.com/lib/pq v1.2.0/go.mod h1:5WUZQaWbwv1U+lTReE5YruASi9Al49XbQIvNi/34Woo= github.com/lib/pq v1.10.9 h1:YXG7RB+JIjhP29X+OtkiDnYaXQwpS4JEWq7dtCCRUEw= github.com/lib/pq v1.10.9/go.mod h1:AlVN5x4E4T544tWzH6hKfbfQvm3HdbOxrmggDNAPY9o= github.com/lightstep/lightstep-tracer-common/golang/gogo v0.0.0-20210210170715-a8dfcb80d3a7 h1:YjW+hUb8Fh2S58z4av4t/0cBMK/Q0aP48RocCFsC8yI= @@ -423,6 +480,7 @@ github.com/mattn/go-isatty v0.0.12/go.mod h1:cbi8OIDigv2wuxKPP5vlRcQ1OAZbq2CE4Ky github.com/mattn/go-isatty v0.0.16/go.mod h1:kYGgaQfpe5nmfYZH+SKPsOc2e4SrIfOl2e/yFXSvRLM= github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY= github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= +github.com/mattn/go-sqlite3 v1.14.6/go.mod h1:NyWgC/yNuGj7Q9rpYnZvas74GogHl5/Z4A/KQRfk6bU= github.com/mattn/go-sqlite3 v2.0.3+incompatible h1:gXHsfypPkaMZrKbD5209QV9jbUTJKjyR5WD3HYQSd+U= github.com/mattn/go-sqlite3 v2.0.3+incompatible/go.mod h1:FPy6KqzDD04eiIsT53CuJW3U88zkxoIYsOqkbpncsNc= github.com/miekg/dns v1.1.58 h1:ca2Hdkz+cDg/7eNF6V56jjzuZ4aCAE+DbVkILdQWG/4= @@ -469,14 +527,16 @@ github.com/nats-io/nuid v1.0.1 h1:5iA8DT8V7q8WK2EScv2padNa/rTESc1KdnPw4TC2paw= github.com/nats-io/nuid v1.0.1/go.mod h1:19wcPz3Ph3q0Jbyiqsd0kePYG7A95tJPxeL+1OSON2c= github.com/ncruces/go-strftime v0.1.9 h1:bY0MQC28UADQmHmaF5dgpLmImcShSi2kHU9XLdhx/f4= github.com/ncruces/go-strftime v0.1.9/go.mod h1:Fwc5htZGVVkseilnfgOVb9mKy6w1naJmn9CehxcKcls= -github.com/nexus-rpc/sdk-go v0.0.11 h1:qH3Us3spfp50t5ca775V1va2eE6z1zMQDZY4mvbw0CI= -github.com/nexus-rpc/sdk-go v0.0.11/go.mod h1:TpfkM2Cw0Rlk9drGkoiSMpFqflKTiQLWUNyKJjF8mKQ= +github.com/nexus-rpc/sdk-go v0.1.0 h1:PUL/0vEY1//WnqyEHT5ao4LBRQ6MeNUihmnNGn0xMWY= +github.com/nexus-rpc/sdk-go v0.1.0/go.mod h1:TpfkM2Cw0Rlk9drGkoiSMpFqflKTiQLWUNyKJjF8mKQ= github.com/niemeyer/pretty v0.0.0-20200227124842-a10e7caefd8e/go.mod h1:zD1mROLANZcx1PVRCS0qkT7pwLkGfwJo4zjcN/Tysno= github.com/nxadm/tail v1.4.4/go.mod h1:kenIhsEOeOJmVchQTgglprH7qJGnHDVpk1VPCcaMI8A= github.com/nxadm/tail v1.4.8 h1:nPr65rt6Y5JFSKQO7qToXr7pePgD6Gwiw05lkbyAQTE= github.com/nxadm/tail v1.4.8/go.mod h1:+ncqLTQzXmGhMZNUePPaPqPvBxHAIsmXswZKocGu+AU= github.com/oklog/ulid/v2 v2.0.2 h1:r4fFzBm+bv0wNKNh5eXTwU7i85y5x+uwkxCUTNVQqLc= github.com/oklog/ulid/v2 v2.0.2/go.mod h1:mtBL0Qe/0HAx6/a4Z30qxVIAL1eQDweXq5lxOEiwQ68= +github.com/olivere/elastic/v7 v7.0.32 h1:R7CXvbu8Eq+WlsLgxmKVKPox0oOwAE/2T9Si5BnvK6E= +github.com/olivere/elastic/v7 v7.0.32/go.mod h1:c7PVmLe3Fxq77PIfY/bZmxY/TAamBhCzZ8xDOE09a9k= github.com/onsi/ginkgo v1.6.0/go.mod h1:lLunBs/Ym6LB5Z9jYTR76FiuTmxDTDusOGeTQH+WWjE= github.com/onsi/ginkgo v1.7.0/go.mod h1:lLunBs/Ym6LB5Z9jYTR76FiuTmxDTDusOGeTQH+WWjE= github.com/onsi/ginkgo v1.12.1/go.mod h1:zj2OWP4+oCPe1qIXoGWkgMRwljMUYCdkwsT2108oapk= @@ -517,31 +577,38 @@ github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 h1:Jamvg5psRI github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/power-devops/perfstat v0.0.0-20210106213030-5aafc221ea8c h1:ncq/mPwQF4JjgDlrVEn3C11VoGHZN7m8qihwgMEtzYw= github.com/power-devops/perfstat v0.0.0-20210106213030-5aafc221ea8c/go.mod h1:OmDBASR4679mdNQnz2pUhc2G8CO2JrUAVFDRBDP/hJE= -github.com/prometheus/client_golang v1.19.0 h1:ygXvpU1AoN1MhdzckN+PyD9QJOSD4x7kmXYlnfbA6JU= -github.com/prometheus/client_golang v1.19.0/go.mod h1:ZRM9uEAypZakd+q/x7+gmsvXdURP+DABIEIjnmDdp+k= +github.com/prashantv/protectmem v0.0.0-20171002184600-e20412882b3a h1:AA9vgIBDjMHPC2McaGPojgV2dcI78ZC0TLNhYCXEKH8= +github.com/prashantv/protectmem v0.0.0-20171002184600-e20412882b3a/go.mod h1:lzZQ3Noex5pfAy7mkAeCjcBDteYU85uWWnJ/y6gKU8k= +github.com/prometheus/client_golang v1.20.4 h1:Tgh3Yr67PaOv/uTqloMsCEdeuFTatm5zIq5+qNN23vI= +github.com/prometheus/client_golang v1.20.4/go.mod h1:PIEt8X02hGcP8JWbeHyeZ53Y/jReSnHgO035n//V5WE= github.com/prometheus/client_model v0.0.0-20190812154241-14fe0d1b01d4/go.mod h1:xMI15A0UPsDsEKsMN9yxemIoYk6Tm2C1GtYGdfGttqA= -github.com/prometheus/client_model v0.6.0 h1:k1v3CzpSRUTrKMppY35TLwPvxHqBu0bYgxZzqGIgaos= -github.com/prometheus/client_model v0.6.0/go.mod h1:NTQHnmxFpouOD0DpvP4XujX3CdOAGQPoaGhyTchlyt8= -github.com/prometheus/common v0.48.0 h1:QO8U2CdOzSn1BBsmXJXduaaW+dY/5QLjfB8svtSzKKE= -github.com/prometheus/common v0.48.0/go.mod h1:0/KsvlIEfPQCQ5I2iNSAWKPZziNCvRs5EC6ILDTlAPc= -github.com/prometheus/procfs v0.12.0 h1:jluTpSng7V9hY0O2R9DzzJHYb2xULk9VTR1V1R/k6Bo= -github.com/prometheus/procfs v0.12.0/go.mod h1:pcuDEFsWDnvcgNzo4EEweacyhjeA9Zk3cnaOZAZEfOo= +github.com/prometheus/client_model v0.6.1 h1:ZKSh/rekM+n3CeS952MLRAdFwIKqeY8b62p8ais2e9E= +github.com/prometheus/client_model v0.6.1/go.mod h1:OrxVMOVHjw3lKMa8+x6HeMGkHMQyHDk9E3jmP2AmGiY= +github.com/prometheus/common v0.60.0 h1:+V9PAREWNvJMAuJ1x1BaWl9dewMW4YrHZQbx0sJNllA= +github.com/prometheus/common v0.60.0/go.mod h1:h0LYf1R1deLSKtD4Vdg8gy4RuOvENW2J/h19V5NADQw= +github.com/prometheus/procfs v0.15.1 h1:YagwOFzUgYfKKHX6Dr+sHT7km/hxC76UB0learggepc= +github.com/prometheus/procfs v0.15.1/go.mod h1:fB45yRUv8NstnjriLhBQLuOUt+WW4BsoGhij/e3PBqk= github.com/prometheus/prometheus v0.50.1 h1:N2L+DYrxqPh4WZStU+o1p/gQlBaqFbcLBTjlp3vpdXw= github.com/prometheus/prometheus v0.50.1/go.mod h1:FvE8dtQ1Ww63IlyKBn1V4s+zMwF9kHkVNkQBR1pM4CU= github.com/prometheus/statsd_exporter v0.22.7 h1:7Pji/i2GuhK6Lu7DHrtTkFmNBCudCPT1pX2CziuyQR0= github.com/prometheus/statsd_exporter v0.22.7/go.mod h1:N/TevpjkIh9ccs6nuzY3jQn9dFqnUakOjnEuMPJJJnI= +github.com/rcrowley/go-metrics v0.0.0-20141108142129-dee209f2455f h1:dfcuI1ZZzn8OXb0mYeJFo/0FzL/9eXT/sEzogrOzGc8= +github.com/rcrowley/go-metrics v0.0.0-20141108142129-dee209f2455f/go.mod h1:bCqnVzQkZxMG4s8nGwiZ5l3QUCyqpo9Y+/ZMZ9VjZe4= github.com/redis/go-redis/v9 v9.5.1 h1:H1X4D3yHPaYrkL5X06Wh6xNVM/pX0Ft4RV0vMGvLBh8= github.com/redis/go-redis/v9 v9.5.1/go.mod h1:hdY0cQFCN4fnSYT6TkisLufl/4W5UIXyv0b/CLO2V2M= github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec h1:W09IVJc94icq4NjY3clb7Lk8O1qJ8BdBEF8z0ibU0rE= github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec/go.mod h1:qqbHyh8v60DhA7CoWK5oRCqLrMHRGoxYCSS9EjAz6Eo= github.com/robfig/cron v1.2.0 h1:ZjScXvvxeQ63Dbyxy76Fj3AT3Ut0aKsyd2/tl3DTMuQ= github.com/robfig/cron v1.2.0/go.mod h1:JGuDeoQd7Z6yL4zQhZ3OPEVHB7fL6Ka6skscFHfmt2k= +github.com/robfig/cron/v3 v3.0.1 h1:WdRxkvbJztn8LMz/QEvLN5sBU+xKpSqwwUO1Pjr4qDs= +github.com/robfig/cron/v3 v3.0.1/go.mod h1:eQICP3HwyT7UooqI/z+Ov+PtYAWygg1TEWWzGIFLtro= github.com/rogpeppe/fastuuid v1.2.0/go.mod h1:jVj6XXZzXRy/MSR5jhDC/2q6DgLz+nrA6LYCDYWNEvQ= -github.com/rogpeppe/go-internal v1.12.0 h1:exVL4IDcn6na9z1rAb56Vxr+CgyK3nn3O+epU5NdKM8= -github.com/rogpeppe/go-internal v1.12.0/go.mod h1:E+RYuTGaKKdloAfM02xzb0FW3Paa99yedzYV+kq4uf4= +github.com/rogpeppe/go-internal v1.13.1 h1:KvO1DLK/DRN07sQ1LQKScxyZJuNnedQ5/wKSR38lUII= +github.com/rogpeppe/go-internal v1.13.1/go.mod h1:uMEvuHeurkdAXX61udpOXGD/AzZDWNMNyH2VO9fmH0o= github.com/rs/xid v1.5.0 h1:mKX4bl4iPYJtEIxp6CYiUuLQ/8DYMoz0PUdtGgMFRVc= github.com/rs/xid v1.5.0/go.mod h1:trrq9SKmegXys3aeAKXMUTdJsYXVwGY3RLcfgqegfbg= github.com/russross/blackfriday/v2 v2.1.0/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM= +github.com/samuel/go-thrift v0.0.0-20190219015601-e8b6b52668fe/go.mod h1:Vrkh1pnjV9Bl8c3P9zH0/D4NlOHWP5d4/hF4YTULaec= github.com/sebest/xff v0.0.0-20210106013422-671bd2870b3a h1:iLcLb5Fwwz7g/DLK89F+uQBDeAhHhwdzB5fSlVdhGcM= github.com/sebest/xff v0.0.0-20210106013422-671bd2870b3a/go.mod h1:wozgYq9WEBQBaIJe4YZ0qTSFAMxmcwBhQH0fO0R34Z0= github.com/sethvargo/go-envconfig v1.1.0 h1:cWZiJxeTm7AlCvzGXrEXaSTCNgip5oJepekh/BOQuog= @@ -553,12 +620,15 @@ github.com/shoenig/go-m1cpu v0.1.6 h1:nxdKQNcEB6vzgA2E2bvzKIYRuNj7XNJ4S/aRSwKzFt github.com/shoenig/go-m1cpu v0.1.6/go.mod h1:1JJMcUBvfNwpq05QDQVAnx3gUHr9IYF7GNg9SUEw2VQ= github.com/shoenig/test v0.6.4 h1:kVTaSd7WLz5WZ2IaoM0RSzRsUD+m8wRR+5qvntpn4LU= github.com/shoenig/test v0.6.4/go.mod h1:byHiCGXqrVaflBLAMq/srcZIHynQPQgeyvkvXnjqq0k= +github.com/sirupsen/logrus v1.0.2-0.20170726183946-abee6f9b0679/go.mod h1:pMByvHTf9Beacp5x1UXfOR9xyW/9antXMhjMPG0dEzc= github.com/sirupsen/logrus v1.4.2/go.mod h1:tLMulIdttU9McNUspp0xgXVQah82FyeX6MwdIuYE2rE= github.com/sirupsen/logrus v1.9.3 h1:dueUQJ1C2q9oE3F7wvmSGAaVtTmUizReu6fjN8uqzbQ= github.com/sirupsen/logrus v1.9.3/go.mod h1:naHLuLoDiP4jHNo9R0sCBMtWGeIprob74mVsIT4qYEQ= github.com/smartystreets/assertions v0.0.0-20180927180507-b2de0cb4f26d/go.mod h1:OnSkiWE9lh6wB0YB77sQom3nweQdgAjqCqsofrRNTgc= github.com/smartystreets/assertions v1.1.0/go.mod h1:tcbTF8ujkAEcZ8TElKY+i30BzYlVhC/LOxJk7iOWnoo= github.com/smartystreets/goconvey v1.6.4/go.mod h1:syvi0/a8iFYH4r/RixwvyeAJjdLS9QV7WQ/tjFTllLA= +github.com/sony/gobreaker v1.0.0 h1:feX5fGGXSl3dYd4aHZItw+FpHLvvoaqkawKjVNiFMNQ= +github.com/sony/gobreaker v1.0.0/go.mod h1:ZKptC7FHNvhBz7dN2LGjPVBz2sZJmc0/PkyDJOjmxWY= github.com/spf13/afero v0.0.0-20170901052352-ee1bd8ee15a1/go.mod h1:j4pytiNVoe2o6bmDsKpLACNPDBIoEAkihy7loJ1B0CQ= github.com/spf13/cast v1.1.0/go.mod h1:r2rcYCSwa1IExKTDiTfzaxqT2FNHs8hODu4LnUfgKEg= github.com/spf13/cast v1.5.1 h1:R+kOtfhWQE6TVQzY+4D7wJLBgkdVasCEFxSUBYBYIlA= @@ -587,14 +657,22 @@ github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/ github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU= github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4= github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo= -github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg= github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= +github.com/stretchr/testify v1.10.0 h1:Xv5erBjTwe/5IxqUQTdXv5kgmIvbHo3QQyRwhJsOfJA= +github.com/stretchr/testify v1.10.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= github.com/swaggo/files v1.0.1 h1:J1bVJ4XHZNq0I46UU90611i9/YzdrF7x92oX1ig5IdE= github.com/swaggo/files v1.0.1/go.mod h1:0qXmMNH6sXNf+73t65aKeB+ApmgxdnkQzVTAj2uaMUg= github.com/swaggo/gin-swagger v1.6.0 h1:y8sxvQ3E20/RCyrXeFfg60r6H0Z+SwpTjMYsMm+zy8M= github.com/swaggo/gin-swagger v1.6.0/go.mod h1:BG00cCEy294xtVpyIAHG6+e2Qzj/xKlRdOqDkvq0uzo= github.com/swaggo/swag v1.16.2 h1:28Pp+8DkQoV+HLzLx8RGJZXNGKbFqnuvSbAAtoxiY04= github.com/swaggo/swag v1.16.2/go.mod h1:6YzXnDcpr0767iOejs318CwYkCQqyGer6BizOg03f+E= +github.com/temporalio/ringpop-go v0.0.0-20241119001152-e505ebd8f887 h1:08Y1jDl4UKVu+TiQHIVKcW6TKQaHl15vBKkcZ094/SA= +github.com/temporalio/ringpop-go v0.0.0-20241119001152-e505ebd8f887/go.mod h1:RE+CHmY+kOZQk47AQaVzwrGmxpflnLgTd6EOK0853j4= +github.com/temporalio/sqlparser v0.0.0-20231115171017-f4060bcfa6cb h1:YzHH/U/dN7vMP+glybzcXRTczTrgfdRisNTzAj7La04= +github.com/temporalio/sqlparser v0.0.0-20231115171017-f4060bcfa6cb/go.mod h1:143qKdh3G45IgV9p+gbAwp3ikRDI8mxsijFiXDfuxsw= +github.com/temporalio/tchannel-go v1.22.1-0.20220818200552-1be8d8cffa5b/go.mod h1:c+V9Z/ZgkzAdyGvHrvC5AsXgN+M9Qwey04cBdKYzV7U= +github.com/temporalio/tchannel-go v1.22.1-0.20240528171429-1db37fdea938 h1:sEJGhmDo+0FaPWM6f0v8Tjia0H5pR6/Baj6+kS78B+M= +github.com/temporalio/tchannel-go v1.22.1-0.20240528171429-1db37fdea938/go.mod h1:ezRQRwu9KQXy8Wuuv1aaFFxoCNz5CeNbVOOkh3xctbY= github.com/testcontainers/testcontainers-go v0.34.0 h1:5fbgF0vIN5u+nD3IWabQwRybuB4GY8G2HHgCkbMzMHo= github.com/testcontainers/testcontainers-go v0.34.0/go.mod h1:6P/kMkQe8yqPHfPWNulFGdFHTD8HB2vLq/231xY2iPQ= github.com/testcontainers/testcontainers-go/modules/mysql v0.32.0 h1:6vjJOVJSWDTyNvQmB8EFTmv20ScquRWZa+pM1hZNodc= @@ -616,6 +694,14 @@ github.com/tmthrgd/go-hex v0.0.0-20190904060850-447a3041c3bc h1:9lRDQMhESg+zvGYm github.com/tmthrgd/go-hex v0.0.0-20190904060850-447a3041c3bc/go.mod h1:bciPuU6GHm1iF1pBvUfxfsH0Wmnc2VbpgvbI9ZWuIRs= github.com/twitchyliquid64/golang-asm v0.15.1 h1:SU5vSMR7hnwNxj24w34ZyCi/FmDZTkS4MhqMhdFk5YI= github.com/twitchyliquid64/golang-asm v0.15.1/go.mod h1:a1lVb/DtPvCB8fslRZhAngC2+aY1QWCk3Cedj/Gdt08= +github.com/twmb/murmur3 v1.1.8 h1:8Yt9taO/WN3l08xErzjeschgZU2QSrwm1kclYq+0aRg= +github.com/twmb/murmur3 v1.1.8/go.mod h1:Qq/R7NUyOfr65zD+6Q5IHKsJLwP7exErjN6lyyq3OSQ= +github.com/uber-common/bark v1.0.0 h1:l5mfssVFEaYr60U8c4LNLRkp6xo6pTcOAcB1aJ4f9+g= +github.com/uber-common/bark v1.0.0/go.mod h1:g0ZuPcD7XiExKHynr93Q742G/sbrdVQkghrqLGOoFuY= +github.com/uber-go/tally v3.3.15+incompatible/go.mod h1:YDTIBxdXyOU/sCWilKB4bgyufu1cEi0jdVnRdxvjnmU= +github.com/uber-go/tally/v4 v4.1.17-0.20240412215630-22fe011f5ff0 h1:z5IgRoL16N7tdzn5oikX2G4oVXopW+CWo3XRxx61OQo= +github.com/uber-go/tally/v4 v4.1.17-0.20240412215630-22fe011f5ff0/go.mod h1:ZdpiHRGSa3z4NIAc1VlEH4SiknR885fOIF08xmS0gaU= +github.com/uber/jaeger-client-go v2.22.1+incompatible/go.mod h1:WVhlPFC8FDjOFMMWRy2pZqQJSXxYSwNYOkTr/Z6d3Kk= github.com/uber/jaeger-client-go v2.30.0+incompatible h1:D6wyKGCecFaSRUpo8lCVbaOOb6ThwMmTEbhRwtKR97o= github.com/uber/jaeger-client-go v2.30.0+incompatible/go.mod h1:WVhlPFC8FDjOFMMWRy2pZqQJSXxYSwNYOkTr/Z6d3Kk= github.com/uber/jaeger-lib v2.4.1+incompatible h1:td4jdvLcExb4cBISKIpHuGoVXh+dVKhn2Um6rjCsSsg= @@ -658,44 +744,63 @@ gitlab.com/gitlab-org/labkit v1.21.2/go.mod h1:Q++SWyCH/abH2pytnX2SU/3mrCX6aK/xK go.opencensus.io v0.22.3/go.mod h1:yxeiOL68Rb0Xd1ddK5vPZ/oVn4vY4Ynel7k9FzqtOIw= go.opencensus.io v0.24.0 h1:y73uSU6J157QMP2kn2r30vwW1A2W2WFwSCGnAVxeaD0= go.opencensus.io v0.24.0/go.mod h1:vNK8G9p7aAivkbmorf4v+7Hgx+Zs0yY+0fOtgBfjQKo= -go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc v0.49.0 h1:4Pp6oUg3+e/6M4C0A/3kJ2VYa++dsWVTtGgLVj5xtHg= -go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc v0.49.0/go.mod h1:Mjt1i1INqiaoZOMGR1RIUJN+i3ChKoFRqzrRQhlkbs0= -go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.49.0 h1:jq9TW8u3so/bN+JPT166wjOI6/vQPF6Xe7nMNIltagk= -go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.49.0/go.mod h1:p8pYQP+m5XfbZm9fxtSKAbM6oIllS7s2AfxrChvc7iw= -go.opentelemetry.io/otel v1.24.0 h1:0LAOdjNmQeSTzGBzduGe/rU4tZhMwL5rWgtp9Ku5Jfo= -go.opentelemetry.io/otel v1.24.0/go.mod h1:W7b9Ozg4nkF5tWI5zsXkaKKDjdVjpD4oAt9Qi/MArHo= -go.opentelemetry.io/otel/exporters/otlp/otlptrace v1.22.0 h1:9M3+rhx7kZCIQQhQRYaZCdNu1V73tm4TvXs2ntl98C4= -go.opentelemetry.io/otel/exporters/otlp/otlptrace v1.22.0/go.mod h1:noq80iT8rrHP1SfybmPiRGc9dc5M8RPmGvtwo7Oo7tc= +go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc v0.56.0 h1:yMkBS9yViCc7U7yeLzJPM2XizlfdVvBRSmsQDWu6qc0= +go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc v0.56.0/go.mod h1:n8MR6/liuGB5EmTETUBeU5ZgqMOlqKRxUaqPQBOANZ8= +go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.52.0 h1:9l89oX4ba9kHbBol3Xin3leYJ+252h0zszDtBwyKe2A= +go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.52.0/go.mod h1:XLZfZboOJWHNKUv7eH0inh0E9VV6eWDFB/9yJyTLPp0= +go.opentelemetry.io/otel v1.31.0 h1:NsJcKPIW0D0H3NgzPDHmo0WW6SptzPdqg/L1zsIm2hY= +go.opentelemetry.io/otel v1.31.0/go.mod h1:O0C14Yl9FgkjqcCZAsE053C13OaddMYr/hz6clDkEJE= +go.opentelemetry.io/otel/exporters/otlp/otlpmetric/otlpmetricgrpc v1.31.0 h1:FZ6ei8GFW7kyPYdxJaV2rgI6M+4tvZzhYsQ2wgyVC08= +go.opentelemetry.io/otel/exporters/otlp/otlpmetric/otlpmetricgrpc v1.31.0/go.mod h1:MdEu/mC6j3D+tTEfvI15b5Ci2Fn7NneJ71YMoiS3tpI= +go.opentelemetry.io/otel/exporters/otlp/otlptrace v1.31.0 h1:K0XaT3DwHAcV4nKLzcQvwAgSyisUghWoY20I7huthMk= +go.opentelemetry.io/otel/exporters/otlp/otlptrace v1.31.0/go.mod h1:B5Ki776z/MBnVha1Nzwp5arlzBbE3+1jk+pGmaP5HME= +go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracegrpc v1.31.0 h1:FFeLy03iVTXP6ffeN2iXrxfGsZGCjVx0/4KlizjyBwU= +go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracegrpc v1.31.0/go.mod h1:TMu73/k1CP8nBUpDLc71Wj/Kf7ZS9FK5b53VapRsP9o= go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracehttp v1.22.0 h1:FyjCyI9jVEfqhUh2MoSkmolPjfh5fp2hnV0b0irxH4Q= go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracehttp v1.22.0/go.mod h1:hYwym2nDEeZfG/motx0p7L7J1N1vyzIThemQsb4g2qY= -go.opentelemetry.io/otel/metric v1.24.0 h1:6EhoGWWK28x1fbpA4tYTOWBkPefTDQnb8WSGXlc88kI= -go.opentelemetry.io/otel/metric v1.24.0/go.mod h1:VYhLe1rFfxuTXLgj4CBiyz+9WYBA8pNGJgDcSFRKBco= -go.opentelemetry.io/otel/sdk v1.22.0 h1:6coWHw9xw7EfClIC/+O31R8IY3/+EiRFHevmHafB2Gw= -go.opentelemetry.io/otel/sdk v1.22.0/go.mod h1:iu7luyVGYovrRpe2fmj3CVKouQNdTOkxtLzPvPz1DOc= -go.opentelemetry.io/otel/trace v1.24.0 h1:CsKnnL4dUAr/0llH9FKuc698G04IrpWV0MQA/Y1YELI= -go.opentelemetry.io/otel/trace v1.24.0/go.mod h1:HPc3Xr/cOApsBI154IU0OI0HJexz+aw5uPdbs3UCjNU= -go.opentelemetry.io/proto/otlp v1.0.0 h1:T0TX0tmXU8a3CbNXzEKGeU5mIVOdf0oykP+u2lIVU/I= -go.opentelemetry.io/proto/otlp v1.0.0/go.mod h1:Sy6pihPLfYHkr3NkUbEhGHFhINUSI/v80hjKIs5JXpM= -go.temporal.io/api v1.40.0 h1:rH3HvUUCFr0oecQTBW5tI6DdDQsX2Xb6OFVgt/bvLto= -go.temporal.io/api v1.40.0/go.mod h1:1WwYUMo6lao8yl0371xWUm13paHExN5ATYT/B7QtFis= -go.temporal.io/sdk v1.30.0 h1:7jzSFZYk+tQ2kIYEP+dvrM7AW9EsCEP52JHCjVGuwbI= -go.temporal.io/sdk v1.30.0/go.mod h1:Pv45F/fVDgWKx+jhix5t/dGgqROVaI+VjPLd3CHWqq0= +go.opentelemetry.io/otel/exporters/prometheus v0.53.0 h1:QXobPHrwiGLM4ufrY3EOmDPJpo2P90UuFau4CDPJA/I= +go.opentelemetry.io/otel/exporters/prometheus v0.53.0/go.mod h1:WOAXGr3D00CfzmFxtTV1eR0GpoHuPEu+HJT8UWW2SIU= +go.opentelemetry.io/otel/metric v1.31.0 h1:FSErL0ATQAmYHUIzSezZibnyVlft1ybhy4ozRPcF2fE= +go.opentelemetry.io/otel/metric v1.31.0/go.mod h1:C3dEloVbLuYoX41KpmAhOqNriGbA+qqH6PQ5E5mUfnY= +go.opentelemetry.io/otel/sdk v1.31.0 h1:xLY3abVHYZ5HSfOg3l2E5LUj2Cwva5Y7yGxnSW9H5Gk= +go.opentelemetry.io/otel/sdk v1.31.0/go.mod h1:TfRbMdhvxIIr/B2N2LQW2S5v9m3gOQ/08KsbbO5BPT0= +go.opentelemetry.io/otel/sdk/metric v1.31.0 h1:i9hxxLJF/9kkvfHppyLL55aW7iIJz4JjxTeYusH7zMc= +go.opentelemetry.io/otel/sdk/metric v1.31.0/go.mod h1:CRInTMVvNhUKgSAMbKyTMxqOBC0zgyxzW55lZzX43Y8= +go.opentelemetry.io/otel/trace v1.31.0 h1:ffjsj1aRouKewfr85U2aGagJ46+MvodynlQ1HYdmJys= +go.opentelemetry.io/otel/trace v1.31.0/go.mod h1:TXZkRk7SM2ZQLtR6eoAWQFIHPvzQ06FJAsO1tJg480A= +go.opentelemetry.io/proto/otlp v1.3.1 h1:TrMUixzpM0yuc/znrFTP9MMRh8trP93mkCiDVeXrui0= +go.opentelemetry.io/proto/otlp v1.3.1/go.mod h1:0X1WI4de4ZsLrrJNLAQbFeLCm3T7yBkR0XqQ7niQU+8= +go.temporal.io/api v1.43.0 h1:lBhq+u5qFJqGMXwWsmg/i8qn1UA/3LCwVc88l2xUMHg= +go.temporal.io/api v1.43.0/go.mod h1:1WwYUMo6lao8yl0371xWUm13paHExN5ATYT/B7QtFis= +go.temporal.io/sdk v1.31.0 h1:CLYiP0R5Sdj0gq8LyYKDDz4ccGOdJPR8wNGJU0JGwj8= +go.temporal.io/sdk v1.31.0/go.mod h1:8U8H7rF9u4Hyb4Ry9yiEls5716DHPNvVITPNkgWUwE8= +go.temporal.io/server v1.26.2 h1:vDW11lxslYPlGDbQklWi/tqbkVZ2ExtRO1jNjvZmUUI= +go.temporal.io/server v1.26.2/go.mod h1:tgY+4z/PuIdqs6ouV1bT90RWSWfEioWkzmrNrLYLUrk= +go.temporal.io/version v0.3.0 h1:dMrei9l9NyHt8nG6EB8vAwDLLTwx2SvRyucCSumAiig= +go.temporal.io/version v0.3.0/go.mod h1:UA9S8/1LaKYae6TyD9NaPMJTZb911JcbqghI2CBSP78= go.uber.org/atomic v1.4.0/go.mod h1:gD2HeocX3+yG+ygLZcrzQJaqmWj9AIm7n08wl/qW/PE= go.uber.org/atomic v1.7.0/go.mod h1:fEN4uk6kAWBTFdckzkM89CLk9XfWZrxpCo0nPH17wJc= go.uber.org/atomic v1.9.0/go.mod h1:fEN4uk6kAWBTFdckzkM89CLk9XfWZrxpCo0nPH17wJc= go.uber.org/atomic v1.11.0 h1:ZvwS0R+56ePWxUNi+Atn9dWONBPp/AUETXlHW0DxSjE= go.uber.org/atomic v1.11.0/go.mod h1:LUxbIzbOniOlMKjJjyPfpl4v+PKK2cNJn91OQbhoJI0= +go.uber.org/dig v1.17.1 h1:Tga8Lz8PcYNsWsyHMZ1Vm0OQOUaJNDyvPImgbAu9YSc= +go.uber.org/dig v1.17.1/go.mod h1:Us0rSJiThwCv2GteUN0Q7OKvU7n5J4dxZ9JKUXozFdE= +go.uber.org/fx v1.22.0 h1:pApUK7yL0OUHMd8vkunWSlLxZVFFk70jR2nKde8X2NM= +go.uber.org/fx v1.22.0/go.mod h1:HT2M7d7RHo+ebKGh9NRcrsrHHfpZ60nW3QRubMRfv48= go.uber.org/goleak v1.1.10/go.mod h1:8a7PlsEVH3e/a/GLqe5IIrQx6GzcnRmZEufDUTk4A7A= go.uber.org/goleak v1.3.0 h1:2K3zAYmnTNqV73imy9J1T3WC+gmCePx2hEGkimedGto= go.uber.org/goleak v1.3.0/go.mod h1:CoHD4mav9JJNrW/WLlf7HGZPjdw8EucARQHekz1X6bE= +go.uber.org/mock v0.4.0 h1:VcM4ZOtdbR4f6VXfiOpwpVJDL6lCReaZ6mw31wqh7KU= +go.uber.org/mock v0.4.0/go.mod h1:a6FSlNadKUHUa9IP5Vyt1zh4fC7uAwxMutEAscFbkZc= go.uber.org/multierr v1.1.0/go.mod h1:wR5kodmAFQ0UK8QlbwjlSNy0Z68gJhDJUG5sjR94q/0= go.uber.org/multierr v1.6.0/go.mod h1:cdWPpRnG4AhwMwsgIHip0KRBQjJy5kYEpYjJxpXp9iU= +go.uber.org/multierr v1.7.0/go.mod h1:7EAYxJLBy9rStEaz58O2t4Uvip6FSURkq8/ppBp95ak= go.uber.org/multierr v1.11.0 h1:blXXJkSxSSfBVBlC76pxqeO+LN3aDfLQo+309xJstO0= go.uber.org/multierr v1.11.0/go.mod h1:20+QtiLqy0Nd6FdQB9TLXag12DsQkrbs3htMFfDN80Y= go.uber.org/zap v1.10.0/go.mod h1:vwi/ZaCAaUcBkycHslxD9B2zi4UTXhF60s6SWpuDF0Q= go.uber.org/zap v1.18.1/go.mod h1:xg/QME4nWcxGxrpdeYfq7UvYrLh66cuVKdrbD1XF/NI= -go.uber.org/zap v1.26.0 h1:sI7k6L95XOKS281NhVKOFCUNIvv9e0w4BF8N3u+tCRo= -go.uber.org/zap v1.26.0/go.mod h1:dtElttAiwGvoJ/vj4IwHBS/gXsEu/pZ50mUIRWuG0so= +go.uber.org/zap v1.27.0 h1:aJMhYGrd5QSmlpLMr2MftRKl7t8J8PTZPA732ud/XR8= +go.uber.org/zap v1.27.0/go.mod h1:GB2qFLM7cTU87MWRP2mPIjqfIDnGu+VIO4V/SdhGo2E= golang.org/x/arch v0.0.0-20210923205945-b76863e36670/go.mod h1:5om86z9Hs0C8fWVUuoMHwpExlXzs5Tkyp9hOrfG7pp8= golang.org/x/arch v0.8.0 h1:3wRIsP3pM4yUptoR96otTUOXI367OS0+c9eeRi9doIc= golang.org/x/arch v0.8.0/go.mod h1:FEVrYAQjsQXMVJ1nsMoVVXPZg6p2JE2mx8psSWTDQys= @@ -721,8 +826,8 @@ golang.org/x/exp v0.0.0-20190121172915-509febef88a4/go.mod h1:CJ0aWSM057203Lf6IL golang.org/x/exp v0.0.0-20190125153040-c74c464bbbf2/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= golang.org/x/exp v0.0.0-20190306152737-a1d7652674e8/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= golang.org/x/exp v0.0.0-20191030013958-a1ab85dbe136/go.mod h1:JXzH8nQsPlswgeRAPE3MuO9GYsAcnJvJ4vnMwN/5qkY= -golang.org/x/exp v0.0.0-20240119083558-1b970713d09a h1:Q8/wZp0KX97QFTc2ywcOE0YRjZPVIx+MXInMzdvQqcA= -golang.org/x/exp v0.0.0-20240119083558-1b970713d09a/go.mod h1:idGWGoKP1toJGkd5/ig9ZLuPcZBC3ewk7SzmH0uou08= +golang.org/x/exp v0.0.0-20240531132922-fd00a4e0eefc h1:O9NuF4s+E/PvMIy+9IUZB9znFwUIXEWSstNjek6VpVg= +golang.org/x/exp v0.0.0-20240531132922-fd00a4e0eefc/go.mod h1:XtvwrStGgqGPLc4cjQfWqZHG1YFdYs6swckp8vpsjnc= golang.org/x/image v0.0.0-20180708004352-c73c2afc3b81/go.mod h1:ux5Hcp/YLpHSI86hEcLt0YII63i6oz57MZXIpbrjZUs= golang.org/x/image v0.0.0-20190227222117-0694c2d4d067/go.mod h1:kZ7UVZpmo3dzQBMxlp+ypCbDeSB+sBbTgSJuh5dn5js= golang.org/x/image v0.0.0-20190802002840-cff245a6509b/go.mod h1:FeLwcggjj3mMvU+oOTbSwawSJRM1uh48EjtB4UJZlP0= @@ -762,6 +867,7 @@ golang.org/x/net v0.0.0-20201110031124-69a78807bb2b/go.mod h1:sp8m0HH+o8qH0wwXwY golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg= golang.org/x/net v0.0.0-20210405180319-a5a99cb37ef4/go.mod h1:p54w0d4576C0XHj96bSt6lcn1PtDYWL6XObtHCRCNQM= golang.org/x/net v0.0.0-20210428140749-89ef3d95e781/go.mod h1:OJAsFXCWl8Ukc7SiCT/9KSuxbyM7479/AVlXFRxuMCk= +golang.org/x/net v0.0.0-20211015210444-4f30a5c0130f/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y= golang.org/x/net v0.0.0-20211112202133-69e39bad7dc2/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y= golang.org/x/net v0.0.0-20220722155237-a158d28d115b/go.mod h1:XRhObCWvk6IyKnWLug+ECip1KBveYUHfp+8e9klMJ9c= golang.org/x/net v0.6.0/go.mod h1:2Tu9+aMcznHK/AK1HMvgo6xiTLG5rD5rZLDS+rp2Bjs= @@ -771,13 +877,13 @@ golang.org/x/net v0.11.0/go.mod h1:2L/ixqYpgIVXmeoSA/4Lu7BzTG4KIyPIryS4IsOd1oQ= golang.org/x/net v0.15.0/go.mod h1:idbUs1IY1+zTqbi8yxTbhexhEEk5ur9LInksu6HrEpk= golang.org/x/net v0.17.0/go.mod h1:NxSsAGuq816PNPmqtQdLE42eU2Fs7NoRIZrHJAlaCOE= golang.org/x/net v0.20.0/go.mod h1:z8BVo6PvndSri0LbOE3hAn0apkU+1YvI6E70E9jsnvY= -golang.org/x/net v0.28.0 h1:a9JDOJc5GMUJ0+UDqmLT86WiEy7iWyIhz8gz8E4e5hE= -golang.org/x/net v0.28.0/go.mod h1:yqtgsTWOOnlGLG9GFRrK3++bGOUEkNBoHZc8MEDWPNg= +golang.org/x/net v0.31.0 h1:68CPQngjLL0r2AlUKiSxtQFKvzRVbnzLwMUn5SzcLHo= +golang.org/x/net v0.31.0/go.mod h1:P4fl1q7dY2hnZFxEk4pPSkDHF+QqjitcnDjUQyMM+pM= golang.org/x/oauth2 v0.0.0-20170912212905-13449ad91cb2/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U= golang.org/x/oauth2 v0.0.0-20180821212333-d2e6202438be/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U= golang.org/x/oauth2 v0.0.0-20200107190931-bf48bf16ab8d/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw= -golang.org/x/oauth2 v0.22.0 h1:BzDx2FehcG7jJwgWLELCdmLuxk2i+x9UDpSiss2u0ZA= -golang.org/x/oauth2 v0.22.0/go.mod h1:XYTD2NtWslqkgxebSiOHnXEap4TF09sJSc7H1sXbhtI= +golang.org/x/oauth2 v0.23.0 h1:PbgcYx2W7i4LvjJWEbf0ngHV6qJYr86PkAV3bXdLEbs= +golang.org/x/oauth2 v0.23.0/go.mod h1:XYTD2NtWslqkgxebSiOHnXEap4TF09sJSc7H1sXbhtI= golang.org/x/sync v0.0.0-20170517211232-f52d1811a629/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20180314180146-1d60e4601c6f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20181108010431-42b317875d0f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= @@ -814,11 +920,13 @@ golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7w golang.org/x/sys v0.0.0-20201204225414-ed752295db88/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20210112080510-489259a85091/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20210217105451-b926d437f341/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20210320140829-1e4c9ba3b0c4/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20210330210617-4fbd30eecc44/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20210423082822-04245dca01da/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20210510120138-977fb7262007/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20210616094352-59db8d763f22/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.0.0-20211019181941-9d821ace8654/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20211025201205-69cdffdb9359/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20211216021012-1d35b9e2eb4e/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220520151302-bc2c85ada10a/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= @@ -898,8 +1006,8 @@ gonum.org/v1/gonum v0.8.2/go.mod h1:oe/vMfY3deqTw+1EZJhuvEW2iwGF1bW9wwu7XCu0+v0= gonum.org/v1/netlib v0.0.0-20190313105609-8cb42192e0e0/go.mod h1:wa6Ws7BG/ESfp6dHfk7C6KdzKA7wR7u/rKwOGE66zvw= gonum.org/v1/plot v0.0.0-20190515093506-e2840ee46a6b/go.mod h1:Wt8AAjI+ypCyYX3nZBvf6cAIx93T+c/OS2HFAYskSZc= google.golang.org/api v0.0.0-20170921000349-586095a6e407/go.mod h1:4mhQ8q/RsB7i+udVvVy5NUi08OU8ZlA0gRVgrF7VFY0= -google.golang.org/api v0.169.0 h1:QwWPy71FgMWqJN/l6jVlFHUa29a7dcUy02I8o799nPY= -google.golang.org/api v0.169.0/go.mod h1:gpNOiMA2tZ4mf5R9Iwf4rK/Dcz0fbdIgWYWVoxmsyLg= +google.golang.org/api v0.182.0 h1:if5fPvudRQ78GeRx3RayIoiuV7modtErPIZC/T2bIvE= +google.golang.org/api v0.182.0/go.mod h1:cGhjy4caqA5yXRzEhkHI8Y9mfyC2VLTlER2l08xaqtM= google.golang.org/appengine v1.1.0/go.mod h1:EbEs0AVv82hx2wNQdGPgUI5lhzA/G0D9YwlJXL52JkM= google.golang.org/appengine v1.4.0/go.mod h1:xpcJRLb0r/rnEns0DIKYYv+WjYCduHsrkT7/EB5XEv4= google.golang.org/appengine v1.6.5/go.mod h1:8WjMMxjGQR8xUklV/ARdw2HLXBOI7O7uCIDZVag1xfc= @@ -911,12 +1019,12 @@ google.golang.org/genproto v0.0.0-20190819201941-24fa4b261c55/go.mod h1:DMBHOl98 google.golang.org/genproto v0.0.0-20200423170343-7949de9c1215/go.mod h1:55QSHmfGQM9UVYDPBsyGGes0y52j32PQ3BqQfXhyH3c= google.golang.org/genproto v0.0.0-20200513103714-09dca8ec2884/go.mod h1:55QSHmfGQM9UVYDPBsyGGes0y52j32PQ3BqQfXhyH3c= google.golang.org/genproto v0.0.0-20200526211855-cb27e3aa2013/go.mod h1:NbSheEEYHJ7i3ixzK3sjbqSGDJWnxyFXZblF3eUsNvo= -google.golang.org/genproto v0.0.0-20240311173647-c811ad7063a7 h1:ImUcDPHjTrAqNhlOkSocDLfG9rrNHH7w7uoKWPaWZ8s= -google.golang.org/genproto v0.0.0-20240311173647-c811ad7063a7/go.mod h1:/3XmxOjePkvmKrHuBy4zNFw7IzxJXtAgdpXi8Ll990U= -google.golang.org/genproto/googleapis/api v0.0.0-20240827150818-7e3bb234dfed h1:3RgNmBoI9MZhsj3QxC+AP/qQhNwpCLOvYDYYsFrhFt0= -google.golang.org/genproto/googleapis/api v0.0.0-20240827150818-7e3bb234dfed/go.mod h1:OCdP9MfskevB/rbYvHTsXTtKC+3bHWajPdoKgjcYkfo= -google.golang.org/genproto/googleapis/rpc v0.0.0-20240827150818-7e3bb234dfed h1:J6izYgfBXAI3xTKLgxzTmUltdYaLsuBxFCgDHWJ/eXg= -google.golang.org/genproto/googleapis/rpc v0.0.0-20240827150818-7e3bb234dfed/go.mod h1:UqMtugtsSgubUsoxbuAoiCXvqvErP7Gf0so0mK9tHxU= +google.golang.org/genproto v0.0.0-20240528184218-531527333157 h1:u7WMYrIrVvs0TF5yaKwKNbcJyySYf+HAIFXxWltJOXE= +google.golang.org/genproto v0.0.0-20240528184218-531527333157/go.mod h1:ubQlAQnzejB8uZzszhrTCU2Fyp6Vi7ZE5nn0c3W8+qQ= +google.golang.org/genproto/googleapis/api v0.0.0-20241007155032-5fefd90f89a9 h1:T6rh4haD3GVYsgEfWExoCZA2o2FmbNyKpTuAxbEFPTg= +google.golang.org/genproto/googleapis/api v0.0.0-20241007155032-5fefd90f89a9/go.mod h1:wp2WsuBYj6j8wUdo3ToZsdxxixbvQNAHqVJrTgi5E5M= +google.golang.org/genproto/googleapis/rpc v0.0.0-20241007155032-5fefd90f89a9 h1:QCqS/PdaHTSWGvupk2F/ehwHtGc0/GYkT+3GAcR1CCc= +google.golang.org/genproto/googleapis/rpc v0.0.0-20241007155032-5fefd90f89a9/go.mod h1:GX3210XPVPUjJbTUbvwI8f2IpZDMZuPJWDzDuebbviI= google.golang.org/grpc v1.2.1-0.20170921194603-d4b75ebd4f9f/go.mod h1:yo6s7OP7yaDglbqo1J04qKzAhqBH6lvTonzMVmEdcZw= google.golang.org/grpc v1.19.0/go.mod h1:mqu4LbDTu4XGKhr4mRzUsmM4RtVoemTSY81AxZiDr8c= google.golang.org/grpc v1.20.1/go.mod h1:10oTOabMzJvdu6/UiuZezV6QK5dSlG84ov/aaiqXj38= @@ -928,8 +1036,8 @@ google.golang.org/grpc v1.29.1/go.mod h1:itym6AZVZYACWQqET3MqgPpjcuV5QH3BxFS3Iji google.golang.org/grpc v1.31.0/go.mod h1:N36X2cJ7JwdamYAgDz+s+rVMFjt3numwzf/HckM8pak= google.golang.org/grpc v1.33.1/go.mod h1:fr5YgcSWrqhRRxogOsw7RzIpsmvOZ6IcH4kBYTpR3n0= google.golang.org/grpc v1.33.2/go.mod h1:JMHMWHQWaTccqQQlmk3MJZS+GWXOdAesneDmEnv2fbc= -google.golang.org/grpc v1.66.0 h1:DibZuoBznOxbDQxRINckZcUvnCEvrW9pcWIE2yF9r1c= -google.golang.org/grpc v1.66.0/go.mod h1:s3/l6xSSCURdVfAnL+TqCNMyTDAGN6+lZeVxnZR128Y= +google.golang.org/grpc v1.67.1 h1:zWnc1Vrcno+lHZCOofnIMvycFcc0QRGIzm9dhnDX68E= +google.golang.org/grpc v1.67.1/go.mod h1:1gLDyUQU7CTLJI90u3nXZ9ekeghjeM7pTDZlqFNg2AA= google.golang.org/protobuf v0.0.0-20200109180630-ec00e32a8dfd/go.mod h1:DFci5gLYBciE7Vtevhsrf46CRTquxDuWsQurQQe4oz8= google.golang.org/protobuf v0.0.0-20200221191635-4d8936d0db64/go.mod h1:kwYJMbMJ01Woi6D6+Kah6886xMZcty6N08ah7+eCXa0= google.golang.org/protobuf v0.0.0-20200228230310-ab0ca4ff8a60/go.mod h1:cfTl7dwQJ+fmap5saPgwCLgHXTUD7jkjRqWcaiX5VyM= @@ -941,8 +1049,8 @@ google.golang.org/protobuf v1.23.1-0.20200526195155-81db48ad09cc/go.mod h1:EGpAD google.golang.org/protobuf v1.25.0/go.mod h1:9JNX74DMeImyA3h4bdi1ymwjUzf21/xIlbajtzgsN7c= google.golang.org/protobuf v1.26.0-rc.1/go.mod h1:jlhhOSvTdKEhbULTjvd4ARK9grFBp09yW+WbY/TyQbw= google.golang.org/protobuf v1.26.0/go.mod h1:9q0QmTI4eRPtz6boOQmLYwt+qCgq0jsYwAQnmE0givc= -google.golang.org/protobuf v1.34.2 h1:6xV6lTsCfpGD21XK49h7MhtcApnLqkfYgPcdHftf6hg= -google.golang.org/protobuf v1.34.2/go.mod h1:qYOHts0dSfpeUzUFpOMr/WGzszTmLH+DiWniOlNbLDw= +google.golang.org/protobuf v1.35.1 h1:m3LfL6/Ca+fqnjnlqQXNpFPABW1UD7mjh8KO2mKFytA= +google.golang.org/protobuf v1.35.1/go.mod h1:9fA7Ob0pmnwhb644+1+CVWFRbNajQ6iRojtC/QF5bRE= gopkg.in/DataDog/dd-trace-go.v1 v1.32.0 h1:DkD0plWEVUB8v/Ru6kRBW30Hy/fRNBC8hPdcExuBZMc= gopkg.in/DataDog/dd-trace-go.v1 v1.32.0/go.mod h1:wRKMf/tRASHwH/UOfPQ3IQmVFhTz2/1a1/mpXoIjF54= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= @@ -954,6 +1062,8 @@ gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EV gopkg.in/evanphx/json-patch.v4 v4.12.0 h1:n6jtcsulIzXPJaxegRbvFNNrZDjbij7ny3gmSPG+6V4= gopkg.in/evanphx/json-patch.v4 v4.12.0/go.mod h1:p8EYWUEYMpynmqDbY58zCKCFZw8pRWMG4EsWvDvM72M= gopkg.in/fsnotify.v1 v1.4.7/go.mod h1:Tz8NjZHkW78fSQdbUxIjBTcgA1z1m8ZHf0WmKUhAMys= +gopkg.in/go-jose/go-jose.v2 v2.6.3 h1:nt80fvSDlhKWQgSWyHyy5CfmlQr+asih51R8PTWNKKs= +gopkg.in/go-jose/go-jose.v2 v2.6.3/go.mod h1:zzZDPkNNw/c9IE7Z9jr11mBZQhKQTMzoEEIoEdZlFBI= gopkg.in/inf.v0 v0.9.1 h1:73M5CoZyi3ZLMOyDlQh031Cx6N9NDJ2Vvfl76EDAgDc= gopkg.in/inf.v0 v0.9.1/go.mod h1:cWUDdTG/fYaXco+Dcufb5Vnc6Gp2YChqWtbxRZE0mXw= gopkg.in/ini.v1 v1.56.0/go.mod h1:pNLf8WUiyNEtQjuu5G5vTm06TEv9tsIgeAvK8hOrP4k= @@ -962,6 +1072,8 @@ gopkg.in/ini.v1 v1.67.0 h1:Dgnx+6+nfE+IfzjUEISNeydPJh9AXNNsWbGP9KzCsOA= gopkg.in/ini.v1 v1.67.0/go.mod h1:pNLf8WUiyNEtQjuu5G5vTm06TEv9tsIgeAvK8hOrP4k= gopkg.in/tomb.v1 v1.0.0-20141024135613-dd632973f1e7 h1:uRGJdciOHaEIrze2W8Q3AKkepLTh2hOroT7a+7czfdQ= gopkg.in/tomb.v1 v1.0.0-20141024135613-dd632973f1e7/go.mod h1:dt/ZhP58zS4L8KSrWDmTeBkI65Dw0HsyUHuEVlX15mw= +gopkg.in/validator.v2 v2.0.1 h1:xF0KWyGWXm/LM2G1TrEjqOu4pa6coO9AlWSf3msVfDY= +gopkg.in/validator.v2 v2.0.1/go.mod h1:lIUZBlB3Im4s/eYp39Ry/wkR02yOPhZ9IwIRBjuPuG8= gopkg.in/yaml.v2 v2.2.1/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= gopkg.in/yaml.v2 v2.2.3/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= @@ -999,16 +1111,28 @@ knative.dev/serving v0.40.1 h1:ZAAK8KwZQYUgCgVi3ay+NqKAISnJQ1OXPYvdtXWUcBc= knative.dev/serving v0.40.1/go.mod h1:Ory3XczDB8b1lH757CSdeDeouY3LHzSamX8IjmStuoU= mellium.im/sasl v0.3.1 h1:wE0LW6g7U83vhvxjC1IY8DnXM+EU095yeo8XClvCdfo= mellium.im/sasl v0.3.1/go.mod h1:xm59PUYpZHhgQ9ZqoJ5QaCqzWMi8IeS49dhp6plPCzw= -modernc.org/gc/v3 v3.0.0-20240107210532-573471604cb6 h1:5D53IMaUuA5InSeMu9eJtlQXS2NxAhyWQvkKEgXZhHI= -modernc.org/gc/v3 v3.0.0-20240107210532-573471604cb6/go.mod h1:Qz0X07sNOR1jWYCrJMEnbW/X55x206Q7Vt4mz6/wHp4= -modernc.org/libc v1.41.0 h1:g9YAc6BkKlgORsUWj+JwqoB1wU3o4DE3bM3yvA3k+Gk= -modernc.org/libc v1.41.0/go.mod h1:w0eszPsiXoOnoMJgrXjglgLuDy/bt5RR4y3QzUUeodY= +modernc.org/cc/v4 v4.21.4 h1:3Be/Rdo1fpr8GrQ7IVw9OHtplU4gWbb+wNgeoBMmGLQ= +modernc.org/cc/v4 v4.21.4/go.mod h1:HM7VJTZbUCR3rV8EYBi9wxnJ0ZBRiGE5OeGXNA0IsLQ= +modernc.org/ccgo/v4 v4.19.2 h1:lwQZgvboKD0jBwdaeVCTouxhxAyN6iawF3STraAal8Y= +modernc.org/ccgo/v4 v4.19.2/go.mod h1:ysS3mxiMV38XGRTTcgo0DQTeTmAO4oCmJl1nX9VFI3s= +modernc.org/fileutil v1.3.0 h1:gQ5SIzK3H9kdfai/5x41oQiKValumqNTDXMvKo62HvE= +modernc.org/fileutil v1.3.0/go.mod h1:XatxS8fZi3pS8/hKG2GH/ArUogfxjpEKs3Ku3aK4JyQ= +modernc.org/gc/v2 v2.4.1 h1:9cNzOqPyMJBvrUipmynX0ZohMhcxPtMccYgGOJdOiBw= +modernc.org/gc/v2 v2.4.1/go.mod h1:wzN5dK1AzVGoH6XOzc3YZ+ey/jPgYHLuVckd62P0GYU= +modernc.org/gc/v3 v3.0.0-20240304020402-f0dba7c97c2b h1:BnN1t+pb1cy61zbvSUV7SeI0PwosMhlAEi/vBY4qxp8= +modernc.org/gc/v3 v3.0.0-20240304020402-f0dba7c97c2b/go.mod h1:Qz0X07sNOR1jWYCrJMEnbW/X55x206Q7Vt4mz6/wHp4= +modernc.org/libc v1.55.3 h1:AzcW1mhlPNrRtjS5sS+eW2ISCgSOLLNyFzRh/V3Qj/U= +modernc.org/libc v1.55.3/go.mod h1:qFXepLhz+JjFThQ4kzwzOjA/y/artDeg+pcYnY+Q83w= modernc.org/mathutil v1.6.0 h1:fRe9+AmYlaej+64JsEEhoWuAYBkOtQiMEU7n/XgfYi4= modernc.org/mathutil v1.6.0/go.mod h1:Ui5Q9q1TR2gFm0AQRqQUaBWFLAhQpCwNcuhBOSedWPo= -modernc.org/memory v1.7.2 h1:Klh90S215mmH8c9gO98QxQFsY+W451E8AnzjoE2ee1E= -modernc.org/memory v1.7.2/go.mod h1:NO4NVCQy0N7ln+T9ngWqOQfi7ley4vpwvARR+Hjw95E= -modernc.org/sqlite v1.29.1 h1:19GY2qvWB4VPw0HppFlZCPAbmxFU41r+qjKZQdQ1ryA= -modernc.org/sqlite v1.29.1/go.mod h1:hG41jCYxOAOoO6BRK66AdRlmOcDzXf7qnwlwjUIOqa0= +modernc.org/memory v1.8.0 h1:IqGTL6eFMaDZZhEWwcREgeMXYwmW83LYW8cROZYkg+E= +modernc.org/memory v1.8.0/go.mod h1:XPZ936zp5OMKGWPqbD3JShgd/ZoQ7899TUuQqxY+peU= +modernc.org/opt v0.1.3 h1:3XOZf2yznlhC+ibLltsDGzABUGVx8J6pnFMS3E4dcq4= +modernc.org/opt v0.1.3/go.mod h1:WdSiB5evDcignE70guQKxYUl14mgWtbClRi5wmkkTX0= +modernc.org/sortutil v1.2.0 h1:jQiD3PfS2REGJNzNCMMaLSp/wdMNieTbKX920Cqdgqc= +modernc.org/sortutil v1.2.0/go.mod h1:TKU2s7kJMf1AE84OoiGppNHJwvB753OYfNl2WRb++Ss= +modernc.org/sqlite v1.34.1 h1:u3Yi6M0N8t9yKRDwhXcyp1eS5/ErhPTBggxWFuR6Hfk= +modernc.org/sqlite v1.34.1/go.mod h1:pXV2xHxhzXZsgT/RtTFAPY6JJDEvOTcTdwADQCCWD4k= modernc.org/strutil v1.2.0 h1:agBi9dp1I+eOnxXeiZawM8F4LawKv4NzGWSaLfyeNZA= modernc.org/strutil v1.2.0/go.mod h1:/mdcBmfOibveCTBxUl5B5l6W+TTH1FXPLHZE6bTosX0= modernc.org/token v1.1.0 h1:Xl7Ap9dKaEs5kLoOQeQmPWevfnk/DM5qcLcYlA8ys6Y= diff --git a/mirror/reposyncer/local_woker.go b/mirror/reposyncer/local_woker.go index 530630b6..b6a98c11 100644 --- a/mirror/reposyncer/local_woker.go +++ b/mirror/reposyncer/local_woker.go @@ -15,6 +15,7 @@ import ( "opencsg.com/csghub-server/builder/git/gitserver" "opencsg.com/csghub-server/builder/git/gitserver/gitaly" "opencsg.com/csghub-server/builder/store/database" + "opencsg.com/csghub-server/builder/temporal" "opencsg.com/csghub-server/common/config" "opencsg.com/csghub-server/common/types" "opencsg.com/csghub-server/mirror/queue" @@ -227,14 +228,13 @@ func (w *LocalMirrorWoker) SyncRepo(ctx context.Context, task queue.MirrorTask) callback.Ref = branch //start workflow to handle push request - workflowClient := workflow.GetWorkflowClient() + workflowClient := temporal.GetClient() workflowOptions := client.StartWorkflowOptions{ TaskQueue: workflow.HandlePushQueueName, } - we, err := workflowClient.ExecuteWorkflow(ctx, workflowOptions, workflow.HandlePushWorkflow, - callback, - w.config, + we, err := workflowClient.ExecuteWorkflow( + ctx, workflowOptions, workflow.HandlePushWorkflow, callback, ) if err != nil { return fmt.Errorf("failed to handle git push callback: %w", err) From 87c11be7c0fbdd8f52efca5f090c1e7f894710be Mon Sep 17 00:00:00 2001 From: Lei Da Date: Wed, 8 Jan 2025 09:44:11 +0800 Subject: [PATCH 34/34] ce,ee allow create public dataset --- api/handler/dataset.go | 4 +++ api/handler/dataset_ce.go | 7 ++++ api/handler/dataset_ce_test.go | 58 ++++++++++++++++++++++++++++++++++ api/handler/dataset_test.go | 26 --------------- 4 files changed, 69 insertions(+), 26 deletions(-) create mode 100644 api/handler/dataset_ce.go create mode 100644 api/handler/dataset_ce_test.go diff --git a/api/handler/dataset.go b/api/handler/dataset.go index 043fe0c2..faceb021 100644 --- a/api/handler/dataset.go +++ b/api/handler/dataset.go @@ -76,6 +76,10 @@ func (h *DatasetHandler) Create(ctx *gin.Context) { httpbase.BadRequest(ctx, fmt.Errorf("sensitive check failed: %w", err).Error()) return } + if !req.Private && !h.allowCreatePublic() { + httpbase.BadRequest(ctx, "creating public dataset is not allowed") + return + } req.Username = currentUser dataset, err := h.dataset.Create(ctx, req) diff --git a/api/handler/dataset_ce.go b/api/handler/dataset_ce.go new file mode 100644 index 00000000..ed72fc94 --- /dev/null +++ b/api/handler/dataset_ce.go @@ -0,0 +1,7 @@ +//go:build !saas + +package handler + +func (h *DatasetHandler) allowCreatePublic() bool { + return true +} diff --git a/api/handler/dataset_ce_test.go b/api/handler/dataset_ce_test.go new file mode 100644 index 00000000..1b55a034 --- /dev/null +++ b/api/handler/dataset_ce_test.go @@ -0,0 +1,58 @@ +//go:build !saas + +package handler + +import ( + "testing" + + "github.com/gin-gonic/gin" + "opencsg.com/csghub-server/common/types" +) + +func TestDatasetHandler_Create(t *testing.T) { + t.Run("no public", func(t *testing.T) { + + tester := NewDatasetTester(t).WithHandleFunc(func(h *DatasetHandler) gin.HandlerFunc { + return h.Create + }) + tester.RequireUser(t) + + tester.mocks.sensitive.EXPECT().CheckRequestV2(tester.ctx, &types.CreateDatasetReq{ + CreateRepoReq: types.CreateRepoReq{Private: true}, + }).Return(true, nil) + tester.mocks.dataset.EXPECT().Create(tester.ctx, &types.CreateDatasetReq{ + CreateRepoReq: types.CreateRepoReq{Private: true, Username: "u"}, + }).Return(&types.Dataset{Name: "d"}, nil) + tester.WithBody(t, &types.CreateDatasetReq{ + CreateRepoReq: types.CreateRepoReq{Private: true}, + }).Execute() + + tester.ResponseEqSimple(t, 200, gin.H{ + "data": &types.Dataset{Name: "d"}, + }) + + }) + + t.Run("public", func(t *testing.T) { + + tester := NewDatasetTester(t).WithHandleFunc(func(h *DatasetHandler) gin.HandlerFunc { + return h.Create + }) + tester.RequireUser(t) + + tester.mocks.sensitive.EXPECT().CheckRequestV2(tester.ctx, &types.CreateDatasetReq{ + CreateRepoReq: types.CreateRepoReq{Private: false}, + }).Return(true, nil) + tester.mocks.dataset.EXPECT().Create(tester.ctx, &types.CreateDatasetReq{ + CreateRepoReq: types.CreateRepoReq{Private: false, Username: "u"}, + }).Return(&types.Dataset{Name: "d"}, nil) + tester.WithBody(t, &types.CreateDatasetReq{ + CreateRepoReq: types.CreateRepoReq{Private: false}, + }).Execute() + + tester.ResponseEqSimple(t, 200, gin.H{ + "data": &types.Dataset{Name: "d"}, + }) + }) + +} diff --git a/api/handler/dataset_test.go b/api/handler/dataset_test.go index de181d1a..6e9d9c76 100644 --- a/api/handler/dataset_test.go +++ b/api/handler/dataset_test.go @@ -41,32 +41,6 @@ func (t *DatasetTester) WithHandleFunc(fn func(h *DatasetHandler) gin.HandlerFun return t } -func TestDatasetHandler_Create(t *testing.T) { - - t.Run("public", func(t *testing.T) { - - tester := NewDatasetTester(t).WithHandleFunc(func(h *DatasetHandler) gin.HandlerFunc { - return h.Create - }) - tester.RequireUser(t) - - tester.mocks.sensitive.EXPECT().CheckRequestV2(tester.ctx, &types.CreateDatasetReq{ - CreateRepoReq: types.CreateRepoReq{Private: true}, - }).Return(true, nil) - tester.mocks.dataset.EXPECT().Create(tester.ctx, &types.CreateDatasetReq{ - CreateRepoReq: types.CreateRepoReq{Private: true, Username: "u"}, - }).Return(&types.Dataset{Name: "d"}, nil) - tester.WithBody(t, &types.CreateDatasetReq{ - CreateRepoReq: types.CreateRepoReq{Private: true}, - }).Execute() - - tester.ResponseEqSimple(t, 200, gin.H{ - "data": &types.Dataset{Name: "d"}, - }) - }) - -} - func TestDatasetHandler_Index(t *testing.T) { cases := []struct { sort string