diff --git a/Makefile b/Makefile index 4e19f93..81e6e6e 100644 --- a/Makefile +++ b/Makefile @@ -24,7 +24,7 @@ install: $(BINS) @install -dm755 $(DESTDIR)$(LIBDIR)/systemd/system @install -dm755 $(DESTDIR)$(LIBDIR)/systemd/user @DESTDIR=$(DESTDIR) PREFIX=$(PREFIX) bin/ssh-tpm-hostkeys --install-system-units - @TEMPLATE_BINARY=1 DESTDIR=$(DESTDIR) PREFIX=$(PREFIX) bin/ssh-tpm-agent --install-user-units --install-system + @TEMPLATE_BINARY=/usr/bin/ssh-tpm-agent DESTDIR=$(DESTDIR) PREFIX=$(PREFIX) bin/ssh-tpm-agent --install-user-units --install-system .PHONY: lint lint: diff --git a/agent/agent.go b/agent/agent.go index 88712ea..3e91c4e 100644 --- a/agent/agent.go +++ b/agent/agent.go @@ -24,9 +24,7 @@ import ( var ErrOperationUnsupported = errors.New("operation unsupported") -var ( - SSH_TPM_AGENT_ADD = "tpm-add-key" -) +var SSH_TPM_AGENT_ADD = "tpm-add-key" type Agent struct { mu sync.Mutex @@ -45,7 +43,7 @@ func (a *Agent) Extension(extensionType string, contents []byte) ([]byte, error) slog.Debug("called extensions") switch extensionType { case SSH_TPM_AGENT_ADD: - slog.Debug("runnning %s", extensionType) + slog.Debug("runnning extension", slog.String("type", extensionType)) return a.AddTPMKey(contents) } return nil, agent.ErrExtensionUnsupported @@ -78,7 +76,7 @@ func (a *Agent) signers() ([]ssh.Signer, error) { for _, agent := range a.agents { l, err := agent.Signers() if err != nil { - slog.Info("failed getting Signers from agent: %f", err) + slog.Info("failed getting Signers from agent", slog.String("error", err.Error())) continue } signers = append(signers, l...) @@ -111,7 +109,7 @@ func (a *Agent) List() ([]*agent.Key, error) { for _, agent := range a.agents { l, err := agent.List() if err != nil { - slog.Info("failed getting list from agent: %v", err) + slog.Info("failed getting list from agent", slog.String("error", err.Error())) continue } agentKeys = append(agentKeys, l...) @@ -160,7 +158,7 @@ func (a *Agent) SignWithFlags(key ssh.PublicKey, data []byte, flags agent.Signat for _, agent := range a.agents { signers, err := agent.Signers() if err != nil { - slog.Info("failed getting signers from agent: %v", err) + slog.Info("failed getting signers from agent", slog.String("error", err.Error())) continue } for _, s := range signers { @@ -181,7 +179,7 @@ func (a *Agent) Sign(key ssh.PublicKey, data []byte) (*ssh.Signature, error) { func (a *Agent) serveConn(c net.Conn) { if err := agent.ServeAgent(a, c); err != io.EOF { - slog.Info("Agent client connection ended with error:", err) + slog.Info("Agent client connection ended unsuccessfully", slog.String("error", err.Error())) } } @@ -202,9 +200,10 @@ func (a *Agent) serve() { if err != nil { type temporary interface { Temporary() bool + Error() string } if err, ok := err.(temporary); ok && err.Temporary() { - slog.Info("Temporary Accept error, sleeping 1s:", err) + slog.Info("Temporary Accept failure, sleeping 1s", slog.String("error", err.Error())) time.Sleep(1 * time.Second) continue } @@ -212,7 +211,7 @@ func (a *Agent) serve() { case <-a.quit: return default: - slog.Error("Failed to accept connections:", err) + slog.Error("Failed to accept connections", slog.String("error", err.Error())) } } a.wg.Add(1) @@ -252,14 +251,17 @@ func (a *Agent) Remove(key ssh.PublicKey) error { slog.Debug("called remove") return ErrOperationUnsupported } + func (a *Agent) RemoveAll() error { slog.Debug("called removeall") return a.Close() } + func (a *Agent) Lock(passphrase []byte) error { slog.Debug("called lock") return ErrOperationUnsupported } + func (a *Agent) Unlock(passphrase []byte) error { slog.Debug("called unlock") return ErrOperationUnsupported @@ -284,7 +286,7 @@ func LoadKeys(keyDir string) (map[string]*key.Key, error) { } k, err := key.DecodeKey(f) if err != nil { - slog.Debug("%s not a TPM sealed key: %v\n", path, err) + slog.Debug("not a TPM-sealed key", slog.String("key_path", path), slog.String("error", err.Error())) return nil } keys[k.Fingerprint()] = k diff --git a/cmd/ssh-tpm-agent/main.go b/cmd/ssh-tpm-agent/main.go index f0d31f9..0f4d6d2 100644 --- a/cmd/ssh-tpm-agent/main.go +++ b/cmd/ssh-tpm-agent/main.go @@ -137,7 +137,13 @@ func main() { slog.SetDefault(logger) if installUserUnits { - utils.InstallUserUnits(system) + if err := utils.InstallUserUnits(system); err != nil { + log.Fatal(err) + fmt.Println(err.Error()) + os.Exit(1) + } + + fmt.Println("Enable with: systemctl --user enable --now ssh-tpm-agent.socket") os.Exit(0) } @@ -152,7 +158,7 @@ func main() { } if keyDir == "" { - keyDir = utils.GetSSHDir() + keyDir = utils.SSHDir() } fi, err := os.Lstat(keyDir) @@ -161,7 +167,7 @@ func main() { os.Exit(1) } if fi.Mode()&os.ModeSymlink == os.ModeSymlink { - slog.Info("Warning: %s is a symbolic link; will not follow it", keyDir) + slog.Info("Not following symbolic link", slog.String("key_directory", keyDir)) } if term.IsTerminal(int(os.Stdin.Fd())) { @@ -205,7 +211,7 @@ func main() { slog.Info("Socket activated agent.") } else { os.Remove(socketPath) - if err := os.MkdirAll(filepath.Dir(socketPath), 0777); err != nil { + if err := os.MkdirAll(filepath.Dir(socketPath), 0o777); err != nil { slog.Error("Failed to create UNIX socket folder:", err) os.Exit(1) } @@ -214,7 +220,7 @@ func main() { slog.Error("Failed to listen on UNIX socket:", err) os.Exit(1) } - slog.Info(fmt.Sprintf("Listening on %v", socketPath)) + slog.Info("Listening on socket", slog.String("path", socketPath)) } a := agent.NewAgent(listener, @@ -222,7 +228,7 @@ func main() { // TPM Callback func() (tpm transport.TPMCloser) { // the agent will close the TPM after this is called - tpm, err := utils.GetTPM(swtpmFlag) + tpm, err := utils.TPM(swtpmFlag) if err != nil { log.Fatal(err) } diff --git a/cmd/ssh-tpm-hostkeys/main.go b/cmd/ssh-tpm-hostkeys/main.go index 025d3d5..58b8007 100644 --- a/cmd/ssh-tpm-hostkeys/main.go +++ b/cmd/ssh-tpm-hostkeys/main.go @@ -39,9 +39,11 @@ func main() { flag.Parse() if installSystemUnits { - if err := utils.InstallSystemUnits(); err != nil { + if err := utils.InstallHostkeyUnits(); err != nil { log.Fatal(err) } + + fmt.Println("Enable with: systemctl enable --now ssh-tpm-agent.socket") os.Exit(0) } if installSshdConfig { diff --git a/cmd/ssh-tpm-keygen/main.go b/cmd/ssh-tpm-keygen/main.go index 7f5b73a..273797e 100644 --- a/cmd/ssh-tpm-keygen/main.go +++ b/cmd/ssh-tpm-keygen/main.go @@ -127,7 +127,7 @@ func main() { flag.Parse() - tpm, err := utils.GetTPM(swtpmFlag) + tpm, err := utils.TPM(swtpmFlag) if err != nil { log.Fatal(err) } @@ -154,22 +154,22 @@ func main() { continue } - slog.Info(fmt.Sprintf("Generating new %s host key\n", strings.ToUpper(n))) + slog.Info("Generating new host key", slog.String("algorithm", strings.ToUpper(n))) k, err := key.CreateKey(tpm, t, []byte(""), []byte(defaultComment)) if err != nil { log.Fatal(err) } - if err := os.WriteFile(pubkeyFilename, k.AuthorizedKey(), 0600); err != nil { + if err := os.WriteFile(pubkeyFilename, k.AuthorizedKey(), 0o600); err != nil { log.Fatal(err) } - if err := os.WriteFile(privatekeyFilename, k.Encode(), 0600); err != nil { + if err := os.WriteFile(privatekeyFilename, k.Encode(), 0o600); err != nil { log.Fatal(err) } - slog.Info(fmt.Sprintf("Wrote %s\n", privatekeyFilename)) + slog.Info("Wrote private key", slog.String("filename", privatekeyFilename)) } os.Exit(0) } @@ -256,7 +256,7 @@ func main() { } else { fmt.Printf("Generating a sealed public/private %s key pair.\n", keyType) - filename = path.Join(utils.GetSSHDir(), filename) + filename = path.Join(utils.SSHDir(), filename) filenameInput, err := getStdin("Enter file in which to save the key (%s): ", filename) if err != nil { log.Fatal(err) @@ -318,12 +318,12 @@ func main() { } if importKey == "" { - if err := os.WriteFile(pubkeyFilename, k.AuthorizedKey(), 0600); err != nil { + if err := os.WriteFile(pubkeyFilename, k.AuthorizedKey(), 0o600); err != nil { log.Fatal(err) } } - if err := os.WriteFile(privatekeyFilename, k.Encode(), 0600); err != nil { + if err := os.WriteFile(privatekeyFilename, k.Encode(), 0o600); err != nil { log.Fatal(err) } diff --git a/contrib/contrib.go b/contrib/contrib.go index 6ca89fc..686bdae 100644 --- a/contrib/contrib.go +++ b/contrib/contrib.go @@ -21,17 +21,14 @@ func readPath(f embed.FS, s string) map[string][]byte { return ret } -// Get user services -func GetUserServices() map[string][]byte { +func EmbeddedUserServices() map[string][]byte { return readPath(services, "services/user") } -// Get system services -func GetSystemServices() map[string][]byte { +func EmbeddedSystemServices() map[string][]byte { return readPath(services, "services/system") } -// Get sshd config -func GetSshdConfig() map[string][]byte { +func EmbeddedSshdConfig() map[string][]byte { return readPath(sshd, "sshd") } diff --git a/contrib/contrib_test.go b/contrib/contrib_test.go index f6d03ac..42f7630 100644 --- a/contrib/contrib_test.go +++ b/contrib/contrib_test.go @@ -5,21 +5,21 @@ import ( ) func TestUserServices(t *testing.T) { - m := GetUserServices() + m := EmbeddedUserServices() if len(m) != 2 { t.Fatalf("invalid number of entries") } } func TestSystemServices(t *testing.T) { - m := GetSystemServices() + m := EmbeddedSystemServices() if len(m) != 3 { t.Fatalf("invalid number of entries") } } func TestSshdConfig(t *testing.T) { - m := GetSshdConfig() + m := EmbeddedSshdConfig() if len(m) != 1 { t.Fatalf("invalid number of entries") } diff --git a/contrib/services/user/ssh-tpm-agent.service b/contrib/services/user/ssh-tpm-agent.service index 423e823..caa5bff 100644 --- a/contrib/services/user/ssh-tpm-agent.service +++ b/contrib/services/user/ssh-tpm-agent.service @@ -1,6 +1,6 @@ [Unit] ConditionEnvironment=!SSH_AGENT_PID -Description=ssh-tpm-agent socket +Description=ssh-tpm-agent service Documentation=man:ssh-agent(1) man:ssh-add(1) man:ssh(1) Requires=ssh-tpm-agent.socket diff --git a/utils/tpm.go b/utils/tpm.go index d7180fd..96c7e07 100644 --- a/utils/tpm.go +++ b/utils/tpm.go @@ -22,12 +22,10 @@ func FlushHandle(tpm transport.TPM, h handle) { flushSrk.Execute(tpm) } -var ( - swtpmPath = "/var/tmp/ssh-tpm-agent" -) +var swtpmPath = "/var/tmp/ssh-tpm-agent" // Smaller wrapper for getting the correct TPM instance -func GetTPM(f bool) (transport.TPMCloser, error) { +func TPM(f bool) (transport.TPMCloser, error) { var tpm transport.TPMCloser var err error if f || os.Getenv("SSH_TPM_AGENT_SWTPM") != "" { diff --git a/utils/utils.go b/utils/utils.go index 56e17cf..e1c8b0a 100644 --- a/utils/utils.go +++ b/utils/utils.go @@ -11,49 +11,34 @@ import ( "github.com/foxboron/ssh-tpm-agent/contrib" ) -func GetSSHDir() string { +func SSHDir() string { dirname, err := os.UserHomeDir() if err != nil { panic("$HOME is not defined") } - return path.Join(dirname, ".ssh") -} -func GetSystemdUserDir() string { - dirname, err := os.UserHomeDir() - if err != nil { - panic("$HOME is not defined") - } - return path.Join(dirname, ".config/systemd/user") -} - -func DirExists(s string) bool { - info, err := os.Stat(s) - if errors.Is(err, fs.ErrNotExist) { - return false - } - return info.IsDir() + return path.Join(dirname, ".ssh") } func FileExists(s string) bool { - info, err := os.Stat(s) - if errors.Is(err, fs.ErrNotExist) { - return false - } - return !info.IsDir() + _, err := os.Stat(s) + + return !errors.Is(err, fs.ErrNotExist) } // This is the sort of things I swore I'd never write. // but here we are. func fmtSystemdInstallPath() string { DESTDIR := "" - PREFIX := "/usr/" - if s, ok := os.LookupEnv("DESTDIR"); ok { - DESTDIR = s + if val, ok := os.LookupEnv("DESTDIR"); ok { + DESTDIR = val } - if s, ok := os.LookupEnv("PREFIX"); ok { - PREFIX = s + + PREFIX := "/usr/" + if val, ok := os.LookupEnv("PREFIX"); ok { + PREFIX = val } + return path.Join(DESTDIR, PREFIX, "lib/systemd") } @@ -64,89 +49,69 @@ func fmtSystemdInstallPath() string { // Passing the env TEMPLATE_BINARY will use /usr/bin/ssh-tpm-agent for the // binary in the service func InstallUserUnits(global bool) error { - var exPath string - var serviceInstallPath string - var err error - - // If ran as root, install global system units - if uid := os.Getuid(); uid == 0 { - global = true + if global || os.Getuid() == 0 { // If ran as root, install global system units + return installUnits(path.Join(fmtSystemdInstallPath(), "/user/"), contrib.EmbeddedUserServices()) } - if global { - serviceInstallPath = path.Join(fmtSystemdInstallPath(), "/user/") - } else { - serviceInstallPath = GetSystemdUserDir() + dirname, err := os.UserHomeDir() + if err != nil { + return err } - // TODO: Use in a Makefile - if s := os.Getenv("TEMPLATE_BINARY"); s != "" { - exPath = "/usr/bin/ssh-tpm-agent" - } else { - exPath, err = os.Executable() + return installUnits(path.Join(dirname, ".config/systemd/user"), contrib.EmbeddedUserServices()) +} + +func InstallHostkeyUnits() error { + return installUnits(path.Join(fmtSystemdInstallPath(), "/system/"), contrib.EmbeddedSystemServices()) +} + +func installUnits(installPath string, files map[string][]byte) (err error) { + execPath := os.Getenv("TEMPLATE_BINARY") + if execPath == "" { + execPath, err = os.Executable() if err != nil { return err } } - if DirExists(serviceInstallPath) { - files := contrib.GetUserServices() - for name := range files { - ff := path.Join(serviceInstallPath, name) - if FileExists(ff) { - fmt.Printf("%s exists. Not installing user units.\n", ff) - return nil - } - } - for name, data := range files { - ff := path.Join(serviceInstallPath, name) - serviceFile, err := os.OpenFile(ff, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0644) - if err != nil { - return err - } - t := template.Must(template.New("service").Parse(string(data))) - if err = t.Execute(serviceFile, &struct { - GoBinary string - }{ - GoBinary: exPath, - }); err != nil { - return err - } - - fmt.Printf("Installed %s\n", ff) - } - fmt.Println("Enable with: systemctl --user enable --now ssh-tpm-agent.socket") - return nil + systemdBootedDir := "/run/systemd/system" // https://www.freedesktop.org/software/systemd/man/sd_booted.html + if !FileExists(systemdBootedDir) { + return fmt.Errorf("systemd not booted (%q does not exist)", systemdBootedDir) } - fmt.Printf("Couldn't find %s, probably not running systemd?\n", serviceInstallPath) - return nil -} - -func InstallSystemUnits() error { - serviceInstallPath := path.Join(fmtSystemdInstallPath(), "/system/") - if !DirExists(serviceInstallPath) { - fmt.Printf("Couldn't find %s, probably not running systemd?\n", serviceInstallPath) - return nil + if !FileExists(installPath) { + if err := os.MkdirAll(installPath, 0o750); err != nil { + return fmt.Errorf("creating service installation directory: %w", err) + } } - files := contrib.GetSystemServices() for name := range files { - ff := path.Join(serviceInstallPath, name) - if FileExists(ff) { - fmt.Printf("%s exists. Not installing user units.\n", ff) + servicePath := path.Join(installPath, name) + if FileExists(servicePath) { + fmt.Printf("%s exists. Not installing units.\n", servicePath) return nil } } + for name, data := range files { - ff := path.Join(serviceInstallPath, name) - if err := os.WriteFile(ff, data, 0644); err != nil { - return fmt.Errorf("failed writing service file: %v", err) + servicePath := path.Join(installPath, name) + + f, err := os.OpenFile(servicePath, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0o640) + if err != nil { + return err } + defer f.Close() - fmt.Printf("Installed %s\n", ff) + t := template.Must(template.New("service").Parse(string(data))) + if err = t.Execute(f, &map[string]string{ + "GoBinary": execPath, + }); err != nil { + return err + } + + fmt.Printf("Installed %s\n", servicePath) } - fmt.Println("Enable with: systemctl enable --now ssh-tpm-agent.socket") + return nil } @@ -158,11 +123,11 @@ func InstallSshdConf() error { sshdConfInstallPath := "/etc/ssh/sshd_config.d/" - if !DirExists(sshdConfInstallPath) { + if !FileExists(sshdConfInstallPath) { return nil } - files := contrib.GetSshdConfig() + files := contrib.EmbeddedSshdConfig() for name := range files { ff := path.Join(sshdConfInstallPath, name) if FileExists(ff) { @@ -172,7 +137,7 @@ func InstallSshdConf() error { } for name, data := range files { ff := path.Join(sshdConfInstallPath, name) - if err := os.WriteFile(ff, data, 0644); err != nil { + if err := os.WriteFile(ff, data, 0o644); err != nil { return fmt.Errorf("failed writing sshd conf: %v", err) } fmt.Printf("Installed %s\n", ff)