From 005844251cce865d4a444f7cc986c16ce16b4c3b Mon Sep 17 00:00:00 2001 From: Henri Koski Date: Thu, 22 Aug 2024 19:34:16 +0300 Subject: [PATCH 1/5] Combine Add and Rotate keys --- v2/auth/auth.go | 1 + v2/auth/cache/cache.go | 20 ++--- v2/auth/cache/cache_test.go | 10 +-- v2/auth/store/memory/memory.go | 12 +-- v2/auth/store/postgres/postgres.go | 115 ++++++++++++++--------------- 5 files changed, 68 insertions(+), 90 deletions(-) diff --git a/v2/auth/auth.go b/v2/auth/auth.go index de54b85..2cfbafb 100644 --- a/v2/auth/auth.go +++ b/v2/auth/auth.go @@ -28,6 +28,7 @@ import ( // JWTKey is struct for storing auth private keys. type JWTKey struct { + CreatedAt time.Time `yaml:"created_at" json:"created_at"` KID string `yaml:"kid" json:"kid"` PrivateKey *rsa.PrivateKey `yaml:"-" json:"-"` PublicKey *rsa.PublicKey `yaml:"-" json:"-"` diff --git a/v2/auth/cache/cache.go b/v2/auth/cache/cache.go index ef23abf..3eee5a3 100644 --- a/v2/auth/cache/cache.go +++ b/v2/auth/cache/cache.go @@ -12,9 +12,8 @@ import ( // Datastore represents required storage interface. type Datastore interface { - AddJWTKey(context.Context, auth.JWTKey) error ListJWTKeys(context.Context) ([]auth.JWTKey, error) - RotateJWTKeys(context.Context, string) error + RotateJWTKeys(context.Context, auth.JWTKey) error } // Cache provides in-memory owerlay for JWT key storage with key rotation functionality. @@ -66,22 +65,15 @@ func (db *Cache) RotateKeys(ctx context.Context) error { return fmt.Errorf("error GenerateNewKeyPair: %w", err) } - if err := db.store.AddJWTKey(ctx, keys); err != nil { - return fmt.Errorf("error AddKeys: %w", err) - } - - err = db.store.RotateJWTKeys(ctx, keys.KID) - if err != nil { + if err := db.store.RotateJWTKeys(ctx, keys); err != nil { return err } - newKeys, err := db.refreshKeys(ctx, false) - if err != nil { + if _, err := db.refreshKeys(ctx, true); err != nil { return err } - db.keys = newKeys - slog.Info("JWT RotateKeys called", - slog.Any("keys", getKIDs(db.keys)), + + slog.Info("JWT RotateKeys finished", slog.Duration("duration", time.Since(start)), ) return nil @@ -94,7 +86,7 @@ func (db *Cache) refreshKeys(ctx context.Context, reload bool) ([]auth.JWTKey, e } if reload { db.keys = keys - slog.Info("JWT RefreshKeys called", + slog.Info("JWT RefreshKeys executed", slog.Any("keys", getKIDs(db.keys)), ) } diff --git a/v2/auth/cache/cache_test.go b/v2/auth/cache/cache_test.go index 460eb09..9d2fddd 100644 --- a/v2/auth/cache/cache_test.go +++ b/v2/auth/cache/cache_test.go @@ -13,19 +13,15 @@ type DB struct { keys []auth.JWTKey } -func (store *DB) AddJWTKey(c context.Context, payload auth.JWTKey) error { - store.keys = append(store.keys, payload) - return nil -} - func (store *DB) ListJWTKeys(c context.Context) ([]auth.JWTKey, error) { return store.keys, nil } -func (store *DB) RotateJWTKeys(c context.Context, kid string) error { +func (store *DB) RotateJWTKeys(c context.Context, payload auth.JWTKey) error { + store.keys = append(store.keys, payload) out := []auth.JWTKey{} for _, key := range store.keys { - if key.KID != kid { + if key.KID != payload.KID { key.PrivateKey = nil } out = append(out, key) diff --git a/v2/auth/store/memory/memory.go b/v2/auth/store/memory/memory.go index 94e546d..86bb550 100644 --- a/v2/auth/store/memory/memory.go +++ b/v2/auth/store/memory/memory.go @@ -17,12 +17,6 @@ func New() *Memory { return &Memory{keys: make([]auth.JWTKey, 0, 3)} } -// AddJWTKey adds jwt key to storage. -func (m *Memory) AddJWTKey(_ context.Context, key auth.JWTKey) error { - m.keys = append([]auth.JWTKey{key}, m.keys...) - return nil -} - // GetKeys fetch all keys from cache. func (m *Memory) ListJWTKeys(context.Context) ([]auth.JWTKey, error) { data := make([]auth.JWTKey, len(m.keys)) @@ -31,10 +25,12 @@ func (m *Memory) ListJWTKeys(context.Context) ([]auth.JWTKey, error) { } // RotateKeys rotates the jwt secrets. -func (m *Memory) RotateJWTKeys(_ context.Context, kid string) error { +func (m *Memory) RotateJWTKeys(_ context.Context, key auth.JWTKey) error { + m.keys = append([]auth.JWTKey{key}, m.keys...) + // private key is needed only in newest which are used to generate new tokens for i := range m.keys { - if m.keys[i].KID != kid { + if m.keys[i].KID != key.KID { m.keys[i].PrivateKey = nil } } diff --git a/v2/auth/store/postgres/postgres.go b/v2/auth/store/postgres/postgres.go index 9d8fd04..7db5a32 100644 --- a/v2/auth/store/postgres/postgres.go +++ b/v2/auth/store/postgres/postgres.go @@ -56,53 +56,6 @@ type RawKey struct { PublicKey []byte `db:"public_key_as_bytes"` } -// AddJWTKey adds new keys to database. -func (db *DB) AddJWTKey(c context.Context, payload auth.JWTKey) error { - privKey, err := auth.Encrypt(auth.EncodePrivateKeyToPEM(payload.PrivateKey), keySecret(payload.KID, db.secret)) - if err != nil { - return err - } - - key := RawKey{ - KID: payload.KID, - PrivateKey: privKey, - PublicKey: auth.EncodePublicKeyToPEM(payload.PublicKey), - } - - if err := db.addKey(c, &key); err != nil { - return err - } - - return nil -} - -func (db *DB) addKey(c context.Context, key *RawKey) error { - span := sentryutil.MakeSpan(c, 1) - defer span.Finish() - - const query = ` - INSERT INTO jwt_keys ( - k_id, - private_key_as_bytes, - public_key_as_bytes, - created_at, - updated_at - ) VALUES ( - :k_id, - :private_key_as_bytes, - :public_key_as_bytes, - :created_at, - :updated_at - ) - RETURNING id` - - if err := sqlxutil.CreateNamed(c, db.db, key, query); err != nil { - return fmt.Errorf("creating jwt_key failed: %w", err) - } - - return nil -} - // ListJWTKeys lists the keys from database. func (db *DB) ListJWTKeys(c context.Context) ([]auth.JWTKey, error) { span := sentryutil.MakeSpan(c, 1) @@ -131,20 +84,45 @@ func (db *DB) ListJWTKeys(c context.Context) ([]auth.JWTKey, error) { } // RotateJWTKeys rotates the JWT keys in database. -func (db *DB) RotateJWTKeys(c context.Context, kid string) error { - span := sentryutil.MakeSpan(c, 1) +func (db *DB) RotateJWTKeys(ctx context.Context, new auth.JWTKey) error { + span := sentryutil.MakeSpan(ctx, 1) defer span.Finish() - const updateQuery = ` - UPDATE jwt_keys - SET private_key_as_bytes=NULL - WHERE k_id != $1` - if _, err := db.db.ExecContext(c, updateQuery, kid); err != nil { - return fmt.Errorf("resetting old jwt keys failed: %w", err) + key, err := prepareRawKey(new, db.secret) + if err != nil { + return err } - // keep 3 latest ones - const deleteQuery = ` + return sqlxutil.WithTx(ctx, db.db, func(ctx context.Context, tx *sqlx.Tx) error { + const addQuery = ` + INSERT INTO jwt_keys ( + k_id, + private_key_as_bytes, + public_key_as_bytes, + created_at, + updated_at + ) VALUES ( + :k_id, + :private_key_as_bytes, + :public_key_as_bytes, + :created_at, + :updated_at + ) + RETURNING id` + if err := sqlxutil.CreateNamed(ctx, tx, key, addQuery); err != nil { + return fmt.Errorf("adding jwt key to db failed: %w", err) + } + + const updateQuery = ` + UPDATE jwt_keys + SET private_key_as_bytes=NULL + WHERE k_id != $1` + if _, err := tx.ExecContext(ctx, updateQuery, new.KID); err != nil { + return fmt.Errorf("resetting old jwt keys failed: %w", err) + } + + // keep 3 latest ones + const deleteQuery = ` DELETE FROM jwt_keys WHERE id not in ( SELECT id @@ -152,10 +130,11 @@ func (db *DB) RotateJWTKeys(c context.Context, kid string) error { ORDER BY ID DESC LIMIT 3 )` - if _, err := db.db.ExecContext(c, deleteQuery); err != nil { - return fmt.Errorf("deleting old jwt keys failed: %w", err) - } - return nil + if _, err := tx.ExecContext(ctx, deleteQuery); err != nil { + return fmt.Errorf("deleting old jwt keys failed: %w", err) + } + return nil + }) } func DecryptRawKey(key RawKey, secret string) (auth.JWTKey, error) { @@ -166,6 +145,7 @@ func DecryptRawKey(key RawKey, secret string) (auth.JWTKey, error) { } response := auth.JWTKey{ + CreatedAt: key.CreatedAt, KID: key.KID, PublicKey: pub, } @@ -189,3 +169,16 @@ func DecryptRawKey(key RawKey, secret string) (auth.JWTKey, error) { func keySecret(kid string, secret string) string { return fmt.Sprintf("%s.%s", secret, kid) } + +func prepareRawKey(key auth.JWTKey, secret string) (*RawKey, error) { + privKey, err := auth.Encrypt(auth.EncodePrivateKeyToPEM(key.PrivateKey), keySecret(key.KID, secret)) + if err != nil { + return nil, err + } + + return &RawKey{ + KID: key.KID, + PrivateKey: privKey, + PublicKey: auth.EncodePublicKeyToPEM(key.PublicKey), + }, nil +} From c8f9a774413f47609b92da0fc5aae80639ba6e68 Mon Sep 17 00:00:00 2001 From: Henri Koski Date: Thu, 22 Aug 2024 19:51:45 +0300 Subject: [PATCH 2/5] Expose IsHTTPS --- v2/httputil/util.go | 22 ++++++++++++++++++++++ v2/middleware/csrf/csrf.go | 21 +++------------------ 2 files changed, 25 insertions(+), 18 deletions(-) create mode 100644 v2/httputil/util.go diff --git a/v2/httputil/util.go b/v2/httputil/util.go new file mode 100644 index 0000000..a6ff0a9 --- /dev/null +++ b/v2/httputil/util.go @@ -0,0 +1,22 @@ +package httputil + +import ( + "net/http" + "strings" +) + +func IsHTTPS(r *http.Request) bool { + const protoHTTPS = "https" + switch { + case r.URL.Scheme == protoHTTPS: + return true + case r.TLS != nil: + return true + case strings.HasPrefix(strings.ToLower(r.Proto), protoHTTPS): + return true + case r.Header.Get("X-Forwarded-Proto") == protoHTTPS: + return true + default: + return false + } +} diff --git a/v2/middleware/csrf/csrf.go b/v2/middleware/csrf/csrf.go index 66cd199..a24fd00 100644 --- a/v2/middleware/csrf/csrf.go +++ b/v2/middleware/csrf/csrf.go @@ -8,8 +8,8 @@ import ( "net/http" "net/url" "slices" - "strings" + "github.com/elisasre/go-common/v2/httputil" "github.com/gin-gonic/gin" ) @@ -71,7 +71,7 @@ func New(excludePaths []string) gin.HandlerFunc { Path: "/", Domain: c.Request.URL.Host, HttpOnly: false, - Secure: isHTTPS(c.Request), + Secure: httputil.IsHTTPS(c.Request), MaxAge: 12 * 60 * 60, SameSite: http.SameSiteLaxMode, }) @@ -82,7 +82,7 @@ func New(excludePaths []string) gin.HandlerFunc { return } - if isHTTPS(c.Request) { + if httputil.IsHTTPS(c.Request) { referer := c.Request.Header.Get("Referer") if referer == "" { c.JSON(403, ErrorResponse{Code: 403, Message: noReferer}) @@ -158,21 +158,6 @@ func randomString(n int) (string, error) { return string(b), nil } -func isHTTPS(r *http.Request) bool { - switch { - case r.URL.Scheme == protoHTTPS: - return true - case r.TLS != nil: - return true - case strings.HasPrefix(strings.ToLower(r.Proto), protoHTTPS): - return true - case r.Header.Get("X-Forwarded-Proto") == protoHTTPS: - return true - default: - return false - } -} - func getHeader(c *gin.Context) string { return c.Request.Header.Get(TokenHeaderKey) } From 5f5f3a9393e88da04ea362b4c8ef6d72ee5a4e09 Mon Sep 17 00:00:00 2001 From: Henri Koski Date: Thu, 22 Aug 2024 20:04:47 +0300 Subject: [PATCH 3/5] Extract ErrorResponse to httputil --- v2/httputil/util.go | 12 +++++++++++ v2/middleware/csrf/csrf.go | 25 +++++++---------------- v2/middleware/ratelimit/ratelimit.go | 17 +++------------ v2/middleware/ratelimit/ratelimit_test.go | 5 +++-- 4 files changed, 25 insertions(+), 34 deletions(-) diff --git a/v2/httputil/util.go b/v2/httputil/util.go index a6ff0a9..1bc3a0b 100644 --- a/v2/httputil/util.go +++ b/v2/httputil/util.go @@ -1,6 +1,7 @@ package httputil import ( + "fmt" "net/http" "strings" ) @@ -20,3 +21,14 @@ func IsHTTPS(r *http.Request) bool { return false } } + +func (e ErrorResponse) Error() string { + return fmt.Sprintf("%d: %s", e.Code, e.Message) +} + +// ErrorResponse provides HTTP error response. +type ErrorResponse struct { + Code uint `json:"code,omitempty" example:"400"` + Message string `json:"message" example:"Bad request"` + ErrorType string `json:"error_type,omitempty" example:"invalid_scope"` +} diff --git a/v2/middleware/csrf/csrf.go b/v2/middleware/csrf/csrf.go index a24fd00..fab8725 100644 --- a/v2/middleware/csrf/csrf.go +++ b/v2/middleware/csrf/csrf.go @@ -33,17 +33,6 @@ const ( var ignoreMethods = []string{"GET", "HEAD", "OPTIONS", "TRACE"} -func (e ErrorResponse) Error() string { - return fmt.Sprintf("%d: %s", e.Code, e.Message) -} - -// ErrorResponse provides HTTP error response. -type ErrorResponse struct { - Code uint `json:"code,omitempty" example:"400"` - Message string `json:"message" example:"Bad request"` - ErrorType string `json:"error_type,omitempty" example:"invalid_scope"` -} - // New creates new CSRF middleware for gin. func New(excludePaths []string) gin.HandlerFunc { return func(c *gin.Context) { @@ -61,7 +50,7 @@ func New(excludePaths []string) gin.HandlerFunc { if csrfToken == "" { val, err := RandomToken() if err != nil { - c.JSON(403, ErrorResponse{Code: 403, Message: malformedReferer}) + c.JSON(403, httputil.ErrorResponse{Code: 403, Message: malformedReferer}) c.Abort() return } @@ -85,27 +74,27 @@ func New(excludePaths []string) gin.HandlerFunc { if httputil.IsHTTPS(c.Request) { referer := c.Request.Header.Get("Referer") if referer == "" { - c.JSON(403, ErrorResponse{Code: 403, Message: noReferer}) + c.JSON(403, httputil.ErrorResponse{Code: 403, Message: noReferer}) c.Abort() return } parsedURL, err := url.Parse(referer) if err != nil { - c.JSON(403, ErrorResponse{Code: 403, Message: malformedReferer}) + c.JSON(403, httputil.ErrorResponse{Code: 403, Message: malformedReferer}) c.Abort() return } if parsedURL.Scheme != protoHTTPS { - c.JSON(403, ErrorResponse{Code: 403, Message: insecureReferer}) + c.JSON(403, httputil.ErrorResponse{Code: 403, Message: insecureReferer}) c.Abort() return } if parsedURL.Host != c.Request.Host { msg := fmt.Sprintf("Referer checking failed - %s does not match any trusted origins.", parsedURL.Host) - c.JSON(403, ErrorResponse{Code: 403, Message: msg}) + c.JSON(403, httputil.ErrorResponse{Code: 403, Message: msg}) c.Abort() return } @@ -113,13 +102,13 @@ func New(excludePaths []string) gin.HandlerFunc { requestCSRFToken := getHeader(c) if csrfToken == "" { - c.JSON(403, ErrorResponse{Code: 403, Message: tokenMissing}) + c.JSON(403, httputil.ErrorResponse{Code: 403, Message: tokenMissing}) c.Abort() return } if requestCSRFToken != csrfToken { - c.JSON(403, ErrorResponse{Code: 403, Message: badTooken}) + c.JSON(403, httputil.ErrorResponse{Code: 403, Message: badTooken}) c.Abort() return } diff --git a/v2/middleware/ratelimit/ratelimit.go b/v2/middleware/ratelimit/ratelimit.go index 4def0fb..74f408e 100644 --- a/v2/middleware/ratelimit/ratelimit.go +++ b/v2/middleware/ratelimit/ratelimit.go @@ -1,11 +1,11 @@ package ratelimit import ( - "fmt" "net/http" "strconv" "time" + "github.com/elisasre/go-common/v2/httputil" "github.com/gin-gonic/gin" "github.com/go-redis/redis_rate/v10" "github.com/redis/go-redis/v9" @@ -22,17 +22,6 @@ const ( HeaderRemaining = "X-Ratelimit-Remaining" ) -func (e ErrorResponse) Error() string { - return fmt.Sprintf("%d: %s", e.Code, e.Message) -} - -// ErrorResponse provides HTTP error response. -type ErrorResponse struct { - Code uint `json:"code,omitempty" example:"400"` - Message string `json:"message" example:"Bad request"` - ErrorType string `json:"error_type,omitempty" example:"invalid_scope"` -} - // New creates a distributed rate limiter middleware using redis for state management. func New(rdb *redis.Client, key KeyFunc, errFunc ErrFunc) gin.HandlerFunc { limiter := redis_rate.NewLimiter(rdb) @@ -40,7 +29,7 @@ func New(rdb *redis.Client, key KeyFunc, errFunc ErrFunc) gin.HandlerFunc { ctx := c.Request.Context() key, limit, err := key(c) if err != nil { - c.JSON(400, ErrorResponse{Code: 400, Message: err.Error()}) + c.JSON(400, httputil.ErrorResponse{Code: 400, Message: err.Error()}) c.Abort() return } @@ -61,7 +50,7 @@ func New(rdb *redis.Client, key KeyFunc, errFunc ErrFunc) gin.HandlerFunc { c.Header(HeaderLimit, strconv.Itoa(*limit)) c.Header(HeaderRemaining, strconv.Itoa(res.Remaining)) if res.Allowed <= 0 { - c.JSON(http.StatusTooManyRequests, ErrorResponse{Code: http.StatusTooManyRequests, Message: "rate limit exceeded"}) + c.JSON(http.StatusTooManyRequests, httputil.ErrorResponse{Code: http.StatusTooManyRequests, Message: "rate limit exceeded"}) c.Abort() return } diff --git a/v2/middleware/ratelimit/ratelimit_test.go b/v2/middleware/ratelimit/ratelimit_test.go index 6548249..26745d6 100644 --- a/v2/middleware/ratelimit/ratelimit_test.go +++ b/v2/middleware/ratelimit/ratelimit_test.go @@ -6,6 +6,7 @@ import ( "testing" "github.com/alicebob/miniredis/v2" + "github.com/elisasre/go-common/v2/httputil" "github.com/elisasre/go-common/v2/middleware/ratelimit" "github.com/gin-gonic/gin" "github.com/redis/go-redis/v9" @@ -117,7 +118,7 @@ func TestRedisRateLimiterForce(t *testing.T) { t.Log(err) } c.JSON(http.StatusBadRequest, - ratelimit.ErrorResponse{Code: http.StatusBadRequest, Message: err.Error()}, + httputil.ErrorResponse{Code: http.StatusBadRequest, Message: err.Error()}, ) c.Abort() return true @@ -154,7 +155,7 @@ func TestRedisRateLimiterNil(t *testing.T) { t.Log(err) } c.JSON(http.StatusBadRequest, - ratelimit.ErrorResponse{Code: http.StatusBadRequest, Message: err.Error()}, + httputil.ErrorResponse{Code: http.StatusBadRequest, Message: err.Error()}, ) c.Abort() return true From b7089a222bf751bf3152ecb822528afdf56ff3d5 Mon Sep 17 00:00:00 2001 From: Henri Koski Date: Thu, 22 Aug 2024 21:09:22 +0300 Subject: [PATCH 4/5] Add RandomString util --- v2/utils.go | 17 +++++++++++++++++ 1 file changed, 17 insertions(+) diff --git a/v2/utils.go b/v2/utils.go index a509815..44d221e 100644 --- a/v2/utils.go +++ b/v2/utils.go @@ -1,6 +1,8 @@ package common import ( + "crypto/rand" + "math/big" "strings" ) @@ -38,3 +40,18 @@ func StringToBool(v string) bool { v = strings.ToLower(v) return v == "true" || v == "t" || v == "yes" || v == "y" || v == "on" } + +func RandomString(n int) (string, error) { + characterRunes := []byte("abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789") + + b := make([]byte, n) + for i := range b { + num, err := rand.Int(rand.Reader, big.NewInt(int64(len(characterRunes)))) + if err != nil { + return "", err + } + b[i] = characterRunes[num.Int64()] + } + + return string(b), nil +} From 554fe9e1fd9734a2dab1f46c653adaab7ba2781d Mon Sep 17 00:00:00 2001 From: Henri Koski Date: Thu, 22 Aug 2024 21:20:39 +0300 Subject: [PATCH 5/5] Add min and max --- v2/utils.go | 18 ++++++++++++++++++ 1 file changed, 18 insertions(+) diff --git a/v2/utils.go b/v2/utils.go index 44d221e..c5f6a35 100644 --- a/v2/utils.go +++ b/v2/utils.go @@ -55,3 +55,21 @@ func RandomString(n int) (string, error) { return string(b), nil } + +type Ordered interface { + ~uint | ~uint8 | ~uint16 | ~uint32 | ~uint64 | ~uintptr | ~int | ~int8 | ~int16 | ~int32 | ~int64 | ~float32 | ~float64 | ~string +} + +func Min[T Ordered](a, b T) T { + if a < b { + return a + } + return b +} + +func Max[T Ordered](a, b T) T { + if a > b { + return a + } + return b +}