Skip to content

Commit

Permalink
aconcli: fix review comments for openID support
Browse files Browse the repository at this point in the history
Signed-off-by: xxu36 <[email protected]>
  • Loading branch information
xxu36 committed Jun 11, 2024
1 parent 30ffab3 commit e6b39d4
Show file tree
Hide file tree
Showing 13 changed files with 126 additions and 51 deletions.
3 changes: 0 additions & 3 deletions aconcli/cmd/invoke.go
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,4 @@ func init() {

invokeCmd.Flags().StringVarP(&inputfile, "input", "i", "",
"optional file serving as stdin to the command")

invokeCmd.Flags().BoolVar(&nologin, "nologin", false,
"if set, login as an anonymous user")
}
3 changes: 0 additions & 3 deletions aconcli/cmd/kill.go
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,4 @@ func init() {
killCmd.Flags().Uint32VarP(&cid, "container", "e", 0,
"the ACON container to which the signal will be sent")
killCmd.MarkFlagRequired("container")

killCmd.Flags().BoolVar(&nologin, "nologin", false,
"if set, login as an anonymous user")
}
2 changes: 1 addition & 1 deletion aconcli/cmd/login.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ func login() error {
if err != nil {
return fmt.Errorf("Login: cannot get the current user: %v", err)
}
if err := c.Login(user.Uid, vmConnTarget); err != nil {
if err := c.Login(user.Uid); err != nil {
return fmt.Errorf("Login: cannot call 'login' service: %v", err)
}
return nil
Expand Down
2 changes: 1 addition & 1 deletion aconcli/cmd/logout.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ func logout() error {
if err != nil {
return fmt.Errorf("Logout: cannot get the current user: %v", err)
}
if err := c.Logout(user.Uid, vmConnTarget); err != nil {
if err := c.Logout(user.Uid); err != nil {
return fmt.Errorf("Logout: cannot call 'logout' service: %v", err)
}
return nil
Expand Down
2 changes: 0 additions & 2 deletions aconcli/cmd/report.go
Original file line number Diff line number Diff line change
Expand Up @@ -172,6 +172,4 @@ func init() {
"getting quote instead of getting report")
reportCmd.Flags().StringVarP(&file, "file", "f", "",
"file path to dump the report or quote raw data")
reportCmd.Flags().BoolVar(&nologin, "nologin", false,
"if set, login as an anonymous user")
}
2 changes: 0 additions & 2 deletions aconcli/cmd/restart.go
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,4 @@ func init() {

restartCmd.Flags().Uint64VarP(&timeout, "timeout", "t", 30,
"optional timeout in seconds to wait before restarting the container")
restartCmd.Flags().BoolVar(&nologin, "nologin", false,
"if set, login as an anonymous user")
}
1 change: 1 addition & 0 deletions aconcli/cmd/root.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,4 +23,5 @@ func init() {
&cobra.Group{"image", "ACON Image and Image Repo Commands:"},
&cobra.Group{"runtime", "ACON TD/VM and Container Commands:"})
rootCmd.PersistentFlags().StringVarP(&targetDir, "directory", "C", "", "change working directory before performing any operations")
rootCmd.PersistentFlags().BoolVar(&nologin, "nologin", false, "if set, login as an anonymous user")
}
2 changes: 1 addition & 1 deletion aconcli/cmd/run.go
Original file line number Diff line number Diff line change
Expand Up @@ -244,7 +244,7 @@ func run(args []string) error {
if err != nil {
return fmt.Errorf("Run: cannot get the current user: %v", err)
}
if err := c.Login(user.Uid, vmConnTarget); err != nil {
if err := c.Login(user.Uid); err != nil {
return fmt.Errorf("Run: cannot login as user %s: %v", user.Uid, err)
} else {
log.Println("Successfully login")
Expand Down
2 changes: 0 additions & 2 deletions aconcli/cmd/shutdown.go
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,4 @@ func init() {
rootCmd.AddCommand(shutDownCmd)
shutDownCmd.Flags().BoolVarP(&force, "force", "f", false,
"force terminating the virtual machines, i.e. no matter whether Shutdown/Kill command works")
shutDownCmd.Flags().BoolVar(&nologin, "nologin", false,
"if set, login as an anonymous user")
}
2 changes: 0 additions & 2 deletions aconcli/cmd/status.go
Original file line number Diff line number Diff line change
Expand Up @@ -147,6 +147,4 @@ func showStatus() error {

func init() {
rootCmd.AddCommand(statusCmd)
statusCmd.Flags().BoolVar(&nologin, "nologin", false,
"if set, login as an anonymous user")
}
2 changes: 0 additions & 2 deletions aconcli/cmd/stop.go
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,4 @@ func stopAcons(args []string) error {

func init() {
rootCmd.AddCommand(stopCmd)
stopCmd.Flags().BoolVar(&nologin, "nologin", false,
"if set, login as an anonymous user")
}
40 changes: 27 additions & 13 deletions aconcli/service/aconclient_http.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (
"context"
"crypto/tls"
"crypto/x509"
"encoding/hex"
"encoding/json"
"fmt"
"io"
Expand All @@ -19,6 +20,8 @@ import (
"strconv"
"strings"
"time"

"aconcli/cryptoutil"
)

const (
Expand Down Expand Up @@ -106,11 +109,12 @@ type OpenidConfig struct {
}

type AconClientHttp struct {
client *http.Client
host string
scheme string
noAuth bool
sessionkey string
client *http.Client
host string
scheme string
noAuth bool
sessionkey string
fingerPrint string
}

type Opt func(*AconClientHttp) error
Expand All @@ -135,7 +139,15 @@ func OptDialTLSContextInsecure() Opt {
DialTLSContext: func(ctx context.Context, network, addr string) (net.Conn, error) {
conn, err := tls.Dial(network, addr, &tls.Config{
InsecureSkipVerify: true,
VerifyConnection: func(tls.ConnectionState) error {
VerifyConnection: func(tcs tls.ConnectionState) error {
digest, err := cryptoutil.BytesDigest(tcs.PeerCertificates[0].RawSubjectPublicKeyInfo, "sha384")
if err != nil {
return fmt.Errorf("fail to digest server's public key info: %v", err)
}
c.fingerPrint = hex.EncodeToString(digest)
fmt.Println("******* TLS *******")
fmt.Println(c.fingerPrint)
fmt.Println("******* TLS *******")
return nil
},
})
Expand Down Expand Up @@ -175,7 +187,7 @@ func OptDialTLSContext(caCertFilePath string) Opt {
}
func NewAconHttpConnWithOpts(host string, opts ...Opt) (*AconClientHttp, error) {
log.Println("Service: Connecting", host)
c := &AconClientHttp{&http.Client{Timeout: DefaultServiceTimeout}, host, "http", false, ""}
c := &AconClientHttp{&http.Client{Timeout: DefaultServiceTimeout}, host, "http", false, "", ""}
for _, opt := range opts {
if err := opt(c); err != nil {
return nil, err
Expand Down Expand Up @@ -264,8 +276,8 @@ func (c *AconClientHttp) setRequestAuthHeader(req *http.Request) error {
return nil
}

func (c *AconClientHttp) Logout(uid string, vmid string) error {
sessionkey, loggedIn := IsLoggedIn(uid, vmid)
func (c *AconClientHttp) Logout(uid string) error {
sessionkey, loggedIn := IsLoggedIn(uid, c.fingerPrint)
if !loggedIn {
return nil
}
Expand All @@ -284,13 +296,13 @@ func (c *AconClientHttp) Logout(uid string, vmid string) error {
return fmt.Errorf(clientProcRespErrFmt, "Logout", err)
}

if err := RemoveAuthToken(uid, vmid); err != nil {
if err := RemoveAuthToken(uid, c.fingerPrint); err != nil {
return fmt.Errorf("fail to log out: %v", err)
}
return nil
}

func (c *AconClientHttp) Login(uid string, vmid string) error {
func (c *AconClientHttp) Login(uid string) error {
clientId := os.Getenv("ATD_CLIENT_ID")
if clientId == "" {
return fmt.Errorf("failed to get env variable ATD_CLIENT_ID for authentication")
Expand Down Expand Up @@ -354,22 +366,24 @@ func (c *AconClientHttp) Login(uid string, vmid string) error {
if err := json.Unmarshal(keydata, &key); err != nil {
return fmt.Errorf("failed to parse access token from response: %v", err)
}
if err := UpdateAuthToken(uid, map[string]string{vmid: key}); err != nil {
if err := UpdateAuthToken(uid, map[string]string{c.fingerPrint: key}); err != nil {
return fmt.Errorf("failed to update access token: %v", err)
}
c.sessionkey = key
return nil
}

func (c *AconClientHttp) fetchSessionKey() error {
fmt.Println("*******FETCH SESSION KEY *******")
fmt.Println(c.fingerPrint)
if len(c.sessionkey) > 0 {
return nil
}
user, err := user.Current()
if err != nil {
return fmt.Errorf("failed to get current user: %v", err)
}
key, err := GetAuthToken(user.Uid, c.host)
key, err := GetAuthToken(user.Uid, c.fingerPrint)
if err != nil {
return err
}
Expand Down
114 changes: 95 additions & 19 deletions aconcli/service/auth.go
Original file line number Diff line number Diff line change
@@ -1,13 +1,16 @@
package service

import (
"bytes"
"encoding/binary"
"encoding/json"
"errors"
"fmt"
"io/fs"
"os"
"path/filepath"
"sync"
"syscall"
"time"
)

const (
Expand All @@ -18,25 +21,36 @@ const (
// map vmid to associated access token
type AuthTable map[string]string

var authTableMutex sync.Mutex

func getUserAuthTable(uid string) (AuthTable, error) {
records, err := os.ReadFile(filepath.Join(UserRuntimeDir, uid, AuthTableFile))
func getUserAuthTable(f *os.File) (AuthTable, error) {
finfo, err := f.Stat()
if err != nil {
return nil, err
}
records := make([]byte, finfo.Size())
n, err := f.Read(records)
if err != nil {
return nil, err
}
var authTable AuthTable
if err := json.Unmarshal(records, &authTable); err != nil {
if err := json.Unmarshal(records[:n], &authTable); err != nil {
return nil, err
}
return authTable, nil
}

func GetAuthToken(uid string, vmid string) (string, error) {
authTableMutex.Lock()
defer authTableMutex.Unlock()
f, err := os.Open(filepath.Join(UserRuntimeDir, uid, AuthTableFile))
if err != nil {
return "", fmt.Errorf("failed to open auth file: %v", err)
}
defer f.Close()

if err := syscall.Flock(int(f.Fd()), syscall.LOCK_EX); err != nil {
return "", err
}
defer syscall.Flock(int(f.Fd()), syscall.LOCK_UN)

authTable, err := getUserAuthTable(uid)
authTable, err := getUserAuthTable(f)
if err != nil {
return "", fmt.Errorf("failed to get auth token: %v", err)
}
Expand All @@ -48,15 +62,36 @@ func GetAuthToken(uid string, vmid string) (string, error) {
}

func UpdateAuthToken(uid string, t AuthTable) error {
authTableMutex.Lock()
defer authTableMutex.Unlock()
filename := filepath.Join(UserRuntimeDir, uid, AuthTableFile)
f, err := os.OpenFile(filename, os.O_RDWR, 0600)
if err != nil {
return fmt.Errorf("failed to open auth file: %v", err)
}
defer f.Close()

if err := syscall.Flock(int(f.Fd()), syscall.LOCK_EX); err != nil {
return err
}
defer syscall.Flock(int(f.Fd()), syscall.LOCK_UN)

wholeTable, err := getUserAuthTable(uid)
wholeTable, err := getUserAuthTable(f)
if err != nil {
if !errors.Is(err, fs.ErrNotExist) {
return fmt.Errorf("failed to retrive whole auth data: %v", err)
}
wholeTable = AuthTable{}
} else {
current := time.Now().UTC().Unix()
for vmid, sk := range wholeTable {
expired, err := isExpired(sk, current)
if err != nil {
return fmt.Errorf("failed to determine expiration: %v", err)
}
if expired {
delete(wholeTable, vmid)
}
}

}
for k, v := range t {
wholeTable[k] = v
Expand All @@ -65,28 +100,51 @@ func UpdateAuthToken(uid string, t AuthTable) error {
if err != nil {
return fmt.Errorf("failed to marshal auth data: %v\n", err)
}
if err := os.WriteFile(filepath.Join(UserRuntimeDir, uid, AuthTableFile),
authData, 0600); err != nil {

f.Truncate(0)
if _, err := f.Write(authData); err != nil {
return fmt.Errorf("failed to write back auth data: %v\n", err)
}
return nil
}

func RemoveAuthToken(uid string, vmid string) error {
authTableMutex.Lock()
defer authTableMutex.Unlock()
filename := filepath.Join(UserRuntimeDir, uid, AuthTableFile)
f, err := os.OpenFile(filename, os.O_RDWR, 0600)
if err != nil {
return fmt.Errorf("failed to open auth file: %v", err)
}
defer f.Close()

if err := syscall.Flock(int(f.Fd()), syscall.LOCK_EX); err != nil {
return err
}
defer syscall.Flock(int(f.Fd()), syscall.LOCK_UN)

wholeTable, err := getUserAuthTable(uid)
wholeTable, err := getUserAuthTable(f)
if err != nil {
return fmt.Errorf("failed to retrive whole auth data: %v", err)
}
delete(wholeTable, vmid)

current := time.Now().UTC().Unix()
for vmid, sk := range wholeTable {
expired, err := isExpired(sk, current)
if err != nil {
return fmt.Errorf("failed to determine expiration: %v", err)
}
if expired {
delete(wholeTable, vmid)
}
}

authData, err := json.Marshal(wholeTable)
if err != nil {
return fmt.Errorf("failed to marshal auth data: %v\n", err)
}
if err := os.WriteFile(filepath.Join(UserRuntimeDir, uid, AuthTableFile),
authData, 0600); err != nil {

f.Truncate(0)
if _, err := f.Write(authData); err != nil {
return fmt.Errorf("failed to write back auth data: %v\n", err)
}
return nil
Expand All @@ -99,3 +157,21 @@ func IsLoggedIn(uid string, vmid string) (string, bool) {
}
return token, true
}

func getExpirationFromSessionKey(sk string) (int64, error) {
b := []byte(sk)
var duration int64
buf := bytes.NewReader(b[:8])
if err := binary.Read(buf, binary.LittleEndian, &duration); err != nil {
return 0, err
}
return duration, nil
}

func isExpired(sk string, current int64) (bool, error) {
expiration, err := getExpirationFromSessionKey(sk)
if err != nil {
return true, err
}
return current >= expiration, nil
}

0 comments on commit e6b39d4

Please sign in to comment.