diff --git a/Makefile b/Makefile index 83af3ae..76fd17f 100644 --- a/Makefile +++ b/Makefile @@ -19,7 +19,7 @@ golint: .git/hooks/pre-commit pre-commit run --all-files test: - go test -race -covermode atomic -coverprofile=gotest-coverage.out ./... $(GOTEST_REPORT_FORMAT) > gotest-report.out && cat gotest-report.out || (cat gotest-report.out; exit 1) + go test -race -v -covermode atomic -coverprofile=gotest-coverage.out ./... $(GOTEST_REPORT_FORMAT) > gotest-report.out && cat gotest-report.out || (cat gotest-report.out; exit 1) git diff --exit-code go.mod go.sum diff --git a/go.mod b/go.mod index 8b32aaa..a197ff1 100644 --- a/go.mod +++ b/go.mod @@ -10,6 +10,7 @@ require ( github.com/gin-gonic/gin v1.9.0 github.com/go-redis/redis/v8 v8.11.5 github.com/go-redis/redis_rate/v9 v9.1.2 + github.com/golang-jwt/jwt v3.2.2+incompatible github.com/pkg/errors v0.9.1 github.com/rs/zerolog v1.29.1 github.com/spf13/viper v1.15.0 diff --git a/go.sum b/go.sum index a3e1176..6edc431 100644 --- a/go.sum +++ b/go.sum @@ -94,6 +94,8 @@ github.com/go-redis/redis_rate/v9 v9.1.2/go.mod h1:oam2de2apSgRG8aJzwJddXbNu91Iy github.com/goccy/go-json v0.10.0 h1:mXKd9Qw4NuzShiRlOXKews24ufknHO7gx30lsDyokKA= github.com/goccy/go-json v0.10.0/go.mod h1:6MelG93GURQebXPDq3khkgXZkazVtN9CRI+MGFi0w8I= github.com/godbus/dbus/v5 v5.0.4/go.mod h1:xhWf0FNVPg57R7Z0UbKHbJfkEywrmjJnf7w5xrFpKfA= +github.com/golang-jwt/jwt v3.2.2+incompatible h1:IfV12K8xAKAnZqdXVzCZ+TOjboZ2keLg81eXfW3O+oY= +github.com/golang-jwt/jwt v3.2.2+incompatible/go.mod h1:8pz2t5EyA70fFQQSrl6XZXzqecmYZeUEB8OUGHkxJ+I= github.com/golang/glog v0.0.0-20160126235308-23def4e6c14b/go.mod h1:SBH7ygxi8pfUlaOkMMuAQtPIUF8ecWP5IEl/CR7VP2Q= github.com/golang/groupcache v0.0.0-20190702054246-869f871628b6/go.mod h1:cIg4eruTrX1D+g88fzRXU5OdNfaM+9IcxsU14FzY7Hc= github.com/golang/groupcache v0.0.0-20191227052852-215e87163ea7/go.mod h1:cIg4eruTrX1D+g88fzRXU5OdNfaM+9IcxsU14FzY7Hc= diff --git a/token/token.go b/token/token.go new file mode 100644 index 0000000..20a83db --- /dev/null +++ b/token/token.go @@ -0,0 +1,147 @@ +package token + +import ( + "fmt" + "strings" + "time" + + "github.com/elisasre/go-common" + "github.com/golang-jwt/jwt" +) + +// Token struct. +type Token struct { + User *common.User +} + +const ( + OpenID = "openid" + Profile = "profile" + Email = "email" + Groups = "groups" + Internal = "internal" +) + +var AllScopes = []string{OpenID, Profile, Email, Groups, Internal} + +// SignClaims contains claims that are passed to SignExpires func. +type SignClaims struct { + Aud string + Exp int64 + Iat int64 + Issuer string + Nonce string + Scopes []string +} + +// SignAlgo const. +const SignAlgo = "RS256" + +// New constructs new token which is passed for application. +func New(user *common.User) *Token { + return &Token{User: user} +} + +// UserJWTClaims contains struct for making and parsing jwt tokens. +type UserJWTClaims struct { + *common.User + jwt.StandardClaims + Nonce string `json:"nonce,omitempty"` +} + +// SignExpires makes new jwt token using expiration time and secret. +func (t *Token) SignExpires(key common.JWTKey, claim SignClaims) (string, error) { + t.User.Email = common.String(strings.ToLower(common.StringValue(t.User.Email))) + sub := t.User.MakeSub() + if claim.Iat == 0 { + claim.Iat = time.Now().Unix() + } + + if !common.Contains(claim.Scopes, OpenID) { + return "", fmt.Errorf("token must contain '%s' scope", OpenID) + } + + if !common.Contains(claim.Scopes, Internal) { + t.User.Internal = nil + } + + if !common.Contains(claim.Scopes, Email) { + t.User.Email = nil + t.User.EmailVerified = nil + } + + if !common.Contains(claim.Scopes, Groups) { + t.User.Groups = nil + } + + if !common.Contains(claim.Scopes, Profile) { + t.User.Name = nil + } + + claims := UserJWTClaims{ + t.User, + jwt.StandardClaims{ + Subject: sub, + Audience: claim.Aud, + ExpiresAt: claim.Exp, + Issuer: claim.Issuer, + IssuedAt: claim.Iat, + }, + claim.Nonce, + } + method := jwt.SigningMethodRS256 + token := jwt.Token{ + Header: map[string]interface{}{ + "typ": "JWT", + "alg": method.Alg(), + "kid": key.KID, + }, + Claims: claims, + Method: method, + } + if key.PrivateKey == nil { + return "", fmt.Errorf("privatekey is nil for key %d", key.ID) + } + return token.SignedString(key.PrivateKey) +} + +func findKidFromArray(keys []common.JWTKey, kid interface{}) (common.JWTKey, error) { + kidAsString, ok := kid.(string) + if !ok { + return common.JWTKey{}, fmt.Errorf("not str") + } + for _, s := range keys { + if s.KID == kidAsString { + return s, nil + } + } + return common.JWTKey{}, fmt.Errorf("could not find kid '%s'", kidAsString) +} + +// Parse will validate jwt token and return token. +func Parse(raw string, keys []common.JWTKey) (*UserJWTClaims, error) { + parsed, err := jwt.ParseWithClaims(raw, &UserJWTClaims{}, func(t *jwt.Token) (interface{}, error) { + if t.Method.Alg() != SignAlgo { + return nil, jwt.ErrSignatureInvalid + } + if val, ok := t.Header["kid"]; ok { + key, err := findKidFromArray(keys, val) + if err != nil { + return nil, err + } + return key.PublicKey, nil + } + return nil, fmt.Errorf("could not find kid from headers") + }) + if err != nil { + return nil, err + } else if !parsed.Valid { + return nil, jwt.ValidationError{} + } + + claims, ok := parsed.Claims.(*UserJWTClaims) + if !ok { + return nil, fmt.Errorf("could not parse struct") + } + return claims, nil +} diff --git a/token/token_test.go b/token/token_test.go new file mode 100644 index 0000000..8683f48 --- /dev/null +++ b/token/token_test.go @@ -0,0 +1,147 @@ +package token + +import ( + "fmt" + "strings" + "testing" + "time" + + "github.com/elisasre/go-common" + "github.com/stretchr/testify/require" +) + +func TestToken(t *testing.T) { + key, err := common.GenerateNewKeyPair() + require.NoError(t, err) + + fullUser := common.User{ + Name: common.String("Test User"), + Email: common.String("email@company.com"), + Groups: []string{"group1", "group2"}, + EmailVerified: common.Bool(true), + Internal: &common.Internal{ + EmployeeID: "123456", + MFA: common.Bool(true), + }, + } + + type testCase struct { + name string + scopes []string + err error + user *common.User + } + + testCases := []testCase{ + { + name: "all scopes", + scopes: AllScopes, + user: &common.User{ + Name: common.String("Test User"), + Email: common.String("email@company.com"), + Groups: []string{"group1", "group2"}, + EmailVerified: common.Bool(true), + Internal: &common.Internal{ + EmployeeID: "123456", + MFA: common.Bool(true), + }, + }, + }, + { + name: "openid", + scopes: []string{OpenID}, + user: nil, + }, + { + name: "openid profile", + scopes: []string{OpenID, Profile}, + user: &common.User{Name: common.String("Test User")}, + }, + { + name: "openid email", + scopes: []string{OpenID, Email}, + user: &common.User{ + Email: common.String("email@company.com"), + EmailVerified: common.Bool(true), + }, + }, + { + name: "openid groups", + scopes: []string{OpenID, Groups}, + user: &common.User{ + Groups: []string{"group1", "group2"}, + }, + }, + { + name: "openid internal", + scopes: []string{OpenID, Internal}, + user: &common.User{ + Internal: &common.Internal{ + EmployeeID: "123456", + MFA: common.Bool(true), + }, + }, + }, + { + name: "no openid scope", + scopes: []string{}, + err: fmt.Errorf("token must contain '%s' scope", OpenID), + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + user := tc.user + newUser := fullUser + testUser := New(&newUser) + token, err := testUser.SignExpires(*key, SignClaims{ + Aud: "internal", + Exp: time.Now().Add(time.Hour).Unix(), + Issuer: "http://localhost", + Scopes: tc.scopes, + }) + if tc.err != nil { + require.Equal(t, tc.err, err) + } else { + require.NoError(t, err) + userClaims, err := Parse(token, []common.JWTKey{*key}) + require.NoError(t, err) + require.Equal(t, user, userClaims.User) + } + }) + } +} + +func TestInvalidKid(t *testing.T) { + key, err := common.GenerateNewKeyPair() + require.NoError(t, err) + key2, err := common.GenerateNewKeyPair() + require.NoError(t, err) + + testUser := New(&common.User{}) + token, err := testUser.SignExpires(*key, SignClaims{ + Aud: "internal", + Exp: time.Now().Add(time.Hour).Unix(), + Issuer: "http://localhost", + Scopes: AllScopes, + }) + require.NoError(t, err) + _, err = Parse(token, []common.JWTKey{*key2}) + require.Equal(t, fmt.Sprintf("could not find kid '%s'", key.KID), err.Error()) +} + +func TestExpired(t *testing.T) { + key, err := common.GenerateNewKeyPair() + require.NoError(t, err) + + testUser := New(&common.User{}) + token, err := testUser.SignExpires(*key, SignClaims{ + Aud: "internal", + Exp: 1, + Issuer: "http://localhost", + Scopes: AllScopes, + }) + require.NoError(t, err) + _, err = Parse(token, []common.JWTKey{*key}) + require.True(t, strings.HasPrefix(err.Error(), "token is expired by")) +} diff --git a/types.go b/types.go index f8cfd03..408bea5 100644 --- a/types.go +++ b/types.go @@ -3,6 +3,8 @@ package common import ( "context" "crypto/rsa" + "crypto/sha256" + "encoding/hex" "fmt" "strings" "time" @@ -130,3 +132,36 @@ type Datastore interface { ListJWTKeys(context.Context) ([]JWTKey, error) RotateJWTKeys(context.Context, uint) error } + +// Internal contains struct for internal non standard variables. +type Internal struct { + Cluster *string `json:"cluster,omitempty"` + ChangeLimit *int `json:"limit,omitempty"` + MFA *bool `json:"mfa"` + EmployeeID string `json:"employeeid,omitempty"` +} + +// User contains struct for single user. +type User struct { + Groups []string `json:"groups,omitempty"` + Eid string `json:"custom:employeeid,omitempty"` + ImportGroups []string `json:"cognito:groups,omitempty"` + Email *string `json:"email,omitempty"` + EmailVerified *bool `json:"email_verified,omitempty"` + Name *string `json:"name,omitempty"` + Internal *Internal `json:"internal,omitempty"` +} + +// MakeSub returns sub value for user. +func (u *User) MakeSub() string { + if u == nil { + return "" + } + sub := StringValue(u.Email) + if u.Internal != nil && u.Internal.EmployeeID != "" { + sub = u.Internal.EmployeeID + } + sub = strings.ToLower(sub) + b := sha256.Sum256([]byte(sub)) + return hex.EncodeToString(b[:]) +}