Skip to content

Commit

Permalink
Merge pull request #369 from elisasre/feat/auth-refactor
Browse files Browse the repository at this point in the history
Combine Add and Rotate keys
  • Loading branch information
heppu authored Aug 22, 2024
2 parents 2dd4b6b + 554fe9e commit 758360c
Show file tree
Hide file tree
Showing 10 changed files with 153 additions and 142 deletions.
1 change: 1 addition & 0 deletions v2/auth/auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ import (

// JWTKey is struct for storing auth private keys.
type JWTKey struct {
CreatedAt time.Time `yaml:"created_at" json:"created_at"`
KID string `yaml:"kid" json:"kid"`
PrivateKey *rsa.PrivateKey `yaml:"-" json:"-"`
PublicKey *rsa.PublicKey `yaml:"-" json:"-"`
Expand Down
20 changes: 6 additions & 14 deletions v2/auth/cache/cache.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,8 @@ import (

// Datastore represents required storage interface.
type Datastore interface {
AddJWTKey(context.Context, auth.JWTKey) error
ListJWTKeys(context.Context) ([]auth.JWTKey, error)
RotateJWTKeys(context.Context, string) error
RotateJWTKeys(context.Context, auth.JWTKey) error
}

// Cache provides in-memory owerlay for JWT key storage with key rotation functionality.
Expand Down Expand Up @@ -66,22 +65,15 @@ func (db *Cache) RotateKeys(ctx context.Context) error {
return fmt.Errorf("error GenerateNewKeyPair: %w", err)
}

if err := db.store.AddJWTKey(ctx, keys); err != nil {
return fmt.Errorf("error AddKeys: %w", err)
}

err = db.store.RotateJWTKeys(ctx, keys.KID)
if err != nil {
if err := db.store.RotateJWTKeys(ctx, keys); err != nil {
return err
}

newKeys, err := db.refreshKeys(ctx, false)
if err != nil {
if _, err := db.refreshKeys(ctx, true); err != nil {
return err
}
db.keys = newKeys
slog.Info("JWT RotateKeys called",
slog.Any("keys", getKIDs(db.keys)),

slog.Info("JWT RotateKeys finished",
slog.Duration("duration", time.Since(start)),
)
return nil
Expand All @@ -94,7 +86,7 @@ func (db *Cache) refreshKeys(ctx context.Context, reload bool) ([]auth.JWTKey, e
}
if reload {
db.keys = keys
slog.Info("JWT RefreshKeys called",
slog.Info("JWT RefreshKeys executed",
slog.Any("keys", getKIDs(db.keys)),
)
}
Expand Down
10 changes: 3 additions & 7 deletions v2/auth/cache/cache_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,19 +13,15 @@ type DB struct {
keys []auth.JWTKey
}

func (store *DB) AddJWTKey(c context.Context, payload auth.JWTKey) error {
store.keys = append(store.keys, payload)
return nil
}

func (store *DB) ListJWTKeys(c context.Context) ([]auth.JWTKey, error) {
return store.keys, nil
}

func (store *DB) RotateJWTKeys(c context.Context, kid string) error {
func (store *DB) RotateJWTKeys(c context.Context, payload auth.JWTKey) error {
store.keys = append(store.keys, payload)
out := []auth.JWTKey{}
for _, key := range store.keys {
if key.KID != kid {
if key.KID != payload.KID {
key.PrivateKey = nil
}
out = append(out, key)
Expand Down
12 changes: 4 additions & 8 deletions v2/auth/store/memory/memory.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,12 +17,6 @@ func New() *Memory {
return &Memory{keys: make([]auth.JWTKey, 0, 3)}
}

// AddJWTKey adds jwt key to storage.
func (m *Memory) AddJWTKey(_ context.Context, key auth.JWTKey) error {
m.keys = append([]auth.JWTKey{key}, m.keys...)
return nil
}

// GetKeys fetch all keys from cache.
func (m *Memory) ListJWTKeys(context.Context) ([]auth.JWTKey, error) {
data := make([]auth.JWTKey, len(m.keys))
Expand All @@ -31,10 +25,12 @@ func (m *Memory) ListJWTKeys(context.Context) ([]auth.JWTKey, error) {
}

// RotateKeys rotates the jwt secrets.
func (m *Memory) RotateJWTKeys(_ context.Context, kid string) error {
func (m *Memory) RotateJWTKeys(_ context.Context, key auth.JWTKey) error {
m.keys = append([]auth.JWTKey{key}, m.keys...)

// private key is needed only in newest which are used to generate new tokens
for i := range m.keys {
if m.keys[i].KID != kid {
if m.keys[i].KID != key.KID {
m.keys[i].PrivateKey = nil
}
}
Expand Down
115 changes: 54 additions & 61 deletions v2/auth/store/postgres/postgres.go
Original file line number Diff line number Diff line change
Expand Up @@ -56,53 +56,6 @@ type RawKey struct {
PublicKey []byte `db:"public_key_as_bytes"`
}

// AddJWTKey adds new keys to database.
func (db *DB) AddJWTKey(c context.Context, payload auth.JWTKey) error {
privKey, err := auth.Encrypt(auth.EncodePrivateKeyToPEM(payload.PrivateKey), keySecret(payload.KID, db.secret))
if err != nil {
return err
}

key := RawKey{
KID: payload.KID,
PrivateKey: privKey,
PublicKey: auth.EncodePublicKeyToPEM(payload.PublicKey),
}

if err := db.addKey(c, &key); err != nil {
return err
}

return nil
}

func (db *DB) addKey(c context.Context, key *RawKey) error {
span := sentryutil.MakeSpan(c, 1)
defer span.Finish()

const query = `
INSERT INTO jwt_keys (
k_id,
private_key_as_bytes,
public_key_as_bytes,
created_at,
updated_at
) VALUES (
:k_id,
:private_key_as_bytes,
:public_key_as_bytes,
:created_at,
:updated_at
)
RETURNING id`

if err := sqlxutil.CreateNamed(c, db.db, key, query); err != nil {
return fmt.Errorf("creating jwt_key failed: %w", err)
}

return nil
}

// ListJWTKeys lists the keys from database.
func (db *DB) ListJWTKeys(c context.Context) ([]auth.JWTKey, error) {
span := sentryutil.MakeSpan(c, 1)
Expand Down Expand Up @@ -131,31 +84,57 @@ func (db *DB) ListJWTKeys(c context.Context) ([]auth.JWTKey, error) {
}

// RotateJWTKeys rotates the JWT keys in database.
func (db *DB) RotateJWTKeys(c context.Context, kid string) error {
span := sentryutil.MakeSpan(c, 1)
func (db *DB) RotateJWTKeys(ctx context.Context, new auth.JWTKey) error {
span := sentryutil.MakeSpan(ctx, 1)
defer span.Finish()

const updateQuery = `
UPDATE jwt_keys
SET private_key_as_bytes=NULL
WHERE k_id != $1`
if _, err := db.db.ExecContext(c, updateQuery, kid); err != nil {
return fmt.Errorf("resetting old jwt keys failed: %w", err)
key, err := prepareRawKey(new, db.secret)
if err != nil {
return err
}

// keep 3 latest ones
const deleteQuery = `
return sqlxutil.WithTx(ctx, db.db, func(ctx context.Context, tx *sqlx.Tx) error {
const addQuery = `
INSERT INTO jwt_keys (
k_id,
private_key_as_bytes,
public_key_as_bytes,
created_at,
updated_at
) VALUES (
:k_id,
:private_key_as_bytes,
:public_key_as_bytes,
:created_at,
:updated_at
)
RETURNING id`
if err := sqlxutil.CreateNamed(ctx, tx, key, addQuery); err != nil {
return fmt.Errorf("adding jwt key to db failed: %w", err)
}

const updateQuery = `
UPDATE jwt_keys
SET private_key_as_bytes=NULL
WHERE k_id != $1`
if _, err := tx.ExecContext(ctx, updateQuery, new.KID); err != nil {
return fmt.Errorf("resetting old jwt keys failed: %w", err)
}

// keep 3 latest ones
const deleteQuery = `
DELETE FROM jwt_keys
WHERE id not in (
SELECT id
FROM jwt_keys
ORDER BY ID DESC
LIMIT 3
)`
if _, err := db.db.ExecContext(c, deleteQuery); err != nil {
return fmt.Errorf("deleting old jwt keys failed: %w", err)
}
return nil
if _, err := tx.ExecContext(ctx, deleteQuery); err != nil {
return fmt.Errorf("deleting old jwt keys failed: %w", err)
}
return nil
})
}

func DecryptRawKey(key RawKey, secret string) (auth.JWTKey, error) {
Expand All @@ -166,6 +145,7 @@ func DecryptRawKey(key RawKey, secret string) (auth.JWTKey, error) {
}

response := auth.JWTKey{
CreatedAt: key.CreatedAt,
KID: key.KID,
PublicKey: pub,
}
Expand All @@ -189,3 +169,16 @@ func DecryptRawKey(key RawKey, secret string) (auth.JWTKey, error) {
func keySecret(kid string, secret string) string {
return fmt.Sprintf("%s.%s", secret, kid)
}

func prepareRawKey(key auth.JWTKey, secret string) (*RawKey, error) {
privKey, err := auth.Encrypt(auth.EncodePrivateKeyToPEM(key.PrivateKey), keySecret(key.KID, secret))
if err != nil {
return nil, err
}

return &RawKey{
KID: key.KID,
PrivateKey: privKey,
PublicKey: auth.EncodePublicKeyToPEM(key.PublicKey),
}, nil
}
34 changes: 34 additions & 0 deletions v2/httputil/util.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
package httputil

import (
"fmt"
"net/http"
"strings"
)

func IsHTTPS(r *http.Request) bool {
const protoHTTPS = "https"
switch {
case r.URL.Scheme == protoHTTPS:
return true
case r.TLS != nil:
return true
case strings.HasPrefix(strings.ToLower(r.Proto), protoHTTPS):
return true
case r.Header.Get("X-Forwarded-Proto") == protoHTTPS:
return true
default:
return false
}
}

func (e ErrorResponse) Error() string {
return fmt.Sprintf("%d: %s", e.Code, e.Message)
}

// ErrorResponse provides HTTP error response.
type ErrorResponse struct {
Code uint `json:"code,omitempty" example:"400"`
Message string `json:"message" example:"Bad request"`
ErrorType string `json:"error_type,omitempty" example:"invalid_scope"`
}
Loading

0 comments on commit 758360c

Please sign in to comment.