Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

VAULT-12798 Correct removal behaviour when JWT is symlink #18863

Merged
merged 9 commits into from
Mar 14, 2023
3 changes: 3 additions & 0 deletions changelog/18863.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
```release-note:improvement
agent: JWT auto-auth has a new config option, `remove_jwt_follows_symlinks` (default: false), that, if set to true will now remove the JWT, instead of the symlink to the JWT, if a symlink to a JWT has been provided in the `path` option, and the `remove_jwt_after_reading` config option is set to true (default).
```
87 changes: 65 additions & 22 deletions command/agent/auth/jwt/jwt.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import (
"io/fs"
"net/http"
"os"
"path/filepath"
"sync"
"sync/atomic"
"time"
Expand All @@ -18,19 +19,20 @@ import (
)

type jwtMethod struct {
logger hclog.Logger
path string
mountPath string
role string
removeJWTAfterReading bool
credsFound chan struct{}
watchCh chan string
stopCh chan struct{}
doneCh chan struct{}
credSuccessGate chan struct{}
ticker *time.Ticker
once *sync.Once
latestToken *atomic.Value
logger hclog.Logger
path string
mountPath string
role string
removeJWTAfterReading bool
removeJWTFollowsSymlinks bool
credsFound chan struct{}
watchCh chan string
stopCh chan struct{}
doneCh chan struct{}
credSuccessGate chan struct{}
ticker *time.Ticker
once *sync.Once
latestToken *atomic.Value
}

// NewJWTAuthMethod returns an implementation of Agent's auth.AuthMethod
Expand Down Expand Up @@ -83,20 +85,39 @@ func NewJWTAuthMethod(conf *auth.AuthConfig) (auth.AuthMethod, error) {
j.removeJWTAfterReading = removeJWTAfterReading
}

if removeJWTFollowsSymlinksRaw, ok := conf.Config["remove_jwt_follows_symlinks"]; ok {
removeJWTFollowsSymlinks, err := parseutil.ParseBool(removeJWTFollowsSymlinksRaw)
if err != nil {
return nil, fmt.Errorf("error parsing 'remove_jwt_follows_symlinks' value: %w", err)
}
j.removeJWTFollowsSymlinks = removeJWTFollowsSymlinks
}

switch {
case j.path == "":
return nil, errors.New("'path' value is empty")
case j.role == "":
return nil, errors.New("'role' value is empty")
}

// If we don't delete the JWT after reading, use a slower reload period,
// otherwise we would re-read the whole file every 500ms, instead of just
// doing a stat on the file every 500ms.
// Default readPeriod
readPeriod := 1 * time.Minute
if j.removeJWTAfterReading {
readPeriod = 500 * time.Millisecond

if jwtReadPeriodRaw, ok := conf.Config["jwt_read_period"]; ok {
jwtReadPeriod, err := parseutil.ParseDurationSecond(jwtReadPeriodRaw)
if err != nil {
return nil, fmt.Errorf("error parsing 'jwt_read_period' value: %w", err)
}
readPeriod = jwtReadPeriod
} else {
// If we don't delete the JWT after reading, use a slower reload period,
// otherwise we would re-read the whole file every 500ms, instead of just
// doing a stat on the file every 500ms.
if j.removeJWTAfterReading {
readPeriod = 500 * time.Millisecond
}
}

j.ticker = time.NewTicker(readPeriod)

go j.runWatcher()
Expand Down Expand Up @@ -147,8 +168,8 @@ func (j *jwtMethod) runWatcher() {

case <-j.credSuccessGate:
// We only start the next loop once we're initially successful,
// since at startup Authenticate will be called and we don't want
// to end up immediately reauthenticating by having found a new
// since at startup Authenticate will be called, and we don't want
// to end up immediately re-authenticating by having found a new
// value
}

Expand Down Expand Up @@ -182,11 +203,27 @@ func (j *jwtMethod) ingressToken() {
// Check that the path refers to a file.
// If it's a symlink, it could still be a symlink to a directory,
// but os.ReadFile below will return a descriptive error.
evalSymlinkPath := j.path
switch mode := fi.Mode(); {
case mode.IsRegular():
// regular file
case mode&fs.ModeSymlink != 0:
// symlink
// If our file path is a symlink, we should also return early (like above) without error
// if the file that is linked to is not present, otherwise we will error when trying
// to read that file by following the link in the os.ReadFile call.
evalSymlinkPath, err = filepath.EvalSymlinks(j.path)
if err != nil {
j.logger.Error("error encountered evaluating symlinks", "error", err)
return
}
_, err := os.Stat(evalSymlinkPath)
if err != nil {
if os.IsNotExist(err) {
return
}
j.logger.Error("error encountered stat'ing jwt file after evaluating symlinks", "error", err)
return
}
default:
j.logger.Error("jwt file is not a regular file or symlink")
return
Expand All @@ -207,7 +244,13 @@ func (j *jwtMethod) ingressToken() {
}

if j.removeJWTAfterReading {
if err := os.Remove(j.path); err != nil {
pathToRemove := j.path
if j.removeJWTFollowsSymlinks {
// If removeJWTFollowsSymlinks is set, we follow the symlink and delete the jwt,
// not just the symlink that links to the jwt
pathToRemove = evalSymlinkPath
}
if err := os.Remove(pathToRemove); err != nil {
j.logger.Error("error removing jwt file", "error", err)
}
}
Expand Down
92 changes: 92 additions & 0 deletions command/agent/auth/jwt/jwt_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -165,3 +165,95 @@ func TestDeleteAfterReading(t *testing.T) {
}
}
}

func TestDeleteAfterReadingSymlink(t *testing.T) {
for _, tc := range map[string]struct {
configValue string
shouldDelete bool
removeJWTFollowsSymlinks bool
}{
"default": {
"",
true,
false,
},
"explicit true": {
"true",
true,
false,
},
"false": {
"false",
false,
false,
},
"default + removeJWTFollowsSymlinks": {
"",
true,
true,
},
"explicit true + removeJWTFollowsSymlinks": {
"true",
true,
true,
},
"false + removeJWTFollowsSymlinks": {
"false",
false,
true,
},
} {
rootDir, err := os.MkdirTemp("", "vault-agent-jwt-auth-test")
if err != nil {
t.Fatalf("failed to create temp dir: %s", err)
}
defer os.RemoveAll(rootDir)
tokenPath := path.Join(rootDir, "token")
err = os.WriteFile(tokenPath, []byte("test"), 0o644)
if err != nil {
t.Fatal(err)
}

symlink, err := os.CreateTemp("", "auth.jwt.symlink.test.")
if err != nil {
t.Fatal(err)
}
symlinkName := symlink.Name()
symlink.Close()
os.Remove(symlinkName)
os.Symlink(tokenPath, symlinkName)

config := &auth.AuthConfig{
Config: map[string]interface{}{
"path": symlinkName,
"role": "unusedrole",
},
Logger: hclog.Default(),
}
if tc.configValue != "" {
config.Config["remove_jwt_after_reading"] = tc.configValue
}
config.Config["remove_jwt_follows_symlinks"] = tc.removeJWTFollowsSymlinks

jwtAuth, err := NewJWTAuthMethod(config)
if err != nil {
t.Fatal(err)
}

jwtAuth.(*jwtMethod).ingressToken()

pathToCheck := symlinkName
if tc.removeJWTFollowsSymlinks {
pathToCheck = tokenPath
}
if _, err := os.Lstat(pathToCheck); tc.shouldDelete {
if err == nil || !os.IsNotExist(err) {
t.Fatal(err)
}
} else {
if err != nil {
t.Fatal(err)
}
}
}
}
Loading