diff --git a/internal/controlplane/handlers_githubwebhooks.go b/internal/controlplane/handlers_githubwebhooks.go index f7239a14c7..f74504f003 100644 --- a/internal/controlplane/handlers_githubwebhooks.go +++ b/internal/controlplane/handlers_githubwebhooks.go @@ -412,16 +412,8 @@ func (s *Server) parseGithubEventForProcessing( return fmt.Errorf("error getting repo information from payload: %w", err) } - ph, err := s.store.GetParentProjects(ctx, dbRepo.ProjectID) - if err != nil { - return fmt.Errorf("error getting project hierarchy: %w", err) - } - // get the provider for the repository - prov, err := s.store.GetProviderByName(ctx, db.GetProviderByNameParams{ - Name: dbRepo.Provider, - Projects: ph, - }) + prov, err := s.providerStore.GetByName(ctx, dbRepo.ProjectID, dbRepo.Provider) if err != nil { return fmt.Errorf("error getting provider: %w", err) } @@ -430,7 +422,7 @@ func (s *Server) parseGithubEventForProcessing( providers.WithProviderMetrics(s.provMt), providers.WithRestClientCache(s.restClientCache), } - provBuilder, err := providers.GetProviderBuilder(ctx, prov, s.store, s.cryptoEngine, &s.cfg.Provider, pbOpts...) + provBuilder, err := providers.GetProviderBuilder(ctx, *prov, s.store, s.cryptoEngine, &s.cfg.Provider, pbOpts...) if err != nil { return fmt.Errorf("error building client: %w", err) } diff --git a/internal/controlplane/handlers_githubwebhooks_test.go b/internal/controlplane/handlers_githubwebhooks_test.go index 1c5711d5c5..decdafbeb6 100644 --- a/internal/controlplane/handlers_githubwebhooks_test.go +++ b/internal/controlplane/handlers_githubwebhooks_test.go @@ -276,16 +276,17 @@ func (s *UnitTestSuite) TestHandleWebHookRepository() { }, nil) mockStore.EXPECT(). - GetProviderByName(gomock.Any(), db.GetProviderByNameParams{ - Name: providerName, + FindProviders(gomock.Any(), db.FindProvidersParams{ + Name: sql.NullString{String: providerName, Valid: true}, Projects: []uuid.UUID{ projectID, }, + Trait: db.NullProviderType{}, }). - Return(db.Provider{ + Return([]db.Provider{{ ProjectID: projectID, Name: providerName, - }, nil) + }}, nil) hook := srv.HandleGitHubWebHook() port, err := rand.GetRandomPort() diff --git a/internal/controlplane/handlers_oauth.go b/internal/controlplane/handlers_oauth.go index 47996162fb..581e1edfbe 100644 --- a/internal/controlplane/handlers_oauth.go +++ b/internal/controlplane/handlers_oauth.go @@ -392,11 +392,7 @@ func (s *Server) StoreProviderToken(ctx context.Context, return nil, status.Errorf(codes.InvalidArgument, "provider name is required") } - provider, err := s.store.GetProviderByName(ctx, db.GetProviderByNameParams{ - // We don't check parent projects here because subprojects should not update the credentials of their parent - Projects: []uuid.UUID{projectID}, - Name: providerName, - }) + provider, err := s.providerStore.GetByNameInSpecificProject(ctx, projectID, providerName) if err != nil { return nil, providerError(err) } @@ -500,11 +496,7 @@ func (s *Server) VerifyProviderTokenFrom(ctx context.Context, return &pb.VerifyProviderTokenFromResponse{Status: "KO"}, nil } - provider, err := s.store.GetProviderByName(ctx, db.GetProviderByNameParams{ - // We don't check parent projects here because subprojects should not check the credentials of their parent - Projects: []uuid.UUID{projectID}, - Name: providerName, - }) + provider, err := s.providerStore.GetByNameInSpecificProject(ctx, projectID, providerName) if err != nil { return nil, providerError(err) } @@ -560,7 +552,7 @@ func (s *Server) VerifyProviderCredential(ctx context.Context, ) if err == nil { - provider, err := s.store.GetProviderByID(ctx, installation.ProviderID.UUID) + provider, err := s.providerStore.GetByID(ctx, installation.ProviderID.UUID) if err != nil { return nil, status.Errorf(codes.Internal, "error getting provider: %v", err) } diff --git a/internal/controlplane/handlers_oauth_test.go b/internal/controlplane/handlers_oauth_test.go index 5d41a7ea67..368a9ed8df 100644 --- a/internal/controlplane/handlers_oauth_test.go +++ b/internal/controlplane/handlers_oauth_test.go @@ -774,6 +774,7 @@ func TestVerifyProviderCredential(t *testing.T) { store: store, evt: evt, providerAuthFactory: providerAuthFactory, + providerStore: providers.NewProviderStore(store), cfg: &serverconfig.Config{ Auth: serverconfig.AuthConfig{}, }, diff --git a/internal/controlplane/handlers_providers.go b/internal/controlplane/handlers_providers.go index 0fb8a7840a..6a28618a7e 100644 --- a/internal/controlplane/handlers_providers.go +++ b/internal/controlplane/handlers_providers.go @@ -19,7 +19,6 @@ import ( "database/sql" "errors" - "github.com/google/uuid" "github.com/rs/zerolog" "google.golang.org/grpc/codes" "google.golang.org/grpc/status" @@ -39,12 +38,7 @@ func (s *Server) GetProvider(ctx context.Context, req *minderv1.GetProviderReque entityCtx := engine.EntityFromContext(ctx) projectID := entityCtx.Project.ID - prov, err := s.store.GetProviderByName(ctx, db.GetProviderByNameParams{ - Name: req.Name, - // Note that this does not take the hierarchy into account in purpose. - // We want to get this call to be explicit for the given project. - Projects: []uuid.UUID{projectID}, - }) + prov, err := s.providerStore.GetByNameInSpecificProject(ctx, projectID, req.GetName()) if err != nil { if errors.Is(err, sql.ErrNoRows) { return nil, util.UserVisibleError(codes.NotFound, "provider not found") @@ -66,11 +60,11 @@ func (s *Server) GetProvider(ctx context.Context, req *minderv1.GetProviderReque Name: prov.Name, Project: projectID.String(), Version: prov.Version, - Implements: protobufProviderImplementsFromDB(ctx, prov), - AuthFlows: protobufProviderAuthFlowFromDB(ctx, prov), + Implements: protobufProviderImplementsFromDB(ctx, *prov), + AuthFlows: protobufProviderAuthFlowFromDB(ctx, *prov), Config: cfg, - CredentialsState: providers.GetCredentialStateForProvider(ctx, prov, s.store, s.cryptoEngine, &s.cfg.Provider), - Class: providers.GetProviderClassString(prov), + CredentialsState: providers.GetCredentialStateForProvider(ctx, *prov, s.store, s.cryptoEngine, &s.cfg.Provider), + Class: providers.GetProviderClassString(*prov), }, }, nil } diff --git a/internal/controlplane/handlers_repositories.go b/internal/controlplane/handlers_repositories.go index 03cbbf7e48..8f64e96b20 100644 --- a/internal/controlplane/handlers_repositories.go +++ b/internal/controlplane/handlers_repositories.go @@ -76,12 +76,12 @@ func (s *Server) RegisterRepository( return nil, status.Errorf(codes.Internal, "cannot get provider: %v", err) } - client, err := s.getClientForProvider(ctx, provider) + client, err := s.getClientForProvider(ctx, *provider) if err != nil { return nil, err } - newRepo, err := s.repos.CreateRepository(ctx, client, &provider, projectID, githubRepo.GetOwner(), githubRepo.GetName()) + newRepo, err := s.repos.CreateRepository(ctx, client, provider, projectID, githubRepo.GetOwner(), githubRepo.GetName()) if err != nil { if errors.Is(err, ghrepo.ErrPrivateRepoForbidden) { return nil, util.UserVisibleError(codes.InvalidArgument, err.Error()) @@ -425,12 +425,12 @@ func (s *Server) deleteRepository( return status.Errorf(codes.Internal, "unexpected error fetching repo: %v", err) } - provider, err := s.findProviderByName(ctx, repo.Provider, projectID) + provider, err := s.providerStore.GetByName(ctx, projectID, repo.Provider) if err != nil { return status.Errorf(codes.Internal, "cannot get provider: %v", err) } - client, err := s.getClientForProvider(ctx, provider) + client, err := s.getClientForProvider(ctx, *provider) if err != nil { return status.Errorf(codes.Internal, "cannot get client for provider: %v", err) } @@ -445,13 +445,13 @@ func (s *Server) deleteRepository( // TODO: move out of controlplane // inferProviderByOwner returns the provider to use for a given repo owner func (s *Server) inferProviderByOwner(ctx context.Context, owner string, projectID uuid.UUID, providerName string, -) (db.Provider, error) { +) (*db.Provider, error) { if providerName != "" { - return s.findProviderByName(ctx, providerName, projectID) + return s.providerStore.GetByName(ctx, projectID, providerName) } opts, err := s.providerStore.GetByNameAndTrait(ctx, projectID, providerName, db.ProviderTypeGithub) if err != nil { - return db.Provider{}, fmt.Errorf("error getting providers: %v", err) + return nil, fmt.Errorf("error getting providers: %v", err) } slices.SortFunc(opts, func(a, b db.Provider) int { @@ -467,11 +467,11 @@ func (s *Server) inferProviderByOwner(ctx context.Context, owner string, project for _, prov := range opts { if github.CanHandleOwner(ctx, prov, owner) { - return prov, nil + return &prov, nil } } - return db.Provider{}, fmt.Errorf("no providers can handle repo owned by %s", owner) + return nil, fmt.Errorf("no providers can handle repo owned by %s", owner) } func (s *Server) getClientForProvider( @@ -496,22 +496,6 @@ func (s *Server) getClientForProvider( return client, nil } -func (s *Server) findProviderByName(ctx context.Context, providerName string, projectID uuid.UUID) (db.Provider, error) { - ph, err := s.store.GetParentProjects(ctx, projectID) - if err != nil { - return db.Provider{}, fmt.Errorf("error getting project hierarchy: %v", err) - } - - provider, err := s.store.GetProviderByName(ctx, db.GetProviderByNameParams{ - Name: providerName, - Projects: ph, - }) - if err != nil { - return db.Provider{}, fmt.Errorf("error getting provider: %v", err) - } - return provider, nil -} - func getProjectID(ctx context.Context) uuid.UUID { entityCtx := engine.EntityFromContext(ctx) return entityCtx.Project.ID diff --git a/internal/controlplane/handlers_repositories_test.go b/internal/controlplane/handlers_repositories_test.go index 8642d19ec7..572cabada0 100644 --- a/internal/controlplane/handlers_repositories_test.go +++ b/internal/controlplane/handlers_repositories_test.go @@ -56,7 +56,7 @@ func TestServer_RegisterRepository(t *testing.T) { RepoOwner: repoOwner, RepoName: repoName, ProviderFails: true, - ExpectedError: "error getting provider", + ExpectedError: "cannot retrieve provider", }, { Name: "Repo creation fails when repo name is missing", @@ -150,7 +150,7 @@ func TestServer_DeleteRepository(t *testing.T) { RepoName: repoOwnerAndName, RepoServiceSetup: newRepoService(withSuccessfulGetRepoByName), ProviderFails: true, - ExpectedError: "error getting provider", + ExpectedError: "cannot retrieve provider", }, { Name: "delete by name fails when name is malformed", @@ -354,12 +354,12 @@ func createServer(ctrl *gomock.Controller, repoServiceSetup repoMockBuilder, pro if providerFails { store.EXPECT(). - GetProviderByName(gomock.Any(), gomock.Any()). - Return(db.Provider{}, errDefault) + FindProviders(gomock.Any(), gomock.Any()). + Return([]db.Provider{}, errDefault) } else { store.EXPECT(). - GetProviderByName(gomock.Any(), gomock.Any()). - Return(provider, nil).AnyTimes() + FindProviders(gomock.Any(), gomock.Any()). + Return([]db.Provider{provider}, nil).AnyTimes() store.EXPECT(). GetAccessTokenByProjectID(gomock.Any(), gomock.Any()). Return(db.ProviderAccessToken{ diff --git a/internal/providers/store.go b/internal/providers/store.go index ba3a6e7a12..13b9d99069 100644 --- a/internal/providers/store.go +++ b/internal/providers/store.go @@ -30,11 +30,22 @@ import ( // ProviderStore provides methods for retrieving Providers from the database type ProviderStore interface { - // GetByName returns the provider instance in the database as identified - // by its project ID and name. + // GetByID returns the provider identified by its UUID primary key. + // It is assumed that the caller carries out some kind of validation to + // ensure that whoever made the request is authorized to access this + // provider. + GetByID(ctx context.Context, providerID uuid.UUID) (*db.Provider, error) + // GetByName returns the provider instance in the database as + // identified by its project ID and name. All parent projects of the + // specified project are included in the search. GetByName(ctx context.Context, projectID uuid.UUID, name string) (*db.Provider, error) + // GetByNameInSpecificProject returns the provider instance in the database as + // identified by its project ID and name. Unlike `GetByName` it will only + // search in the specified project, and ignore the project hierarchy. + GetByNameInSpecificProject(ctx context.Context, projectID uuid.UUID, name string) (*db.Provider, error) // GetByNameAndTrait returns the providers in the project which match the - // specified trait. + // specified trait. All parent projects of the specified project are + // included in the search. // Note that if error is nil, there will always be at least one element // in the list of providers which is returned. GetByNameAndTrait( @@ -54,8 +65,17 @@ func NewProviderStore(store db.Store) ProviderStore { return &providerStore{store: store} } +func (p *providerStore) GetByID(ctx context.Context, providerID uuid.UUID) (*db.Provider, error) { + provider, err := p.store.GetProviderByID(ctx, providerID) + if err != nil { + return nil, fmt.Errorf("failed to find provider by ID: %w", err) + } + return &provider, nil +} + func (p *providerStore) GetByName(ctx context.Context, projectID uuid.UUID, name string) (*db.Provider, error) { nameFilter := getNameFilterParam(name) + providers, err := p.findProvider(ctx, nameFilter, db.NullProviderType{}, projectID) if err != nil { return nil, err @@ -77,6 +97,17 @@ func (p *providerStore) GetByName(ctx context.Context, projectID uuid.UUID, name return &providers[0], nil } +func (p *providerStore) GetByNameInSpecificProject(ctx context.Context, projectID uuid.UUID, name string) (*db.Provider, error) { + provider, err := p.store.GetProviderByName(ctx, db.GetProviderByNameParams{ + Name: name, + Projects: []uuid.UUID{projectID}, + }) + if err != nil { + return nil, fmt.Errorf("cannot retrieve provider: %w", err) + } + return &provider, nil +} + func (p *providerStore) GetByNameAndTrait( ctx context.Context, projectID uuid.UUID, @@ -111,16 +142,16 @@ func (p *providerStore) findProvider( ctx context.Context, name sql.NullString, trait db.NullProviderType, - projectId uuid.UUID, + projectID uuid.UUID, ) ([]db.Provider, error) { // Allows us to take into account the hierarchy to find the provider - parents, err := p.store.GetParentProjects(ctx, projectId) + projectHierarchy, err := p.store.GetParentProjects(ctx, projectID) if err != nil { return nil, status.Errorf(codes.InvalidArgument, "cannot retrieve parent projects: %s", err) } provs, err := p.store.FindProviders(ctx, db.FindProvidersParams{ - Projects: parents, + Projects: projectHierarchy, Name: name, Trait: trait, })