Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Move more provider access methods to store #2911

Merged
merged 2 commits into from
Apr 3, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 2 additions & 10 deletions internal/controlplane/handlers_githubwebhooks.go
Original file line number Diff line number Diff line change
Expand Up @@ -412,16 +412,8 @@ func (s *Server) parseGithubEventForProcessing(
return fmt.Errorf("error getting repo information from payload: %w", err)
}

ph, err := s.store.GetParentProjects(ctx, dbRepo.ProjectID)
if err != nil {
return fmt.Errorf("error getting project hierarchy: %w", err)
}

// get the provider for the repository
prov, err := s.store.GetProviderByName(ctx, db.GetProviderByNameParams{
Name: dbRepo.Provider,
Projects: ph,
})
prov, err := s.providerStore.GetByName(ctx, dbRepo.ProjectID, dbRepo.Provider)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Genuine question: The old code was calling GetParentProjects and then getting the provider by name in parent. The new code goes to GetByName, which calls findProjects which calls getProjectHierarchy. Are those equivalent?

Copy link
Contributor Author

@dmjb dmjb Apr 2, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

getProjectHierarchy also calls GetParentProjects. In fact, I may get rid of the separate getProjectHierarchy function since it's only called in one places (originally in this branch it was called in a second place too)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That being said - one important difference is that the GetByName will check to see if there are any duplicate results and return an error if there is ambiguity about being what is returned, whereas the code I have replaced will not (see the implementation of GetByName).

@eleftherias - since I think you wrote this code: is there any possibility that the original query could match more than one provider?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It is possible, but only in a strange corner case where someone renamed their GitHub organization to have the same name as a GitHub organization that they previously installed Minder on. We generally don't handle renames yet, so I think we can ignore that case for now.

if err != nil {
return fmt.Errorf("error getting provider: %w", err)
}
Expand All @@ -430,7 +422,7 @@ func (s *Server) parseGithubEventForProcessing(
providers.WithProviderMetrics(s.provMt),
providers.WithRestClientCache(s.restClientCache),
}
provBuilder, err := providers.GetProviderBuilder(ctx, prov, s.store, s.cryptoEngine, &s.cfg.Provider, pbOpts...)
provBuilder, err := providers.GetProviderBuilder(ctx, *prov, s.store, s.cryptoEngine, &s.cfg.Provider, pbOpts...)
if err != nil {
return fmt.Errorf("error building client: %w", err)
}
Expand Down
9 changes: 5 additions & 4 deletions internal/controlplane/handlers_githubwebhooks_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -276,16 +276,17 @@ func (s *UnitTestSuite) TestHandleWebHookRepository() {
}, nil)

mockStore.EXPECT().
GetProviderByName(gomock.Any(), db.GetProviderByNameParams{
Name: providerName,
FindProviders(gomock.Any(), db.FindProvidersParams{
Name: sql.NullString{String: providerName, Valid: true},
Projects: []uuid.UUID{
projectID,
},
Trait: db.NullProviderType{},
}).
Return(db.Provider{
Return([]db.Provider{{
ProjectID: projectID,
Name: providerName,
}, nil)
}}, nil)

hook := srv.HandleGitHubWebHook()
port, err := rand.GetRandomPort()
Expand Down
14 changes: 3 additions & 11 deletions internal/controlplane/handlers_oauth.go
Original file line number Diff line number Diff line change
Expand Up @@ -392,11 +392,7 @@ func (s *Server) StoreProviderToken(ctx context.Context,
return nil, status.Errorf(codes.InvalidArgument, "provider name is required")
}

provider, err := s.store.GetProviderByName(ctx, db.GetProviderByNameParams{
// We don't check parent projects here because subprojects should not update the credentials of their parent
Projects: []uuid.UUID{projectID},
Name: providerName,
})
provider, err := s.providerStore.GetByNameInSpecificProject(ctx, projectID, providerName)
if err != nil {
return nil, providerError(err)
}
Expand Down Expand Up @@ -500,11 +496,7 @@ func (s *Server) VerifyProviderTokenFrom(ctx context.Context,
return &pb.VerifyProviderTokenFromResponse{Status: "KO"}, nil
}

provider, err := s.store.GetProviderByName(ctx, db.GetProviderByNameParams{
// We don't check parent projects here because subprojects should not check the credentials of their parent
Projects: []uuid.UUID{projectID},
Name: providerName,
})
provider, err := s.providerStore.GetByNameInSpecificProject(ctx, projectID, providerName)
if err != nil {
return nil, providerError(err)
}
Expand Down Expand Up @@ -560,7 +552,7 @@ func (s *Server) VerifyProviderCredential(ctx context.Context,
)

if err == nil {
provider, err := s.store.GetProviderByID(ctx, installation.ProviderID.UUID)
provider, err := s.providerStore.GetByID(ctx, installation.ProviderID.UUID)
if err != nil {
return nil, status.Errorf(codes.Internal, "error getting provider: %v", err)
}
Expand Down
1 change: 1 addition & 0 deletions internal/controlplane/handlers_oauth_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -774,6 +774,7 @@ func TestVerifyProviderCredential(t *testing.T) {
store: store,
evt: evt,
providerAuthFactory: providerAuthFactory,
providerStore: providers.NewProviderStore(store),
cfg: &serverconfig.Config{
Auth: serverconfig.AuthConfig{},
},
Expand Down
16 changes: 5 additions & 11 deletions internal/controlplane/handlers_providers.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@ import (
"database/sql"
"errors"

"github.com/google/uuid"
"github.com/rs/zerolog"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
Expand All @@ -39,12 +38,7 @@ func (s *Server) GetProvider(ctx context.Context, req *minderv1.GetProviderReque
entityCtx := engine.EntityFromContext(ctx)
projectID := entityCtx.Project.ID

prov, err := s.store.GetProviderByName(ctx, db.GetProviderByNameParams{
Name: req.Name,
// Note that this does not take the hierarchy into account in purpose.
// We want to get this call to be explicit for the given project.
Projects: []uuid.UUID{projectID},
})
prov, err := s.providerStore.GetByNameInSpecificProject(ctx, projectID, req.GetName())
if err != nil {
if errors.Is(err, sql.ErrNoRows) {
return nil, util.UserVisibleError(codes.NotFound, "provider not found")
Expand All @@ -66,11 +60,11 @@ func (s *Server) GetProvider(ctx context.Context, req *minderv1.GetProviderReque
Name: prov.Name,
Project: projectID.String(),
Version: prov.Version,
Implements: protobufProviderImplementsFromDB(ctx, prov),
AuthFlows: protobufProviderAuthFlowFromDB(ctx, prov),
Implements: protobufProviderImplementsFromDB(ctx, *prov),
AuthFlows: protobufProviderAuthFlowFromDB(ctx, *prov),
Config: cfg,
CredentialsState: providers.GetCredentialStateForProvider(ctx, prov, s.store, s.cryptoEngine, &s.cfg.Provider),
Class: providers.GetProviderClassString(prov),
CredentialsState: providers.GetCredentialStateForProvider(ctx, *prov, s.store, s.cryptoEngine, &s.cfg.Provider),
Class: providers.GetProviderClassString(*prov),
},
}, nil
}
Expand Down
34 changes: 9 additions & 25 deletions internal/controlplane/handlers_repositories.go
Original file line number Diff line number Diff line change
Expand Up @@ -76,12 +76,12 @@ func (s *Server) RegisterRepository(
return nil, status.Errorf(codes.Internal, "cannot get provider: %v", err)
}

client, err := s.getClientForProvider(ctx, provider)
client, err := s.getClientForProvider(ctx, *provider)
if err != nil {
return nil, err
}

newRepo, err := s.repos.CreateRepository(ctx, client, &provider, projectID, githubRepo.GetOwner(), githubRepo.GetName())
newRepo, err := s.repos.CreateRepository(ctx, client, provider, projectID, githubRepo.GetOwner(), githubRepo.GetName())
if err != nil {
if errors.Is(err, ghrepo.ErrPrivateRepoForbidden) {
return nil, util.UserVisibleError(codes.InvalidArgument, err.Error())
Expand Down Expand Up @@ -425,12 +425,12 @@ func (s *Server) deleteRepository(
return status.Errorf(codes.Internal, "unexpected error fetching repo: %v", err)
}

provider, err := s.findProviderByName(ctx, repo.Provider, projectID)
provider, err := s.providerStore.GetByName(ctx, projectID, repo.Provider)
if err != nil {
return status.Errorf(codes.Internal, "cannot get provider: %v", err)
}

client, err := s.getClientForProvider(ctx, provider)
client, err := s.getClientForProvider(ctx, *provider)
if err != nil {
return status.Errorf(codes.Internal, "cannot get client for provider: %v", err)
}
Expand All @@ -445,13 +445,13 @@ func (s *Server) deleteRepository(
// TODO: move out of controlplane
// inferProviderByOwner returns the provider to use for a given repo owner
func (s *Server) inferProviderByOwner(ctx context.Context, owner string, projectID uuid.UUID, providerName string,
) (db.Provider, error) {
) (*db.Provider, error) {
if providerName != "" {
return s.findProviderByName(ctx, providerName, projectID)
return s.providerStore.GetByName(ctx, projectID, providerName)
}
opts, err := s.providerStore.GetByNameAndTrait(ctx, projectID, providerName, db.ProviderTypeGithub)
if err != nil {
return db.Provider{}, fmt.Errorf("error getting providers: %v", err)
return nil, fmt.Errorf("error getting providers: %v", err)
}

slices.SortFunc(opts, func(a, b db.Provider) int {
Expand All @@ -467,11 +467,11 @@ func (s *Server) inferProviderByOwner(ctx context.Context, owner string, project

for _, prov := range opts {
if github.CanHandleOwner(ctx, prov, owner) {
return prov, nil
return &prov, nil
}
}

return db.Provider{}, fmt.Errorf("no providers can handle repo owned by %s", owner)
return nil, fmt.Errorf("no providers can handle repo owned by %s", owner)
}

func (s *Server) getClientForProvider(
Expand All @@ -496,22 +496,6 @@ func (s *Server) getClientForProvider(
return client, nil
}

func (s *Server) findProviderByName(ctx context.Context, providerName string, projectID uuid.UUID) (db.Provider, error) {
ph, err := s.store.GetParentProjects(ctx, projectID)
if err != nil {
return db.Provider{}, fmt.Errorf("error getting project hierarchy: %v", err)
}

provider, err := s.store.GetProviderByName(ctx, db.GetProviderByNameParams{
Name: providerName,
Projects: ph,
})
if err != nil {
return db.Provider{}, fmt.Errorf("error getting provider: %v", err)
}
return provider, nil
}

func getProjectID(ctx context.Context) uuid.UUID {
entityCtx := engine.EntityFromContext(ctx)
return entityCtx.Project.ID
Expand Down
12 changes: 6 additions & 6 deletions internal/controlplane/handlers_repositories_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ func TestServer_RegisterRepository(t *testing.T) {
RepoOwner: repoOwner,
RepoName: repoName,
ProviderFails: true,
ExpectedError: "error getting provider",
ExpectedError: "cannot retrieve provider",
},
{
Name: "Repo creation fails when repo name is missing",
Expand Down Expand Up @@ -150,7 +150,7 @@ func TestServer_DeleteRepository(t *testing.T) {
RepoName: repoOwnerAndName,
RepoServiceSetup: newRepoService(withSuccessfulGetRepoByName),
ProviderFails: true,
ExpectedError: "error getting provider",
ExpectedError: "cannot retrieve provider",
},
{
Name: "delete by name fails when name is malformed",
Expand Down Expand Up @@ -354,12 +354,12 @@ func createServer(ctrl *gomock.Controller, repoServiceSetup repoMockBuilder, pro

if providerFails {
store.EXPECT().
GetProviderByName(gomock.Any(), gomock.Any()).
Return(db.Provider{}, errDefault)
FindProviders(gomock.Any(), gomock.Any()).
Return([]db.Provider{}, errDefault)
} else {
store.EXPECT().
GetProviderByName(gomock.Any(), gomock.Any()).
Return(provider, nil).AnyTimes()
FindProviders(gomock.Any(), gomock.Any()).
Return([]db.Provider{provider}, nil).AnyTimes()
store.EXPECT().
GetAccessTokenByProjectID(gomock.Any(), gomock.Any()).
Return(db.ProviderAccessToken{
Expand Down
43 changes: 37 additions & 6 deletions internal/providers/store.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,11 +30,22 @@ import (

// ProviderStore provides methods for retrieving Providers from the database
type ProviderStore interface {
// GetByName returns the provider instance in the database as identified
// by its project ID and name.
// GetByID returns the provider identified by its UUID primary key.
// It is assumed that the caller carries out some kind of validation to
// ensure that whoever made the request is authorized to access this
// provider.
GetByID(ctx context.Context, providerID uuid.UUID) (*db.Provider, error)
// GetByName returns the provider instance in the database as
// identified by its project ID and name. All parent projects of the
// specified project are included in the search.
GetByName(ctx context.Context, projectID uuid.UUID, name string) (*db.Provider, error)
// GetByNameInSpecificProject returns the provider instance in the database as
// identified by its project ID and name. Unlike `GetByName` it will only
// search in the specified project, and ignore the project hierarchy.
GetByNameInSpecificProject(ctx context.Context, projectID uuid.UUID, name string) (*db.Provider, error)
// GetByNameAndTrait returns the providers in the project which match the
// specified trait.
// specified trait. All parent projects of the specified project are
// included in the search.
// Note that if error is nil, there will always be at least one element
// in the list of providers which is returned.
GetByNameAndTrait(
Expand All @@ -54,8 +65,17 @@ func NewProviderStore(store db.Store) ProviderStore {
return &providerStore{store: store}
}

func (p *providerStore) GetByID(ctx context.Context, providerID uuid.UUID) (*db.Provider, error) {
provider, err := p.store.GetProviderByID(ctx, providerID)
if err != nil {
return nil, fmt.Errorf("failed to find provider by ID: %w", err)
}
return &provider, nil
}

func (p *providerStore) GetByName(ctx context.Context, projectID uuid.UUID, name string) (*db.Provider, error) {
nameFilter := getNameFilterParam(name)

providers, err := p.findProvider(ctx, nameFilter, db.NullProviderType{}, projectID)
if err != nil {
return nil, err
Expand All @@ -77,6 +97,17 @@ func (p *providerStore) GetByName(ctx context.Context, projectID uuid.UUID, name
return &providers[0], nil
}

func (p *providerStore) GetByNameInSpecificProject(ctx context.Context, projectID uuid.UUID, name string) (*db.Provider, error) {
provider, err := p.store.GetProviderByName(ctx, db.GetProviderByNameParams{
Name: name,
Projects: []uuid.UUID{projectID},
})
if err != nil {
return nil, fmt.Errorf("cannot retrieve provider: %w", err)
}
return &provider, nil
}

func (p *providerStore) GetByNameAndTrait(
ctx context.Context,
projectID uuid.UUID,
Expand Down Expand Up @@ -111,16 +142,16 @@ func (p *providerStore) findProvider(
ctx context.Context,
name sql.NullString,
trait db.NullProviderType,
projectId uuid.UUID,
projectID uuid.UUID,
) ([]db.Provider, error) {
// Allows us to take into account the hierarchy to find the provider
parents, err := p.store.GetParentProjects(ctx, projectId)
projectHierarchy, err := p.store.GetParentProjects(ctx, projectID)
if err != nil {
return nil, status.Errorf(codes.InvalidArgument, "cannot retrieve parent projects: %s", err)
}

provs, err := p.store.FindProviders(ctx, db.FindProvidersParams{
Projects: parents,
Projects: projectHierarchy,
Name: name,
Trait: trait,
})
Expand Down
Loading