Skip to content

Commit

Permalink
Refactor systemd unit installation
Browse files Browse the repository at this point in the history
  • Loading branch information
jtagcat committed Oct 14, 2023
1 parent e1ecab3 commit 3d70fda
Show file tree
Hide file tree
Showing 9 changed files with 86 additions and 115 deletions.
2 changes: 1 addition & 1 deletion Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ install: $(BINS)
@install -dm755 $(DESTDIR)$(LIBDIR)/systemd/system
@install -dm755 $(DESTDIR)$(LIBDIR)/systemd/user
@DESTDIR=$(DESTDIR) PREFIX=$(PREFIX) bin/ssh-tpm-hostkeys --install-system-units
@TEMPLATE_BINARY=1 DESTDIR=$(DESTDIR) PREFIX=$(PREFIX) bin/ssh-tpm-agent --install-user-units --install-system
@TEMPLATE_BINARY=/usr/bin/ssh-tpm-agent DESTDIR=$(DESTDIR) PREFIX=$(PREFIX) bin/ssh-tpm-agent --install-user-units --install-system

.PHONY: lint
lint:
Expand Down
14 changes: 10 additions & 4 deletions cmd/ssh-tpm-agent/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -137,7 +137,13 @@ func main() {
slog.SetDefault(logger)

if installUserUnits {
utils.InstallUserUnits(system)
if err := utils.InstallUserUnits(system); err != nil {
log.Fatal(err)
fmt.Println(err.Error())
os.Exit(1)
}

fmt.Println("Enable with: systemctl --user enable --now ssh-tpm-agent.socket")
os.Exit(0)
}

Expand All @@ -152,7 +158,7 @@ func main() {
}

if keyDir == "" {
keyDir = utils.GetSSHDir()
keyDir = utils.SSHDir()
}

fi, err := os.Lstat(keyDir)
Expand Down Expand Up @@ -205,7 +211,7 @@ func main() {
slog.Info("Socket activated agent.")
} else {
os.Remove(socketPath)
if err := os.MkdirAll(filepath.Dir(socketPath), 0777); err != nil {
if err := os.MkdirAll(filepath.Dir(socketPath), 0o777); err != nil {
slog.Error("Failed to create UNIX socket folder:", err)
os.Exit(1)
}
Expand All @@ -222,7 +228,7 @@ func main() {
// TPM Callback
func() (tpm transport.TPMCloser) {
// the agent will close the TPM after this is called
tpm, err := utils.GetTPM(swtpmFlag)
tpm, err := utils.TPM(swtpmFlag)
if err != nil {
log.Fatal(err)
}
Expand Down
4 changes: 3 additions & 1 deletion cmd/ssh-tpm-hostkeys/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -39,9 +39,11 @@ func main() {
flag.Parse()

if installSystemUnits {
if err := utils.InstallSystemUnits(); err != nil {
if err := utils.InstallHostkeyUnits(); err != nil {
log.Fatal(err)
}

fmt.Println("Enable with: systemctl enable --now ssh-tpm-agent.socket")
os.Exit(0)
}
if installSshdConfig {
Expand Down
12 changes: 6 additions & 6 deletions cmd/ssh-tpm-keygen/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -127,7 +127,7 @@ func main() {

flag.Parse()

tpm, err := utils.GetTPM(swtpmFlag)
tpm, err := utils.TPM(swtpmFlag)
if err != nil {
log.Fatal(err)
}
Expand Down Expand Up @@ -161,11 +161,11 @@ func main() {
log.Fatal(err)
}

if err := os.WriteFile(pubkeyFilename, k.AuthorizedKey(), 0600); err != nil {
if err := os.WriteFile(pubkeyFilename, k.AuthorizedKey(), 0o600); err != nil {
log.Fatal(err)
}

if err := os.WriteFile(privatekeyFilename, k.Encode(), 0600); err != nil {
if err := os.WriteFile(privatekeyFilename, k.Encode(), 0o600); err != nil {
log.Fatal(err)
}

Expand Down Expand Up @@ -256,7 +256,7 @@ func main() {
} else {
fmt.Printf("Generating a sealed public/private %s key pair.\n", keyType)

filename = path.Join(utils.GetSSHDir(), filename)
filename = path.Join(utils.SSHDir(), filename)
filenameInput, err := getStdin("Enter file in which to save the key (%s): ", filename)
if err != nil {
log.Fatal(err)
Expand Down Expand Up @@ -318,12 +318,12 @@ func main() {
}

if importKey == "" {
if err := os.WriteFile(pubkeyFilename, k.AuthorizedKey(), 0600); err != nil {
if err := os.WriteFile(pubkeyFilename, k.AuthorizedKey(), 0o600); err != nil {
log.Fatal(err)
}
}

if err := os.WriteFile(privatekeyFilename, k.Encode(), 0600); err != nil {
if err := os.WriteFile(privatekeyFilename, k.Encode(), 0o600); err != nil {
log.Fatal(err)
}

Expand Down
9 changes: 3 additions & 6 deletions contrib/contrib.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,17 +21,14 @@ func readPath(f embed.FS, s string) map[string][]byte {
return ret
}

// Get user services
func GetUserServices() map[string][]byte {
func EmbeddedUserServices() map[string][]byte {
return readPath(services, "services/user")
}

// Get system services
func GetSystemServices() map[string][]byte {
func EmbeddedSystemServices() map[string][]byte {
return readPath(services, "services/system")
}

// Get sshd config
func GetSshdConfig() map[string][]byte {
func EmbeddedSshdConfig() map[string][]byte {
return readPath(sshd, "sshd")
}
6 changes: 3 additions & 3 deletions contrib/contrib_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,21 +5,21 @@ import (
)

func TestUserServices(t *testing.T) {
m := GetUserServices()
m := EmbeddedUserServices()
if len(m) != 2 {
t.Fatalf("invalid number of entries")
}
}

func TestSystemServices(t *testing.T) {
m := GetSystemServices()
m := EmbeddedSystemServices()
if len(m) != 3 {
t.Fatalf("invalid number of entries")
}
}

func TestSshdConfig(t *testing.T) {
m := GetSshdConfig()
m := EmbeddedSshdConfig()
if len(m) != 1 {
t.Fatalf("invalid number of entries")
}
Expand Down
2 changes: 1 addition & 1 deletion contrib/services/user/ssh-tpm-agent.service
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[Unit]
ConditionEnvironment=!SSH_AGENT_PID
Description=ssh-tpm-agent socket
Description=ssh-tpm-agent service
Documentation=man:ssh-agent(1) man:ssh-add(1) man:ssh(1)
Requires=ssh-tpm-agent.socket

Expand Down
6 changes: 2 additions & 4 deletions utils/tpm.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,12 +22,10 @@ func FlushHandle(tpm transport.TPM, h handle) {
flushSrk.Execute(tpm)
}

var (
swtpmPath = "/var/tmp/ssh-tpm-agent"
)
var swtpmPath = "/var/tmp/ssh-tpm-agent"

// Smaller wrapper for getting the correct TPM instance
func GetTPM(f bool) (transport.TPMCloser, error) {
func TPM(f bool) (transport.TPMCloser, error) {
var tpm transport.TPMCloser
var err error
if f || os.Getenv("SSH_TPM_AGENT_SWTPM") != "" {
Expand Down
146 changes: 57 additions & 89 deletions utils/utils.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,49 +11,37 @@ import (
"github.com/foxboron/ssh-tpm-agent/contrib"
)

func GetSSHDir() string {
func SSHDir() string {
dirname, err := os.UserHomeDir()
if err != nil {
panic("$HOME is not defined")
}
return path.Join(dirname, ".ssh")
}

func GetSystemdUserDir() string {
dirname, err := os.UserHomeDir()
if err != nil {
panic("$HOME is not defined")
}
return path.Join(dirname, ".config/systemd/user")
}

func DirExists(s string) bool {
info, err := os.Stat(s)
if errors.Is(err, fs.ErrNotExist) {
return false
}
return info.IsDir()
return path.Join(dirname, ".ssh")
}

func FileExists(s string) bool {
info, err := os.Stat(s)
_, err := os.Stat(s)
if errors.Is(err, fs.ErrNotExist) {
return false
}
return !info.IsDir()

return true
}

// This is the sort of things I swore I'd never write.
// but here we are.
func fmtSystemdInstallPath() string {
DESTDIR := ""
PREFIX := "/usr/"
if s, ok := os.LookupEnv("DESTDIR"); ok {
DESTDIR = s
if val, ok := os.LookupEnv("DESTDIR"); ok {
DESTDIR = val
}
if s, ok := os.LookupEnv("PREFIX"); ok {
PREFIX = s

PREFIX := "/usr/"
if val, ok := os.LookupEnv("PREFIX"); ok {
PREFIX = val
}

return path.Join(DESTDIR, PREFIX, "lib/systemd")
}

Expand All @@ -64,89 +52,69 @@ func fmtSystemdInstallPath() string {
// Passing the env TEMPLATE_BINARY will use /usr/bin/ssh-tpm-agent for the
// binary in the service
func InstallUserUnits(global bool) error {
var exPath string
var serviceInstallPath string
var err error

// If ran as root, install global system units
if uid := os.Getuid(); uid == 0 {
global = true
if global || os.Getuid() == 0 { // If ran as root, install global system units
return installUnits(path.Join(fmtSystemdInstallPath(), "/user/"), contrib.EmbeddedUserServices())
}

if global {
serviceInstallPath = path.Join(fmtSystemdInstallPath(), "/user/")
} else {
serviceInstallPath = GetSystemdUserDir()
dirname, err := os.UserHomeDir()
if err != nil {
return err
}

// TODO: Use in a Makefile
if s := os.Getenv("TEMPLATE_BINARY"); s != "" {
exPath = "/usr/bin/ssh-tpm-agent"
} else {
exPath, err = os.Executable()
return installUnits(path.Join(dirname, ".config/systemd/user"), contrib.EmbeddedUserServices())
}

func InstallHostkeyUnits() error {
return installUnits(path.Join(fmtSystemdInstallPath(), "/system/"), contrib.EmbeddedSystemServices())
}

func installUnits(installPath string, files map[string][]byte) (err error) {
execPath := os.Getenv("TEMPLATE_BINARY")
if execPath == "" {
execPath, err = os.Executable()
if err != nil {
return err
}
}

if DirExists(serviceInstallPath) {
files := contrib.GetUserServices()
for name := range files {
ff := path.Join(serviceInstallPath, name)
if FileExists(ff) {
fmt.Printf("%s exists. Not installing user units.\n", ff)
return nil
}
}
for name, data := range files {
ff := path.Join(serviceInstallPath, name)
serviceFile, err := os.OpenFile(ff, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0644)
if err != nil {
return err
}
t := template.Must(template.New("service").Parse(string(data)))
if err = t.Execute(serviceFile, &struct {
GoBinary string
}{
GoBinary: exPath,
}); err != nil {
return err
}

fmt.Printf("Installed %s\n", ff)
}
fmt.Println("Enable with: systemctl --user enable --now ssh-tpm-agent.socket")
return nil
systemdBootedDir := "/run/systemd/system" // https://www.freedesktop.org/software/systemd/man/sd_booted.html
if !FileExists(systemdBootedDir) {
return fmt.Errorf("systemd not booted (%q does not exist)", systemdBootedDir)
}
fmt.Printf("Couldn't find %s, probably not running systemd?\n", serviceInstallPath)
return nil
}

func InstallSystemUnits() error {
serviceInstallPath := path.Join(fmtSystemdInstallPath(), "/system/")

if !DirExists(serviceInstallPath) {
fmt.Printf("Couldn't find %s, probably not running systemd?\n", serviceInstallPath)
return nil
if !FileExists(installPath) {
if err := os.MkdirAll(installPath, 0o750); err != nil {
return fmt.Errorf("creating service installation directory: %w", err)
}
}

files := contrib.GetSystemServices()
for name := range files {
ff := path.Join(serviceInstallPath, name)
if FileExists(ff) {
fmt.Printf("%s exists. Not installing user units.\n", ff)
servicePath := path.Join(installPath, name)
if FileExists(servicePath) {
fmt.Printf("%s exists. Not installing units.\n", servicePath)
return nil
}
}

for name, data := range files {
ff := path.Join(serviceInstallPath, name)
if err := os.WriteFile(ff, data, 0644); err != nil {
return fmt.Errorf("failed writing service file: %v", err)
servicePath := path.Join(installPath, name)

f, err := os.OpenFile(servicePath, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0o640)
if err != nil {
return err
}
defer f.Close()

fmt.Printf("Installed %s\n", ff)
t := template.Must(template.New("service").Parse(string(data)))
if err = t.Execute(f, &map[string]string{
"GoBinary": execPath,
}); err != nil {
return err
}

fmt.Printf("Installed %s\n", servicePath)
}
fmt.Println("Enable with: systemctl enable --now ssh-tpm-agent.socket")

return nil
}

Expand All @@ -158,11 +126,11 @@ func InstallSshdConf() error {

sshdConfInstallPath := "/etc/ssh/sshd_config.d/"

if !DirExists(sshdConfInstallPath) {
if !FileExists(sshdConfInstallPath) {
return nil
}

files := contrib.GetSshdConfig()
files := contrib.EmbeddedSshdConfig()
for name := range files {
ff := path.Join(sshdConfInstallPath, name)
if FileExists(ff) {
Expand All @@ -172,7 +140,7 @@ func InstallSshdConf() error {
}
for name, data := range files {
ff := path.Join(sshdConfInstallPath, name)
if err := os.WriteFile(ff, data, 0644); err != nil {
if err := os.WriteFile(ff, data, 0o644); err != nil {
return fmt.Errorf("failed writing sshd conf: %v", err)
}
fmt.Printf("Installed %s\n", ff)
Expand Down

0 comments on commit 3d70fda

Please sign in to comment.