From 3d70fdac8ae23ef8a09ee0d79265f5e9a8c4b274 Mon Sep 17 00:00:00 2001 From: jtagcat Date: Sat, 14 Oct 2023 15:14:21 +0300 Subject: [PATCH] Refactor systemd unit installation --- Makefile | 2 +- cmd/ssh-tpm-agent/main.go | 14 +- cmd/ssh-tpm-hostkeys/main.go | 4 +- cmd/ssh-tpm-keygen/main.go | 12 +- contrib/contrib.go | 9 +- contrib/contrib_test.go | 6 +- contrib/services/user/ssh-tpm-agent.service | 2 +- utils/tpm.go | 6 +- utils/utils.go | 146 ++++++++------------ 9 files changed, 86 insertions(+), 115 deletions(-) diff --git a/Makefile b/Makefile index 4e19f93..81e6e6e 100644 --- a/Makefile +++ b/Makefile @@ -24,7 +24,7 @@ install: $(BINS) @install -dm755 $(DESTDIR)$(LIBDIR)/systemd/system @install -dm755 $(DESTDIR)$(LIBDIR)/systemd/user @DESTDIR=$(DESTDIR) PREFIX=$(PREFIX) bin/ssh-tpm-hostkeys --install-system-units - @TEMPLATE_BINARY=1 DESTDIR=$(DESTDIR) PREFIX=$(PREFIX) bin/ssh-tpm-agent --install-user-units --install-system + @TEMPLATE_BINARY=/usr/bin/ssh-tpm-agent DESTDIR=$(DESTDIR) PREFIX=$(PREFIX) bin/ssh-tpm-agent --install-user-units --install-system .PHONY: lint lint: diff --git a/cmd/ssh-tpm-agent/main.go b/cmd/ssh-tpm-agent/main.go index f0d31f9..a66fa2f 100644 --- a/cmd/ssh-tpm-agent/main.go +++ b/cmd/ssh-tpm-agent/main.go @@ -137,7 +137,13 @@ func main() { slog.SetDefault(logger) if installUserUnits { - utils.InstallUserUnits(system) + if err := utils.InstallUserUnits(system); err != nil { + log.Fatal(err) + fmt.Println(err.Error()) + os.Exit(1) + } + + fmt.Println("Enable with: systemctl --user enable --now ssh-tpm-agent.socket") os.Exit(0) } @@ -152,7 +158,7 @@ func main() { } if keyDir == "" { - keyDir = utils.GetSSHDir() + keyDir = utils.SSHDir() } fi, err := os.Lstat(keyDir) @@ -205,7 +211,7 @@ func main() { slog.Info("Socket activated agent.") } else { os.Remove(socketPath) - if err := os.MkdirAll(filepath.Dir(socketPath), 0777); err != nil { + if err := os.MkdirAll(filepath.Dir(socketPath), 0o777); err != nil { slog.Error("Failed to create UNIX socket folder:", err) os.Exit(1) } @@ -222,7 +228,7 @@ func main() { // TPM Callback func() (tpm transport.TPMCloser) { // the agent will close the TPM after this is called - tpm, err := utils.GetTPM(swtpmFlag) + tpm, err := utils.TPM(swtpmFlag) if err != nil { log.Fatal(err) } diff --git a/cmd/ssh-tpm-hostkeys/main.go b/cmd/ssh-tpm-hostkeys/main.go index 025d3d5..58b8007 100644 --- a/cmd/ssh-tpm-hostkeys/main.go +++ b/cmd/ssh-tpm-hostkeys/main.go @@ -39,9 +39,11 @@ func main() { flag.Parse() if installSystemUnits { - if err := utils.InstallSystemUnits(); err != nil { + if err := utils.InstallHostkeyUnits(); err != nil { log.Fatal(err) } + + fmt.Println("Enable with: systemctl enable --now ssh-tpm-agent.socket") os.Exit(0) } if installSshdConfig { diff --git a/cmd/ssh-tpm-keygen/main.go b/cmd/ssh-tpm-keygen/main.go index 7f5b73a..ecacc44 100644 --- a/cmd/ssh-tpm-keygen/main.go +++ b/cmd/ssh-tpm-keygen/main.go @@ -127,7 +127,7 @@ func main() { flag.Parse() - tpm, err := utils.GetTPM(swtpmFlag) + tpm, err := utils.TPM(swtpmFlag) if err != nil { log.Fatal(err) } @@ -161,11 +161,11 @@ func main() { log.Fatal(err) } - if err := os.WriteFile(pubkeyFilename, k.AuthorizedKey(), 0600); err != nil { + if err := os.WriteFile(pubkeyFilename, k.AuthorizedKey(), 0o600); err != nil { log.Fatal(err) } - if err := os.WriteFile(privatekeyFilename, k.Encode(), 0600); err != nil { + if err := os.WriteFile(privatekeyFilename, k.Encode(), 0o600); err != nil { log.Fatal(err) } @@ -256,7 +256,7 @@ func main() { } else { fmt.Printf("Generating a sealed public/private %s key pair.\n", keyType) - filename = path.Join(utils.GetSSHDir(), filename) + filename = path.Join(utils.SSHDir(), filename) filenameInput, err := getStdin("Enter file in which to save the key (%s): ", filename) if err != nil { log.Fatal(err) @@ -318,12 +318,12 @@ func main() { } if importKey == "" { - if err := os.WriteFile(pubkeyFilename, k.AuthorizedKey(), 0600); err != nil { + if err := os.WriteFile(pubkeyFilename, k.AuthorizedKey(), 0o600); err != nil { log.Fatal(err) } } - if err := os.WriteFile(privatekeyFilename, k.Encode(), 0600); err != nil { + if err := os.WriteFile(privatekeyFilename, k.Encode(), 0o600); err != nil { log.Fatal(err) } diff --git a/contrib/contrib.go b/contrib/contrib.go index 6ca89fc..686bdae 100644 --- a/contrib/contrib.go +++ b/contrib/contrib.go @@ -21,17 +21,14 @@ func readPath(f embed.FS, s string) map[string][]byte { return ret } -// Get user services -func GetUserServices() map[string][]byte { +func EmbeddedUserServices() map[string][]byte { return readPath(services, "services/user") } -// Get system services -func GetSystemServices() map[string][]byte { +func EmbeddedSystemServices() map[string][]byte { return readPath(services, "services/system") } -// Get sshd config -func GetSshdConfig() map[string][]byte { +func EmbeddedSshdConfig() map[string][]byte { return readPath(sshd, "sshd") } diff --git a/contrib/contrib_test.go b/contrib/contrib_test.go index f6d03ac..42f7630 100644 --- a/contrib/contrib_test.go +++ b/contrib/contrib_test.go @@ -5,21 +5,21 @@ import ( ) func TestUserServices(t *testing.T) { - m := GetUserServices() + m := EmbeddedUserServices() if len(m) != 2 { t.Fatalf("invalid number of entries") } } func TestSystemServices(t *testing.T) { - m := GetSystemServices() + m := EmbeddedSystemServices() if len(m) != 3 { t.Fatalf("invalid number of entries") } } func TestSshdConfig(t *testing.T) { - m := GetSshdConfig() + m := EmbeddedSshdConfig() if len(m) != 1 { t.Fatalf("invalid number of entries") } diff --git a/contrib/services/user/ssh-tpm-agent.service b/contrib/services/user/ssh-tpm-agent.service index 423e823..caa5bff 100644 --- a/contrib/services/user/ssh-tpm-agent.service +++ b/contrib/services/user/ssh-tpm-agent.service @@ -1,6 +1,6 @@ [Unit] ConditionEnvironment=!SSH_AGENT_PID -Description=ssh-tpm-agent socket +Description=ssh-tpm-agent service Documentation=man:ssh-agent(1) man:ssh-add(1) man:ssh(1) Requires=ssh-tpm-agent.socket diff --git a/utils/tpm.go b/utils/tpm.go index d7180fd..96c7e07 100644 --- a/utils/tpm.go +++ b/utils/tpm.go @@ -22,12 +22,10 @@ func FlushHandle(tpm transport.TPM, h handle) { flushSrk.Execute(tpm) } -var ( - swtpmPath = "/var/tmp/ssh-tpm-agent" -) +var swtpmPath = "/var/tmp/ssh-tpm-agent" // Smaller wrapper for getting the correct TPM instance -func GetTPM(f bool) (transport.TPMCloser, error) { +func TPM(f bool) (transport.TPMCloser, error) { var tpm transport.TPMCloser var err error if f || os.Getenv("SSH_TPM_AGENT_SWTPM") != "" { diff --git a/utils/utils.go b/utils/utils.go index 56e17cf..27dfb9c 100644 --- a/utils/utils.go +++ b/utils/utils.go @@ -11,49 +11,37 @@ import ( "github.com/foxboron/ssh-tpm-agent/contrib" ) -func GetSSHDir() string { +func SSHDir() string { dirname, err := os.UserHomeDir() if err != nil { panic("$HOME is not defined") } - return path.Join(dirname, ".ssh") -} -func GetSystemdUserDir() string { - dirname, err := os.UserHomeDir() - if err != nil { - panic("$HOME is not defined") - } - return path.Join(dirname, ".config/systemd/user") -} - -func DirExists(s string) bool { - info, err := os.Stat(s) - if errors.Is(err, fs.ErrNotExist) { - return false - } - return info.IsDir() + return path.Join(dirname, ".ssh") } func FileExists(s string) bool { - info, err := os.Stat(s) + _, err := os.Stat(s) if errors.Is(err, fs.ErrNotExist) { return false } - return !info.IsDir() + + return true } // This is the sort of things I swore I'd never write. // but here we are. func fmtSystemdInstallPath() string { DESTDIR := "" - PREFIX := "/usr/" - if s, ok := os.LookupEnv("DESTDIR"); ok { - DESTDIR = s + if val, ok := os.LookupEnv("DESTDIR"); ok { + DESTDIR = val } - if s, ok := os.LookupEnv("PREFIX"); ok { - PREFIX = s + + PREFIX := "/usr/" + if val, ok := os.LookupEnv("PREFIX"); ok { + PREFIX = val } + return path.Join(DESTDIR, PREFIX, "lib/systemd") } @@ -64,89 +52,69 @@ func fmtSystemdInstallPath() string { // Passing the env TEMPLATE_BINARY will use /usr/bin/ssh-tpm-agent for the // binary in the service func InstallUserUnits(global bool) error { - var exPath string - var serviceInstallPath string - var err error - - // If ran as root, install global system units - if uid := os.Getuid(); uid == 0 { - global = true + if global || os.Getuid() == 0 { // If ran as root, install global system units + return installUnits(path.Join(fmtSystemdInstallPath(), "/user/"), contrib.EmbeddedUserServices()) } - if global { - serviceInstallPath = path.Join(fmtSystemdInstallPath(), "/user/") - } else { - serviceInstallPath = GetSystemdUserDir() + dirname, err := os.UserHomeDir() + if err != nil { + return err } - // TODO: Use in a Makefile - if s := os.Getenv("TEMPLATE_BINARY"); s != "" { - exPath = "/usr/bin/ssh-tpm-agent" - } else { - exPath, err = os.Executable() + return installUnits(path.Join(dirname, ".config/systemd/user"), contrib.EmbeddedUserServices()) +} + +func InstallHostkeyUnits() error { + return installUnits(path.Join(fmtSystemdInstallPath(), "/system/"), contrib.EmbeddedSystemServices()) +} + +func installUnits(installPath string, files map[string][]byte) (err error) { + execPath := os.Getenv("TEMPLATE_BINARY") + if execPath == "" { + execPath, err = os.Executable() if err != nil { return err } } - if DirExists(serviceInstallPath) { - files := contrib.GetUserServices() - for name := range files { - ff := path.Join(serviceInstallPath, name) - if FileExists(ff) { - fmt.Printf("%s exists. Not installing user units.\n", ff) - return nil - } - } - for name, data := range files { - ff := path.Join(serviceInstallPath, name) - serviceFile, err := os.OpenFile(ff, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0644) - if err != nil { - return err - } - t := template.Must(template.New("service").Parse(string(data))) - if err = t.Execute(serviceFile, &struct { - GoBinary string - }{ - GoBinary: exPath, - }); err != nil { - return err - } - - fmt.Printf("Installed %s\n", ff) - } - fmt.Println("Enable with: systemctl --user enable --now ssh-tpm-agent.socket") - return nil + systemdBootedDir := "/run/systemd/system" // https://www.freedesktop.org/software/systemd/man/sd_booted.html + if !FileExists(systemdBootedDir) { + return fmt.Errorf("systemd not booted (%q does not exist)", systemdBootedDir) } - fmt.Printf("Couldn't find %s, probably not running systemd?\n", serviceInstallPath) - return nil -} - -func InstallSystemUnits() error { - serviceInstallPath := path.Join(fmtSystemdInstallPath(), "/system/") - if !DirExists(serviceInstallPath) { - fmt.Printf("Couldn't find %s, probably not running systemd?\n", serviceInstallPath) - return nil + if !FileExists(installPath) { + if err := os.MkdirAll(installPath, 0o750); err != nil { + return fmt.Errorf("creating service installation directory: %w", err) + } } - files := contrib.GetSystemServices() for name := range files { - ff := path.Join(serviceInstallPath, name) - if FileExists(ff) { - fmt.Printf("%s exists. Not installing user units.\n", ff) + servicePath := path.Join(installPath, name) + if FileExists(servicePath) { + fmt.Printf("%s exists. Not installing units.\n", servicePath) return nil } } + for name, data := range files { - ff := path.Join(serviceInstallPath, name) - if err := os.WriteFile(ff, data, 0644); err != nil { - return fmt.Errorf("failed writing service file: %v", err) + servicePath := path.Join(installPath, name) + + f, err := os.OpenFile(servicePath, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0o640) + if err != nil { + return err } + defer f.Close() - fmt.Printf("Installed %s\n", ff) + t := template.Must(template.New("service").Parse(string(data))) + if err = t.Execute(f, &map[string]string{ + "GoBinary": execPath, + }); err != nil { + return err + } + + fmt.Printf("Installed %s\n", servicePath) } - fmt.Println("Enable with: systemctl enable --now ssh-tpm-agent.socket") + return nil } @@ -158,11 +126,11 @@ func InstallSshdConf() error { sshdConfInstallPath := "/etc/ssh/sshd_config.d/" - if !DirExists(sshdConfInstallPath) { + if !FileExists(sshdConfInstallPath) { return nil } - files := contrib.GetSshdConfig() + files := contrib.EmbeddedSshdConfig() for name := range files { ff := path.Join(sshdConfInstallPath, name) if FileExists(ff) { @@ -172,7 +140,7 @@ func InstallSshdConf() error { } for name, data := range files { ff := path.Join(sshdConfInstallPath, name) - if err := os.WriteFile(ff, data, 0644); err != nil { + if err := os.WriteFile(ff, data, 0o644); err != nil { return fmt.Errorf("failed writing sshd conf: %v", err) } fmt.Printf("Installed %s\n", ff)