Skip to content

Commit

Permalink
authn: refactor, maintain
Browse files Browse the repository at this point in the history
* config: remove rlock; use pointers

Signed-off-by: Alex Aizman <[email protected]>
  • Loading branch information
alex-aizman committed Jul 16, 2024
1 parent e0a312c commit 73682b8
Show file tree
Hide file tree
Showing 7 changed files with 93 additions and 72 deletions.
61 changes: 37 additions & 24 deletions api/authn/config.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
// Package authn provides AuthN API over HTTP(S)
/*
* Copyright (c) 2018-2022, NVIDIA CORPORATION. All rights reserved.
* Copyright (c) 2018-2024, NVIDIA CORPORATION. All rights reserved.
*/
package authn

Expand All @@ -18,11 +18,12 @@ import (

type (
Config struct {
sync.RWMutex `list:"omit"` // for cmn.IterFields
Log LogConf `json:"log"`
Net NetConf `json:"net"`
Server ServerConf `json:"auth"`
Timeout TimeoutConf `json:"timeout"`
Log LogConf `json:"log"`
Net NetConf `json:"net"`
Server ServerConf `json:"auth"`
Timeout TimeoutConf `json:"timeout"`
// private
mu sync.RWMutex
}
LogConf struct {
Dir string `json:"dir"`
Expand All @@ -32,14 +33,17 @@ type (
HTTP HTTPConf `json:"http"`
}
HTTPConf struct {
Port int `json:"port"`
UseHTTPS bool `json:"use_https"`
Certificate string `json:"server_crt"`
Key string `json:"server_key"`
Port int `json:"port"`
UseHTTPS bool `json:"use_https"`
}
ServerConf struct {
Secret string `json:"secret"`
ExpirePeriod cos.Duration `json:"expiration_time"`
Secret string `json:"secret"`
Expire cos.Duration `json:"expiration_time"`
// private
psecret *string
pexpire *cos.Duration
}
TimeoutConf struct {
Default cos.Duration `json:"default_timeout"`
Expand All @@ -48,8 +52,8 @@ type (
Server *ServerConfToSet `json:"auth"`
}
ServerConfToSet struct {
Secret *string `json:"secret"`
ExpirePeriod *string `json:"expiration_time"`
Secret *string `json:"secret,omitempty"`
Expire *string `json:"expiration_time,omitempty"`
}
// TokenList is a list of tokens pushed by authn
TokenList struct {
Expand All @@ -67,11 +71,12 @@ var (

func (*Config) JspOpts() jsp.Options { return authcfgJspOpts }

func (c *Config) Secret() (secret string) {
c.RLock()
secret = c.Server.Secret
c.RUnlock()
return
func (c *Config) Lock() { c.mu.Lock() }
func (c *Config) Unlock() { c.mu.Unlock() }

func (c *Config) Init() {
c.Server.psecret = &c.Server.Secret
c.Server.pexpire = &c.Server.Expire
}

func (c *Config) Verbose() bool {
Expand All @@ -80,24 +85,32 @@ func (c *Config) Verbose() bool {
return level > 3
}

func (c *Config) Secret() string { return *c.Server.psecret }
func (c *Config) Expire() time.Duration { return time.Duration(*c.Server.pexpire) }

func (c *Config) SetSecret(val *string) {
c.Server.Secret = *val
c.Server.psecret = val
}

func (c *Config) ApplyUpdate(cu *ConfigToUpdate) error {
if cu.Server == nil {
return errors.New("configuration is empty")
}
c.Lock()
defer c.Unlock()
if cu.Server.Secret != nil {
if *cu.Server.Secret == "" {
return errors.New("secret not defined")
}
c.Server.Secret = *cu.Server.Secret
c.SetSecret(cu.Server.Secret)
}
if cu.Server.ExpirePeriod != nil {
dur, err := time.ParseDuration(*cu.Server.ExpirePeriod)
if cu.Server.Expire != nil {
dur, err := time.ParseDuration(*cu.Server.Expire)
if err != nil {
return fmt.Errorf("invalid time format %s, err: %v", *cu.Server.ExpirePeriod, err)
return fmt.Errorf("invalid time format %s: %v", *cu.Server.Expire, err)
}
c.Server.ExpirePeriod = cos.Duration(dur)
v := cos.Duration(dur)
c.Server.Expire = v
c.Server.pexpire = &v
}
return nil
}
15 changes: 10 additions & 5 deletions cmd/authn/config.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
// Package authn is authentication server for AIStore.
/*
* Copyright (c) 2018-2023, NVIDIA CORPORATION. All rights reserved.
* Copyright (c) 2018-2024, NVIDIA CORPORATION. All rights reserved.
*/
package main

Expand Down Expand Up @@ -30,9 +30,9 @@ func httpConfigGet(w http.ResponseWriter, r *http.Request) {
if err := validateAdminPerms(w, r); err != nil {
return
}
Conf.RLock()
writeJSON(w, Conf, "config")
Conf.RUnlock()
Conf.Lock()
writeJSON(w, Conf, "get config")
Conf.Unlock()
}

func httpConfigPut(w http.ResponseWriter, r *http.Request) {
Expand All @@ -44,10 +44,15 @@ func httpConfigPut(w http.ResponseWriter, r *http.Request) {
cmn.WriteErrMsg(w, r, "Invalid request")
return
}
if err := Conf.ApplyUpdate(updateCfg); err != nil {

Conf.Lock()
err := Conf.ApplyUpdate(updateCfg)
Conf.Unlock()
if err != nil {
cmn.WriteErr(w, r, err)
return
}

if err := jsp.SaveMeta(configPath, Conf, nil); err != nil {
cmn.WriteErr(w, r, err)
}
Expand Down
38 changes: 17 additions & 21 deletions cmd/authn/hserv.go
Original file line number Diff line number Diff line change
Expand Up @@ -249,7 +249,7 @@ func (h *hserv) httpUserGet(w http.ResponseWriter, r *http.Request) {
return
}
uInfo.Password = ""
writeJSON(w, uInfo, "user info")
writeJSON(w, uInfo, "get user")
}

// Checks if the request header contains valid admin credentials.
Expand Down Expand Up @@ -283,7 +283,6 @@ func validateAdminPerms(w http.ResponseWriter, r *http.Request) error {
// If h token is already issued and it is not expired yet then the old
// token is returned
func (h *hserv) userLogin(w http.ResponseWriter, r *http.Request) {
var err error
apiItems, err := parseURL(w, r, 1, apc.URLPathUsers.L)
if err != nil {
return
Expand All @@ -293,39 +292,36 @@ func (h *hserv) userLogin(w http.ResponseWriter, r *http.Request) {
return
}
if msg.Password == "" {
cmn.WriteErrMsg(w, r, "Not authorized", http.StatusUnauthorized)
cmn.WriteErrMsg(w, r, "empty password", http.StatusUnauthorized)
return
}
userID := apiItems[0]
pass := msg.Password

tokenString, err := h.mgr.issueToken(userID, pass, msg)
if err != nil {
nlog.Errorf("Failed to generate token for user %q: %v\n", userID, err)
var (
token string
userID = apiItems[0]
)
if token, err = h.mgr.issueToken(userID, msg.Password, msg); err != nil {
nlog.Errorf("failed to generate token for user %q: %v\n", userID, err)
cmn.WriteErr(w, r, err, http.StatusUnauthorized)
return
}

repl := fmt.Sprintf(`{"token": %q}`, tokenString)
writeBytes(w, []byte(repl), "auth")
repl := fmt.Sprintf(`{"token": %q}`, token)
writeBytes(w, cos.UnsafeB(repl), "login")
}

func writeJSON(w http.ResponseWriter, val any, tag string) {
w.Header().Set(cos.HdrContentType, cos.ContentJSON)
var err error
if err = jsoniter.NewEncoder(w).Encode(val); err == nil {
return
if err := jsoniter.NewEncoder(w).Encode(val); err != nil {
nlog.Errorf("%s: failed to write response: %v", tag, err)
}
nlog.Errorf("%s: failed to write json, err: %v", tag, err)
}

func writeBytes(w http.ResponseWriter, jsbytes []byte, tag string) {
w.Header().Set(cos.HdrContentType, cos.ContentJSON)
var err error
if _, err = w.Write(jsbytes); err == nil {
return
if _, err := w.Write(jsbytes); err != nil {
nlog.Errorf("%s: failed to write response: %v", tag, err)
}
nlog.Errorf("%s: failed to write json, err: %v", tag, err)
}

func (h *hserv) httpSrvPost(w http.ResponseWriter, r *http.Request) {
Expand Down Expand Up @@ -412,7 +408,7 @@ func (h *hserv) httpSrvGet(w http.ResponseWriter, r *http.Request) {
}
cluList = &authn.RegisteredClusters{Clusters: clus}
}
writeJSON(w, cluList, "auth")
writeJSON(w, cluList, "get cluster")
}

func (h *hserv) roleHandler(w http.ResponseWriter, r *http.Request) {
Expand Down Expand Up @@ -446,7 +442,7 @@ func (h *hserv) httpRoleGet(w http.ResponseWriter, r *http.Request) {
cmn.WriteErr(w, r, err)
return
}
writeJSON(w, roles, "rolelist")
writeJSON(w, roles, "list roles")
return
}

Expand All @@ -465,7 +461,7 @@ func (h *hserv) httpRoleGet(w http.ResponseWriter, r *http.Request) {
clu.Alias = cInfo.Alias
}
}
writeJSON(w, role, "role")
writeJSON(w, role, "get role")
}

func (h *hserv) httpRoleDel(w http.ResponseWriter, r *http.Request) {
Expand Down
3 changes: 2 additions & 1 deletion cmd/authn/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -72,8 +72,9 @@ func main() {
if _, err := jsp.LoadMeta(configPath, Conf); err != nil {
cos.ExitLogf("Failed to load configuration from %q: %v", configPath, err)
}
Conf.Init()
if val := os.Getenv(secretKeyPodEnv); val != "" {
Conf.Server.Secret = val
Conf.SetSecret(&val)
}
if err := updateLogOptions(); err != nil {
cos.ExitLogf("Failed to set up logger: %v", err)
Expand Down
30 changes: 16 additions & 14 deletions cmd/authn/mgr.go
Original file line number Diff line number Diff line change
Expand Up @@ -349,53 +349,55 @@ func (m *mgr) delCluster(cluID string) error {
// already generated and is not expired yet the existing token is returned.
// Token includes user ID, permissions, and token expiration time.
// If a new token was generated then it sends the proxy a new valid token list
func (m *mgr) issueToken(userID, pwd string, msg *authn.LoginMsg) (string, error) {
func (m *mgr) issueToken(uid, pwd string, msg *authn.LoginMsg) (token string, err error) {
var (
err error
expires time.Time
token string
uInfo = &authn.User{}
cid string
cluACLs []*authn.CluACL
bckACLs []*authn.BckACL
)

err = m.db.Get(usersCollection, userID, uInfo)
err = m.db.Get(usersCollection, uid, uInfo)
if err != nil {
nlog.Errorln(err)
return "", errInvalidCredentials
}

debug.Assert(uid == uInfo.ID, uid, " vs ", uInfo.ID)

if !isSamePassword(pwd, uInfo.Password) {
return "", errInvalidCredentials
}

// update ACLs with roles's ones
// update ACLs with roles' ones
for _, role := range uInfo.Roles {
cluACLs = mergeClusterACLs(cluACLs, role.ClusterACLs, cid)
bckACLs = mergeBckACLs(bckACLs, role.BucketACLs, cid)
}

// generate token
Conf.RLock()
defer Conf.RUnlock()
issued := time.Now()
expDelta := time.Duration(Conf.Server.ExpirePeriod)
token, err = m._token(msg, uInfo, cluACLs, bckACLs)
return token, err
}

func (m *mgr) _token(msg *authn.LoginMsg, uInfo *authn.User, cluACLs []*authn.CluACL, bckACLs []*authn.BckACL) (token string, err error) {
expDelta := Conf.Expire()
if msg.ExpiresIn != nil {
expDelta = *msg.ExpiresIn
}
if expDelta == 0 {
expDelta = foreverTokenTime
}
expires = issued.Add(expDelta)

// put all useful info into token: who owns the token, when it was issued,
// when it expires and credentials to log in AWS, GCP etc.
// If a user is a super user, it is enough to pass only isAdmin marker
expires := time.Now().Add(expDelta)
uid := uInfo.ID
if uInfo.IsAdmin() {
token, err = tok.IssueAdminJWT(expires, userID, Conf.Server.Secret)
token, err = tok.AdminJWT(expires, uid, Conf.Secret())
} else {
m.fixClusterIDs(cluACLs)
token, err = tok.IssueJWT(expires, userID, bckACLs, cluACLs, Conf.Server.Secret)
token, err = tok.JWT(expires, uid, bckACLs, cluACLs, Conf.Secret())
}
return token, err
}
Expand Down
4 changes: 2 additions & 2 deletions cmd/authn/tok/token.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ var (
ErrTokenRevoked = errors.New("token revoked")
)

func IssueAdminJWT(expires time.Time, userID, secret string) (string, error) {
func AdminJWT(expires time.Time, userID, secret string) (string, error) {
t := jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.MapClaims{
"expires": expires,
"username": userID,
Expand All @@ -46,7 +46,7 @@ func IssueAdminJWT(expires time.Time, userID, secret string) (string, error) {
return t.SignedString([]byte(secret))
}

func IssueJWT(expires time.Time, userID string, bucketACLs []*authn.BckACL, clusterACLs []*authn.CluACL,
func JWT(expires time.Time, userID string, bucketACLs []*authn.BckACL, clusterACLs []*authn.CluACL,
secret string) (string, error) {
t := jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.MapClaims{
"expires": expires,
Expand Down
14 changes: 9 additions & 5 deletions cmd/authn/unit_test.go → cmd/authn/unit_internal_test.go
Original file line number Diff line number Diff line change
@@ -1,9 +1,13 @@
// Package authn is authentication server for AIStore.
//go:build debug

// Package authn
/*
* Copyright (c) 2018-2024, NVIDIA CORPORATION. All rights reserved.
*/
package main

// NOTE go:build debug (above) =====================================

import (
"testing"
"time"
Expand All @@ -30,9 +34,9 @@ var (
)

func init() {
// Set default expiration time to 30 minutes
if Conf.Server.ExpirePeriod == 0 {
Conf.Server.ExpirePeriod = cos.Duration(time.Minute * 30)
Conf.Init()
if Conf.Server.Expire == 0 {
Conf.Server.Expire = cos.Duration(time.Minute * 30) // NOTE: default token expiration time
}
}

Expand Down Expand Up @@ -150,7 +154,7 @@ func TestToken(t *testing.T) {
var (
err error
token string
secret = Conf.Server.Secret
secret = Conf.Secret()
)

driver := mock.NewDBDriver()
Expand Down

0 comments on commit 73682b8

Please sign in to comment.