Skip to content

Commit

Permalink
add token functions
Browse files Browse the repository at this point in the history
  • Loading branch information
zetaab committed May 10, 2023
1 parent 0496e26 commit ff8c139
Show file tree
Hide file tree
Showing 6 changed files with 333 additions and 1 deletion.
2 changes: 1 addition & 1 deletion Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ golint: .git/hooks/pre-commit
pre-commit run --all-files

test:
go test -race -covermode atomic -coverprofile=gotest-coverage.out ./... $(GOTEST_REPORT_FORMAT) > gotest-report.out && cat gotest-report.out || (cat gotest-report.out; exit 1)
go test -race -v -covermode atomic -coverprofile=gotest-coverage.out ./... $(GOTEST_REPORT_FORMAT) > gotest-report.out && cat gotest-report.out || (cat gotest-report.out; exit 1)
git diff --exit-code go.mod go.sum


Expand Down
1 change: 1 addition & 0 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ require (
github.com/gin-gonic/gin v1.9.0
github.com/go-redis/redis/v8 v8.11.5
github.com/go-redis/redis_rate/v9 v9.1.2
github.com/golang-jwt/jwt v3.2.2+incompatible
github.com/pkg/errors v0.9.1
github.com/rs/zerolog v1.29.1
github.com/spf13/viper v1.15.0
Expand Down
2 changes: 2 additions & 0 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,8 @@ github.com/go-redis/redis_rate/v9 v9.1.2/go.mod h1:oam2de2apSgRG8aJzwJddXbNu91Iy
github.com/goccy/go-json v0.10.0 h1:mXKd9Qw4NuzShiRlOXKews24ufknHO7gx30lsDyokKA=
github.com/goccy/go-json v0.10.0/go.mod h1:6MelG93GURQebXPDq3khkgXZkazVtN9CRI+MGFi0w8I=
github.com/godbus/dbus/v5 v5.0.4/go.mod h1:xhWf0FNVPg57R7Z0UbKHbJfkEywrmjJnf7w5xrFpKfA=
github.com/golang-jwt/jwt v3.2.2+incompatible h1:IfV12K8xAKAnZqdXVzCZ+TOjboZ2keLg81eXfW3O+oY=
github.com/golang-jwt/jwt v3.2.2+incompatible/go.mod h1:8pz2t5EyA70fFQQSrl6XZXzqecmYZeUEB8OUGHkxJ+I=
github.com/golang/glog v0.0.0-20160126235308-23def4e6c14b/go.mod h1:SBH7ygxi8pfUlaOkMMuAQtPIUF8ecWP5IEl/CR7VP2Q=
github.com/golang/groupcache v0.0.0-20190702054246-869f871628b6/go.mod h1:cIg4eruTrX1D+g88fzRXU5OdNfaM+9IcxsU14FzY7Hc=
github.com/golang/groupcache v0.0.0-20191227052852-215e87163ea7/go.mod h1:cIg4eruTrX1D+g88fzRXU5OdNfaM+9IcxsU14FzY7Hc=
Expand Down
147 changes: 147 additions & 0 deletions token/token.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,147 @@
package token

import (
"fmt"
"strings"
"time"

"github.com/elisasre/go-common"
"github.com/golang-jwt/jwt"
)

// Token struct.
type Token struct {
User *common.User
}

const (
OpenID = "openid"
Profile = "profile"
Email = "email"
Groups = "groups"
Internal = "internal"
)

var AllScopes = []string{OpenID, Profile, Email, Groups, Internal}

// SignClaims contains claims that are passed to SignExpires func.
type SignClaims struct {
Aud string
Exp int64
Iat int64
Issuer string
Nonce string
Scopes []string
}

// SignAlgo const.
const SignAlgo = "RS256"

// New constructs new token which is passed for application.
func New(user *common.User) *Token {
return &Token{User: user}
}

// UserJWTClaims contains struct for making and parsing jwt tokens.
type UserJWTClaims struct {
*common.User
jwt.StandardClaims
Nonce string `json:"nonce,omitempty"`
}

// SignExpires makes new jwt token using expiration time and secret.
func (t *Token) SignExpires(key common.JWTKey, claim SignClaims) (string, error) {
t.User.Email = common.String(strings.ToLower(common.StringValue(t.User.Email)))
sub := t.User.MakeSub()
if claim.Iat == 0 {
claim.Iat = time.Now().Unix()
}

if !common.Contains(claim.Scopes, OpenID) {
return "", fmt.Errorf("token must contain '%s' scope", OpenID)
}

if !common.Contains(claim.Scopes, Internal) {
t.User.Internal = nil
}

if !common.Contains(claim.Scopes, Email) {
t.User.Email = nil
t.User.EmailVerified = nil
}

if !common.Contains(claim.Scopes, Groups) {
t.User.Groups = nil
}

if !common.Contains(claim.Scopes, Profile) {
t.User.Name = nil
}

claims := UserJWTClaims{
t.User,
jwt.StandardClaims{
Subject: sub,
Audience: claim.Aud,
ExpiresAt: claim.Exp,
Issuer: claim.Issuer,
IssuedAt: claim.Iat,
},
claim.Nonce,
}
method := jwt.SigningMethodRS256
token := jwt.Token{
Header: map[string]interface{}{
"typ": "JWT",
"alg": method.Alg(),
"kid": key.KID,
},
Claims: claims,
Method: method,
}
if key.PrivateKey == nil {
return "", fmt.Errorf("privatekey is nil for key %d", key.ID)
}
return token.SignedString(key.PrivateKey)
}

func findKidFromArray(keys []common.JWTKey, kid interface{}) (common.JWTKey, error) {
kidAsString, ok := kid.(string)
if !ok {
return common.JWTKey{}, fmt.Errorf("not str")
}
for _, s := range keys {
if s.KID == kidAsString {
return s, nil
}
}
return common.JWTKey{}, fmt.Errorf("could not find kid '%s'", kidAsString)
}

// Parse will validate jwt token and return token.
func Parse(raw string, keys []common.JWTKey) (*UserJWTClaims, error) {
parsed, err := jwt.ParseWithClaims(raw, &UserJWTClaims{}, func(t *jwt.Token) (interface{}, error) {
if t.Method.Alg() != SignAlgo {
return nil, jwt.ErrSignatureInvalid
}
if val, ok := t.Header["kid"]; ok {
key, err := findKidFromArray(keys, val)
if err != nil {
return nil, err
}
return key.PublicKey, nil
}
return nil, fmt.Errorf("could not find kid from headers")
})
if err != nil {
return nil, err
} else if !parsed.Valid {
return nil, jwt.ValidationError{}
}

claims, ok := parsed.Claims.(*UserJWTClaims)
if !ok {
return nil, fmt.Errorf("could not parse struct")
}
return claims, nil
}
147 changes: 147 additions & 0 deletions token/token_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,147 @@
package token

import (
"fmt"
"strings"
"testing"
"time"

"github.com/elisasre/go-common"
"github.com/stretchr/testify/require"
)

func TestToken(t *testing.T) {
key, err := common.GenerateNewKeyPair()
require.NoError(t, err)

fullUser := common.User{
Name: common.String("Test User"),
Email: common.String("[email protected]"),
Groups: []string{"group1", "group2"},
EmailVerified: common.Bool(true),
Internal: &common.Internal{
EmployeeID: "123456",
MFA: common.Bool(true),
},
}

type testCase struct {
name string
scopes []string
err error
user *common.User
}

testCases := []testCase{
{
name: "all scopes",
scopes: AllScopes,
user: &common.User{
Name: common.String("Test User"),
Email: common.String("[email protected]"),
Groups: []string{"group1", "group2"},
EmailVerified: common.Bool(true),
Internal: &common.Internal{
EmployeeID: "123456",
MFA: common.Bool(true),
},
},
},
{
name: "openid",
scopes: []string{OpenID},
user: nil,
},
{
name: "openid profile",
scopes: []string{OpenID, Profile},
user: &common.User{Name: common.String("Test User")},
},
{
name: "openid email",
scopes: []string{OpenID, Email},
user: &common.User{
Email: common.String("[email protected]"),
EmailVerified: common.Bool(true),
},
},
{
name: "openid groups",
scopes: []string{OpenID, Groups},
user: &common.User{
Groups: []string{"group1", "group2"},
},
},
{
name: "openid internal",
scopes: []string{OpenID, Internal},
user: &common.User{
Internal: &common.Internal{
EmployeeID: "123456",
MFA: common.Bool(true),
},
},
},
{
name: "no openid scope",
scopes: []string{},
err: fmt.Errorf("token must contain '%s' scope", OpenID),
},
}

for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
user := tc.user
newUser := fullUser
testUser := New(&newUser)
token, err := testUser.SignExpires(*key, SignClaims{
Aud: "internal",
Exp: time.Now().Add(time.Hour).Unix(),
Issuer: "http://localhost",
Scopes: tc.scopes,
})
if tc.err != nil {
require.Equal(t, tc.err, err)
} else {
require.NoError(t, err)
userClaims, err := Parse(token, []common.JWTKey{*key})
require.NoError(t, err)
require.Equal(t, user, userClaims.User)
}
})
}
}

func TestInvalidKid(t *testing.T) {
key, err := common.GenerateNewKeyPair()
require.NoError(t, err)
key2, err := common.GenerateNewKeyPair()
require.NoError(t, err)

testUser := New(&common.User{})
token, err := testUser.SignExpires(*key, SignClaims{
Aud: "internal",
Exp: time.Now().Add(time.Hour).Unix(),
Issuer: "http://localhost",
Scopes: AllScopes,
})
require.NoError(t, err)
_, err = Parse(token, []common.JWTKey{*key2})
require.Equal(t, fmt.Sprintf("could not find kid '%s'", key.KID), err.Error())
}

func TestExpired(t *testing.T) {
key, err := common.GenerateNewKeyPair()
require.NoError(t, err)

testUser := New(&common.User{})
token, err := testUser.SignExpires(*key, SignClaims{
Aud: "internal",
Exp: 1,
Issuer: "http://localhost",
Scopes: AllScopes,
})
require.NoError(t, err)
_, err = Parse(token, []common.JWTKey{*key})
require.True(t, strings.HasPrefix(err.Error(), "token is expired by"))
}
35 changes: 35 additions & 0 deletions types.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@ package common
import (
"context"
"crypto/rsa"
"crypto/sha256"
"encoding/hex"
"fmt"
"strings"
"time"
Expand Down Expand Up @@ -130,3 +132,36 @@ type Datastore interface {
ListJWTKeys(context.Context) ([]JWTKey, error)
RotateJWTKeys(context.Context, uint) error
}

// Internal contains struct for internal non standard variables.
type Internal struct {
Cluster *string `json:"cluster,omitempty"`
ChangeLimit *int `json:"limit,omitempty"`
MFA *bool `json:"mfa"`
EmployeeID string `json:"employeeid,omitempty"`
}

// User contains struct for single user.
type User struct {
Groups []string `json:"groups,omitempty"`
Eid string `json:"custom:employeeid,omitempty"`
ImportGroups []string `json:"cognito:groups,omitempty"`
Email *string `json:"email,omitempty"`
EmailVerified *bool `json:"email_verified,omitempty"`
Name *string `json:"name,omitempty"`
Internal *Internal `json:"internal,omitempty"`
}

// MakeSub returns sub value for user.
func (u *User) MakeSub() string {
if u == nil {
return ""
}
sub := StringValue(u.Email)
if u.Internal != nil && u.Internal.EmployeeID != "" {
sub = u.Internal.EmployeeID
}
sub = strings.ToLower(sub)
b := sha256.Sum256([]byte(sub))
return hex.EncodeToString(b[:])
}

0 comments on commit ff8c139

Please sign in to comment.