Skip to content

Commit

Permalink
Upsert provider access tokens instead of Create and Delete (#2486)
Browse files Browse the repository at this point in the history
For OAuth flow provider enrollments we used to delete and recreate the
provider access token in order to be able to list the new token based on
CreatedAt. For the case where we pass a PAT, we would just error out
in case the token was already created.

Let's just upsert in both cases and instead of CreatedAt let's look at
UpdatedAt.

This will be useful for tests.

Fixes: #2411
  • Loading branch information
jhrozek authored Mar 5, 2024
1 parent f7eb7b4 commit aa7c815
Show file tree
Hide file tree
Showing 6 changed files with 126 additions and 133 deletions.
59 changes: 15 additions & 44 deletions database/mock/store.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

25 changes: 15 additions & 10 deletions database/query/provider_access_tokens.sql
Original file line number Diff line number Diff line change
@@ -1,17 +1,22 @@
-- name: CreateAccessToken :one
INSERT INTO provider_access_tokens (project_id, provider, encrypted_token, expiration_time, owner_filter) VALUES ($1, $2, $3, $4, $5) RETURNING *;

-- name: GetAccessTokenByProjectID :one
SELECT * FROM provider_access_tokens WHERE provider = $1 AND project_id = $2;

-- name: UpdateAccessToken :one
UPDATE provider_access_tokens SET encrypted_token = $3, expiration_time = $4, owner_filter = $5, updated_at = NOW() WHERE provider = $1 AND project_id = $2 RETURNING *;

-- name: DeleteAccessToken :exec
DELETE FROM provider_access_tokens WHERE provider = $1 AND project_id = $2;

-- name: GetAccessTokenByProvider :many
SELECT * FROM provider_access_tokens WHERE provider = $1;

-- name: GetAccessTokenSinceDate :one
SELECT * FROM provider_access_tokens WHERE provider = $1 AND project_id = $2 AND created_at >= $3;
SELECT * FROM provider_access_tokens WHERE provider = $1 AND project_id = $2 AND updated_at >= $3;

-- name: UpsertAccessToken :one
INSERT INTO provider_access_tokens
(project_id, provider, encrypted_token, expiration_time, owner_filter)
VALUES
($1, $2, $3, $4, $5)
ON CONFLICT (project_id, provider)
DO UPDATE SET
encrypted_token = $3,
expiration_time = $4,
owner_filter = $5,
updated_at = NOW()
WHERE provider_access_tokens.project_id = $1 AND provider_access_tokens.provider = $2
RETURNING *;
27 changes: 9 additions & 18 deletions internal/controlplane/handlers_oauth.go
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,6 @@ import (
"github.com/stacklok/minder/internal/db"
"github.com/stacklok/minder/internal/engine"
"github.com/stacklok/minder/internal/logger"
"github.com/stacklok/minder/internal/util"
pb "github.com/stacklok/minder/pkg/api/protobuf/go/minder/v1"
)

Expand Down Expand Up @@ -246,23 +245,14 @@ func (s *Server) generateOAuthToken(ctx context.Context, provider string, code s

encodedToken := base64.StdEncoding.EncodeToString(encryptedToken)

// delete token if it exists
err = s.store.DeleteAccessToken(ctx, db.DeleteAccessTokenParams{
Provider: provider,
ProjectID: stateData.ProjectID,
})
if err != nil {
return fmt.Errorf("error deleting access token: %w", err)
}

var owner sql.NullString
if stateData.OwnerFilter.Valid {
owner = sql.NullString{Valid: true, String: stateData.OwnerFilter.String}
} else {
owner = sql.NullString{Valid: false}
}

_, err = s.store.CreateAccessToken(ctx, db.CreateAccessTokenParams{
_, err = s.store.UpsertAccessToken(ctx, db.UpsertAccessTokenParams{
ProjectID: stateData.ProjectID,
Provider: provider,
EncryptedToken: encodedToken,
Expand Down Expand Up @@ -337,12 +327,13 @@ func (s *Server) StoreProviderToken(ctx context.Context,
owner = sql.NullString{String: *in.Owner, Valid: true}
}

_, err = s.store.CreateAccessToken(ctx, db.CreateAccessTokenParams{ProjectID: projectID, Provider: provider.Name,
EncryptedToken: encodedToken, OwnerFilter: owner})

if db.ErrIsUniqueViolation(err) {
return nil, util.UserVisibleError(codes.AlreadyExists, "token already exists")
} else if err != nil {
_, err = s.store.UpsertAccessToken(ctx, db.UpsertAccessTokenParams{
ProjectID: projectID,
Provider: provider.Name,
EncryptedToken: encodedToken,
OwnerFilter: owner,
})
if err != nil {
return nil, status.Errorf(codes.Internal, "error storing access token: %v", err)
}

Expand All @@ -366,7 +357,7 @@ func (s *Server) VerifyProviderTokenFrom(ctx context.Context,

// check if a token has been created since timestamp
_, err = s.store.GetAccessTokenSinceDate(ctx,
db.GetAccessTokenSinceDateParams{Provider: provider.Name, ProjectID: projectID, CreatedAt: in.Timestamp.AsTime()})
db.GetAccessTokenSinceDateParams{Provider: provider.Name, ProjectID: projectID, UpdatedAt: in.Timestamp.AsTime()})
if err != nil {
if errors.Is(err, sql.ErrNoRows) {
return &pb.VerifyProviderTokenFromResponse{Status: "KO"}, nil
Expand Down
79 changes: 21 additions & 58 deletions internal/db/provider_access_tokens.sql.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

65 changes: 65 additions & 0 deletions internal/db/provider_access_tokens_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
//
// Copyright 2023 Stacklok, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package db

import (
"context"
"database/sql"
"testing"

"github.com/stretchr/testify/require"
)

func TestUpsertProviderAccessToken(t *testing.T) {
t.Parallel()

org := createRandomOrganization(t)
project := createRandomProject(t, org.ID)
prov := createRandomProvider(t, project.ID)

tok, err := testQueries.UpsertAccessToken(context.Background(), UpsertAccessTokenParams{
ProjectID: project.ID,
Provider: prov.Name,
EncryptedToken: "abc",
OwnerFilter: sql.NullString{},
})

require.NoError(t, err)
require.NotEmpty(t, tok)
require.NotEmpty(t, tok.ID)
require.NotEmpty(t, tok.CreatedAt)
require.NotEmpty(t, tok.UpdatedAt)
require.Equal(t, project.ID, tok.ProjectID)
require.Equal(t, prov.Name, tok.Provider)
require.Equal(t, "abc", tok.EncryptedToken)
require.Equal(t, sql.NullString{}, tok.OwnerFilter)

tokUpdate, err := testQueries.UpsertAccessToken(context.Background(), UpsertAccessTokenParams{
ProjectID: project.ID,
Provider: prov.Name,
EncryptedToken: "def",
OwnerFilter: sql.NullString{},
})

require.NoError(t, err)
require.Equal(t, project.ID, tokUpdate.ProjectID)
require.Equal(t, prov.Name, tokUpdate.Provider)
require.Equal(t, "def", tokUpdate.EncryptedToken)
require.Equal(t, sql.NullString{}, tokUpdate.OwnerFilter)
require.Equal(t, tok.ID, tokUpdate.ID)
require.Equal(t, tok.CreatedAt, tokUpdate.CreatedAt)
require.NotEqual(t, tok.UpdatedAt, tokUpdate.UpdatedAt)
}
4 changes: 1 addition & 3 deletions internal/db/querier.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

0 comments on commit aa7c815

Please sign in to comment.