Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Actually check for auth flows in provider enrollment #2601

Merged
merged 1 commit into from
Mar 13, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
67 changes: 55 additions & 12 deletions cmd/cli/app/provider/provider_enroll.go
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@ actions such as adding repositories.`,
// EnrollProviderCommand is the command for enrolling a provider
func EnrollProviderCommand(ctx context.Context, cmd *cobra.Command, conn *grpc.ClientConn) error {
client := minderv1.NewOAuthServiceClient(conn)
provcli := minderv1.NewProvidersServiceClient(conn)

provider := viper.GetString("provider")
project := viper.GetString("project")
Expand Down Expand Up @@ -87,24 +88,66 @@ func EnrollProviderCommand(ctx context.Context, cmd *cobra.Command, conn *grpc.C
}
}

oAuthCallbackCtx, oAuthCancel := context.WithTimeout(context.Background(), MAX_WAIT+5*time.Second)
defer oAuthCancel()
prov, err := provcli.GetProvider(ctx, &minderv1.GetProviderRequest{
Context: &minderv1.Context{Provider: &provider, Project: &project},
Name: provider,
})
if err != nil {
return cli.MessageAndError("Error getting provider", err)
}

if token != "" {
// use pat for enrollment
_, err := client.StoreProviderToken(context.Background(), &minderv1.StoreProviderTokenRequest{
Context: &minderv1.Context{Provider: &provider, Project: &project},
AccessToken: token,
Owner: &owner,
})
if err != nil {
return cli.MessageAndError("Error storing token", err)
if !prov.Provider.SupportsAuthFlow(minderv1.AuthorizationFlow_AUTHORIZATION_FLOW_USER_INPUT) {
return fmt.Errorf("provider %s does not support token enrollment", provider)
}

cmd.Println("Provider enrolled successfully")
return nil
return enrollUsingToken(ctx, cmd, client, provider, project, token, owner)
}

if !prov.Provider.SupportsAuthFlow(
minderv1.AuthorizationFlow_AUTHORIZATION_FLOW_OAUTH2_AUTHORIZATION_CODE_FLOW) {
return fmt.Errorf("provider %s does not support OAuth2 enrollment", provider)
}

// This will have a different timeout
enrollemntCtx := cmd.Context()

return enrollUsingOAuth2Flow(enrollemntCtx, cmd, client, provider, project, owner)
}

func enrollUsingToken(
ctx context.Context,
cmd *cobra.Command,
client minderv1.OAuthServiceClient,
provider string,
project string,
token string,
owner string,
) error {
_, err := client.StoreProviderToken(ctx, &minderv1.StoreProviderTokenRequest{
Context: &minderv1.Context{Provider: &provider, Project: &project},
AccessToken: token,
Owner: &owner,
})
if err != nil {
return cli.MessageAndError("Error storing token", err)
}

cmd.Println("Provider enrolled successfully")
return nil
}

func enrollUsingOAuth2Flow(
ctx context.Context,
cmd *cobra.Command,
client minderv1.OAuthServiceClient,
provider string,
project string,
owner string,
) error {
oAuthCallbackCtx, oAuthCancel := context.WithTimeout(ctx, MAX_WAIT+5*time.Second)
defer oAuthCancel()

// Get random port
port, err := rand.GetRandomPort()
if err != nil {
Expand Down
12 changes: 12 additions & 0 deletions internal/controlplane/handlers_oauth.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ import (
"fmt"
"net/http"
"net/url"
"slices"

"github.com/google/uuid"
"github.com/grpc-ecosystem/grpc-gateway/v2/runtime"
Expand All @@ -39,6 +40,7 @@ import (
"github.com/stacklok/minder/internal/db"
"github.com/stacklok/minder/internal/engine"
"github.com/stacklok/minder/internal/logger"
"github.com/stacklok/minder/internal/util"
pb "github.com/stacklok/minder/pkg/api/protobuf/go/minder/v1"
)

Expand All @@ -56,6 +58,11 @@ func (s *Server) GetAuthorizationURL(ctx context.Context,
return nil, providerError(err)
}

if !slices.Contains(provider.AuthFlows, db.AuthorizationFlowOauth2AuthorizationCodeFlow) {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we use provider.SupportsAuthFlow(...) here? I think this is the same comment jakub had down below.

Ideally, provider returned from getProviderFromRequestOrDefault would be either an actual Provider instance, or at least a ProviderBuilder instance. But that's a larger refactor.

return nil, util.UserVisibleError(codes.InvalidArgument,
"provider does not support authorization code flow")
}

// Configure tracing
// trace call to AuthCodeURL
span := trace.SpanFromContext(ctx)
Expand Down Expand Up @@ -288,6 +295,11 @@ func (s *Server) StoreProviderToken(ctx context.Context,
return nil, providerError(err)
}

if !slices.Contains(provider.AuthFlows, db.AuthorizationFlowUserInput) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why not call p.SupportsAuthFlow() here? is it because one is using a protobuf constant and the other a db constant? Since we're using the constant verbatim and not from a variable, could we use the protobuf constant here as well?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think getProviderFromRequestOrDefault actually returns a db.Provider, not a proto object mentioned in SupportsAuthFlow.

With that said, I'd prefer to have it return something (maybe a ProviderBuilder?) that encapsulates this check more than slices.Contains on a returned database row.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should we have a separate PR with that refactor?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm fine with a separate PR if it comes soon and we don't propagate code patters like this over time

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

on it.

return nil, util.UserVisibleError(codes.InvalidArgument,
"provider does not support token enrollment")
}

// validate token
err = auth.ValidateProviderToken(ctx, provider.Name, in.AccessToken)
if err != nil {
Expand Down
37 changes: 37 additions & 0 deletions internal/controlplane/handlers_oauth_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ import (
"testing"

"github.com/google/uuid"
"github.com/stretchr/testify/assert"
"go.uber.org/mock/gomock"
"golang.org/x/oauth2/github"
"golang.org/x/oauth2/google"
Expand Down Expand Up @@ -114,6 +115,9 @@ func TestGetAuthorizationURL(t *testing.T) {
Return([]db.Provider{{
ID: providerID,
Name: "github",
AuthFlows: []db.AuthorizationFlow{
db.AuthorizationFlowOauth2AuthorizationCodeFlow,
},
}}, nil)
store.EXPECT().
CreateSessionState(gomock.Any(), gomock.Any()).
Expand All @@ -137,6 +141,39 @@ func TestGetAuthorizationURL(t *testing.T) {

expectedStatusCode: codes.OK,
},
{
name: "Unsupported auth flow",
req: &pb.GetAuthorizationURLRequest{
Context: &pb.Context{
Provider: &providerName,
Project: &projectIdStr,
},
Port: 8080,
Cli: true,
},
buildStubs: func(store *mockdb.MockStore) {
store.EXPECT().
GetParentProjects(gomock.Any(), projectID).
Return([]uuid.UUID{projectID}, nil)
store.EXPECT().
ListProvidersByProjectID(gomock.Any(), []uuid.UUID{projectID}).
Return([]db.Provider{{
ID: providerID,
Name: "github",
AuthFlows: []db.AuthorizationFlow{
db.AuthorizationFlowNone,
},
}}, nil)
},

checkResponse: func(t *testing.T, _ *pb.GetAuthorizationURLResponse, err error) {
t.Helper()

assert.Error(t, err, "Expected error in GetAuthorizationURL")
},

expectedStatusCode: codes.InvalidArgument,
},
}

rpcOptions := &pb.RpcOptions{
Expand Down
7 changes: 7 additions & 0 deletions pkg/api/protobuf/go/minder/v1/providers.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@

package v1

import "slices"

// ToString returns the string representation of the ProviderType
func (provt ProviderType) ToString() string {
return enumToStringViaDescriptor(provt.Descriptor(), provt.Number())
Expand All @@ -23,3 +25,8 @@ func (provt ProviderType) ToString() string {
func (a AuthorizationFlow) ToString() string {
return enumToStringViaDescriptor(a.Descriptor(), a.Number())
}

// SupportsAuthFlow returns true if the provider supports the given auth flow
func (p *Provider) SupportsAuthFlow(flow AuthorizationFlow) bool {
return slices.Contains(p.GetAuthFlows(), flow)
}
Comment on lines +28 to +32
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't see where this is used?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It is in the CLI

Loading